Merge pull request #1516 from lrzpellegrini/ffcv_improvements
Ffcv improvements
AntonioCarta authored Oct 11, 2023
2 parents 0515a47 + 8906a28 commit 1c313f7
Showing 7 changed files with 187 additions and 15 deletions.
138 changes: 138 additions & 0 deletions avalanche/benchmarks/utils/ffcv_support/center_crop.py
@@ -0,0 +1,138 @@
"""
Implementation of the CenterCrop transformation for FFCV
"""

from typing import Callable, Tuple
from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.pipeline.state import State
from ffcv.pipeline.allocation_query import AllocationQuery
import numpy as np
from dataclasses import replace
from ffcv.fields.rgb_image import IMAGE_MODES
from ffcv.pipeline.compiler import Compiler
from ffcv.libffcv import imdecode


def get_center_crop_torchvision_alike(
    image_height, image_width, output_size, img, out_buffer
):
    crop_height = output_size[0]
    crop_width = output_size[1]

    padding_h = (crop_height - image_height) // 2 if crop_height > image_height else 0
    padding_w = (crop_width - image_width) // 2 if crop_width > image_width else 0

    crop_t = (
        int(round((image_height - crop_height) / 2.0))
        if image_height > crop_height
        else 0
    )
    crop_l = (
        int(round((image_width - crop_width) / 2.0)) if image_width > crop_width else 0
    )
    crop_height_effective = min(crop_height, image_height)
    crop_width_effective = min(crop_width, image_width)

    # print(image_height, image_width, crop_height, crop_width, padding_h,
    #       padding_w, crop_t, crop_l, crop_height_effective, crop_width_effective)
    # print(f'From ({crop_t} : {crop_t+crop_height_effective}, {crop_l} : {crop_l+crop_width_effective}) to '
    #       f'{padding_h} : {padding_h+crop_height_effective}, {padding_w} : {padding_w+crop_width_effective}')

    if crop_height_effective != crop_height or crop_width_effective != crop_width:
        out_buffer[:] = 0  # Set padding color

    out_buffer[
        padding_h : padding_h + crop_height_effective,
        padding_w : padding_w + crop_width_effective,
    ] = img[
        crop_t : crop_t + crop_height_effective,
        crop_l : crop_l + crop_width_effective,
    ]

    return out_buffer


class CenterCropRGBImageDecoderTVAlike(SimpleRGBImageDecoder):
    """Decoder for :class:`~ffcv.fields.RGBImageField` that performs
    a center crop operation.

    It supports both variable and constant resolution datasets.

    Unlike FFCV's original CenterCropRGBImageDecoder, this decoder
    behaves like torchvision's CenterCrop: images smaller than the
    output size are zero-padded rather than rescaled.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, AllocationQuery]:
        widths = self.metadata["width"]
        heights = self.metadata["height"]
        # We convert to uint64 to avoid overflows
        self.max_width = np.uint64(widths.max())
        self.max_height = np.uint64(heights.max())
        output_shape = (self.output_size[0], self.output_size[1], 3)
        my_dtype = np.dtype("<u1")

        return (
            replace(previous_state, jit_mode=True, shape=output_shape, dtype=my_dtype),
            (
                AllocationQuery(output_shape, my_dtype),
                AllocationQuery(
                    (self.max_height * self.max_width * np.uint64(3),), my_dtype
                ),
            ),
        )

    def generate_code(self) -> Callable:
        jpg = IMAGE_MODES["jpg"]

        mem_read = self.memory_read
        my_range = Compiler.get_iterator()
        imdecode_c = Compiler.compile(imdecode)
        c_crop = Compiler.compile(self.get_crop_generator)
        output_size = self.output_size

        def decode(batch_indices, my_storage, metadata, storage_state):
            destination, temp_storage = my_storage
            for dst_ix in my_range(len(batch_indices)):
                source_ix = batch_indices[dst_ix]
                field = metadata[source_ix]
                image_data = mem_read(field["data_ptr"], storage_state)
                height = np.uint32(field["height"])
                width = np.uint32(field["width"])

                if field["mode"] == jpg:
                    temp_buffer = temp_storage[dst_ix]
                    imdecode_c(
                        image_data,
                        temp_buffer,
                        height,
                        width,
                        height,
                        width,
                        0,
                        0,
                        1,
                        1,
                        False,
                        False,
                    )
                    selected_size = 3 * height * width
                    temp_buffer = temp_buffer.reshape(-1)[:selected_size]
                    temp_buffer = temp_buffer.reshape(height, width, 3)
                else:
                    temp_buffer = image_data.reshape(height, width, 3)

                c_crop(height, width, output_size, temp_buffer, destination[dst_ix])

            return destination[: len(batch_indices)]

        decode.is_parallel = True
        return decode

    @property
    def get_crop_generator(self):
        return get_center_crop_torchvision_alike


__all__ = ["CenterCropRGBImageDecoderTVAlike"]
14 changes: 13 additions & 1 deletion avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py
@@ -20,15 +20,18 @@
import torch

from avalanche.benchmarks.utils.transforms import flat_transforms_recursive
+from avalanche.benchmarks.utils.ffcv_support.center_crop import (
+    CenterCropRGBImageDecoderTVAlike,
+)

from torchvision.transforms import ToTensor as ToTensorTV
from torchvision.transforms import PILToTensor as PILToTensorTV
from torchvision.transforms import Normalize as NormalizeTV
from torchvision.transforms import ConvertImageDtype as ConvertTV
from torchvision.transforms import RandomResizedCrop as RandomResizedCropTV
from torchvision.transforms import CenterCrop as CenterCropTV
from torchvision.transforms import RandomHorizontalFlip as RandomHorizontalFlipTV
from torchvision.transforms import RandomCrop as RandomCropTV
from torchvision.transforms import Lambda

from ffcv.transforms import ToTensor as ToTensorFFCV
from ffcv.transforms import ToDevice as ToDeviceFFCV
@@ -282,6 +285,15 @@ def _apply_transforms_pre_optimization(
            elif len(size) == 1:
                size = [size[0], size[0]]
            result[-1] = RandomResizedCropRGBImageDecoder(size, t.scale, t.ratio)
+        elif isinstance(t, CenterCropTV) and isinstance(
+            result[-1], SimpleRGBImageDecoder
+        ):
+            size = t.size
+            if isinstance(size, int):
+                size = [size, size]
+            elif len(size) == 1:
+                size = [size[0], size[0]]
+            result[-1] = CenterCropRGBImageDecoderTVAlike(size)
        else:
            result.append(t)

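The size-normalization in the new branch mirrors torchvision's CenterCrop semantics, where size may be an int, a 1-element sequence, or a (height, width) pair, while the FFCV decoder expects an explicit pair. A distilled sketch of that logic (the helper name is ours, for illustration only):

def _normalize_crop_size(size):
    # torchvision accepts an int, a 1-element sequence, or (height, width);
    # the decoder needs an explicit [height, width] pair.
    if isinstance(size, int):
        return [size, size]
    if len(size) == 1:
        return [size[0], size[0]]
    return list(size)

assert _normalize_crop_size(224) == [224, 224]
assert _normalize_crop_size((224,)) == [224, 224]
assert _normalize_crop_size((224, 196)) == [224, 196]
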
17 changes: 13 additions & 4 deletions avalanche/training/supervised/naive_object_detection.py
@@ -247,10 +247,19 @@ def forward(self):
    def _unpack_minibatch(self):
        # Unpack minibatch mainly takes care of moving tensors to devices.
        # In addition, it will prepare the targets in the proper dict format.
-        images = list(image.to(self.device) for image in self.mbatch[0])
-        targets = [{k: v.to(self.device) for k, v in t.items()} for t in self.mbatch[1]]
-
-        mbatch = [images, targets, torch.as_tensor(self.mbatch[2]).to(self.device)]
+        images = list(
+            image.to(self.device, non_blocking=True) for image in self.mbatch[0]
+        )
+        targets = [
+            {k: v.to(self.device, non_blocking=True) for k, v in t.items()}
+            for t in self.mbatch[1]
+        ]
+
+        mbatch = [
+            images,
+            targets,
+            torch.as_tensor(self.mbatch[2]).to(self.device, non_blocking=True),
+        ]
        self.mbatch = tuple(mbatch)

    def backward(self):
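The switch to non_blocking=True lets host-to-device copies overlap with compute, but it only takes effect when the source tensors live in pinned (page-locked) host memory; otherwise the call silently degrades to a synchronous copy. A minimal generic-PyTorch sketch of the pairing (not Avalanche-specific):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))
# pin_memory=True page-locks the host buffers, which is what makes
# non_blocking=True transfers truly asynchronous.
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for images, labels in loader:
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
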
2 changes: 1 addition & 1 deletion avalanche/training/templates/problem_type/supervised_problem.py
@@ -49,7 +49,7 @@ def _unpack_minibatch(self):
        self.mbatch = mbatch

        for i in range(len(mbatch)):
-            mbatch[i] = mbatch[i].to(self.device)  # type: ignore
+            mbatch[i] = mbatch[i].to(self.device, non_blocking=True)  # type: ignore


__all__ = ["SupervisedProblem"]
4 changes: 3 additions & 1 deletion examples/ffcv/ffcv_enable.py
@@ -24,7 +24,9 @@

def main(cuda: int):
    # --- CONFIG
-    device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")
+    device = torch.device(
+        f"cuda:{cuda}" if cuda >= 0 and torch.cuda.is_available() else "cpu"
+    )
    RNGManager.set_random_seeds(1234)

    benchmark_type = "cifar100"
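The updated check treats a negative --cuda index as an explicit request for CPU, instead of silently grabbing a GPU whenever one is available. A small helper capturing the same idiom (the helper name is ours, for illustration):

import torch

def pick_device(cuda_idx: int) -> torch.device:
    # A negative index means "force CPU", even on CUDA-capable machines.
    if cuda_idx >= 0 and torch.cuda.is_available():
        return torch.device(f"cuda:{cuda_idx}")
    return torch.device("cpu")

assert pick_device(-1).type == "cpu"
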
4 changes: 3 additions & 1 deletion examples/ffcv/ffcv_enable_rgb_compress.py
@@ -31,7 +31,9 @@


def main(cuda: int):
-    device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")
+    device = torch.device(
+        f"cuda:{cuda}" if cuda >= 0 and torch.cuda.is_available() else "cpu"
+    )
    RNGManager.set_random_seeds(1234)

    benchmark_type = "tinyimagenet"
23 changes: 16 additions & 7 deletions examples/ffcv/ffcv_try_speed.py
@@ -28,12 +28,18 @@
from torchvision.transforms import Compose, ToTensor, Normalize

from torch.utils.data import DataLoader
+from torch.utils.data.sampler import (
+    BatchSampler,
+    SequentialSampler,
+)
from tqdm import tqdm


def main(cuda: int):
    # --- CONFIG
-    device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")
+    device = torch.device(
+        f"cuda:{cuda}" if cuda >= 0 and torch.cuda.is_available() else "cpu"
+    )
    RNGManager.set_random_seeds(1234)

    benchmark_type = "cifar100"
@@ -114,16 +120,19 @@ def benchmark_ffcv_speed(

    start_time = time.time()
    ffcv_loader = HybridFfcvLoader(
-        avl_set,
-        None,
-        batch_size,
-        dict(num_workers=num_workers, drop_last=True),
+        dataset=avl_set,
+        batch_sampler=BatchSampler(
+            SequentialSampler(avl_set),
+            batch_size=batch_size,
+            drop_last=True,
+        ),
+        ffcv_loader_parameters=dict(num_workers=num_workers),
        device=device,
        print_ffcv_summary=False,
    )

    for _ in tqdm(range(epochs)):
-        for batch in ffcv_loader:
+        for batch in tqdm(ffcv_loader):
            # "Touch" tensors to make sure they already moved to GPU
            batch[0][0]
            batch[-1][0]
@@ -152,7 +161,7 @@ def benchmark_pytorch_speed(benchmark, device, batch_size=128, num_workers=1, ep…

    batch: Tuple[torch.Tensor]
    for _ in tqdm(range(epochs)):
-        for batch in torch_loader:
+        for batch in tqdm(torch_loader):
            batch = tuple(x.to(device, non_blocking=True) for x in batch)

            # "Touch" tensors to make sure they already moved to GPU
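HybridFfcvLoader is now driven by an explicit PyTorch BatchSampler rather than a bare batch size, which makes the batching policy (ordering, drop_last) visible at the call site. The sampler's semantics are easy to verify in isolation:

from torch.utils.data.sampler import BatchSampler, SequentialSampler

# Ten indices, batches of four, incomplete final batch dropped.
batches = list(
    BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=True)
)
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7]]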
