diff --git a/configs/localization/drn/README.md b/configs/localization/drn/README.md
new file mode 100644
index 0000000000..7eb5b3edda
--- /dev/null
+++ b/configs/localization/drn/README.md
@@ -0,0 +1,84 @@
+# DRN
+
+[Dense Regression Network for Video Grounding](https://openaccess.thecvf.com/content_CVPR_2020/papers/Zeng_Dense_Regression_Network_for_Video_Grounding_CVPR_2020_paper.pdf)
+
+
+## Abstract
+
+
+We address the problem of video grounding from natural language queries. The key challenge in this task is that one training video might only contain a few annotated starting/ending frames that can be used as positive examples for model training. Most conventional approaches directly train a binary classifier using such imbalanced data, thus achieving inferior results. The key idea of this paper is to use the distances between the frames within the ground truth and the starting (ending) frame as dense supervision to improve the video grounding accuracy. Specifically, we design a novel dense regression network (DRN) to regress the distances from each frame to the starting (ending) frame of the video segment described by the query. We also propose a simple but effective IoU regression head module to explicitly consider the localization quality of the grounding results (i.e., the IoU between the predicted location and the ground truth). Experimental results show that our approach significantly outperforms the state of the art on three datasets (i.e., Charades-STA, ActivityNet-Captions, and TACoS).
+
+
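+To make the "dense supervision" idea concrete: for every temporal location inside a ground-truth segment, DRN regresses the distances to the two segment boundaries, which is what `compute_targets_for_locations` in `drn_utils/loss.py` computes. A minimal sketch of the regression target is shown below; the helper name is illustrative only and is not part of the codebase:
+
+```python
+def boundary_regression_target(t: float, t_start: float, t_end: float) -> tuple:
+    """Dense regression target for a temporal location t inside [t_start, t_end]."""
+    left = t - t_start   # distance to the starting boundary
+    right = t_end - t    # distance to the ending boundary
+    return left, right
+```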
+
+## Results and Models
+
+### Charades STA C3D feature
+
+| feature | gpus | pretrain | Recall@Top1(IoU=0.5) | Recall@Top5(IoU=0.5) | config | ckpt | log |
+| :-----: | :--: | :------: | :------------------: | :------------------: | :----------------------------------------------: | :---------------------------------------------: | :--------------------------------------------: |
+| C3D | 2 | None | 47.04 | 84.57 | [config](configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/drn/drn_2xb16-4096-10e_c3d-feature_20230809-ec0429a6.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/drn_2xb16-4096-10e_c3d-feature.log) |
+
+For more details on data preparation, you can refer to [Charades STA Data Preparation](/tools/data/charades-sta/README.md).
+
+## Train
+
+DRN is trained in three stages. Following the official paper, the second and third stages load the best checkpoint from the previous stage.
+
+First-stage training:
+
+```shell
+bash tools/dist_train.sh configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_first.py 2
+```
+
+Second-stage training:
+
+```shell
+BEST_CKPT=work_dirs/drn_2xb16-4096-10e_c3d-feature_first/SOME.PTH
+bash tools/dist_train.sh configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_second.py 2 --cfg-options load_from=${BEST_CKPT}
+```
+
+Third-stage training:
+
+```shell
+BEST_CKPT=work_dirs/drn_2xb16-4096-10e_c3d-feature_second/SOME.PTH
+bash tools/dist_train.sh configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py 2 --cfg-options load_from=${BEST_CKPT}
+```
+
+## Test
+
+Test DRN on the Charades STA C3D feature:
+
+```shell
+python3 tools/test.py configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py CHECKPOINT.PTH
+```
+
+For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
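+You can also test with multiple GPUs. A sketch assuming the standard MMAction2 `tools/dist_test.sh` launcher, which follows the same usage pattern as `dist_train.sh` above:
+
+```shell
+bash tools/dist_test.sh configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py CHECKPOINT.PTH 2
+```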
+ +## Citation + +```BibTeX +@inproceedings{DRN2020CVPR, + author = {Runhao, Zeng and Haoming, Xu and Wenbing, Huang and Peihao, Chen and Mingkui, Tan and Chuang Gan}, + title = {Dense Regression Network for Video Grounding}, + booktitle = {CVPR}, + year = {2020}, +} +``` + + + +```BibTeX +@inproceedings{gao2017tall, + title={Tall: Temporal activity localization via language query}, + author={Gao, Jiyang and Sun, Chen and Yang, Zhenheng and Nevatia, Ram}, + booktitle={Proceedings of the IEEE international conference on computer vision}, + pages={5267--5275}, + year={2017} +} +``` diff --git a/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_first.py b/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_first.py new file mode 100644 index 0000000000..e66076e962 --- /dev/null +++ b/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_first.py @@ -0,0 +1,115 @@ +_base_ = ['../../_base_/default_runtime.py'] + +# model settings +model = dict( + type='DRN', + vocab_size=1301, + feature_dim=4096, + embed_dim=300, + hidden_dim=512, + bidirection=True, + first_output_dim=256, + fpn_feature_dim=512, + lstm_layers=1, + graph_node_features=1024, + fcos_pre_nms_top_n=32, + fcos_inference_thr=0.05, + fcos_prior_prob=0.01, + focal_alpha=0.25, + focal_gamma=2.0, + fpn_stride=[1, 2, 4], + fcos_nms_thr=0.6, + fcos_conv_layers=1, + fcos_num_class=2, + is_first_stage=True, + is_second_stage=False) + +# dataset settings +dataset_type = 'CharadesSTADataset' +root = 'data/CharadesSTA' +data_root = f'{root}/C3D_unit16_overlap0.5_merged/' +data_root_val = f'{root}/C3D_unit16_overlap0.5_merged/' +ann_file_train = f'{root}/Charades_sta_train.txt' +ann_file_val = f'{root}/Charades_sta_test.txt' +ann_file_test = f'{root}/Charades_sta_test.txt' + +word2id_file = f'{root}/Charades_word2id.json' +fps_file = f'{root}/Charades_fps_dict.json' +duration_file = f'{root}/Charades_duration.json' +num_frames_file = f'{root}/Charades_frames_info.json' +window_size = 16 +ft_overlap = 0.5 + +train_pipeline = [ + dict( + type='PackLocalizationInputs', + keys=('gt_bbox', 'proposals'), + meta_keys=('vid_name', 'query_tokens', 'query_length', 'num_proposals', + 'num_frames')) +] + +val_pipeline = train_pipeline +test_pipeline = val_pipeline + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline, + word2id_file=word2id_file, + fps_file=fps_file, + duration_file=duration_file, + num_frames_file=num_frames_file, + window_size=window_size, + ft_overlap=ft_overlap), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root), + pipeline=val_pipeline, + word2id_file=word2id_file, + fps_file=fps_file, + duration_file=duration_file, + num_frames_file=num_frames_file, + window_size=window_size, + ft_overlap=ft_overlap), +) +test_dataloader = val_dataloader + +max_epochs = 10 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=1, + val_interval=1) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +val_evaluator = dict(type='RecallatTopK', topK_list=(1, 5), threshold=0.5) +test_evaluator = val_evaluator + +optim_wrapper = dict( + optimizer=dict(type='Adam', 
lr=1e-3), + clip_grad=dict(max_norm=5, norm_type=2), +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=True, begin=0, end=5), +] + +find_unused_parameters = True diff --git a/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_second.py b/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_second.py new file mode 100644 index 0000000000..46a671db4c --- /dev/null +++ b/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_second.py @@ -0,0 +1,110 @@ +_base_ = ['../../_base_/default_runtime.py'] + +# model settings +model = dict( + type='DRN', + vocab_size=1301, + feature_dim=4096, + embed_dim=300, + hidden_dim=512, + bidirection=True, + first_output_dim=256, + fpn_feature_dim=512, + lstm_layers=1, + graph_node_features=1024, + fcos_pre_nms_top_n=32, + fcos_inference_thr=0.05, + fcos_prior_prob=0.01, + focal_alpha=0.25, + focal_gamma=2.0, + fpn_stride=[1, 2, 4], + fcos_nms_thr=0.6, + fcos_conv_layers=1, + fcos_num_class=2, + is_first_stage=False, + is_second_stage=True) + +# dataset settings +dataset_type = 'CharadesSTADataset' +root = 'data/CharadesSTA' +data_root = f'{root}/C3D_unit16_overlap0.5_merged/' +data_root_val = f'{root}/C3D_unit16_overlap0.5_merged/' +ann_file_train = f'{root}/Charades_sta_train.txt' +ann_file_val = f'{root}/Charades_sta_test.txt' +ann_file_test = f'{root}/Charades_sta_test.txt' + +word2id_file = f'{root}/Charades_word2id.json' +fps_file = f'{root}/Charades_fps_dict.json' +duration_file = f'{root}/Charades_duration.json' +num_frames_file = f'{root}/Charades_frames_info.json' +window_size = 16 +ft_overlap = 0.5 + +train_pipeline = [ + dict( + type='PackLocalizationInputs', + keys=('gt_bbox', 'proposals'), + meta_keys=('vid_name', 'query_tokens', 'query_length', 'num_proposals', + 'num_frames')) +] + +val_pipeline = train_pipeline +test_pipeline = val_pipeline + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline, + word2id_file=word2id_file, + fps_file=fps_file, + duration_file=duration_file, + num_frames_file=num_frames_file, + window_size=window_size, + ft_overlap=ft_overlap), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root), + pipeline=val_pipeline, + word2id_file=word2id_file, + fps_file=fps_file, + duration_file=duration_file, + num_frames_file=num_frames_file, + window_size=window_size, + ft_overlap=ft_overlap), +) +test_dataloader = val_dataloader + +max_epochs = 10 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=1, + val_interval=1) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +val_evaluator = dict(type='RecallatTopK', topK_list=(1, 5), threshold=0.5) +test_evaluator = val_evaluator + +optim_wrapper = dict( + optimizer=dict(type='Adam', lr=1e-5), + clip_grad=dict(max_norm=5, norm_type=2)) + +find_unused_parameters = True diff --git a/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py b/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py new file mode 100644 index 0000000000..2a286415bc --- /dev/null +++ b/configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py @@ -0,0 +1,110 @@ +_base_ = 
['../../_base_/default_runtime.py'] + +# model settings +model = dict( + type='DRN', + vocab_size=1301, + feature_dim=4096, + embed_dim=300, + hidden_dim=512, + bidirection=True, + first_output_dim=256, + fpn_feature_dim=512, + lstm_layers=1, + graph_node_features=1024, + fcos_pre_nms_top_n=32, + fcos_inference_thr=0.05, + fcos_prior_prob=0.01, + focal_alpha=0.25, + focal_gamma=2.0, + fpn_stride=[1, 2, 4], + fcos_nms_thr=0.6, + fcos_conv_layers=1, + fcos_num_class=2, + is_first_stage=False, + is_second_stage=False) + +# dataset settings +dataset_type = 'CharadesSTADataset' +root = 'data/CharadesSTA' +data_root = f'{root}/C3D_unit16_overlap0.5_merged/' +data_root_val = f'{root}/C3D_unit16_overlap0.5_merged/' +ann_file_train = f'{root}/Charades_sta_train.txt' +ann_file_val = f'{root}/Charades_sta_test.txt' +ann_file_test = f'{root}/Charades_sta_test.txt' + +word2id_file = f'{root}/Charades_word2id.json' +fps_file = f'{root}/Charades_fps_dict.json' +duration_file = f'{root}/Charades_duration.json' +num_frames_file = f'{root}/Charades_frames_info.json' +window_size = 16 +ft_overlap = 0.5 + +train_pipeline = [ + dict( + type='PackLocalizationInputs', + keys=('gt_bbox', 'proposals'), + meta_keys=('vid_name', 'query_tokens', 'query_length', 'num_proposals', + 'num_frames')) +] + +val_pipeline = train_pipeline +test_pipeline = val_pipeline + +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline, + word2id_file=word2id_file, + fps_file=fps_file, + duration_file=duration_file, + num_frames_file=num_frames_file, + window_size=window_size, + ft_overlap=ft_overlap), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root), + pipeline=val_pipeline, + word2id_file=word2id_file, + fps_file=fps_file, + duration_file=duration_file, + num_frames_file=num_frames_file, + window_size=window_size, + ft_overlap=ft_overlap), +) +test_dataloader = val_dataloader + +max_epochs = 10 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=1, + val_interval=1) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +val_evaluator = dict(type='RecallatTopK', topK_list=(1, 5), threshold=0.5) +test_evaluator = val_evaluator + +optim_wrapper = dict( + optimizer=dict(type='Adam', lr=1e-6), + clip_grad=dict(max_norm=5, norm_type=2)) + +find_unused_parameters = True diff --git a/configs/localization/drn/metafile.yml b/configs/localization/drn/metafile.yml new file mode 100644 index 0000000000..d092668b1e --- /dev/null +++ b/configs/localization/drn/metafile.yml @@ -0,0 +1,26 @@ +Collections: +- Name: DRN + README: configs/localization/drn/README.md + Paper: + URL: https://openaccess.thecvf.com/content_CVPR_2020/papers/Zeng_Dense_Regression_Network_for_Video_Grounding_CVPR_2020_paper.pdf + Title: "Dense Regression Network for Video Grounding" + +Models: + - Name: drn_2xb16-4096-10e_c3d-feature_third + Config: configs/localization/drn/drn_2xb16-4096-10e_c3d-feature_third.py + In Collection: DRN + Metadata: + Batch Size: 16 + Epochs: 10 + Training Data: Charades STA + Training Resources: 2 GPUs + feature: C3D + Modality: RGB + Results: + - Dataset: Charades STA 
+ Task: Video Grounding + Metrics: + Recall@Top1(IoU=0.5): 47.04 + Recall@Top5(IoU=0.5): 84.57 + Training Log: https://download.openmmlab.com/mmaction/v1.0/drn_2xb16-4096-10e_c3d-feature.log + Weights: https://download.openmmlab.com/mmaction/v1.0/localization/drn/drn_2xb16-4096-10e_c3d-feature_20230809-ec0429a6.pth diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py index cc838f8f31..eef565309d 100644 --- a/mmaction/datasets/__init__.py +++ b/mmaction/datasets/__init__.py @@ -3,6 +3,7 @@ from .audio_dataset import AudioDataset from .ava_dataset import AVADataset, AVAKineticsDataset from .base import BaseActionDataset +from .charades_sta_dataset import CharadesSTADataset from .msrvtt_datasets import MSRVTTVQA, MSRVTTVQAMC, MSRVTTRetrieval from .pose_dataset import PoseDataset from .rawframe_dataset import RawframeDataset @@ -15,5 +16,5 @@ 'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset', 'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset', 'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset', - 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC' + 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC', 'CharadesSTADataset' ] diff --git a/mmaction/datasets/charades_sta_dataset.py b/mmaction/datasets/charades_sta_dataset.py new file mode 100644 index 0000000000..aca9c9a6bb --- /dev/null +++ b/mmaction/datasets/charades_sta_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Callable, List, Optional, Union + +import mmengine +import numpy as np +import torch +from mmengine.fileio import exists + +from mmaction.registry import DATASETS +from mmaction.utils import ConfigType +from .base import BaseActionDataset + +try: + import nltk + nltk_imported = True +except ImportError: + nltk_imported = False + + +@DATASETS.register_module() +class CharadesSTADataset(BaseActionDataset): + + def __init__(self, + ann_file: str, + pipeline: List[Union[dict, Callable]], + word2id_file: str, + fps_file: str, + duration_file: str, + num_frames_file: str, + window_size: int, + ft_overlap: float, + data_prefix: Optional[ConfigType] = dict(video=''), + test_mode: bool = False, + **kwargs): + if not nltk_imported: + raise ImportError('nltk is required for CharadesSTADataset') + + self.fps_info = mmengine.load(fps_file) + self.duration_info = mmengine.load(duration_file) + self.num_frames = mmengine.load(num_frames_file) + self.word2id = mmengine.load(word2id_file) + self.ft_interval = int(window_size * (1 - ft_overlap)) + + super().__init__( + ann_file, + pipeline=pipeline, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation file to get video information.""" + exists(self.ann_file) + data_list = [] + with open(self.ann_file) as f: + anno_database = f.readlines() + + for item in anno_database: + first_part, query_sentence = item.strip().split('##') + query_sentence = query_sentence.replace('.', '') + query_words = nltk.word_tokenize(query_sentence) + query_tokens = [self.word2id[word] for word in query_words] + query_length = len(query_tokens) + query_tokens = torch.from_numpy(np.array(query_tokens)) + + vid_name, start_time, end_time = first_part.split() + duration = float(self.duration_info[vid_name]) + fps = float(self.fps_info[vid_name]) + + gt_start_time = float(start_time) + gt_end_time = float(end_time) + + gt_bbox = (gt_start_time / duration, min(gt_end_time / duration, + 1)) + + num_frames = 
int(self.num_frames[vid_name]) + proposal_frames = self.get_proposals(num_frames) + + proposals = proposal_frames / num_frames + proposals = torch.from_numpy(proposals) + proposal_indexes = proposal_frames / self.ft_interval + proposal_indexes = proposal_indexes.astype(np.int32) + + info = dict( + vid_name=vid_name, + fps=fps, + num_frames=num_frames, + duration=duration, + query_tokens=query_tokens, + query_length=query_length, + gt_start_time=gt_start_time, + gt_end_time=gt_end_time, + gt_bbox=gt_bbox, + proposals=proposals, + num_proposals=proposals.shape[0], + proposal_indexes=proposal_indexes) + data_list.append(info) + return data_list + + def get_proposals(self, num_frames): + proposals = (num_frames - 1) / 32 * np.arange(33) + proposals = proposals.astype(np.int32) + proposals = np.stack([proposals[:-1], proposals[1:]]).T + return proposals + + def get_data_info(self, idx: int) -> dict: + """Get annotation by index.""" + data_info = super().get_data_info(idx) + vid_name = data_info['vid_name'] + feature_path = os.path.join(self.data_prefix['video'], + f'{vid_name}.pt') + vid_feature = torch.load(feature_path) + proposal_feats = [] + proposal_indexes = data_info['proposal_indexes'].clip( + max=vid_feature.shape[0] - 1) + for s, e in proposal_indexes: + prop_feature, _ = vid_feature[s:e + 1].max(dim=0) + proposal_feats.append(prop_feature) + + proposal_feats = torch.stack(proposal_feats) + + data_info['raw_feature'] = proposal_feats + return data_info diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py index 341ec577ce..fd50aded2e 100644 --- a/mmaction/evaluation/metrics/__init__.py +++ b/mmaction/evaluation/metrics/__init__.py @@ -5,9 +5,10 @@ from .multimodal_metric import VQAMCACC, ReportVQA, RetrievalRecall, VQAAcc from .multisports_metric import MultiSportsMetric from .retrieval_metric import RetrievalMetric +from .video_grounding_metric import RecallatTopK __all__ = [ 'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix', 'MultiSportsMetric', 'RetrievalMetric', 'VQAAcc', 'ReportVQA', 'VQAMCACC', - 'RetrievalRecall' + 'RetrievalRecall', 'RecallatTopK' ] diff --git a/mmaction/evaluation/metrics/video_grounding_metric.py b/mmaction/evaluation/metrics/video_grounding_metric.py new file mode 100644 index 0000000000..310db64452 --- /dev/null +++ b/mmaction/evaluation/metrics/video_grounding_metric.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Optional, Sequence, Tuple + +from mmengine.evaluator import BaseMetric + +from mmaction.registry import METRICS + + +@METRICS.register_module() +class RecallatTopK(BaseMetric): + """ActivityNet dataset evaluation metric.""" + + def __init__(self, + topK_list: Tuple[int] = (1, 5), + threshold: float = 0.5, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.topK_list = topK_list + self.threshold = threshold + + def process(self, data_batch: Sequence[Tuple[Any, dict]], + predictions: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[Tuple[Any, dict]]): A batch of data + from the dataloader. + predictions (Sequence[dict]): A batch of outputs from + the model. 
+ """ + for pred in predictions: + self.results.append(pred) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + eval_results = dict() + for topK in self.topK_list: + total = len(results) + correct = 0.0 + for result in results: + gt = result['gt'] + predictions = result['predictions'][:topK] + for prediction in predictions: + IoU = self.calculate_IoU(gt, prediction) + if IoU > self.threshold: + correct += 1 + break + acc = correct / total + eval_results[f'Recall@Top{topK}_IoU={self.threshold}'] = acc + return eval_results + + def calculate_IoU(self, i0, i1): + union = (min(i0[0], i1[0]), max(i0[1], i1[1])) + inter = (max(i0[0], i1[0]), min(i0[1], i1[1])) + iou = (inter[1] - inter[0]) / (union[1] - union[0]) + return iou diff --git a/mmaction/models/localizers/__init__.py b/mmaction/models/localizers/__init__.py index 26e016410b..debd9a16f4 100644 --- a/mmaction/models/localizers/__init__.py +++ b/mmaction/models/localizers/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .bmn import BMN from .bsn import PEM, TEM +from .drn.drn import DRN from .tcanet import TCANet -__all__ = ['TEM', 'PEM', 'BMN', 'TCANet'] +__all__ = ['TEM', 'PEM', 'BMN', 'TCANet', 'DRN'] diff --git a/mmaction/models/localizers/drn/drn.py b/mmaction/models/localizers/drn/drn.py new file mode 100644 index 0000000000..869791e6bb --- /dev/null +++ b/mmaction/models/localizers/drn/drn.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +from mmengine.model import BaseModel + +from mmaction.registry import MODELS +from mmaction.utils import OptConfigType +from ..utils import soft_nms +from .drn_utils import FPN, Backbone, FCOSModule, QueryEncoder + + +@MODELS.register_module() +class DRN(BaseModel): + """Dense Regression Network for Video Grounding. + + Please refer `Dense Regression Network for Video Grounding + `_. + Code Reference: https://github.com/Alvin-Zeng/DRN + + Args: + vocab_size (int): number of all possible words in the query. + Defaults to 1301. + hidden_dim (int): the hidden dimension of the LSTM in the + language model. Defaults to 512. + embed_dim (int): the embedding dimension of the query. Defaults + to 300. + bidirection (bool): if True, use bi-direction LSTM in the + language model. Defaults to True. + first_output_dim (int): the output dimension of the first layer + in the backbone. Defaults to 256. + fpn_feature_dim (int): the output dimension of the FPN. Defaults + to 512. + feature_dim (int): the dimension of the video clip feature. + lstm_layers (int): the number of LSTM layers in the language model. + Defaults to 1. + fcos_pre_nms_top_n (int): value of Top-N in the FCOS module before + nms. Defaults to 32. + fcos_inference_thr (float): threshold in the FOCS inference. BBoxes + with scores higher than this threshold are regarded as positive. + Defaults to 0.05. + fcos_prior_prob (float): A prior probability of the positive bboexes. + Used to initialized the bias of the classification head. + Defaults to 0.01. + focal_alpha (float):Focal loss hyper-parameter alpha. + Defaults to 0.25. + focal_gamma (float): Focal loss hyper-parameter gamma. + Defaults to 2.0. + fpn_stride (Sequence[int]): the strides in the FPN. 
Defaults to + [1, 2, 4]. + fcos_nms_thr (float): NMS threshold in the FOCS module. + Defaults to 0.6. + fcos_conv_layers (int): number of convolution layers in FCOS. + Defaults to 1. + fcos_num_class (int): number of classes in FCOS. + Defaults to 2. + is_first_stage (bool): if true, the model is in the first stage + training. + is_second_stage (bool): if true, the model is in the second stage + training. + """ + + def __init__(self, + vocab_size: int = 1301, + hidden_dim: int = 512, + embed_dim: int = 300, + bidirection: bool = True, + first_output_dim: int = 256, + fpn_feature_dim: int = 512, + feature_dim: int = 4096, + lstm_layers: int = 1, + fcos_pre_nms_top_n: int = 32, + fcos_inference_thr: float = 0.05, + fcos_prior_prob: float = 0.01, + focal_alpha: float = 0.25, + focal_gamma: float = 2.0, + fpn_stride: Sequence[int] = [1, 2, 4], + fcos_nms_thr: float = 0.6, + fcos_conv_layers: int = 1, + fcos_num_class: int = 2, + is_first_stage: bool = False, + is_second_stage: bool = False, + init_cfg: OptConfigType = None, + **kwargs) -> None: + super(DRN, self).__init__(init_cfg) + + self.query_encoder = QueryEncoder( + vocab_size=vocab_size, + hidden_dim=hidden_dim, + embed_dim=embed_dim, + num_layers=lstm_layers, + bidirection=bidirection) + + channels_list = [ + (feature_dim + 256, first_output_dim, 3, 1), + (first_output_dim, first_output_dim * 2, 3, 2), + (first_output_dim * 2, first_output_dim * 4, 3, 2), + ] + self.backbone_net = Backbone(channels_list) + + self.fpn = FPN( + in_channels_list=[256, 512, 1024], out_channels=fpn_feature_dim) + + self.fcos = FCOSModule( + in_channels=fpn_feature_dim, + fcos_num_class=fcos_num_class, + fcos_conv_layers=fcos_conv_layers, + fcos_prior_prob=fcos_prior_prob, + fcos_inference_thr=fcos_inference_thr, + fcos_pre_nms_top_n=fcos_pre_nms_top_n, + fcos_nms_thr=fcos_nms_thr, + test_detections_per_img=32, + fpn_stride=fpn_stride, + focal_alpha=focal_alpha, + focal_gamma=focal_gamma, + is_first_stage=is_first_stage, + is_second_stage=is_second_stage) + + self.prop_fc = nn.Linear(feature_dim, feature_dim) + self.position_transform = nn.Linear(3, 256) + + qInput = [] + for t in range(len(channels_list)): + if t > 0: + qInput += [nn.Linear(1024, channels_list[t - 1][1])] + else: + qInput += [nn.Linear(1024, feature_dim)] + self.qInput = nn.ModuleList(qInput) + + self.is_second_stage = is_second_stage + + def forward(self, inputs, data_samples, mode, **kwargs): + props_features = torch.stack(inputs) + batch_size = props_features.shape[0] + device = props_features.device + proposals = torch.stack([ + sample.proposals['proposals'] for sample in data_samples + ]).to(device) + gt_bbox = torch.stack([ + sample.gt_instances['gt_bbox'] for sample in data_samples + ]).to(device) + + video_info = [i.metainfo for i in data_samples] + query_tokens_ = [i['query_tokens'] for i in video_info] + query_length = [i['query_length'] for i in video_info] + query_length = torch.from_numpy(np.array(query_length)) + + max_query_len = max([i.shape[0] for i in query_tokens_]) + query_tokens = torch.zeros(batch_size, max_query_len) + for idx, query_token in enumerate(query_tokens_): + query_len = query_token.shape[0] + query_tokens[idx, :query_len] = query_token + + query_tokens = query_tokens.to(device).long() + query_length = query_length.to(device).long() # should be on CPU! 
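+        # Note: `QueryEncoder.forward` calls `query_length.cpu()` before
+        # `pack_padded_sequence`, so the device copy kept here is only used for
+        # the length-based sorting of the batch below.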
+ + sort_index = query_length.argsort(descending=True) + box_lists, loss_dict = self._forward(query_tokens[sort_index], + query_length[sort_index], + props_features[sort_index], + proposals[sort_index], + gt_bbox[sort_index]) + if mode == 'loss': + return loss_dict + elif mode == 'predict': + # only support batch size = 1 + bbox = box_lists[0] + + per_vid_detections = bbox['detections'] + per_vid_scores = bbox['scores'] + + props_pred = torch.cat( + (per_vid_detections, per_vid_scores.unsqueeze(-1)), dim=-1) + + props_pred = props_pred.cpu().numpy() + props_pred = sorted(props_pred, key=lambda x: x[-1], reverse=True) + props_pred = np.array(props_pred) + + props_pred = soft_nms( + props_pred, + alpha=0.4, + low_threshold=0.5, + high_threshold=0.9, + top_k=5) + result = { + 'vid_name': data_samples[0].metainfo['vid_name'], + 'gt': gt_bbox[0].cpu().numpy(), + 'predictions': props_pred, + } + return [result] + + raise ValueError(f'Unsupported mode {mode}!') + + def nms_temporal(self, start, end, score, overlap=0.45): + pick = [] + assert len(start) == len(score) + assert len(end) == len(score) + if len(start) == 0: + return pick + + union = end - start + # sort and get index + intervals = [ + i[0] for i in sorted(enumerate(score), key=lambda x: x[1]) + ] + + while len(intervals) > 0: + i = intervals[-1] + pick.append(i) + + xx1 = [max(start[i], start[j]) for j in intervals[:-1]] + xx2 = [min(end[i], end[j]) for j in intervals[:-1]] + inter = [max(0., k2 - k1) for k1, k2 in zip(xx1, xx2)] + o = [ + inter[u] / (union[i] + union[intervals[u]] - inter[u]) + for u in range(len(intervals) - 1) + ] + I_new = [] + for j in range(len(o)): + if o[j] <= overlap: + I_new.append(intervals[j]) + intervals = I_new + return np.array(pick) + + def _forward(self, query_tokens, query_length, props_features, + props_start_end, gt_bbox): + + position_info = [props_start_end, props_start_end] + position_feats = [] + query_features = self.query_encoder(query_tokens, query_length) + for i in range(len(query_features)): + query_features[i] = self.qInput[i](query_features[i]) + if i > 1: + position_info.append( + torch.cat([ + props_start_end[:, ::2 * (i - 1), [0]], + props_start_end[:, 1::2 * (i - 1), [1]] + ], + dim=-1)) + props_duration = position_info[i][:, :, 1] - position_info[i][:, :, + 0] + props_duration = props_duration.unsqueeze(-1) + position_feat = torch.cat((position_info[i], props_duration), + dim=-1).float() + position_feats.append( + self.position_transform(position_feat).permute(0, 2, 1)) + + props_features = self.prop_fc(props_features) + + inputs = props_features.permute(0, 2, 1) + outputs = self.backbone_net(inputs, query_features, position_feats) + outputs = self.fpn(outputs) + + if self.is_second_stage: + outputs = [_.detach() for _ in outputs] + box_lists, loss_dict = self.fcos(outputs, gt_bbox.float()) + + return box_lists, loss_dict diff --git a/mmaction/models/localizers/drn/drn_utils/FPN.py b/mmaction/models/localizers/drn/drn_utils/FPN.py new file mode 100644 index 0000000000..1170ac5cf3 --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/FPN.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List, Tuple + +import torch.nn.functional as F +from torch import Tensor, nn + +from .backbone import conv_block + + +class FPN(nn.Module): + + def __init__(self, in_channels_list: List, out_channels: int) -> None: + super(FPN, self).__init__() + + inner_blocks = [] + layer_blocks = [] + for idx, in_channels in enumerate(in_channels_list, 1): + inner_block = conv_block(in_channels, out_channels, 1, 1) + layer_block = conv_block(out_channels, out_channels, 3, 1) + + inner_blocks.append(inner_block) + layer_blocks.append(layer_block) + + self.inner_blocks = nn.ModuleList(inner_blocks) + self.layer_blocks = nn.ModuleList(layer_blocks) + + def forward(self, x: Tensor) -> Tuple[Tensor]: + # process the last lowest resolution feat and + # first feed it into 1 x 1 conv + last_inner = self.inner_blocks[-1](x[-1]) + results = [self.layer_blocks[-1](last_inner)] + + for feature, inner_block, layer_block in zip( + x[:-1][::-1], self.inner_blocks[:-1][::-1], + self.layer_blocks[:-1][::-1]): + if not inner_block: + continue + inner_top_down = F.interpolate( + last_inner, scale_factor=2, mode='nearest') + inner_lateral = inner_block(feature) + last_inner = inner_lateral + inner_top_down + results.insert(0, layer_block(last_inner)) + + return tuple(results) diff --git a/mmaction/models/localizers/drn/drn_utils/__init__.py b/mmaction/models/localizers/drn/drn_utils/__init__.py new file mode 100644 index 0000000000..4d371a5055 --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbone import Backbone +from .fcos import FCOSModule +from .FPN import FPN +from .language_module import QueryEncoder + +__all__ = ['Backbone', 'FPN', 'QueryEncoder', 'FCOSModule'] diff --git a/mmaction/models/localizers/drn/drn_utils/backbone.py b/mmaction/models/localizers/drn/drn_utils/backbone.py new file mode 100644 index 0000000000..ac2c6338d0 --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/backbone.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from torch import Tensor, nn + + +def conv_block(in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1) -> nn.Module: + module = nn.Sequential( + nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + bias=False), nn.BatchNorm1d(out_channels), nn.ReLU()) + return module + + +class Backbone(nn.Module): + + def __init__(self, channels_list: List[tuple]) -> None: + super(Backbone, self).__init__() + + self.num_layers = len(channels_list) + layers = [] + for idx, channels_config in enumerate(channels_list): + layer = conv_block(*channels_config) + layers.append(layer) + self.layers = nn.ModuleList(layers) + + def forward(self, x: Tensor, query_fts: Tensor, + position_fts: Tensor) -> Tuple[Tensor]: + results = [] + + for idx in range(self.num_layers): + query_ft = query_fts[idx].unsqueeze(1).permute(0, 2, 1) + position_ft = position_fts[idx] + x = query_ft * x + if idx == 0: + x = torch.cat([x, position_ft], dim=1) + x = self.layers[idx](x) + results.append(x) + + return tuple(results) diff --git a/mmaction/models/localizers/drn/drn_utils/fcos.py b/mmaction/models/localizers/drn/drn_utils/fcos.py new file mode 100644 index 0000000000..33b30c4cb1 --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/fcos.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math + +import torch +from torch import nn + +from .inference import make_fcos_postprocessor +from .loss import make_fcos_loss_evaluator + + +class Scale(nn.Module): + + def __init__(self, init_value=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, x): + return x * self.scale + + +class FCOSHead(torch.nn.Module): + + def __init__(self, in_channels: int, fcos_num_class: int, + fcos_conv_layers: int, fcos_prior_prob: float, + is_second_stage: bool) -> None: + super(FCOSHead, self).__init__() + num_classes = fcos_num_class - 1 + + cls_tower = [] + bbox_tower = [] + for i in range(fcos_conv_layers): + cls_tower.append( + nn.Conv1d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1)) + cls_tower.append(nn.BatchNorm1d(in_channels)) + cls_tower.append(nn.ReLU()) + bbox_tower.append( + nn.Conv1d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1)) + bbox_tower.append(nn.BatchNorm1d(in_channels)) + bbox_tower.append(nn.ReLU()) + + self.cls_tower = nn.Sequential(*cls_tower) + self.bbox_tower = nn.Sequential(*bbox_tower) + self.cls_logits = nn.Conv1d( + in_channels, num_classes, kernel_size=3, stride=1, padding=1) + + self.bbox_pred = nn.Conv1d( + in_channels, 2, kernel_size=3, stride=1, padding=1) + + self.mix_fc = nn.Sequential( + nn.Conv1d(2 * in_channels, in_channels, kernel_size=1, stride=1), + nn.BatchNorm1d(in_channels), nn.ReLU()) + + self.iou_scores = nn.Sequential( + nn.Conv1d( + in_channels, + in_channels // 2, + kernel_size=3, + stride=1, + padding=1), + nn.BatchNorm1d(in_channels // 2), + nn.ReLU(), + nn.Conv1d(in_channels // 2, 1, kernel_size=1, stride=1), + ) + + # initialization + for module in self.modules(): + if isinstance(module, nn.Conv1d): + torch.nn.init.normal_(module.weight, std=0.01) + torch.nn.init.constant_(module.bias, 0) + + # initialize the bias for focal loss + bias_value = -math.log((1 - fcos_prior_prob) / fcos_prior_prob) + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(3)]) + self.is_second_stage = is_second_stage + + def forward(self, x): + logits = [] + bbox_reg = [] + iou_scores = [] + for idx, feature in enumerate(x): + cls_tower = self.cls_tower(feature) + box_tower = self.bbox_tower(feature) + logits.append(self.cls_logits(cls_tower)) + + bbox_reg_ = torch.exp(self.scales[idx](self.bbox_pred(box_tower))) + if self.is_second_stage: + bbox_reg_ = bbox_reg_.detach() + bbox_reg.append(bbox_reg_) + + mix_feature = torch.cat([cls_tower, box_tower], dim=1) + if self.is_second_stage: + mix_feature = mix_feature.detach() + mix_feature = self.mix_fc(mix_feature) + iou_scores.append(self.iou_scores(mix_feature)) + return logits, bbox_reg, iou_scores + + +class FCOSModule(torch.nn.Module): + + def __init__(self, in_channels: int, fcos_num_class: int, + fcos_conv_layers: int, fcos_prior_prob: float, + fcos_inference_thr: float, fcos_pre_nms_top_n: int, + fcos_nms_thr: float, test_detections_per_img: int, + fpn_stride: int, focal_alpha: float, focal_gamma: float, + is_first_stage: bool, is_second_stage: bool) -> None: + super(FCOSModule, self).__init__() + + head = FCOSHead( + in_channels=in_channels, + fcos_num_class=fcos_num_class, + fcos_conv_layers=fcos_conv_layers, + fcos_prior_prob=fcos_prior_prob, + is_second_stage=is_second_stage) + + self.is_first_stage = is_first_stage + self.is_second_stage = is_second_stage + box_selector_test = 
make_fcos_postprocessor(fcos_num_class, + fcos_inference_thr, + fcos_pre_nms_top_n, + fcos_nms_thr, + test_detections_per_img, + is_first_stage) + loss_evaluator = make_fcos_loss_evaluator(focal_alpha, focal_gamma) + self.head = head + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + self.fpn_strides = fpn_stride + + def forward(self, features, targets=None): + box_cls, box_regression, iou_scores = self.head(features) + locations = self.compute_locations(features) + + if self.training: + return self._forward_train(locations, box_cls, box_regression, + targets, iou_scores) + else: + return self._forward_test(locations, box_cls, box_regression, + targets, iou_scores) + + def _forward_train(self, locations, box_cls, box_regression, targets, + iou_scores): + loss_box_cls, loss_box_reg, loss_iou = self.loss_evaluator( + locations, box_cls, box_regression, targets, iou_scores, + self.is_first_stage) + + if self.is_second_stage: + loss_box_cls = loss_box_cls.detach() + loss_box_reg = loss_box_reg.detach() + if self.is_first_stage: + loss_iou = loss_iou.detach() + + losses = { + 'loss_cls': loss_box_cls, + 'loss_reg': loss_box_reg, + 'loss_iou': loss_iou + } + return None, losses + + def _forward_test(self, locations, box_cls, box_regression, targets, + iou_scores): + boxes = self.box_selector_test(locations, box_cls, box_regression, + iou_scores) + losses = None + return boxes, losses + + def compute_locations(self, features): + locations = [] + for level, feature in enumerate(features): + t = feature.size(-1) + locations_per_level = self.compute_locations_per_level( + t, self.fpn_strides[level], feature.device) + locations.append(locations_per_level) + return locations + + def compute_locations_per_level(self, t, stride, device): + shifts_t = torch.arange( + 0, t * stride, step=stride, dtype=torch.float32, device=device) + shifts_t = shifts_t.reshape(-1) + locations = shifts_t + stride / 2 + return locations diff --git a/mmaction/models/localizers/drn/drn_utils/inference.py b/mmaction/models/localizers/drn/drn_utils/inference.py new file mode 100644 index 0000000000..09cc7ef989 --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/inference.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Copied from https://github.com/Alvin-Zeng/DRN/""" + +import torch + + +class FCOSPostProcessor(torch.nn.Module): + """Performs post-processing on the outputs of the RetinaNet boxes. + + This is only used in the testing. 
+ """ + + def __init__(self, pre_nms_thresh, pre_nms_top_n, nms_thresh, + fpn_post_nms_top_n, min_size, num_classes, is_first_stage): + """ + Arguments: + pre_nms_thresh (float) + pre_nms_top_n (int) + nms_thresh (float) + fpn_post_nms_top_n (int) + min_size (int) + num_classes (int) + box_coder (BoxCoder) + """ + super(FCOSPostProcessor, self).__init__() + self.pre_nms_thresh = pre_nms_thresh + self.pre_nms_top_n = pre_nms_top_n + self.nms_thresh = nms_thresh + self.fpn_post_nms_top_n = fpn_post_nms_top_n + self.min_size = min_size + self.num_classes = num_classes + self.innerness_threshold = 0.15 + self.downsample_scale = 32 + self.is_first_stage = is_first_stage + + def forward_for_single_feature_map(self, locations, box_cls, + box_regression, level, iou_scores): + """ + Arguments: + anchors: list[BoxList] + box_cls: tensor of size N, A * C, H, W + box_regression: tensor of size N, A * 4, H, W + """ + N, C, T = box_cls.shape + + # put in the same format as locations + box_cls = box_cls.permute(0, 2, 1).contiguous().sigmoid() + iou_scores = iou_scores.permute(0, 2, 1).contiguous().sigmoid() + box_regression = box_regression.permute(0, 2, 1) + + # centerness = centerness.permute(0, 2, 1) + # centerness = centerness.reshape(N, -1).sigmoid() + # inner = inner.squeeze().sigmoid() + + candidate_inds = (box_cls > self.pre_nms_thresh) + pre_nms_top_n = candidate_inds.view(N, -1).sum(1) + pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) + + # multiply the classification scores with centerness scores + # box_cls = box_cls * centerness[:, :, None] + # box_cls = box_cls + centerness[:, :, None] + if not self.is_first_stage: + box_cls = box_cls * iou_scores + + results = [] + for i in range(N): + + # per_centerness = centerness[i] + + per_box_cls = box_cls[i] + per_candidate_inds = candidate_inds[i] + per_box_cls = per_box_cls[per_candidate_inds] + + per_candidate_nonzeros = per_candidate_inds.nonzero() + per_box_loc = per_candidate_nonzeros[:, 0] + per_class = per_candidate_nonzeros[:, 1] + 1 + + per_box_regression = box_regression[i] + per_box_regression = per_box_regression[per_box_loc] + per_locations = locations[per_box_loc] + + # per_centerness = per_centerness[per_box_loc] + + per_pre_nms_top_n = pre_nms_top_n[i] + + if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): + per_box_cls, top_k_indices = \ + per_box_cls.topk(per_pre_nms_top_n, sorted=False) + per_class = per_class[top_k_indices] + per_box_regression = per_box_regression[top_k_indices] + per_locations = per_locations[top_k_indices] + + # per_centerness = per_centerness[top_k_indices] + + detections = torch.stack([ + per_locations - per_box_regression[:, 0], + per_locations + per_box_regression[:, 1], + ], + dim=1) / self.downsample_scale + + detections[:, 0].clamp_(min=0, max=1) + detections[:, 1].clamp_(min=0, max=1) + + # remove small boxes + p_start, p_end = detections.unbind(dim=1) + duration = p_end - p_start + keep = (duration >= self.min_size).nonzero().squeeze(1) + detections = detections[keep] + + temp_dict = {} + temp_dict['detections'] = detections + temp_dict['labels'] = per_class + temp_dict['scores'] = torch.sqrt(per_box_cls) + temp_dict['level'] = [level] + # temp_dict['centerness'] = per_centerness + temp_dict['locations'] = per_locations / 32 + + results.append(temp_dict) + + return results + + def forward(self, locations, box_cls, box_regression, iou_scores): + """ + Arguments: + anchors: list[list[BoxList]] + box_cls: list[tensor] + box_regression: list[tensor] + image_sizes: list[(h, w)] + 
Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + for i, (l, o, b, iou_s) in enumerate( + zip(locations, box_cls, box_regression, iou_scores)): + sampled_boxes.append( + self.forward_for_single_feature_map(l, o, b, i, iou_s)) + + boxlists = list(zip(*sampled_boxes)) + # boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + # TODO very similar to filter_results from PostProcessor + # but filter_results is per image + # TODO Yang: solve this issue in the future. No good solution + # right now. + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + dicts = boxlists[i] + per_vid_scores = [] + per_vid_detections = [] + per_vid_labels = [] + # add level number + per_vid_level = [] + per_vid_locations = [] + # per_vid_centerness = [] + for per_scale_dict in dicts: + if len(per_scale_dict['detections']) != 0: + per_vid_detections.append(per_scale_dict['detections']) + if len(per_scale_dict['scores']) != 0: + per_vid_scores.append(per_scale_dict['scores']) + if len(per_scale_dict['level']) != 0: + per_vid_level.append(per_scale_dict['level'] * + len(per_scale_dict['detections'])) + + if len(per_scale_dict['locations']) != 0: + per_vid_locations.append(per_scale_dict['locations']) + + # if len(per_scale_dict['centerness']) != 0: + # per_vid_centerness.append(per_scale_dict['centerness']) + if len(per_vid_detections) == 0: + per_vid_detections = torch.Tensor([0, 1]).unsqueeze(0) + per_vid_scores = torch.Tensor([1]) + per_vid_level = [[-1]] + per_vid_locations = torch.Tensor([0.5]) + # per_vid_centerness = torch.Tensor([0.5]).cuda() + else: + per_vid_detections = torch.cat(per_vid_detections, dim=0) + per_vid_scores = torch.cat(per_vid_scores, dim=0) + per_vid_level = per_vid_level + per_vid_locations = torch.cat(per_vid_locations, dim=0) + # per_vid_centerness = torch.cat(per_vid_centerness, dim=0) + + temp_dict = {} + temp_dict['detections'] = per_vid_detections + temp_dict['labels'] = per_vid_labels + temp_dict['scores'] = per_vid_scores + temp_dict['level'] = per_vid_level + # temp_dict['centerness'] = per_vid_centerness + temp_dict['locations'] = per_vid_locations + results.append(temp_dict) + + return results + + +def make_fcos_postprocessor(fcos_num_class, fcos_inference_thr, + fcos_pre_nms_top_n, fcos_nms_thr, + test_detections_per_img, is_first_stage): + box_selector = FCOSPostProcessor( + pre_nms_thresh=fcos_inference_thr, + pre_nms_top_n=fcos_pre_nms_top_n, + nms_thresh=fcos_nms_thr, + fpn_post_nms_top_n=test_detections_per_img, + min_size=0, + num_classes=fcos_num_class, + is_first_stage=is_first_stage) + + return box_selector diff --git a/mmaction/models/localizers/drn/drn_utils/language_module.py b/mmaction/models/localizers/drn/drn_utils/language_module.py new file mode 100644 index 0000000000..135652a5eb --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/language_module.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import List + +import torch +from torch import Tensor, nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + +class QueryEncoder(nn.Module): + + def __init__(self, + vocab_size: int, + hidden_dim: int = 512, + embed_dim: int = 300, + num_layers: int = 1, + bidirection: bool = True) -> None: + super(QueryEncoder, self).__init__() + self.hidden_dim = hidden_dim + self.embed_dim = embed_dim + self.embedding = nn.Embedding( + num_embeddings=vocab_size + 1, + embedding_dim=embed_dim, + padding_idx=0) + # self.embedding.weight.data.copy_(torch.load('glove_weights')) + self.biLSTM = nn.LSTM( + input_size=embed_dim, + hidden_size=self.hidden_dim, + num_layers=num_layers, + dropout=0.0, + batch_first=True, + bidirectional=bidirection) + + self.W3 = nn.Linear(hidden_dim * 4, hidden_dim) + self.W2 = nn.ModuleList( + [nn.Linear(hidden_dim, hidden_dim * 2) for _ in range(3)]) + self.W1 = nn.Linear(hidden_dim * 2, 1) + + def extract_textual(self, q_encoding: Tensor, lstm_outputs: Tensor, + q_length: Tensor, t: int): + q_cmd = self.W3(q_encoding).relu() + q_cmd = self.W2[t](q_cmd) + q_cmd = q_cmd[:, None, :] * lstm_outputs + raw_att = self.W1(q_cmd).squeeze(-1) + + raw_att = apply_mask1d(raw_att, q_length) + att = raw_att.softmax(dim=-1) + cmd = torch.bmm(att[:, None, :], lstm_outputs).squeeze(1) + return cmd + + def forward(self, query_tokens: Tensor, + query_length: Tensor) -> List[Tensor]: + self.biLSTM.flatten_parameters() + + query_embedding = self.embedding(query_tokens) + + # output denotes the forward and backward hidden states in Eq 2. + query_embedding = pack_padded_sequence( + query_embedding, query_length.cpu(), batch_first=True) + output, _ = self.biLSTM(query_embedding) + output, _ = pad_packed_sequence(output, batch_first=True) + + # q_vector denotes the global representation `g` in Eq 2. + q_vector_list = [] + + for i, length in enumerate(query_length): + h1 = output[i][0] + hs = output[i][length - 1] + q_vector = torch.cat((h1, hs), dim=-1) + q_vector_list.append(q_vector) + q_vector = torch.stack(q_vector_list) + # outputs denotes the query feature in Eq3 in 3 levels. + outputs = [] + for cmd_t in range(3): + query_feat = self.extract_textual(q_vector, output, query_length, + cmd_t) + outputs.append(query_feat) + + # Note: the output here is zero-padded + # we need slice the non-zero items for the following operations. + return outputs + + +def apply_mask1d(attention: Tensor, image_locs: Tensor) -> Tensor: + batch_size, num_loc = attention.size() + tmp1 = torch.arange( + num_loc, dtype=attention.dtype, device=attention.device) + tmp1 = tmp1.expand(batch_size, num_loc) + + tmp2 = image_locs.unsqueeze(dim=1).expand(batch_size, num_loc) + mask = tmp1 >= tmp2.to(tmp1.dtype) + attention = attention.masked_fill(mask, -1e30) + return attention diff --git a/mmaction/models/localizers/drn/drn_utils/loss.py b/mmaction/models/localizers/drn/drn_utils/loss.py new file mode 100644 index 0000000000..920ebac0b3 --- /dev/null +++ b/mmaction/models/localizers/drn/drn_utils/loss.py @@ -0,0 +1,240 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+"""Adapted from https://github.com/Alvin-Zeng/DRN/""" + +import torch +import torchvision +from torch import nn + +INF = 100000000 + + +def SigmoidFocalLoss(alpha, gamma): + + def loss_fn(inputs, targets): + loss = torchvision.ops.sigmoid_focal_loss( + inputs=inputs, + targets=targets, + alpha=alpha, + gamma=gamma, + reduction='sum') + return loss + + return loss_fn + + +def IOULoss(): + + def loss_fn(pred, target): + pred_left = pred[:, 0] + pred_right = pred[:, 1] + + target_left = target[:, 0] + target_right = target[:, 1] + + intersect = torch.min(pred_right, target_right) + torch.min( + pred_left, target_left) + target_area = target_left + target_right + pred_area = pred_left + pred_right + union = target_area + pred_area - intersect + + losses = -torch.log((intersect + 1e-8) / (union + 1e-8)) + return losses.mean() + + return loss_fn + + +class FCOSLossComputation(object): + """This class computes the FCOS losses.""" + + def __init__(self, focal_alpha, focal_gamma): + self.cls_loss_fn = SigmoidFocalLoss(focal_alpha, focal_gamma) + self.box_reg_loss_fn = IOULoss() + self.centerness_loss_fn = nn.BCEWithLogitsLoss() + self.iou_loss_fn = nn.SmoothL1Loss() + + def prepare_targets(self, points, targets): + object_sizes_of_interest = [ + [-1, 6], + [5.6, 11], + [11, INF], + ] + expanded_object_sizes_of_interest = [] + for idx, points_per_level in enumerate(points): + object_sizes_of_interest_per_level = \ + points_per_level.new_tensor(object_sizes_of_interest[idx]) + expanded_object_sizes_of_interest.append( + object_sizes_of_interest_per_level[None].expand( + len(points_per_level), -1)) + + expanded_object_sizes_of_interest = torch.cat( + expanded_object_sizes_of_interest, dim=0) + num_points_per_level = [ + len(points_per_level) for points_per_level in points + ] + points_all_level = torch.cat(points, dim=0) + labels, reg_targets = self.compute_targets_for_locations( + points_all_level, targets, expanded_object_sizes_of_interest) + + for i in range(len(labels)): + labels[i] = torch.split(labels[i], num_points_per_level, dim=0) + reg_targets[i] = torch.split( + reg_targets[i], num_points_per_level, dim=0) + + labels_level_first = [] + reg_targets_level_first = [] + for level in range(len(points)): + labels_level_first.append( + torch.cat([labels_per_im[level] for labels_per_im in labels], + dim=0)) + reg_targets_level_first.append( + torch.cat([ + reg_targets_per_im[level] + for reg_targets_per_im in reg_targets + ], + dim=0)) + + return labels_level_first, reg_targets_level_first + + def compute_targets_for_locations(self, locations, targets, + object_sizes_of_interest): + labels = [] + reg_targets = [] + ts = locations + + for im_i in range(len(targets)): + targets_per_im = targets[im_i] + bboxes = targets_per_im * 32 + + left = ts[:, None] - bboxes[None, 0] + right = bboxes[None, 1] - ts[:, None] + reg_targets_per_im = torch.cat([left, right], dim=1) + + is_in_boxes = reg_targets_per_im.min(dim=1)[0] > 0 + max_reg_targets_per_im = reg_targets_per_im.max(dim=1)[0] + is_cared_in_the_level = \ + (max_reg_targets_per_im >= object_sizes_of_interest[:, 0]) & \ + (max_reg_targets_per_im <= object_sizes_of_interest[:, 1]) + + locations_to_gt_area = bboxes[1] - bboxes[0] + locations_to_gt_area = locations_to_gt_area.repeat( + len(locations), 1) + locations_to_gt_area[is_in_boxes == 0] = INF + locations_to_gt_area[is_cared_in_the_level == 0] = INF + + _ = locations_to_gt_area.min(dim=1) + locations_to_min_area, locations_to_gt_inds = _ + + labels_per_im = reg_targets_per_im.new_ones( + 
len(reg_targets_per_im)) + labels_per_im[locations_to_min_area == INF] = 0 + + labels.append(labels_per_im) + reg_targets.append(reg_targets_per_im) + + return labels, reg_targets + + def __call__(self, + locations, + box_cls, + box_regression, + targets, + iou_scores, + is_first_stage=True): + N = box_cls[0].size(0) + num_classes = box_cls[0].size(1) + labels, reg_targets = self.prepare_targets(locations, targets) + + box_cls_flatten = [] + box_regression_flatten = [] + # centerness_flatten = [] + labels_flatten = [] + reg_targets_flatten = [] + + for idx in range(len(labels)): + box_cls_flatten.append(box_cls[idx].permute(0, 2, 1).reshape( + -1, num_classes)) + box_regression_flatten.append(box_regression[idx].permute( + 0, 2, 1).reshape(-1, 2)) + labels_flatten.append(labels[idx].reshape(-1)) + reg_targets_flatten.append(reg_targets[idx].reshape(-1, 2)) + + if not is_first_stage: + # [batch, 56, 2] + merged_box_regression = torch.cat( + box_regression, dim=-1).transpose(2, 1) + # [56] + merged_locations = torch.cat(locations, dim=0) + # [batch, 56] + full_locations = merged_locations[None, :].expand( + merged_box_regression.size(0), -1).contiguous() + pred_start = full_locations - merged_box_regression[:, :, 0] + pred_end = full_locations + merged_box_regression[:, :, 1] + # [batch, 56, 2] + predictions = torch.cat( + [pred_start.unsqueeze(-1), + pred_end.unsqueeze(-1)], dim=-1) / 32 + # TODO: make sure the predictions are legal. (e.g. start < end) + predictions.clamp_(min=0, max=1) + # gt: [batch, 2] + gt_box = targets[:, None, :] + + iou_target = segment_tiou(predictions, gt_box) + iou_pred = torch.cat(iou_scores, dim=-1).squeeze().sigmoid() + iou_pos_ind = iou_target > 0.9 + pos_iou_target = iou_target[iou_pos_ind] + + pos_iou_pred = iou_pred[iou_pos_ind] + + if iou_pos_ind.sum().item() == 0: + iou_loss = torch.tensor([0.]).to(iou_pos_ind.device) + else: + iou_loss = self.iou_loss_fn(pos_iou_pred, pos_iou_target) + + box_cls_flatten = torch.cat(box_cls_flatten, dim=0) + box_regression_flatten = torch.cat(box_regression_flatten, dim=0) + labels_flatten = torch.cat(labels_flatten, dim=0) + reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0) + + pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1) + cls_loss = self.cls_loss_fn( + box_cls_flatten, labels_flatten.unsqueeze(1)) / ( + pos_inds.numel() + N) # add N to avoid dividing by a zero + + box_regression_flatten = box_regression_flatten[pos_inds] + reg_targets_flatten = reg_targets_flatten[pos_inds] + + if pos_inds.numel() > 0: + reg_loss = self.box_reg_loss_fn( + box_regression_flatten, + reg_targets_flatten, + ) + else: + reg_loss = box_regression_flatten.sum() + + if not is_first_stage: + return cls_loss, reg_loss, iou_loss + + return cls_loss, reg_loss, torch.tensor([0.]).to(cls_loss.device) + + +def segment_tiou(box_a, box_b): + + # gt: [batch, 1, 2], detections: [batch, 56, 2] + # calculate interaction + inter_max_xy = torch.min(box_a[:, :, -1], box_b[:, :, -1]) + inter_min_xy = torch.max(box_a[:, :, 0], box_b[:, :, 0]) + inter = torch.clamp((inter_max_xy - inter_min_xy), min=0) + + # calculate union + union_max_xy = torch.max(box_a[:, :, -1], box_b[:, :, -1]) + union_min_xy = torch.min(box_a[:, :, 0], box_b[:, :, 0]) + union = torch.clamp((union_max_xy - union_min_xy), min=0) + + iou = inter / (union + 1e-6) + + return iou + + +def make_fcos_loss_evaluator(focal_alpha, focal_gamma): + loss_evaluator = FCOSLossComputation(focal_alpha, focal_gamma) + return loss_evaluator diff --git a/tools/data/charades-sta/README.md 
b/tools/data/charades-sta/README.md
new file mode 100644
index 0000000000..b2bea83d2b
--- /dev/null
+++ b/tools/data/charades-sta/README.md
@@ -0,0 +1,59 @@
+# Preparing Charades-STA
+
+## Introduction
+
+
+```BibTeX
+@inproceedings{gao2017tall,
+  title={Tall: Temporal activity localization via language query},
+  author={Gao, Jiyang and Sun, Chen and Yang, Zhenheng and Nevatia, Ram},
+  booktitle={Proceedings of the IEEE international conference on computer vision},
+  pages={5267--5275},
+  year={2017}
+}
+
+@inproceedings{DRN2020CVPR,
+  author = {Runhao, Zeng and Haoming, Xu and Wenbing, Huang and Peihao, Chen and Mingkui, Tan and Chuang Gan},
+  title = {Dense Regression Network for Video Grounding},
+  booktitle = {CVPR},
+  year = {2020},
+}
+```
+
+Charades-STA is a dataset built on top of Charades by adding sentence temporal annotations. It was introduced by Gao et al. in `TALL: Temporal Activity Localization via Language Query`. Currently, we only support the C3D features from `Dense Regression Network for Video Grounding`.
+
+## Step 1. Prepare Annotations
+
+First of all, you can run the following script to download the annotations from the official repository of DRN:
+
+```shell
+bash download_annotations.sh
+```
+
+## Step 2. Prepare C3D Features
+
+After the first step, you should be at `${MMACTION2}/data/CharadesSTA/`. Download the C3D features following the [official command](https://github.com/Alvin-Zeng/DRN/tree/master#download-features) to the current directory `${MMACTION2}/data/CharadesSTA/`.
+
+After finishing the two steps, the folder structure will look like:
+
+```
+mmaction2
+├── mmaction
+├── tools
+├── configs
+├── data
+│   ├── CharadesSTA
+│   │   ├── C3D_unit16_overlap0.5_merged
+│   │   │   ├── 001YG.pt
+│   │   │   ├── 003WS.pt
+│   │   │   ├── 004QE.pt
+│   │   │   ├── 00607.pt
+│   │   │   ├── ...
+│   │   ├── Charades_duration.json
+│   │   ├── Charades_fps_dict.json
+│   │   ├── Charades_frames_info.json
+│   │   ├── Charades_sta_test.txt
+│   │   ├── Charades_sta_train.txt
+│   │   ├── Charades_word2id.json
+```
diff --git a/tools/data/charades-sta/download_annotations.sh b/tools/data/charades-sta/download_annotations.sh
new file mode 100644
index 0000000000..85bdb7d1a8
--- /dev/null
+++ b/tools/data/charades-sta/download_annotations.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+DATA_DIR="../../../data/CharadesSTA/"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+  echo "${DATA_DIR} does not exist. Creating";
+  mkdir -p ${DATA_DIR}
+fi
+
+cd ${DATA_DIR}
+
+URL="https://raw.githubusercontent.com/Alvin-Zeng/DRN/master/data/dataset/Charades"
+wget ${URL}/Charades_frames_info.json
+wget ${URL}/Charades_duration.json
+wget ${URL}/Charades_fps_dict.json
+wget ${URL}/Charades_sta_test.txt
+wget ${URL}/Charades_sta_train.txt
+wget ${URL}/Charades_word2id.json