From b0007812d653d09f703e167aae0154ef13a81748 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Mon, 21 Nov 2022 18:10:39 +0800
Subject: [PATCH] [Enhance] Enhance ArcFaceClsHead. (#1181)

* update arcface

* fix unit tests

* add adv-margins

  add adv-margins

  update arcface

* rebase

* update doc and fix ut

* rebase

* update code

* rebase

* use label data

* update set-margins

* Modify ArcFace related method names.

Co-authored-by: mzr1996
---
 docs/en/api/engine.rst                              |   3 +-
 docs/en/api/models.rst                              |   1 +
 mmcls/engine/hooks/__init__.py                      |   4 +-
 mmcls/engine/hooks/margin_head_hooks.py             |  61 ++++
 mmcls/models/backbones/hornet.py                    |  14 +-
 mmcls/models/heads/__init__.py                      |   2 +-
 mmcls/models/heads/arcface_head.py                  | 176 -----------
 mmcls/models/heads/margin_head.py                   | 299 ++++++++++++++++++
 .../test_hooks/test_arcface_hooks.py                | 102 ++++++
 tests/test_models/test_heads.py                     |  58 +++-
 10 files changed, 535 insertions(+), 185 deletions(-)
 create mode 100644 mmcls/engine/hooks/margin_head_hooks.py
 delete mode 100644 mmcls/models/heads/arcface_head.py
 create mode 100644 mmcls/models/heads/margin_head.py
 create mode 100644 tests/test_engine/test_hooks/test_arcface_hooks.py

diff --git a/docs/en/api/engine.rst b/docs/en/api/engine.rst
index a85760aa57f..d1fa82bde95 100644
--- a/docs/en/api/engine.rst
+++ b/docs/en/api/engine.rst
@@ -31,7 +31,8 @@ Hooks
    ClassNumCheckHook
    PreciseBNHook
    VisualizationHook
-   SwitchRecipeHook
+   PrepareProtoBeforeValLoopHook
+   SetAdaptiveMarginsHook
 
 .. module:: mmcls.engine.optimizers

diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 2894f630926..8b0bfab2f3f 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -140,6 +140,7 @@ Heads
    EfficientFormerClsHead
    DeiTClsHead
    ConformerHead
+   ArcFaceClsHead
    MultiLabelClsHead
    MultiLabelLinearClsHead
    CSRAClsHead

diff --git a/mmcls/engine/hooks/__init__.py b/mmcls/engine/hooks/__init__.py
index 29d73fb462a..54343b7af19 100644
--- a/mmcls/engine/hooks/__init__.py
+++ b/mmcls/engine/hooks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .class_num_check_hook import ClassNumCheckHook
+from .margin_head_hooks import SetAdaptiveMarginsHook
 from .precise_bn_hook import PreciseBNHook
 from .retriever_hooks import PrepareProtoBeforeValLoopHook
 from .switch_recipe_hook import SwitchRecipeHook
@@ -7,5 +8,6 @@
 
 __all__ = [
     'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
-    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook'
+    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
+    'SetAdaptiveMarginsHook'
 ]

diff --git a/mmcls/engine/hooks/margin_head_hooks.py b/mmcls/engine/hooks/margin_head_hooks.py
new file mode 100644
index 00000000000..7ca878433d2
--- /dev/null
+++ b/mmcls/engine/hooks/margin_head_hooks.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved
+import numpy as np
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+
+from mmcls.models.heads import ArcFaceClsHead
+from mmcls.registry import HOOKS
+
+
+@HOOKS.register_module()
+class SetAdaptiveMarginsHook(Hook):
+    r"""Set adaptive margins in ``ArcFaceClsHead`` based on a power of the
+    category-wise sample count.
+
+    A PyTorch implementation of paper `Google Landmark Recognition 2020
+    Competition Third Place Solution `_.
+    The margins are computed as
+    :math:`\text{f}(n) = (\text{margin}_{max} - \text{margin}_{min})
+    \cdot \text{norm}(n^p) + \text{margin}_{min}`,
+    where :math:`n` is the number of occurrences of a category.
+
+    Args:
+        margin_min (float): Lower bound of margins. Defaults to 0.05.
+        margin_max (float): Upper bound of margins. Defaults to 0.5.
+        power (float): The power applied to the category frequency.
+            Defaults to -0.25.
+    """
+
+    def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
+        self.margin_min = margin_min
+        self.margin_max = margin_max
+        self.margin_range = margin_max - margin_min
+        self.p = power
+
+    def before_train(self, runner):
+        """Change the margins in ``ArcFaceClsHead`` before training.
+
+        Args:
+            runner (obj:`Runner`): The runner of the training process.
+        """
+        model = runner.model
+        if is_model_wrapper(model):
+            model = model.module
+
+        if (hasattr(model, 'head')
+                and not isinstance(model.head, ArcFaceClsHead)):
+            raise ValueError(
+                'Hook ``SetAdaptiveMarginsHook`` can only be used with '
+                f'``ArcFaceClsHead``, but got {type(model.head)}')
+
+        # Generate margins based on the dataset statistics.
+        gt_labels = runner.train_dataloader.dataset.get_gt_labels()
+        label_count = np.bincount(gt_labels)
+        label_count[label_count == 0] = 1  # At least one occurrence
+        pow_freq = np.power(label_count, self.p)
+
+        min_f, max_f = pow_freq.min(), pow_freq.max()
+        normalized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
+        margins = normalized_pow_freq * self.margin_range + self.margin_min
+
+        assert len(margins) == model.head.num_classes
+
+        model.head.set_margins(margins)
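Note: the count-to-margin mapping implemented in ``before_train`` above can be
sanity-checked in isolation. A minimal sketch (the ``counts`` values below are
made up for illustration, and it assumes at least two distinct counts so the
min-max normalization is well defined):

.. code:: python

    import numpy as np

    def adaptive_margins(counts, margin_min=0.05, margin_max=0.5, power=-0.25):
        """Mirror of the margin math in SetAdaptiveMarginsHook.before_train."""
        counts = np.asarray(counts, dtype=float)
        counts[counts == 0] = 1  # every class counts at least once
        pow_freq = counts ** power
        # min-max normalize, then map into [margin_min, margin_max]
        norm = (pow_freq - pow_freq.min()) / (pow_freq.max() - pow_freq.min())
        return norm * (margin_max - margin_min) + margin_min

    # rare classes receive large margins, frequent classes small ones
    print(adaptive_margins([2, 3, 3, 1, 2]).round(4))
    # -> [0.2019 0.05   0.05   0.5    0.2019]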
diff --git a/mmcls/models/backbones/hornet.py b/mmcls/models/backbones/hornet.py
index aa98aa0a79b..e6d107045f5 100644
--- a/mmcls/models/backbones/hornet.py
+++ b/mmcls/models/backbones/hornet.py
@@ -250,13 +250,16 @@ def forward(self, x):
 
 @MODELS.register_module()
 class HorNet(BaseBackbone):
-    """HorNet
-    A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
-    with Recursive Gated Convolutions`
-    Inspiration from
-    https://github.com/raoyongming/HorNet
+    """HorNet.
+
+    A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial
+    Interactions with Recursive Gated Convolutions
+    <https://arxiv.org/abs/2207.14284>`_ .
+    Inspiration from https://github.com/raoyongming/HorNet
+
     Args:
         arch (str | dict): HorNet architecture.
+
+            If use string, choose from 'tiny', 'small', 'base' and 'large'.
             If use dict, it should have below keys:
 
             - **base_dim** (int): The base dimensions of embedding.
             - **depths** (List[int]): The number of blocks in each stage.
             - **orders** (List[int]): The number of order of gnConv in each
               stage.
             - **dw_cfg** (List[dict]): The Config for dw conv.
 
+            Defaults to 'tiny'.
         in_channels (int): Number of input image channels. Defaults to 3.
         drop_path_rate (float): Stochastic depth rate. Defaults to 0.

diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py
index 104f1c53588..3e359d37227 100644
--- a/mmcls/models/heads/__init__.py
+++ b/mmcls/models/heads/__init__.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .arcface_head import ArcFaceClsHead
 from .cls_head import ClsHead
 from .conformer_head import ConformerHead
 from .deit_head import DeiTClsHead
 from .efficientformer_head import EfficientFormerClsHead
 from .linear_head import LinearClsHead
+from .margin_head import ArcFaceClsHead
 from .multi_label_cls_head import MultiLabelClsHead
 from .multi_label_csra_head import CSRAClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead

diff --git a/mmcls/models/heads/arcface_head.py b/mmcls/models/heads/arcface_head.py
deleted file mode 100644
index 23cb3d23024..00000000000
--- a/mmcls/models/heads/arcface_head.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-from typing import List, Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from mmcls.registry import MODELS
-from mmcls.structures import ClsDataSample
-from .cls_head import ClsHead
-
-
-class NormLinear(nn.Linear):
-    """An enhanced linear layer, which could normalize the input and the
-    linear weight.
-
-    Args:
-        in_features (int): size of each input sample.
-        out_features (int): size of each output sample
-        bias (bool): Whether there is bias. If set to ``False``, the
-            layer will not learn an additive bias. Defaults to ``True``.
-        feature_norm (bool): Whether to normalize the input feature.
-            Defaults to ``True``.
-        weight_norm (bool):Whether to normalize the weight.
-            Defaults to ``True``.
-    """
-
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = False,
-                 feature_norm: bool = True,
-                 weight_norm: bool = True):
-
-        super().__init__(in_features, out_features, bias=bias)
-        self.weight_norm = weight_norm
-        self.feature_norm = feature_norm
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.feature_norm:
-            input = F.normalize(input)
-        if self.weight_norm:
-            weight = F.normalize(self.weight)
-        else:
-            weight = self.weight
-        return F.linear(input, weight, self.bias)
-
-
-@MODELS.register_module()
-class ArcFaceClsHead(ClsHead):
-    """ArcFace classifier head.
-
-    Args:
-        num_classes (int): Number of categories excluding the background
-            category.
-        in_channels (int): Number of channels in the input feature map.
-        s (float): Norm of input feature. Defaults to 30.0.
-        m (float): Margin. Defaults to 0.5.
-        easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
-        ls_eps (float): Label smoothing. Defaults to 0.
-        bias (bool): Whether to use bias in norm layer. Defaults to False.
-        loss (dict): Config of classification loss. Defaults to
-            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
-        init_cfg (dict, optional): the config to control the initialization.
-            Defaults to None.
-    """
-
-    def __init__(self,
-                 num_classes: int,
-                 in_channels: int,
-                 s: float = 30.0,
-                 m: float = 0.50,
-                 easy_margin: bool = False,
-                 ls_eps: float = 0.0,
-                 bias: bool = False,
-                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 init_cfg: Optional[dict] = None):
-
-        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
-        self.loss_module = MODELS.build(loss)
-
-        self.in_channels = in_channels
-        self.num_classes = num_classes
-
-        if self.num_classes <= 0:
-            raise ValueError(
-                f'num_classes={num_classes} must be a positive integer')
-
-        self.s = s
-        self.m = m
-        self.ls_eps = ls_eps
-
-        self.norm_linear = NormLinear(in_channels, num_classes, bias=bias)
-
-        self.easy_margin = easy_margin
-        self.th = math.cos(math.pi - m)
-        self.mm = math.sin(math.pi - m) * m
-
-    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
-        """The process before the final classification head.
-
-        The input ``feats`` is a tuple of tensor, and each tensor is the
-        feature of a backbone stage. In ``ArcFaceHead``, we just obtain the
-        feature of the last stage.
-        """
-        # The ArcFaceHead doesn't have other module, just return after
-        # unpacking.
-        return feats[-1]
-
-    def forward(self,
-                feats: Tuple[torch.Tensor],
-                target: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """The forward process."""
-
-        pre_logits = self.pre_logits(feats)
-
-        # cos=(a*b)/(||a||*||b||)
-        cosine = self.norm_linear(pre_logits)
-
-        if target is None:
-            return self.s * cosine
-
-        phi = torch.cos(torch.acos(cosine) + self.m)
-
-        if self.easy_margin:
-            # when cosine>0, choose phi
-            # when cosine<=0, choose cosine
-            phi = torch.where(cosine > 0, phi, cosine)
-        else:
-            # when cos>th, choose phi
-            # when cos<=th, choose cosine-mm
-            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
-
-        one_hot = torch.zeros(cosine.size(), device=pre_logits.device)
-        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
-        if self.ls_eps > 0:
-            one_hot = (1 -
-                       self.ls_eps) * one_hot + self.ls_eps / self.num_classes
-
-        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
-        return output * self.s
-
-    def loss(self, feats: Tuple[torch.Tensor],
-             data_samples: List[ClsDataSample], **kwargs) -> dict:
-        """Calculate losses from the classification score.
-
-        Args:
-            feats (tuple[Tensor]): The features extracted from the backbone.
-                Multiple stage inputs are acceptable but only the last stage
-                will be used to classify. The shape of every item should be
-                ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample]): The annotation data of
-                every samples.
-            **kwargs: Other keyword arguments to forward the loss module.
-
-        Returns:
-            dict[str, Tensor]: a dictionary of loss components
-        """
-
-        if 'score' in data_samples[0].gt_label:
-            # Batch augmentation may convert labels to one-hot format scores.
-            target = torch.stack([i.gt_label.score for i in data_samples])
-        else:
-            target = torch.cat([i.gt_label.label for i in data_samples])
-
-        # The part can be traced by torch.fx
-        cls_score = self(feats, target)
-
-        # compute loss
-        losses = dict()
-        loss = self.loss_module(
-            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
-        losses['loss'] = loss
-
-        return losses
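Note: both the deleted head above and its replacement below build on the same
ArcFace identity: the margin is added in angle space, ``cos(theta + m)``, with
a linear fallback once ``theta + m`` would pass ``pi``, so the penalized logit
stays monotonically decreasing in ``theta``. A minimal standalone sketch of
that piecewise rule (illustrative only, not code from the patch; assumes
``m`` in ``(0, pi)``):

.. code:: python

    import math

    import torch

    def penalized_logit(cosine: torch.Tensor, m: float) -> torch.Tensor:
        """cos(theta + m) with the monotonicity fallback for theta + m > pi."""
        phi = torch.cos(torch.acos(cosine) + m)
        threshold = math.cos(math.pi - m)
        # past the threshold, cos(theta + m) would start increasing again,
        # so fall back to the linear penalty cosine - sin(m) * m
        return torch.where(cosine > threshold, phi, cosine - math.sin(m) * m)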
+ """ + + def __init__(self, + in_features: int, + out_features: int, + k=1, + bias: bool = False, + feature_norm: bool = True, + weight_norm: bool = True): + + super().__init__(in_features, out_features * k, bias=bias) + self.weight_norm = weight_norm + self.feature_norm = feature_norm + self.out_features = out_features + self.k = k + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.feature_norm: + input = F.normalize(input) + if self.weight_norm: + weight = F.normalize(self.weight) + else: + weight = self.weight + cosine_all = F.linear(input, weight, self.bias) + + if self.k == 1: + return cosine_all + else: + cosine_all = cosine_all.view(-1, self.out_features, self.k) + cosine, _ = torch.max(cosine_all, dim=2) + return cosine + + +@MODELS.register_module() +class ArcFaceClsHead(ClsHead): + """ArcFace classifier head. + + A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss + for Deep Face Recognition `_ and + `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web + Faces `_ + + Example: + To use ArcFace in config files. + + 1. use vanilla ArcFace + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + 2. use SubCenterArcFace with 3 sub-centers + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + num_subcenters=3, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + 3. use SubCenterArcFace With CountPowerAdaptiveMargins + + .. code:: python + + mode = dict( + backbone = xxx, + neck = xxxx, + head=dict( + type='ArcFaceClsHead', + num_classes=5000, + in_channels=1024, + num_subcenters=3, + loss = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg=None), + ) + + custom_hooks = [dict(type='SetAdaptiveMarginsHook')] + + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + num_subcenters (int): Number of subcenters. Defaults to 1. + scale (float): Scale factor of output logit. Defaults to 64.0. + margins (float): The penalty margin. Could be the fllowing formats: + + - float: The margin, would be same for all the categories. + - Sequence[float]: The category-based margins list. + - str: A '.txt' file path which contains a list. Each line + represents the margin of a category, and the number in the + i-th row indicates the margin of the i-th class. + + Defaults to 0.5. + easy_margin (bool): Avoid theta + m >= PI. Defaults to False. + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + num_classes: int, + in_channels: int, + num_subcenters: int = 1, + scale: float = 64., + margins: Optional[Union[float, Sequence[float], str]] = 0.50, + easy_margin: bool = False, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + init_cfg: Optional[dict] = None): + + super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg) + self.loss_module = MODELS.build(loss) + + assert num_subcenters >= 1 and num_classes >= 0 + self.in_channels = in_channels + self.num_classes = num_classes + self.num_subcenters = num_subcenters + self.scale = scale + self.easy_margin = easy_margin + + self.norm_product = NormProduct(in_channels, num_classes, + num_subcenters) + + if isinstance(margins, float): + margins = [margins] * num_classes + elif isinstance(margins, str) and margins.endswith('.txt'): + margins = [float(item) for item in list_from_file(margins)] + else: + assert is_seq_of(list(margins), (float, int)), ( + 'the attribute `margins` in ``ArcFaceClsHead`` should be a ' + ' float, a Sequence of float, or a ".txt" file path.') + + assert len(margins) == num_classes, \ + 'The length of margins must be equal with num_classes.' + + self.register_buffer( + 'margins', torch.tensor(margins).float(), persistent=False) + # To make `phi` monotonic decreasing, refers to + # https://github.com/deepinsight/insightface/issues/108 + sinm_m = torch.sin(math.pi - self.margins) * self.margins + threshold = torch.cos(math.pi - self.margins) + self.register_buffer('sinm_m', sinm_m, persistent=False) + self.register_buffer('threshold', threshold, persistent=False) + + def set_margins(self, margins: Union[Sequence[float], float]) -> None: + """set margins of arcface head. + + Args: + margins (Union[Sequence[float], float]): The marigins. + """ + if isinstance(margins, float): + margins = [margins] * self.num_classes + assert is_seq_of( + list(margins), float) and (len(margins) == self.num_classes), ( + f'margins must be Sequence[Union(float, int)], get {margins}') + + self.margins = torch.tensor( + margins, device=self.margins.device, dtype=torch.float32) + self.sinm_m = torch.sin(self.margins) * self.margins + self.threshold = -torch.cos(self.margins) + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ArcFaceHead``, we just obtain the + feature of the last stage. + """ + # The ArcFaceHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def _get_logit_with_margin(self, pre_logits, target): + """add arc margin to the cosine in target index. + + The target must be in index format. + """ + assert target.dim() == 1 or ( + target.dim() == 2 and target.shape[1] == 1), \ + 'The target must be in index format.' 
+        cosine = self.norm_product(pre_logits)
+        phi = torch.cos(torch.acos(cosine) + self.margins)
+
+        if self.easy_margin:
+            # when cosine > 0, choose phi
+            # when cosine <= 0, choose cosine
+            phi = torch.where(cosine > 0, phi, cosine)
+        else:
+            # when cosine > threshold, choose phi
+            # when cosine <= threshold, choose cosine - sin(m) * m
+            phi = torch.where(cosine > self.threshold, phi,
+                              cosine - self.sinm_m)
+
+        target = convert_to_one_hot(target, self.num_classes)
+        output = target * phi + (1 - target) * cosine
+        return output
+
+    def forward(self,
+                feats: Tuple[torch.Tensor],
+                target: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """The forward process."""
+        # Disable AMP
+        with autocast(enabled=False):
+            pre_logits = self.pre_logits(feats)
+
+            if target is None:
+                # when eval, the logit is the cosine between W and pre_logits;
+                # cos(theta_yj) = (x / ||x||) * (W / ||W||)
+                logit = self.norm_product(pre_logits)
+            else:
+                # when training, add an angular margin to the cosine at the
+                # target index
+                logit = self._get_logit_with_margin(pre_logits, target)
+
+        return self.scale * logit
+
+    def loss(self, feats: Tuple[torch.Tensor],
+             data_samples: List[ClsDataSample], **kwargs) -> dict:
+        """Calculate losses from the classification score.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+                Multiple stage inputs are acceptable but only the last stage
+                will be used to classify. The shape of every item should be
+                ``(num_samples, in_channels)``.
+            data_samples (List[ClsDataSample]): The annotation data of
+                every sample.
+            **kwargs: Other keyword arguments to forward the loss module.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components.
+        """
+        # Unpack data samples and pack targets
+        label_target = torch.cat([i.gt_label.label for i in data_samples])
+        if 'score' in data_samples[0].gt_label:
+            # Batch augmentation may convert labels to one-hot format scores.
+            target = torch.stack([i.gt_label.score for i in data_samples])
+        else:
+            # use the index-format labels directly as the loss target
+            target = label_target
+
+        # the forward always needs the index-format target
+        cls_score = self(feats, label_target)
+
+        # compute loss
+        losses = dict()
+        loss = self.loss_module(
+            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
+        losses['loss'] = loss
+
+        return losses
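Note: with ``num_subcenters > 1``, ``NormProduct`` keeps ``k`` normalized
weight vectors per class and scores each sample against the closest one,
which is what makes the sub-center variant tolerant to noisy labels. A
minimal sketch of that reduction (shapes are illustrative, not from the
patch):

.. code:: python

    import torch
    import torch.nn.functional as F

    n, c, k, d = 4, 5, 3, 10  # batch, classes, sub-centers, feature dim
    feat = F.normalize(torch.randn(n, d))
    weight = F.normalize(torch.randn(c * k, d))  # k centers per class

    cosine_all = feat @ weight.t()  # (n, c * k), all sub-center cosines
    # keep only the closest center of each class, as NormProduct.forward does
    cosine = cosine_all.view(n, c, k).max(dim=2).values
    assert cosine.shape == (n, c)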
diff --git a/tests/test_engine/test_hooks/test_arcface_hooks.py b/tests/test_engine/test_hooks/test_arcface_hooks.py
new file mode 100644
index 00000000000..041d36a466f
--- /dev/null
+++ b/tests/test_engine/test_hooks/test_arcface_hooks.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import tempfile
+from unittest import TestCase
+
+import numpy as np
+import torch
+from mmengine.runner import Runner
+from torch.utils.data import DataLoader, Dataset
+
+
+class ExampleDataset(Dataset):
+
+    def __init__(self):
+        self.index = 0
+        self.metainfo = None
+
+    def __getitem__(self, idx):
+        results = dict(imgs=torch.rand((224, 224, 3)).float(), )
+        return results
+
+    def get_gt_labels(self):
+        gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
+        return gt_labels
+
+    def __len__(self):
+        return 10
+
+
+class TestSetAdaptiveMarginsHook(TestCase):
+    DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')
+    DEFAULT_MODEL = dict(
+        type='ImageClassifier',
+        backbone=dict(
+            type='ResNet',
+            depth=34,
+            num_stages=4,
+            out_indices=(3, ),
+            style='pytorch'),
+        neck=dict(type='GlobalAveragePooling'),
+        head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))
+
+    def test_before_train(self):
+        default_hooks = dict(
+            timer=dict(type='IterTimerHook'),
+            logger=None,
+            param_scheduler=dict(type='ParamSchedulerHook'),
+            checkpoint=dict(type='CheckpointHook', interval=1),
+            sampler_seed=dict(type='DistSamplerSeedHook'),
+            visualization=dict(type='VisualizationHook', enable=False),
+        )
+        tmpdir = tempfile.TemporaryDirectory()
+        loader = DataLoader(ExampleDataset(), batch_size=2)
+        self.runner = Runner(
+            model=self.DEFAULT_MODEL,
+            work_dir=tmpdir.name,
+            train_dataloader=loader,
+            train_cfg=dict(by_epoch=True, max_epochs=1),
+            log_level='WARNING',
+            optim_wrapper=dict(
+                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
+            param_scheduler=dict(
+                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
+            default_scope='mmcls',
+            default_hooks=default_hooks,
+            experiment_name='test_construct_with_arcface',
+            custom_hooks=[self.DEFAULT_HOOK_CFG])
+
+        default_margins = torch.tensor([0.5] * 5)
+        self.assertTrue(
+            torch.allclose(self.runner.model.head.margins.cpu(),
+                           default_margins))
+        self.runner.call_hook('before_train')
+        # counts = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2] (at least one occurrence)
+        # frequency**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
+        # normalized = [0.33752196, 0., 0., 1., 0.33752196]
+        # margins = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
+        expected_margins = torch.tensor(
+            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
+        self.assertTrue(
+            torch.allclose(self.runner.model.head.margins.cpu(),
+                           expected_margins))
+
+        model_cfg = {**self.DEFAULT_MODEL}
+        model_cfg['head'] = dict(
+            type='LinearClsHead',
+            num_classes=1000,
+            in_channels=512,
+            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+            topk=(1, 5),
+        )
+        self.runner = Runner(
+            model=model_cfg,
+            work_dir=tmpdir.name,
+            train_dataloader=loader,
+            train_cfg=dict(by_epoch=True, max_epochs=1),
+            log_level='WARNING',
+            optim_wrapper=dict(
+                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
+            param_scheduler=dict(
+                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
+            default_scope='mmcls',
+            default_hooks=default_hooks,
+            experiment_name='test_construct_wo_arcface',
+            custom_hooks=[self.DEFAULT_HOOK_CFG])
+        with self.assertRaises(ValueError):
+            self.runner.call_hook('before_train')
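Note: the ``margins='....txt'`` branch exercised by the updated
``test_heads.py`` below expects one margin per line, with the i-th line
giving the margin of the i-th class. A hypothetical ``margins.txt`` for
five classes:

.. code:: text

    0.1
    0.2
    0.3
    0.4
    0.5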
diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index 6f63eec4519..0b1f72f1db8 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import os
 import random
+import tempfile
 from unittest import TestCase
 
 import numpy as np
@@ -486,9 +488,37 @@ class TestArcFaceClsHead(TestCase):
     DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)
 
     def test_initialize(self):
-        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
+        with self.assertRaises(AssertionError):
             MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
 
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 0})
+
+        # Test margins
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'margins': dict()})
+
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4})
+
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4 + ['0.1']})
+
+        arcface = MODELS.build(self.DEFAULT_ARGS)
+        self.assertTrue(
+            torch.allclose(arcface.margins, torch.tensor([0.5] * 5)))
+
+        arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 5})
+        self.assertTrue(
+            torch.allclose(arcface.margins, torch.tensor([0.1] * 5)))
+
+        margins = [0.1, 0.2, 0.3, 0.4, 5]
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmp_path = os.path.join(tmpdirname, 'margins.txt')
+            with open(tmp_path, 'w') as tmp_file:
+                for m in margins:
+                    tmp_file.write(f'{m}\n')
+            arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': tmp_path})
+        self.assertTrue(
+            torch.allclose(arcface.margins, torch.tensor(margins)))
+
     def test_pre_logits(self):
         head = MODELS.build(self.DEFAULT_ARGS)
 
@@ -497,10 +527,29 @@ def test_pre_logits(self):
         pre_logits = head.pre_logits(feats)
         self.assertIs(pre_logits, feats[-1])
 
+        # Test with SubCenterArcFace
+        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        pre_logits = head.pre_logits(feats)
+        self.assertIs(pre_logits, feats[-1])
+
     def test_forward(self):
         head = MODELS.build(self.DEFAULT_ARGS)
 
         # target is not None
         feats = (torch.rand(4, 10), torch.rand(4, 10))
+        target = torch.zeros(4).long()
+        outs = head(feats, target)
+        self.assertEqual(outs.shape, (4, 5))
+
+        # target is None
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        outs = head(feats)
+        self.assertEqual(outs.shape, (4, 5))
+
+        # Test with SubCenterArcFace
+        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
+        # target is not None
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
         target = torch.zeros(4)
         outs = head(feats, target)
         self.assertEqual(outs.shape, (4, 5))
@@ -519,3 +568,10 @@ def test_loss(self):
         losses = head.loss(feats, data_samples)
         self.assertEqual(losses.keys(), {'loss'})
         self.assertGreater(losses['loss'].item(), 0)
+
+        # Test loss with SubCenterArcFace
+        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
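Note: as a quick smoke test outside a ``Runner``, the new head can also be
exercised directly. A minimal sketch (values are illustrative; it assumes the
mmcls registry is importable so the default ``CrossEntropyLoss`` config can
be built):

.. code:: python

    import torch

    from mmcls.models.heads import ArcFaceClsHead

    head = ArcFaceClsHead(num_classes=5, in_channels=10, num_subcenters=3)
    feats = (torch.rand(4, 10), )
    target = torch.randint(0, 5, (4, ))

    train_logits = head(feats, target)  # margin applied at the target index
    eval_logits = head(feats)           # plain scaled cosine similarities
    assert train_logits.shape == eval_logits.shape == (4, 5)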