Skip to content

Commit

Permalink
Resolve label deletion issue (#2315)
Browse files Browse the repository at this point in the history
* fix label detection issue
---------
Co-authored-by: Eunwoo Shin <[email protected]>
  • Loading branch information
sungmanc authored Jul 10, 2023
1 parent 9fd5871 commit 645cf7b
Show file tree
Hide file tree
Showing 16 changed files with 546 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ def evaluate(
)

eval_results["MHAcc"] = total_acc
eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"]
if self.hierarchical_info["num_multiclass_heads"] > 0:
eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"]
else:
eval_results["avgClsAcc"] = total_acc_sl
eval_results["mAP"] = mAP_value
eval_results["accuracy"] = total_acc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
logger = get_logger()


def is_hierarchical_chkpt(chkpt: dict):
"""Detect whether previous checkpoint is hierarchical or not."""
for k, v in chkpt.items():
if "fc" in k:
return True
return False


@CLASSIFIERS.register_module()
class SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier):
"""SAM-enabled ImageClassifier."""
Expand Down Expand Up @@ -193,11 +201,19 @@ def load_state_dict_pre_hook(module, state_dict, prefix, *args, **kwargs): # no
def load_state_dict_mixing_hook(
model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs
): # pylint: disable=unused-argument, too-many-branches, too-many-locals
"""Modify input state_dict according to class name matching before weight loading."""
"""Modify input state_dict according to class name matching before weight loading.
If previous training is hierarchical training,
then the current training should be hierarchical training. vice versa.
"""
backbone_type = type(model.backbone).__name__
if backbone_type not in ["OTXMobileNetV3", "OTXEfficientNet", "OTXEfficientNetV2"]:
return

if model.hierarchical != is_hierarchical_chkpt(chkpt_dict):
return

# Dst to src mapping index
model_classes = list(model_classes)
chkpt_classes = list(chkpt_classes)
Expand Down Expand Up @@ -249,13 +265,15 @@ def load_state_dict_mixing_hook(
continue

# Mix weights
chkpt_param = chkpt_dict[chkpt_name]
for module, c in enumerate(model2chkpt):
if c >= 0:
model_param[module].copy_(chkpt_param[c])
# NOTE: Label mix is not supported for H-label classification.
if not model.hierarchical:
chkpt_param = chkpt_dict[chkpt_name]
for module, c in enumerate(model2chkpt):
if c >= 0:
model_param[module].copy_(chkpt_param[c])

# Replace checkpoint weight by mixed weights
chkpt_dict[chkpt_name] = model_param
# Replace checkpoint weight by mixed weights
chkpt_dict[chkpt_name] = model_param

def extract_feat(self, img):
"""Directly extract features from the backbone + neck.
Expand Down
22 changes: 15 additions & 7 deletions otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from otx.api.entities.inference_parameters import (
default_progress_callback as default_infer_progress_callback,
)
from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelGroup
from otx.api.entities.metadata import FloatMetadata, FloatType
from otx.api.entities.metrics import (
CurveMetric,
Expand Down Expand Up @@ -127,16 +129,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
if self._task_environment.model is not None:
self._load_model()

def _is_multi_label(self, label_groups: List[LabelGroup], all_labels: List[LabelEntity]):
"""Check whether the current training mode is multi-label or not."""
# NOTE: In the current Geti, multi-label should have `___` symbol for all group names.
find_multilabel_symbol = ["___" in getattr(i, "name", "") for i in label_groups]
return (
(len(label_groups) > 1) and (len(label_groups) == len(all_labels)) and (False not in find_multilabel_symbol)
)

def _set_train_mode(self):
self._multilabel = len(self._task_environment.label_schema.get_groups(False)) > 1 and len(
self._task_environment.label_schema.get_groups(False)
) == len(
self._task_environment.get_labels(include_empty=False)
) # noqa:E127
label_groups = self._task_environment.label_schema.get_groups(include_empty=False)
all_labels = self._task_environment.label_schema.get_labels(include_empty=False)

self._multilabel = self._is_multi_label(label_groups, all_labels)
if self._multilabel:
logger.info("Classification mode: multilabel")

if not self._multilabel and len(self._task_environment.label_schema.get_groups(False)) > 1:
elif len(label_groups) > 1:
logger.info("Classification mode: hierarchical")
self._hierarchical = True
self._hierarchical_info = get_hierarchical_info(self._task_environment.label_schema)
Expand Down
181 changes: 181 additions & 0 deletions tests/assets/datumaro_h-label_class_decremental/annotations/train.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
{
"info": {},
"categories": {
"label": {
"labels": [
{
"name": "right",
"parent": "triangle",
"attributes": []
},
{
"name": "multi a",
"parent": "triangle",
"attributes": []
},
{
"name": "equilateral",
"parent": "triangle",
"attributes": []
},
{
"name": "square",
"parent": "rectangle",
"attributes": []
},
{
"name": "triangle",
"parent": "",
"attributes": []
},
{
"name": "non_square",
"parent": "rectangle",
"attributes": []
},
{
"name": "rectangle",
"parent": "",
"attributes": []
}
],
"label_groups": [
{
"name": "shape",
"group_type": "exclusive",
"labels": ["rectangle", "triangle"]
},
{
"name": "rectangle default",
"group_type": "exclusive",
"labels": ["non_square", "square"]
},
{
"name": "triangle default",
"group_type": "exclusive",
"labels": ["equilateral", "right"]
},
{
"name": "shape___multiple example___multi a",
"group_type": "exclusive",
"labels": ["multi a"]
}
],
"attributes": []
},
"mask": {
"colormap": [
{
"label_id": 0,
"r": 129,
"g": 64,
"b": 123
},
{
"label_id": 1,
"r": 91,
"g": 105,
"b": 255
},
{
"label_id": 2,
"r": 91,
"g": 105,
"b": 255
},
{
"label_id": 3,
"r": 255,
"g": 86,
"b": 98
},
{
"label_id": 4,
"r": 204,
"g": 148,
"b": 218
},
{
"label_id": 5,
"r": 0,
"g": 251,
"b": 87
},
{
"label_id": 6,
"r": 84,
"g": 143,
"b": 173
}
]
}
},
"items": [
{
"id": "a",
"annotations": [
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 4
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 5
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 1
}
],
"image": {
"path": "a.jpg",
"size": [10, 5]
},
"media": {
"path": ""
}
},
{
"id": "b",
"annotations": [
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 6
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 5
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 2
}
],
"image": {
"path": "b.jpg",
"size": [10, 5]
},
"media": {
"path": ""
}
}
]
}
Loading

0 comments on commit 645cf7b

Please sign in to comment.