Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix gt_labels #2690

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs/en/get_started/guide_to_framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule, Sequential
from mmengine.structures import LabelData
from mmaction.registry import MODELS


Expand Down Expand Up @@ -498,8 +497,7 @@ class ClsHeadZelda(BaseModule):
cls_scores = self.average_clip(cls_scores, num_views)

for ds, sc in zip(data_samples, cls_scores):
pred = LabelData(item=sc)
ds.pred_scores = pred
ds.set_pred_score(sc)
return data_samples

def average_clip(self, cls_scores, num_views):
Expand Down
4 changes: 1 addition & 3 deletions docs/zh_cn/get_started/guide_to_framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule, Sequential
from mmengine.structures import LabelData
from mmaction.registry import MODELS


Expand Down Expand Up @@ -498,8 +497,7 @@ class ClsHeadZelda(BaseModule):
cls_scores = self.average_clip(cls_scores, num_views)

for ds, sc in zip(data_samples, cls_scores):
pred = LabelData(item=sc)
ds.pred_scores = pred
ds.set_pred_score(sc)
return data_samples

def average_clip(self, cls_scores, num_views):
Expand Down
5 changes: 1 addition & 4 deletions mmaction/utils/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData, LabelData
from mmengine.structures import InstanceData

from mmaction.structures import ActionDataSample

Expand All @@ -18,9 +18,6 @@
InstanceList = List[InstanceData]
OptInstanceList = Optional[InstanceList]

LabelList = List[LabelData]
OptLabelList = Optional[LabelList]

SampleList = List[ActionDataSample]
OptSampleList = Optional[SampleList]

Expand Down
3 changes: 1 addition & 2 deletions mmaction/visualization/action_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@ class ActionVisualizer(Visualizer):
>>> import decord
>>> from pathlib import Path
>>> from mmaction.structures import ActionDataSample, ActionVisualizer
>>> from mmengine.structures import LabelData
>>> # Example frame
>>> video = decord.VideoReader('./demo/demo.mp4')
>>> video = video.get_batch(range(32)).asnumpy()
>>> # Example annotation
>>> data_sample = ActionDataSample()
>>> data_sample.gt_label = LabelData(item=torch.tensor([2]))
>>> data_sample.set_pred_label(torch.tensor([2]))
>>> # Setup the visualizer
>>> vis = ActionVisualizer(
... save_dir="./outputs",
Expand Down
3 changes: 1 addition & 2 deletions projects/actionclip/models/actionclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel
from mmengine.structures import LabelData

from mmaction.registry import MODELS
from .adapter import TransformerAdapter
Expand Down Expand Up @@ -108,7 +107,7 @@ def forward(self,
cls_scores = F.softmax(similarity, dim=2).mean(dim=1)

for data_sample, score in zip(data_samples, cls_scores):
data_sample.pred_scores = LabelData(item=score)
data_sample.set_pred_score(score)

return data_samples

Expand Down
3 changes: 0 additions & 3 deletions tools/deployment/export_onnx_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.structures import LabelData

from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
Expand Down Expand Up @@ -117,8 +116,6 @@ def main():
base_model.eval()

data_sample = ActionDataSample()
data_sample.pred_scores = LabelData()
data_sample.pred_labels = LabelData()
base_output = base_model(
input_tensor.unsqueeze(0), data_samples=[data_sample],
mode='predict')[0]
Expand Down
3 changes: 0 additions & 3 deletions tools/deployment/export_onnx_posec3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.structures import LabelData

from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
Expand Down Expand Up @@ -113,8 +112,6 @@ def main():
base_model.eval()

data_sample = ActionDataSample()
data_sample.pred_scores = LabelData()
data_sample.pred_labels = LabelData()
base_output = base_model(
input_tensor.unsqueeze(0), data_samples=[data_sample],
mode='predict')[0]
Expand Down
Loading