Skip to content

Commit

Permalink
[Enhance] Enhance ArcFaceClsHead. (#1181)
Browse files Browse the repository at this point in the history
* update arcface

* fix unit tests

* add adv-margins

add adv-margins

update arcface

* rebase

* update doc and fix ut

* rebase

* update code

* rebase

* use label data

* update set-margins

* Modify Arcface related method names.

Co-authored-by: mzr1996 <[email protected]>
  • Loading branch information
Ezra-Yu and mzr1996 authored Nov 21, 2022
1 parent 4fb44f8 commit b000781
Show file tree
Hide file tree
Showing 10 changed files with 535 additions and 185 deletions.
3 changes: 2 additions & 1 deletion docs/en/api/engine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Hooks
ClassNumCheckHook
PreciseBNHook
VisualizationHook
SwitchRecipeHook
PrepareProtoBeforeValLoopHook
SetAdaptiveMarginsHook

.. module:: mmcls.engine.optimizers

Expand Down
1 change: 1 addition & 0 deletions docs/en/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ Heads
EfficientFormerClsHead
DeiTClsHead
ConformerHead
ArcFaceClsHead
MultiLabelClsHead
MultiLabelLinearClsHead
CSRAClsHead
Expand Down
4 changes: 3 additions & 1 deletion mmcls/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
from .margin_head_hooks import SetAdaptiveMarginsHook
from .precise_bn_hook import PreciseBNHook
from .retriever_hooks import PrepareProtoBeforeValLoopHook
from .switch_recipe_hook import SwitchRecipeHook
from .visualization_hook import VisualizationHook

__all__ = [
'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook'
'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
'SetAdaptiveMarginsHook'
]
61 changes: 61 additions & 0 deletions mmcls/engine/hooks/margin_head_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved
import numpy as np
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmcls.models.heads import ArcFaceClsHead
from mmcls.registry import HOOKS


@HOOKS.register_module()
class SetAdaptiveMarginsHook(Hook):
r"""Set adaptive-margins in ArcFaceClsHead based on the power of
category-wise count.
A PyTorch implementation of paper `Google Landmark Recognition 2020
Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
The margins will be
:math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`.
The `n` indicates the number of occurrences of a category.
Args:
margin_min (float): Lower bound of margins. Defaults to 0.05.
margin_max (float): Upper bound of margins. Defaults to 0.5.
power (float): The power of category freqercy. Defaults to -0.25.
"""

def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
self.margin_min = margin_min
self.margin_max = margin_max
self.margin_range = margin_max - margin_min
self.p = power

def before_train(self, runner):
"""change the margins in ArcFaceClsHead.
Args:
runner (obj: `Runner`): Runner.
"""
model = runner.model
if is_model_wrapper(model):
model = model.module

if (hasattr(model, 'head')
and not isinstance(model.head, ArcFaceClsHead)):
raise ValueError(
'Hook ``SetFreqPowAdvMarginsHook`` could only be used '
f'for ``ArcFaceClsHead``, but get {type(model.head)}')

# generate margins base on the dataset.
gt_labels = runner.train_dataloader.dataset.get_gt_labels()
label_count = np.bincount(gt_labels)
label_count[label_count == 0] = 1 # At least one occurrence
pow_freq = np.power(label_count, self.p)

min_f, max_f = pow_freq.min(), pow_freq.max()
normized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
margins = normized_pow_freq * self.margin_range + self.margin_min

assert len(margins) == runner.model.head.num_classes

model.head.set_margins(margins)
14 changes: 9 additions & 5 deletions mmcls/models/backbones/hornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,20 +250,24 @@ def forward(self, x):

@MODELS.register_module()
class HorNet(BaseBackbone):
"""HorNet
A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
with Recursive Gated Convolutions`
Inspiration from
https://github.com/raoyongming/HorNet
"""HorNet.
A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial
Interactions with Recursive Gated Convolutions
<https://arxiv.org/abs/2207.14284>`_ .
Inspiration from https://github.com/raoyongming/HorNet
Args:
arch (str | dict): HorNet architecture.
If use string, choose from 'tiny', 'small', 'base' and 'large'.
If use dict, it should have below keys:
- **base_dim** (int): The base dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **orders** (List[int]): The number of order of gnConv in each
stage.
- **dw_cfg** (List[dict]): The Config for dw conv.
Defaults to 'tiny'.
in_channels (int): Number of input image channels. Defaults to 3.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
Expand Down
2 changes: 1 addition & 1 deletion mmcls/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .arcface_head import ArcFaceClsHead
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .efficientformer_head import EfficientFormerClsHead
from .linear_head import LinearClsHead
from .margin_head import ArcFaceClsHead
from .multi_label_cls_head import MultiLabelClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
Expand Down
176 changes: 0 additions & 176 deletions mmcls/models/heads/arcface_head.py

This file was deleted.

Loading

0 comments on commit b000781

Please sign in to comment.