Skip to content

Commit

Permalink
add inference demo
Browse files Browse the repository at this point in the history
  • Loading branch information
youqingxiaozhua committed Mar 31, 2023
1 parent 7448d37 commit f3b392b
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 12 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ APViT: Vision Transformer With Attentive Pooling for Robust Facial Expression Re

APViT is a simple and efficient Transformer-based method for facial expression recognition (FER). It builds on the [TransFER](https://openaccess.thecvf.com/content/ICCV2021/html/Xue_TransFER_Learning_Relation-Aware_Facial_Expression_Representations_With_Transformers_ICCV_2021_paper.html), but introduces two attentive pooling (AP) modules that do not require any learnable parameters. These modules help the model focus on the most expressive features and ignore the less relevant ones. You can read more about our method in our [paper](https://arxiv.org/abs/2212.05463).

## Update

- 2023-03-31: Added a [notebook demo](demo.ipynb) for inference.


## Installation

Expand Down
4 changes: 2 additions & 2 deletions configs/_base_/datasets/RAF.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
dict(type='Resize', size=img_size),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label', ]),
dict(type='Collect', keys=['img', 'gt_label',])
# dict(type='ToTensor', keys=['gt_label', ]),
dict(type='Collect', keys=['img', ])
]

base_path = 'data/RAF-DB/basic/'
Expand Down
195 changes: 195 additions & 0 deletions demo.ipynb

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions mmcls/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,20 @@ def init_model(config, checkpoint=None, device='cuda:0', options=None):
if options is not None:
config.merge_from_dict(options)
config.model.pretrained = None
config.model.extractor.pretrained = None
config.model.vit.pretrained = None
model = build_classifier(config.model)
if checkpoint is not None:
map_loc = 'cpu' if device == 'cpu' else None
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets import ImageNet
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
model.CLASSES = ImageNet.CLASSES
class_loaded = False
if 'meta' in checkpoint:
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
class_loaded = True
if not class_loaded:
from mmcls.datasets.raf import FER_CLASSES
model.CLASSES = FER_CLASSES
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
Expand Down
4 changes: 2 additions & 2 deletions mmcls/models/vit/vit_siam_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,8 @@ def init_weights(self, pretrained, patch_num=0):
if patch_num != pos_embed.shape[1] - 1:
logger.warning(f'interpolate pos_embed from {patch_pos_embed.shape[1]} to {patch_num}')
pos_embed_new = resize_pos_embed_v2(patch_pos_embed, patch_num, 0)
else: # 去掉 cls_token
print('does not need to resize')
else: # remove cls_token
print('does not need to resize!')
pos_embed_new = patch_pos_embed
del state_dict['pos_embed']
state_dict['patch_pos_embed'] = pos_embed_new
Expand Down
Binary file added resources/demo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit f3b392b

Please sign in to comment.