Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix decoding issue with PYAV due to new support for multiple training… #541

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 82 additions & 33 deletions slowfast/datasets/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def get_multiple_start_end_idx(
num_clips_uniform,
min_delta=0,
max_delta=math.inf,
use_offset=False
):
"""
Sample a clip of size clip_size from a video of size video_size and
Expand Down Expand Up @@ -114,20 +115,28 @@ def sample_clips(
min_delta=0,
max_delta=math.inf,
num_retries=100,
use_offset=False
):
se_inds = np.empty((0, 2))
dt = np.empty((0))
for clip_size in clip_sizes:
for i_try in range(num_retries):
clip_size = int(clip_size)
# clip_size = int(clip_size)
max_start = max(video_size - clip_size, 0)
if clip_idx == -1:
# Random temporal sampling.
start_idx = random.uniform(0, max_start)
else:
# Uniformly sample the clip with the given index.
start_idx = max_start * clip_idx / num_clips_uniform
end_idx = start_idx + clip_size # - 1
else: # Uniformly sample the clip with the given index.
if use_offset:
if num_clips_uniform == 1:
# Take the center clip if num_clips is 1.
start_idx = math.floor(max_start / 2)
else:
start_idx = clip_idx * math.floor(max_start / (num_clips_uniform - 1))
else:
start_idx = max_start * clip_idx / num_clips_uniform

end_idx = start_idx + clip_size - 1

se_inds_new = np.append(se_inds, [[start_idx, end_idx]], axis=0)
if se_inds.shape[0] < 1:
Expand Down Expand Up @@ -156,6 +165,7 @@ def sample_clips(
min_delta,
max_delta,
100,
use_offset,
)
success = not (any(dt < min_delta) or any(dt > max_delta))
if success or clip_idx != -1:
Expand Down Expand Up @@ -295,9 +305,7 @@ def torchvision_decode(
clip_sizes = [
np.maximum(
1.0,
np.ceil(
sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
),
sampling_rate[i] * num_frames[i] / target_fps * fps
)
for i in range(len(sampling_rate))
]
Expand All @@ -308,6 +316,7 @@ def torchvision_decode(
num_clips_uniform,
min_delta=min_delta,
max_delta=max_delta,
use_offset=use_offset,
)
frames_out = [None] * len(num_frames)
for k in range(len(num_frames)):
Expand Down Expand Up @@ -374,6 +383,10 @@ def pyav_decode(
num_clips_uniform=10,
target_fps=30,
use_offset=False,
modalities=("visual",),
max_spatial_scale=0,
min_delta=-math.inf,
max_delta=math.inf,
):
"""
Convert the video from its original fps to the target_fps. If the video
Expand Down Expand Up @@ -411,38 +424,69 @@ def pyav_decode(
# If failed to fetch the decoding information, decode the entire video.
decode_all_video = True
video_start_pts, video_end_pts = 0, math.inf
start_end_delta_time = None

frames = None
if container.streams.video:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
container.close()

frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))
frames_out = [frames]

else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For lines 414-417, we should also decode the whole video and return the frames. Example code could be:

        decode_all_video = True
        video_start_pts, video_end_pts = 0, math.inf
        start_end_delta_time = None

        frames = None
        if container.streams.video:
            video_frames, max_pts = pyav_decode_stream(
                container,
                video_start_pts,
                video_end_pts,
                container.streams.video[0],
                {"video": 0},
            )

            frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
            frames = torch.as_tensor(np.stack(frames))
        frames_out = [frames]

# Perform selective decoding.
decode_all_video = False
clip_size = np.maximum(
1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
)
start_idx, end_idx, fraction = get_start_end_idx(
clip_sizes = [
np.maximum(
1.0,
np.ceil(
sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
),
)
for i in range(len(sampling_rate))
]
start_end_delta_time = get_multiple_start_end_idx(
frames_length,
clip_size,
clip_sizes,
clip_idx,
num_clips_uniform,
use_offset=use_offset,
)
timebase = duration / frames_length
video_start_pts = int(start_idx * timebase)
video_end_pts = int(end_idx * timebase)

frames = None
# If video stream was found, fetch video frames from the video.
if container.streams.video:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
min_delta=min_delta,
max_delta=max_delta,
)
frames_out = [None] * len(num_frames)
for k in range(len(num_frames)):
start_idx = start_end_delta_time[k, 0]
end_idx = start_end_delta_time[k, 1]
timebase = duration / frames_length
video_start_pts = int(start_idx * timebase)
video_end_pts = int(end_idx * timebase)

frames = None
# If video stream was found, fetch video frames from the video.
if container.streams.video:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)

frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))

frames_out[k] = frames
container.close()

frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))
return frames, fps, decode_all_video
return frames_out, fps, decode_all_video, start_end_delta_time


def decode(
Expand Down Expand Up @@ -504,14 +548,18 @@ def decode(
if backend == "pyav":
assert min_delta == -math.inf and max_delta == math.inf, \
"delta sampling not supported in pyav"
frames_decoded, fps, decode_all_video = pyav_decode(
frames_decoded, fps, decode_all_video, start_end_delta_time = pyav_decode(
container,
sampling_rate,
num_frames,
clip_idx,
num_clips_uniform,
target_fps,
use_offset=use_offset,
modalities=("visual",),
max_spatial_scale=max_spatial_scale,
min_delta=min_delta,
max_delta=max_delta,
)
elif backend == "torchvision":
(
Expand Down Expand Up @@ -551,7 +599,7 @@ def decode(
clip_sizes = [
np.maximum(
1.0,
np.ceil(sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps),
sampling_rate[i] * num_frames[i] / target_fps * fps
)
for i in range(len(sampling_rate))
]
Expand All @@ -565,6 +613,7 @@ def decode(
num_clips_uniform if decode_all_video else 1,
min_delta=min_delta,
max_delta=max_delta,
use_offset=use_offset,
)

frames_out, start_inds, time_diff_aug = (
Expand Down