-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhance] Enhance ArcFaceClsHead. (#1181)
* update arcface * fix unit tests * add adv-margins add adv-margins update arcface * rebase * update doc and fix ut * rebase * update code * rebase * use label data * update set-margins * Modify Arcface related method names. Co-authored-by: mzr1996 <[email protected]>
- Loading branch information
Showing
10 changed files
with
535 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .class_num_check_hook import ClassNumCheckHook | ||
from .margin_head_hooks import SetAdaptiveMarginsHook | ||
from .precise_bn_hook import PreciseBNHook | ||
from .retriever_hooks import PrepareProtoBeforeValLoopHook | ||
from .switch_recipe_hook import SwitchRecipeHook | ||
from .visualization_hook import VisualizationHook | ||
|
||
__all__ = [ | ||
'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook', | ||
'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook' | ||
'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook', | ||
'SetAdaptiveMarginsHook' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved | ||
import numpy as np | ||
from mmengine.hooks import Hook | ||
from mmengine.model import is_model_wrapper | ||
|
||
from mmcls.models.heads import ArcFaceClsHead | ||
from mmcls.registry import HOOKS | ||
|
||
|
||
@HOOKS.register_module() | ||
class SetAdaptiveMarginsHook(Hook): | ||
r"""Set adaptive-margins in ArcFaceClsHead based on the power of | ||
category-wise count. | ||
A PyTorch implementation of paper `Google Landmark Recognition 2020 | ||
Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_. | ||
The margins will be | ||
:math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`. | ||
The `n` indicates the number of occurrences of a category. | ||
Args: | ||
margin_min (float): Lower bound of margins. Defaults to 0.05. | ||
margin_max (float): Upper bound of margins. Defaults to 0.5. | ||
power (float): The power of category freqercy. Defaults to -0.25. | ||
""" | ||
|
||
def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None: | ||
self.margin_min = margin_min | ||
self.margin_max = margin_max | ||
self.margin_range = margin_max - margin_min | ||
self.p = power | ||
|
||
def before_train(self, runner): | ||
"""change the margins in ArcFaceClsHead. | ||
Args: | ||
runner (obj: `Runner`): Runner. | ||
""" | ||
model = runner.model | ||
if is_model_wrapper(model): | ||
model = model.module | ||
|
||
if (hasattr(model, 'head') | ||
and not isinstance(model.head, ArcFaceClsHead)): | ||
raise ValueError( | ||
'Hook ``SetFreqPowAdvMarginsHook`` could only be used ' | ||
f'for ``ArcFaceClsHead``, but get {type(model.head)}') | ||
|
||
# generate margins base on the dataset. | ||
gt_labels = runner.train_dataloader.dataset.get_gt_labels() | ||
label_count = np.bincount(gt_labels) | ||
label_count[label_count == 0] = 1 # At least one occurrence | ||
pow_freq = np.power(label_count, self.p) | ||
|
||
min_f, max_f = pow_freq.min(), pow_freq.max() | ||
normized_pow_freq = (pow_freq - min_f) / (max_f - min_f) | ||
margins = normized_pow_freq * self.margin_range + self.margin_min | ||
|
||
assert len(margins) == runner.model.head.num_classes | ||
|
||
model.head.set_margins(margins) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.