Add base elements to support distributed comms. Add supports_distributed plugin flag. #1370

Merged
6 changes: 5 additions & 1 deletion .github/workflows/environment-update.yml
@@ -61,7 +61,11 @@ jobs:
run: |
python -m unittest discover tests &&
echo "Running checkpointing tests..." &&
bash ./tests/checkpointing/test_checkpointing.sh
bash ./tests/checkpointing/test_checkpointing.sh &&
echo "Running distributed training tests..." &&
cd tests &&
PYTHONPATH=.. python run_dist_tests.py &&
cd ..
- name: checkout avalanche-docker repo
if: always()
uses: actions/checkout@v3
4 changes: 4 additions & 0 deletions .github/workflows/unit-test.yml
@@ -58,5 +58,9 @@ jobs:
PYTHONPATH=. python examples/eval_plugin.py &&
echo "Running checkpointing tests..." &&
bash ./tests/checkpointing/test_checkpointing.sh &&
echo "Running distributed training tests..." &&
cd tests &&
PYTHONPATH=.. python run_dist_tests.py &&
cd .. &&
echo "While running unit tests, the following datasets were downloaded:" &&
ls ~/.avalanche/data
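
Both workflows now run the distributed test suite through tests/run_dist_tests.py. The runner's contents are not part of this diff; purely as a rough sketch, such a script typically spawns one worker per rank, initializes a process group, and runs the suite under it. Everything below (the test directory name, the world size of 2) is an assumption for illustration.

# Hypothetical sketch of a runner like tests/run_dist_tests.py; the real
# script is not shown in this diff. Directory name and world size are
# illustrative assumptions.
import os
import unittest

import torch.distributed as dist
import torch.multiprocessing as mp


def _run_suite(rank: int, world_size: int):
    # Gloo works on CPU-only CI runners; NCCL would require GPUs.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    suite = unittest.defaultTestLoader.discover("distributed")
    unittest.TextTestRunner().run(suite)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_run_suite, args=(2,), nprocs=2)  # world_size = 2
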
7 changes: 5 additions & 2 deletions avalanche/benchmarks/classic/cmnist.py
@@ -28,8 +28,11 @@
)
from avalanche.benchmarks.datasets.external_datasets.mnist import \
get_mnist_dataset
from ..utils import make_classification_dataset, DefaultTransformGroups
from ..utils.data import make_avalanche_dataset
from avalanche.benchmarks.utils import (
make_classification_dataset,
DefaultTransformGroups,
)
from avalanche.benchmarks.utils.data import make_avalanche_dataset

_default_mnist_train_transform = Compose(
[Normalize((0.1307,), (0.3081,))]
52 changes: 42 additions & 10 deletions avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py
@@ -159,20 +159,20 @@ def __getitem__(self, index):
"""
img_id = self.img_ids[index]
img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0]
annotation_dicts = self.targets[index]
annotation_dicts: LVISImgTargets = self.targets[index]

# Transform from LVIS dictionary to torchvision-style target
num_objs = len(annotation_dicts)
num_objs = annotation_dicts["bbox"].shape[0]

boxes = []
labels = []
for i in range(num_objs):
xmin = annotation_dicts[i]["bbox"][0]
ymin = annotation_dicts[i]["bbox"][1]
xmax = xmin + annotation_dicts[i]["bbox"][2]
ymax = ymin + annotation_dicts[i]["bbox"][3]
xmin = annotation_dicts["bbox"][i][0]
ymin = annotation_dicts["bbox"][i][1]
xmax = xmin + annotation_dicts["bbox"][i][2]
ymax = ymin + annotation_dicts["bbox"][i][3]
boxes.append([xmin, ymin, xmax, ymax])
labels.append(annotation_dicts[i]["category_id"])
labels.append(annotation_dicts["category_id"][i])

if len(boxes) > 0:
boxes = torch.as_tensor(boxes, dtype=torch.float32)
@@ -183,7 +183,7 @@ def __getitem__(self, index):
image_id = torch.tensor([img_id])
areas = []
for i in range(num_objs):
areas.append(annotation_dicts[i]["area"])
areas.append(annotation_dicts["area"][i])
areas = torch.as_tensor(areas, dtype=torch.float32)
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
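
The loop above converts LVIS xywh boxes into the xyxy convention used by torchvision detection targets, now reading each value from the columnar "bbox" tensor instead of per-annotation dictionaries. A standalone sketch of the same conversion, vectorized (illustrative values, not the repository's code):

# xywh -> xyxy over a columnar (N, 4) bbox tensor; values illustrative.
import torch

bbox_xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # x, y, width, height
bbox_xyxy = bbox_xywh.clone()
bbox_xyxy[:, 2] = bbox_xywh[:, 0] + bbox_xywh[:, 2]  # xmax = xmin + width
bbox_xyxy[:, 3] = bbox_xywh[:, 1] + bbox_xywh[:, 3]  # ymax = ymin + height
# bbox_xyxy is now tensor([[10., 20., 40., 60.]])
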

@@ -233,7 +233,17 @@ class LVISAnnotationEntry(TypedDict):
category_id: int


class LVISDetectionTargets(Sequence[List[LVISAnnotationEntry]]):
class LVISImgTargets(TypedDict):
id: torch.Tensor
area: torch.Tensor
segmentation: List[List[List[float]]]
image_id: torch.Tensor
bbox: torch.Tensor
category_id: torch.Tensor
labels: torch.Tensor


class LVISDetectionTargets(Sequence[List[LVISImgTargets]]):
def __init__(
self,
lvis_api: LVIS,
@@ -254,7 +264,28 @@ def __getitem__(self, index):
annotation_dicts: List[LVISAnnotationEntry] = self.lvis_api.load_anns(
annotation_ids
)
return annotation_dicts

n_annotations = len(annotation_dicts)

category_tensor = torch.empty((n_annotations,), dtype=torch.long)
target_dict: LVISImgTargets = {
'bbox': torch.empty((n_annotations, 4), dtype=torch.float32),
'category_id': category_tensor,
'id': torch.empty((n_annotations,), dtype=torch.long),
'area': torch.empty((n_annotations,), dtype=torch.float32),
'image_id': torch.full((1,), img_id, dtype=torch.long),
'segmentation': [],
'labels': category_tensor # Alias of category_id
}

for ann_idx, annotation in enumerate(annotation_dicts):
target_dict['bbox'][ann_idx] = torch.as_tensor(annotation['bbox'])
target_dict['category_id'][ann_idx] = annotation['category_id']
target_dict['id'][ann_idx] = annotation['id']
target_dict['area'][ann_idx] = annotation['area']
target_dict['segmentation'].append(annotation['segmentation'])

return target_dict
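
With this change each element of LVISDetectionTargets is a single LVISImgTargets dictionary holding one stacked tensor per field, instead of a list of per-annotation dictionaries. A sketch of the resulting access pattern (lvis_targets and the values are illustrative):

# Illustrative access pattern for the new columnar targets.
target = lvis_targets[0]                      # an LVISImgTargets dict
num_objs = target["bbox"].shape[0]            # bbox is an (N, 4) tensor
first_box = target["bbox"][0]                 # xywh of annotation 0
first_label = int(target["category_id"][0])
# "labels" references the same tensor as "category_id", per the
# construction above.
assert target["labels"] is target["category_id"]
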


def _test_to_tensor(a, b):
@@ -316,5 +347,6 @@ def _plot_detection_sample(img: Image.Image, target):
"LvisDataset",
"LVISImgEntry",
"LVISAnnotationEntry",
"LVISImgTargets",
"LVISDetectionTargets",
]
44 changes: 22 additions & 22 deletions avalanche/benchmarks/utils/data_loader.py
@@ -17,7 +17,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch.utils.data import RandomSampler, DistributedSampler
from torch.utils.data import RandomSampler, DistributedSampler, Dataset
from torch.utils.data.dataloader import DataLoader

from avalanche.benchmarks.utils.collate_functions import (
@@ -31,6 +31,7 @@
)
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.distributed.distributed_helper import DistributedHelper

_default_collate_mbatches_fn = classification_collate_mbatches_fn

@@ -284,14 +285,14 @@ def __init__(
self.collate_mbatches = collate_mbatches

for data in self.datasets:
if _DistributedHelper.is_distributed and distributed_sampling:
if DistributedHelper.is_distributed and distributed_sampling:
seed = torch.randint(
0,
2 ** 32 - 1 - _DistributedHelper.world_size,
2 ** 32 - 1 - DistributedHelper.world_size,
(1,),
dtype=torch.int64,
)
seed += _DistributedHelper.rank
seed += DistributedHelper.rank
generator = torch.Generator()
generator.manual_seed(int(seed))
else:
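
In the distributed branch above, each process builds its sampling generator from a random base seed shifted by its rank, so replay sampling differs across ranks; the shift yields distinct per-rank seeds to the extent that the base draw is synchronized across processes (for example through a shared global seed). A standalone sketch:

# Standalone sketch of the per-rank seeding above; values illustrative
# (normally DistributedHelper.world_size and DistributedHelper.rank).
import torch

world_size, rank = 4, 1
seed = torch.randint(0, 2 ** 32 - 1 - world_size, (1,), dtype=torch.int64)
seed += rank  # shift the base seed by this process' rank
generator = torch.Generator()
generator.manual_seed(int(seed))
# Passing `generator` to RandomSampler gives each rank its own order.
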
@@ -584,11 +585,11 @@ def _get_batch_sizes(


def _make_data_loader(
dataset,
distributed_sampling,
data_loader_args,
batch_size,
force_no_workers=False,
dataset: Dataset,
distributed_sampling: bool,
data_loader_args: Dict[str, Any],
batch_size: int,
force_no_workers: bool = False,
):
data_loader_args = data_loader_args.copy()

@@ -601,14 +602,22 @@ def _make_data_loader(
if 'prefetch_factor' in data_loader_args:
data_loader_args['prefetch_factor'] = 2

if _DistributedHelper.is_distributed and distributed_sampling:
if DistributedHelper.is_distributed and distributed_sampling:
# Note: shuffle only goes in the sampler, while
# drop_last must be passed to both the sampler
# and the DataLoader
drop_last = data_loader_args.pop("drop_last", False)
sampler = DistributedSampler(
dataset,
shuffle=data_loader_args.pop("shuffle", False),
drop_last=data_loader_args.pop("drop_last", False),
shuffle=data_loader_args.pop("shuffle", True),
drop_last=drop_last,
)
data_loader = DataLoader(
dataset, sampler=sampler, batch_size=batch_size, **data_loader_args
dataset,
sampler=sampler,
batch_size=batch_size,
drop_last=drop_last,
**data_loader_args
)
else:
sampler = None
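
Two details in the branch above are easy to miss: shuffling must be requested on the DistributedSampler itself (passing shuffle to a DataLoader that already has a sampler is an error), and drop_last plays two distinct roles, trimming the dataset tail so every rank gets the same sample count, then discarding each rank's incomplete final batch. A minimal sketch with explicit rank and world size so no process group is required:

# drop_last on the sampler equalizes samples per rank; on the loader it
# drops the final incomplete batch of each rank.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(11))
sampler = DistributedSampler(
    dataset, num_replicas=2, rank=0, shuffle=True, drop_last=True
)
loader = DataLoader(dataset, sampler=sampler, batch_size=4, drop_last=True)
assert len(sampler) == 5  # 11 samples -> 5 per rank, tail dropped
assert len(loader) == 1   # 5 samples -> one full batch of 4
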
@@ -619,15 +628,6 @@ def _make_data_loader(
return data_loader, sampler


class __DistributedHelperPlaceholder:
is_distributed = False
world_size = 1
rank = 0


_DistributedHelper = __DistributedHelperPlaceholder()


__all__ = [
"detection_collate_fn",
"detection_collate_mbatches_fn",
31 changes: 28 additions & 3 deletions avalanche/core.py
@@ -1,5 +1,5 @@
from abc import ABC
from typing import TypeVar, Generic
from typing import Optional, Type, TypeVar, Generic
from typing import TYPE_CHECKING

if TYPE_CHECKING:
@@ -27,8 +27,16 @@ class BasePlugin(Generic[Template], ABC):
and loggers.
"""

supports_distributed: bool = False
"""
A flag describing whether this plugin supports distributed training.
"""

def __init__(self):
pass
"""
Initializes an instance of a base plugin.
"""
super().__init__()

def before_training(self, strategy: Template, *args, **kwargs):
"""Called before `train` by the `BaseTemplate`."""
@@ -68,13 +76,26 @@ def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult:
"""Called after `eval` by the `BaseTemplate`."""
pass

def __init_subclass__(
cls,
supports_distributed: bool = False,
**kwargs) -> None:
cls.supports_distributed = supports_distributed
return super().__init_subclass__(**kwargs)


class BaseSGDPlugin(BasePlugin[Template], ABC):
"""ABC for BaseSGDTemplate plugins.

See `BaseSGDTemplate` for complete description of the train/eval loop.
"""

def __init__(self):
"""
Initializes an instance of a base SGD plugin.
"""
super().__init__()

def before_training_epoch(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
@@ -193,7 +214,11 @@ class SupervisedPlugin(BaseSGDPlugin[Template], ABC):

See `BaseTemplate` for complete description of the train/eval loop.
"""
pass
def __init__(self):
"""
Initializes an instance of a supervised plugin.
"""
super().__init__()
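
How the flag is enforced is outside this diff. Purely as a hypothetical sketch, a strategy could scan its plugins when running distributed and warn about unverified ones (warn_on_unsupported_plugins is an invented helper, not part of this PR):

# Hypothetical enforcement sketch; not part of this diff.
import warnings

from avalanche.distributed import DistributedHelper


def warn_on_unsupported_plugins(plugins):
    if not DistributedHelper.is_distributed:
        return
    for plugin in plugins:
        if not type(plugin).supports_distributed:
            warnings.warn(
                f"{type(plugin).__name__} is not marked as supporting "
                f"distributed training and may behave incorrectly."
            )
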


class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC):
1 change: 1 addition & 0 deletions avalanche/distributed/__init__.py
@@ -0,0 +1 @@
from .distributed_helper import *
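
The new package re-exports DistributedHelper, whose is_distributed, rank, and world_size attributes are the ones exercised by the data_loader.py changes above; a small usage sketch limited to those attributes:

# Usage sketch restricted to the attributes shown in this diff.
from avalanche.distributed import DistributedHelper

if DistributedHelper.is_distributed:
    print(f"rank {DistributedHelper.rank} of {DistributedHelper.world_size}")
else:
    print("running in a single process")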