diff --git a/.github/workflows/environment-update.yml b/.github/workflows/environment-update.yml index 0f8b86dee..eabda6fa3 100644 --- a/.github/workflows/environment-update.yml +++ b/.github/workflows/environment-update.yml @@ -61,7 +61,11 @@ jobs: run: | python -m unittest discover tests && echo "Running checkpointing tests..." && - bash ./tests/checkpointing/test_checkpointing.sh + bash ./tests/checkpointing/test_checkpointing.sh && + echo "Running distributed training tests..." && + cd tests && + PYTHONPATH=.. python run_dist_tests.py && + cd .. - name: checkout avalanche-docker repo if: always() uses: actions/checkout@v3 diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 09c2d9a22..df1ce8522 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -58,5 +58,9 @@ jobs: PYTHONPATH=. python examples/eval_plugin.py && echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && + echo "Running distributed training tests..." && + cd tests && + PYTHONPATH=.. python run_dist_tests.py && + cd .. && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data diff --git a/avalanche/benchmarks/classic/cmnist.py b/avalanche/benchmarks/classic/cmnist.py index 93971ef8b..d53365dba 100644 --- a/avalanche/benchmarks/classic/cmnist.py +++ b/avalanche/benchmarks/classic/cmnist.py @@ -28,8 +28,11 @@ ) from avalanche.benchmarks.datasets.external_datasets.mnist import \ get_mnist_dataset -from ..utils import make_classification_dataset, DefaultTransformGroups -from ..utils.data import make_avalanche_dataset +from avalanche.benchmarks.utils import ( + make_classification_dataset, + DefaultTransformGroups, +) +from avalanche.benchmarks.utils.data import make_avalanche_dataset _default_mnist_train_transform = Compose( [Normalize((0.1307,), (0.3081,))] diff --git a/avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py b/avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py index d250d941b..a63696d7f 100644 --- a/avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py +++ b/avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py @@ -159,20 +159,20 @@ def __getitem__(self, index): """ img_id = self.img_ids[index] img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0] - annotation_dicts = self.targets[index] + annotation_dicts: LVISImgTargets = self.targets[index] # Transform from LVIS dictionary to torchvision-style target - num_objs = len(annotation_dicts) + num_objs = annotation_dicts["bbox"].shape[0] boxes = [] labels = [] for i in range(num_objs): - xmin = annotation_dicts[i]["bbox"][0] - ymin = annotation_dicts[i]["bbox"][1] - xmax = xmin + annotation_dicts[i]["bbox"][2] - ymax = ymin + annotation_dicts[i]["bbox"][3] + xmin = annotation_dicts["bbox"][i][0] + ymin = annotation_dicts["bbox"][i][1] + xmax = xmin + annotation_dicts["bbox"][i][2] + ymax = ymin + annotation_dicts["bbox"][i][3] boxes.append([xmin, ymin, xmax, ymax]) - labels.append(annotation_dicts[i]["category_id"]) + labels.append(annotation_dicts["category_id"][i]) if len(boxes) > 0: boxes = torch.as_tensor(boxes, dtype=torch.float32) @@ -183,7 +183,7 @@ def __getitem__(self, index): image_id = torch.tensor([img_id]) areas = [] for i in range(num_objs): - areas.append(annotation_dicts[i]["area"]) + areas.append(annotation_dicts["area"][i]) areas = torch.as_tensor(areas, dtype=torch.float32) iscrowd = torch.zeros((num_objs,), dtype=torch.int64) @@ -233,7 +233,17 @@ class 
LVISAnnotationEntry(TypedDict): category_id: int -class LVISDetectionTargets(Sequence[List[LVISAnnotationEntry]]): +class LVISImgTargets(TypedDict): + id: torch.Tensor + area: torch.Tensor + segmentation: List[List[List[float]]] + image_id: torch.Tensor + bbox: torch.Tensor + category_id: torch.Tensor + labels: torch.Tensor + + +class LVISDetectionTargets(Sequence[List[LVISImgTargets]]): def __init__( self, lvis_api: LVIS, @@ -254,7 +264,28 @@ def __getitem__(self, index): annotation_dicts: List[LVISAnnotationEntry] = self.lvis_api.load_anns( annotation_ids ) - return annotation_dicts + + n_annotations = len(annotation_dicts) + + category_tensor = torch.empty((n_annotations,), dtype=torch.long) + target_dict: LVISImgTargets = { + 'bbox': torch.empty((n_annotations, 4), dtype=torch.float32), + 'category_id': category_tensor, + 'id': torch.empty((n_annotations,), dtype=torch.long), + 'area': torch.empty((n_annotations,), dtype=torch.float32), + 'image_id': torch.full((1,), img_id, dtype=torch.long), + 'segmentation': [], + 'labels': category_tensor # Alias of category_id + } + + for ann_idx, annotation in enumerate(annotation_dicts): + target_dict['bbox'][ann_idx] = torch.as_tensor(annotation['bbox']) + target_dict['category_id'][ann_idx] = annotation['category_id'] + target_dict['id'][ann_idx] = annotation['id'] + target_dict['area'][ann_idx] = annotation['area'] + target_dict['segmentation'].append(annotation['segmentation']) + + return target_dict def _test_to_tensor(a, b): @@ -316,5 +347,6 @@ def _plot_detection_sample(img: Image.Image, target): "LvisDataset", "LVISImgEntry", "LVISAnnotationEntry", + "LVISImgTargets", "LVISDetectionTargets", ] diff --git a/avalanche/benchmarks/utils/data_loader.py b/avalanche/benchmarks/utils/data_loader.py index d0e6a7e91..bdfd8452a 100644 --- a/avalanche/benchmarks/utils/data_loader.py +++ b/avalanche/benchmarks/utils/data_loader.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union import torch -from torch.utils.data import RandomSampler, DistributedSampler +from torch.utils.data import RandomSampler, DistributedSampler, Dataset from torch.utils.data.dataloader import DataLoader from avalanche.benchmarks.utils.collate_functions import ( @@ -31,6 +31,7 @@ ) from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import DataAttribute +from avalanche.distributed.distributed_helper import DistributedHelper _default_collate_mbatches_fn = classification_collate_mbatches_fn @@ -284,14 +285,14 @@ def __init__( self.collate_mbatches = collate_mbatches for data in self.datasets: - if _DistributedHelper.is_distributed and distributed_sampling: + if DistributedHelper.is_distributed and distributed_sampling: seed = torch.randint( 0, - 2 ** 32 - 1 - _DistributedHelper.world_size, + 2 ** 32 - 1 - DistributedHelper.world_size, (1,), dtype=torch.int64, ) - seed += _DistributedHelper.rank + seed += DistributedHelper.rank generator = torch.Generator() generator.manual_seed(int(seed)) else: @@ -584,11 +585,11 @@ def _get_batch_sizes( def _make_data_loader( - dataset, - distributed_sampling, - data_loader_args, - batch_size, - force_no_workers=False, + dataset: Dataset, + distributed_sampling: bool, + data_loader_args: Dict[str, Any], + batch_size: int, + force_no_workers: bool = False, ): data_loader_args = data_loader_args.copy() @@ -601,14 +602,22 @@ def _make_data_loader( if 'prefetch_factor' in data_loader_args: data_loader_args['prefetch_factor'] = 2 - if 
_DistributedHelper.is_distributed and distributed_sampling: + if DistributedHelper.is_distributed and distributed_sampling: + # Note: shuffle only goes in the sampler, while + # drop_last must be passed to both the sampler + # and the DataLoader + drop_last = data_loader_args.pop("drop_last", False) sampler = DistributedSampler( dataset, - shuffle=data_loader_args.pop("shuffle", False), - drop_last=data_loader_args.pop("drop_last", False), + shuffle=data_loader_args.pop("shuffle", True), + drop_last=drop_last, ) data_loader = DataLoader( - dataset, sampler=sampler, batch_size=batch_size, **data_loader_args + dataset, + sampler=sampler, + batch_size=batch_size, + drop_last=drop_last, + **data_loader_args ) else: sampler = None @@ -619,15 +628,6 @@ return data_loader, sampler -class __DistributedHelperPlaceholder: - is_distributed = False - world_size = 1 - rank = 0 - - -_DistributedHelper = __DistributedHelperPlaceholder() - - __all__ = [ "detection_collate_fn", "detection_collate_mbatches_fn", diff --git a/avalanche/core.py b/avalanche/core.py index 99105b392..f266b7f0d 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import TypeVar, Generic +from typing import Optional, Type, TypeVar, Generic from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -27,8 +27,16 @@ class BasePlugin(Generic[Template], ABC): and loggers. """ + supports_distributed: bool = False + """ + A flag describing whether this plugin supports distributed training. + """ + def __init__(self): - pass + """ + Initializes an instance of a base plugin. + """ + super().__init__() def before_training(self, strategy: Template, *args, **kwargs): """Called before `train` by the `BaseTemplate`.""" pass @@ -68,6 +76,13 @@ def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult: """Called after `eval` by the `BaseTemplate`.""" pass + def __init_subclass__( + cls, + supports_distributed: bool = False, + **kwargs) -> None: + cls.supports_distributed = supports_distributed + return super().__init_subclass__(**kwargs) + class BaseSGDPlugin(BasePlugin[Template], ABC): """ABC for BaseSGDTemplate plugins. @@ -75,6 +90,12 @@ class BaseSGDPlugin(BasePlugin[Template], ABC): See `BaseSGDTemplate` for complete description of the train/eval loop. """ + def __init__(self): + """ + Initializes an instance of a base SGD plugin. + """ + super().__init__() + def before_training_epoch( self, strategy: Template, *args, **kwargs ) -> CallbackResult: @@ -193,7 +214,11 @@ class SupervisedPlugin(BaseSGDPlugin[Template], ABC): See `BaseTemplate` for complete description of the train/eval loop. """ - pass + def __init__(self): + """ + Initializes an instance of a supervised plugin.
+ """ + super().__init__() class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC): diff --git a/avalanche/distributed/__init__.py b/avalanche/distributed/__init__.py new file mode 100644 index 000000000..88357b2f1 --- /dev/null +++ b/avalanche/distributed/__init__.py @@ -0,0 +1 @@ +from .distributed_helper import * diff --git a/avalanche/distributed/distributed_consistency_verification.py b/avalanche/distributed/distributed_consistency_verification.py new file mode 100644 index 000000000..689e85c28 --- /dev/null +++ b/avalanche/distributed/distributed_consistency_verification.py @@ -0,0 +1,108 @@ +import hashlib +import io + +from typing import Tuple, TYPE_CHECKING + +import torch + +from torch.utils.data import DataLoader + +if TYPE_CHECKING: + from torch import Tensor + from torch.nn import Module + from avalanche.benchmarks import DatasetScenario + from torch.utils.data import Dataset + + +def hash_benchmark(benchmark: 'DatasetScenario', *, + hash_engine=None, num_workers=0) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + for stream_name in sorted(benchmark.streams.keys()): + stream = benchmark.streams[stream_name] + hash_engine.update(stream_name.encode()) + for experience in stream: + exp_dataset = experience.dataset + hash_dataset(exp_dataset, + hash_engine=hash_engine, + num_workers=num_workers) + return hash_engine.hexdigest() + + +def hash_dataset(dataset: 'Dataset', *, hash_engine=None, num_workers=0) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + data_loader = DataLoader( + dataset, + collate_fn=lambda batch: tuple(zip(*batch)), + num_workers=num_workers + ) + for loaded_elem in data_loader: + example = tuple(tuple_element[0] for tuple_element in loaded_elem) + + # https://stackoverflow.com/a/63880190 + buff = io.BytesIO() + torch.save(example, buff) + buff.seek(0) + hash_engine.update(buff.read()) + return hash_engine.hexdigest() + + +def hash_minibatch(minibatch: 'Tuple[Tensor]', *, hash_engine=None) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + for tuple_elem in minibatch: + buff = io.BytesIO() + torch.save(tuple_elem, buff) + buff.seek(0) + hash_engine.update(buff.read()) + return hash_engine.hexdigest() + + +def hash_tensor(tensor: 'Tensor', *, hash_engine=None) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + buff = io.BytesIO() + torch.save(tensor, buff) + buff.seek(0) + hash_engine.update(buff.read()) + return hash_engine.hexdigest() + + +def hash_model( + model: 'Module', + include_buffers=True, + *, + hash_engine=None) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + for name, param in model.named_parameters(): + hash_engine.update(name.encode()) + buff = io.BytesIO() + torch.save(param.detach().cpu(), buff) + buff.seek(0) + hash_engine.update(buff.read()) + + if include_buffers: + for name, model_buffer in model.named_buffers(): + hash_engine.update(name.encode()) + buff = io.BytesIO() + torch.save(model_buffer.detach().cpu(), buff) + buff.seek(0) + hash_engine.update(buff.read()) + + return hash_engine.hexdigest() + + +__all__ = [ + 'hash_benchmark', + 'hash_dataset', + 'hash_minibatch', + 'hash_tensor', + 'hash_model' +] diff --git a/avalanche/distributed/distributed_helper.py b/avalanche/distributed/distributed_helper.py new file mode 100644 index 000000000..5d7054da6 --- /dev/null +++ b/avalanche/distributed/distributed_helper.py @@ -0,0 +1,819 @@ +import os +import pickle +import warnings +from io import BytesIO +from typing import 
ContextManager, Optional, List, Any, Iterable, Dict, TypeVar + +import torch +from torch import Tensor +from torch.nn.modules import Module +from torch.nn.parallel import DistributedDataParallel +from typing_extensions import Literal +from torch.distributed import ( + init_process_group, + broadcast_object_list +) + + +BroadcastT = TypeVar('BroadcastT') + + +from avalanche.distributed.distributed_consistency_verification import \ + hash_tensor + + +class _Singleton(type): + _instances: Dict[Any, Any] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(_Singleton, cls).__call__( + *args, **kwargs) + return cls._instances[cls] + + +class RollingSeedContext(object): + """ + Implement seed alignment by storing the state of random number generators. + + Doesn't require distributed communication (not even a broadcast), which makes + this the best choice when wrapping sections that (may) both: + - behave differently depending on the rank + - change the global state of random number generators + """ + def __init__(self): + self.rng_manager_state = None + + def save_generators_state(self): + from avalanche.training.determinism.rng_manager import RNGManager + self.rng_manager_state = RNGManager.__getstate__() + + def load_generators_state(self): + from avalanche.training.determinism.rng_manager import RNGManager + self.rng_manager_state = RNGManager.__setstate__(self.rng_manager_state) + + def step_random_generators(self): + from avalanche.training.determinism.rng_manager import RNGManager + RNGManager.step_generators() + + def __enter__(self): + self.save_generators_state() + + def __exit__(self, *_): + self.load_generators_state() + self.step_random_generators() + + +class BroadcastSeedContext(object): + """ + Implement seed alignment by broadcasting a new seed from the main process. + + This is usually slower than using :class:`RollingSeedContext`. + """ + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, *_): + DistributedHelper.align_seeds() + + +class _MainProcessFirstContext(object): + """ + A context in which the main process must enter and exit the section before + other processes. + + For instance, it can be used to wrap the dataset download procedure. + """ + + def __init__( + self, + seed_alignment: Literal["rolling", "broadcast"] = 'rolling', + final_barrier: bool = False): + + self._seed_aligner: ContextManager + if seed_alignment == 'rolling': + self._seed_aligner = RollingSeedContext() + else: + self._seed_aligner = BroadcastSeedContext() + + self._final_barrier = final_barrier + + def __enter__(self): + self._seed_aligner.__enter__() + + if not DistributedHelper.is_main_process: + # Wait for the main process + DistributedHelper.barrier() + + def __exit__(self, exc_type, exc_val, exc_tb): + if DistributedHelper.is_main_process: + # Let other processes enter the section + DistributedHelper.barrier() + + self._seed_aligner.__exit__(exc_type, exc_val, exc_tb) + if self._final_barrier: + DistributedHelper.barrier() + + +class _DistributedHelperCls(object): + """ + Implementation of the distributed helper. + + Methods of this class can be used as high-level wrappers + around the torch.distributed API to allow for simpler + distributed communication. + + Only a single object of this class is instantiated + as the "DistributedHelper" singleton.
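+
+    A minimal usage sketch (illustrative only; the exact arguments depend
+    on the launcher and on the available hardware):
+
+        from avalanche.distributed import DistributedHelper
+
+        DistributedHelper.init_distributed(random_seed=1234, use_cuda=True)
+        device = DistributedHelper.make_device()
+        model = DistributedHelper.wrap_model(model)
+        if DistributedHelper.is_main_process:
+            print('World size:', DistributedHelper.world_size)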
+ + Note: differently from the original PyTorch API, which requires + that input tensor(s) be moved to the default device (forced to + CUDA if using NCCL), these functions usually also manage input tensors + residing on different devices. The returned elements will + be moved to the same device as the input tensor. Consider looking at + the documentation of each method for more details. + """ + __metaclass__ = _Singleton + + def __init__(self): + self.use_cuda = False + self._dev_map = _DistributedHelperCls._make_map('cpu') + + def init_distributed(self, random_seed, backend=None, use_cuda=True): + if self.is_distributed: + raise RuntimeError('Distributed API already initialized') + + use_cuda = use_cuda and torch.cuda.is_available() + + if backend is None: + if use_cuda: + backend = 'nccl' + else: + backend = 'gloo' + + if backend == 'nccl' and not use_cuda: + warnings.warn( + 'Bad configuration: using NCCL, but you set use_cuda=False!') + + could_initialize_distributed = False + if os.environ.get('LOCAL_RANK', None) is None: + warnings.warn( + 'Torch distributed could not be initialized ' + '(missing environment configuration)') + else: + init_process_group(backend=backend) + could_initialize_distributed = True + + self.set_random_seeds(random_seed) + self.use_cuda = use_cuda + + # Used only for debugging purposes + # if use_cuda or backend == 'nccl': + # # https://github.com/pytorch/pytorch/issues/6351 + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + + # Force-init the default CUDA device (if any) + reference_device = self.make_device(set_cuda_device=True) + + # Create map for device placement of unpickled tensors + self._dev_map = _DistributedHelperCls._make_map(reference_device) + + return could_initialize_distributed + + def get_device_id(self): + """ + Obtain the id of the device to use. + + :return: an int, describing the device to use. -1 for CPU. + """ + if self.is_distributed: + device_id = self.rank + else: + device_id = 0 + + if self.use_cuda: + return device_id + + return -1 + + def make_device(self, set_cuda_device: bool = False) -> torch.device: + """ + Returns (and optionally sets) the default `torch.device` to use. + + Automatically called from :meth:`init_distributed`. + + :param set_cuda_device: If True, sets the default device + by calling :func:`torch.cuda.set_device`. + :return: The default device to be used for `torch.distributed` + communications. + """ + if self.is_distributed: + device_id = self.rank + else: + device_id = 0 + + if self.use_cuda and device_id >= 0: + ref_device = torch.device(f'cuda:{device_id}') + if set_cuda_device: + torch.cuda.set_device(ref_device) + else: + ref_device = torch.device('cpu') + return ref_device + + def wrap_model(self, model: Module) -> Module: + """ + Wraps a given model to enable distributed training. + + The given model will be wrapped using :class:`DistributedDataParallel`. + + :return: The model wrapped in :class:`DistributedDataParallel` if + running a distributed training, or the model itself if running a + single-process training. + """ + # Note: find_unused_parameters is needed for multi task models. + if self.is_distributed: + if self.forced_cuda_comm or self.use_cuda: + # forced_cuda_comm is True if using NCCL; use_cuda may be true + # even when not using NCCL. + # User already warned if using NCCL with use_cuda==False.
+ # device_ids must be a single device id + # (an int, a device object or a str) + # If not set, output_device defaults to device_ids[0] + return DistributedDataParallel( + model, + device_ids=[self.make_device()], + find_unused_parameters=True) + else: + return DistributedDataParallel( + model, + find_unused_parameters=True) + else: + return model + + def unwrap_model(self, model: Module) -> Module: + """ + Unwrap a model. + + :param model: A model to be unwrapped. + :return: The unwrapped model. + """ + if isinstance(model, DistributedDataParallel): + return model.module + + return model + + def set_random_seeds(self, random_seed): + """ + Set the random seed for all number generators. + + :param random_seed: The random seed to set. + """ + from avalanche.training.determinism.rng_manager import RNGManager + RNGManager.set_random_seeds(random_seed) + + def align_seeds(self): + """ + Aligns the random seed for all number generators across all processes. + """ + + if not self.is_distributed: + return + + if self.is_main_process: + reference_seed = torch.randint(0, 2**32-1, (1,), dtype=torch.int64) + else: + reference_seed = torch.empty((1,), dtype=torch.int64) + + self.broadcast(reference_seed) + seed = int(reference_seed) + self.set_random_seeds(seed) + + def main_process_first(self): + """ + Returns an execution context allowing the main process + to complete the section before allowing other processes + to enter it. + + A common use case is to allow the main process to download + the dataset without the risk of interference from the + other processes. + """ + return _MainProcessFirstContext() + + def barrier(self): + """ + Waits for all processes. + + No-op if not running a distributed training. + """ + + if self.is_distributed: + torch.distributed.barrier() + + def broadcast(self, tensor: Tensor, src: int = 0): + """ + Broadcasts the given tensor from a source process to all processes. + + Differences with torch.distributed: + - The input tensor can reside on any device. + - The input tensor will be transmitted using the current backend. + However, the resulting tensor will be moved to the same device + as the `tensor` parameter before returning it, + no matter the backend in use. + - No-op if not running a distributed training. + + :param tensor: The tensor to be broadcasted. + :param src: The rank of the source process. Defaults to 0, + which is the main process. + :return: The tensor obtained from the source process, on the same + device as the tensor parameter. + """ + + if not self.is_distributed: + return tensor + + tensor_distrib, orig_data = self._prepare_for_distributed_comm(tensor) + torch.distributed.broadcast(tensor_distrib, src=src) + tensor = self._revert_to_original_device(tensor_distrib, orig_data) + + return tensor + + def broadcast_object(self, obj: BroadcastT, src=0) -> BroadcastT: + """ + Broadcasts the given object from a source process to all processes. + + Note: if broadcasting a Tensor, consider using :meth:`broadcast` + instead. + + Differences with torch.distributed: + - No-op if not running a distributed training. + + :param obj: The object to be broadcasted. + :param src: The rank of the source process. Defaults to 0, + which is the main process. + :return: The object obtained from the source process. + """ + + if not self.is_distributed: + return obj + + io_list = [obj] + + broadcast_object_list(io_list, src=src) + return io_list[0] + + def cat_all(self, tensor: Tensor): + """ + Concatenates tensors from all processes.
+ + The resulting tensor will be concatenated in the order given by the + rank of each source process. + + Differences with torch.distributed: + - The input tensor can reside on any device. + - The input tensor will be transmitted using the current backend. + However, the resulting tensor will be moved to the same device + as the `tensor` parameter before returning it, + no matter the backend in use. + - No-op if not running a distributed training. + + :param tensor: The tensor from the current process. Tensors across + processes must have the same `tensor.shape[1:]`. + :return: A single tensor, as a concatenation of the tensors from + all processes. + """ + # TODO: use all_gather_into_tensor (if available and + # if NCCL and tensor.device == 'default device') + + if not self.is_distributed: + return tensor + + gathered_tensors = self.gather_all(tensor) + for i, t in enumerate(gathered_tensors): + if len(t.shape) == 0: + # Tensor with 0-length shape + gathered_tensors[i] = torch.reshape(t, (1,)) + + return torch.cat(gathered_tensors) + + def gather_tensor_shapes(self, tensor: Tensor, max_shape_len=10) \ + -> List[List[int]]: + """ + Gathers the shapes of the tensors from all processes. + + :param tensor: The tensor from the current process. + :param max_shape_len: If an int, defines the maximum expected length + of the shapes. In that case, an efficient communication + primitive will be used. If None, shapes will be obtained + via :meth:`gather_all_objects`. Defaults to 10. + :return: A list of shapes (one from each process, in rank order). + Each shape is returned as a list of `int`s. + """ + # Tensors differ by whole shape + tensor_size = torch.zeros(max_shape_len, dtype=torch.int64) + for i in range(len(tensor.shape)): + tensor_size[i] = tensor.shape[i] + all_tensors_shape = [ + self._prepare_for_distributed_comm( + torch.zeros_like(tensor_size))[0] + for _ in range(self.world_size)] + tensor_size, _ = self._prepare_for_distributed_comm(tensor_size) + + torch.distributed.all_gather(all_tensors_shape, tensor_size) + + all_tensors_shape = [t.cpu() for t in all_tensors_shape] + + # Trim shape + for i, t in enumerate(all_tensors_shape): + for x in range(len(t)): + if t[x] == 0: + if x == 0: + # Tensor with 0-length shape + all_tensors_shape[i] = t[:x+1] + else: + all_tensors_shape[i] = t[:x] + + break + + return [t_shape.tolist() for t_shape in all_tensors_shape] + + def gather_all( + self, + tensor: Tensor, + same_shape: bool = False, + shapes: Optional[List[List[int]]] = None) -> List[Tensor]: + """ + Gather all for tensors only. + + Differences with torch.distributed: + - The input tensor can reside on any device. + - The input tensor will be transmitted using the current backend. + However, the resulting tensors will be moved to the same device + as the `tensor` parameter before returning them, + no matter the backend in use. + - No-op if not running a distributed training. + + This will also manage tensors of different shapes. If you + are sure that the tensors will be of the same shape, consider + passing same_shape to speed up the communication. + + Beware that, if you are in need of concatenating multiple tensors, + method :meth:`cat_all` may be more suitable. + + :param tensor: The tensor to be sent from the current process. + :return: A list of tensors, one from each process (in rank order).
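+
+        A sketch of the expected result (illustrative only): with two
+        processes, where rank 0 holds a tensor of shape [2, 3] and rank 1
+        holds a tensor of shape [5, 3], every process obtains (here
+        `local_tensor` is the tensor held by the calling process):
+
+            tensors = DistributedHelper.gather_all(local_tensor)
+            # len(tensors) == 2
+            # tensors[0].shape == (2, 3), tensors[1].shape == (5, 3)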
+ """ + if not self.is_distributed: + return [tensor] + + # Based on: + # https://discuss.pytorch.org/t/how-to-concatenate-different-size-tensors-from-distributed-processes/44819/4 + + if same_shape: + # Same size for all tensors + if len(tensor.shape) > 0: + tensor_size = list(tensor.shape) + else: + tensor_size = [0] + all_tensors_shape = \ + [tensor_size for _ in range(self.world_size)] + elif shapes is not None: + # Shapes given by the user + # make sure it is a list of lists + all_tensors_shape = [list(s) for s in shapes] + else: + # Tensors differ by whole shape + all_tensors_shape = self.gather_tensor_shapes(tensor) + + same_shape = all(all_tensors_shape[0] == x for x in all_tensors_shape) + orig_device = tensor.device + + if same_shape: + # Same shape: create identical tensors and proceed with all_gather + out_tensors = [torch.empty_like(tensor) for _ in all_tensors_shape] + else: + # Different shapes: create tensors of the size of the biggest one + all_tensors_numel = [] + dtype = tensor.dtype + for t_shape in all_tensors_shape: + if t_shape[0] == 0 and len(t_shape) == 1: + # Tensor with 0-length shape + curr_size = 1 + else: + curr_size = 1 + for t_s in t_shape: + curr_size *= t_s + all_tensors_numel.append(curr_size) + + max_numel = max(all_tensors_numel) + out_tensors = [torch.empty((max_numel,), dtype=dtype) + for _ in all_tensors_shape] + + tensor = tensor.flatten() + n_padding = max_numel - tensor.numel() + if n_padding > 0: + padding = torch.zeros((n_padding,), + dtype=tensor.dtype, + device=orig_device) + tensor = torch.cat((tensor, padding), dim=0) + + tensor, _ = self._prepare_for_distributed_comm(tensor) + out_tensors = [self._prepare_for_distributed_comm(t)[0] + for t in out_tensors] + + torch.distributed.all_gather(out_tensors, tensor) + + if not same_shape: + # The tensors are flat and of the wrong dimension: re-shape them + for tensor_idx, (tensor_sz, tensor_numel, out_t) in \ + enumerate(zip(all_tensors_shape, + all_tensors_numel, + out_tensors)): + if tensor_sz[0] == 0: + # Tensor with 0-length shape + out_tensors[tensor_idx] = \ + out_t[:tensor_numel].reshape(tuple()) + else: + out_tensors[tensor_idx] = \ + out_t[:tensor_numel].reshape(tensor_sz) + + out_tensors = [t.to(orig_device) for t in out_tensors] + return out_tensors + + def gather_all_objects(self, obj: BroadcastT) -> List[BroadcastT]: + """ + Gather all for objects. This will also take care of moving cuda tensors + (even the ones nested inside objects) to the correct default device. + + Same as torch.distributed: + - Tensors nested inside the input object must reside on the + default device. Future versions of Avalanche may adopt + solutions to circumvent the limitations of + torch.distributed. + + Differences with torch.distributed: + - The input object will be transmitted using the current backend. + However, the resulting tensors nested inside of it + will be moved to the default device before returning them, + no matter the backend in use. + - No-op if not running a distributed training. + + :param obj: The object to be sent from the current process. + :return: A list of objects, one from each process (in rank order). + """ + out_list = [None for _ in range(self.world_size)] + torch.distributed.all_gather_object(out_list, obj) + return out_list # type: ignore + + def check_equal_tensors(self, tensor: Tensor): + """ + Checks if the given tensor is the same across all processes. + + This method will raise an error if the tensors are not equal. + + :param tensor: The tensor to be compared.
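+
+        For instance (illustrative only), it can be used as a sanity check
+        after a synchronization step:
+
+            DistributedHelper.check_equal_tensors(classifier.weight.detach().cpu())
+
+        where `classifier` is any module whose parameters are expected to
+        hold the same values in all processes.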
+ """ + if not DistributedHelper.is_distributed: + return + + all_tensors = self.gather_all(tensor) + + tensors_hashes = [hash_tensor(t) for t in all_tensors] + + if len(set(tensors_hashes)) != 1: + # Tensors differ across processes + raise ValueError('Different tensors. Got hashes: {}'.format( + tensors_hashes)) + + def check_equal_objects(self, obj: Any): + """ + Checks if the given object is the same across all processes. + + This method will raise an error if the objects are not equal. + + :param obj: The object to be compared. + """ + if not DistributedHelper.is_distributed: + return + + output: List[Any] = [None for _ in range(self.world_size)] + torch.distributed.all_gather_object(output, obj) + + obj_bt = _base_typed(obj) + + for i, o in enumerate(output): + o_bt = _base_typed(o) + if obj_bt != o_bt: + raise ValueError( + 'Different objects (ranks this={}, remote={}). ' + 'Got this={}, remote={}'.format( + self.rank, i, obj, o)) + + def _prepare_for_distributed_comm(self, tensor: Tensor): + """ + Internal utility used to move the tensor to the backend device. + + :param tensor: The tensor to be sent using the torch.distributed API. + :return: A tuple of 2 elements: + 1. The first element is the tensor moved to the correct device. + 2. The descriptor, as a tuple of 3 elements: + 1. The original device of the input tensor + 2. A boolean, describing if the tensor should be moved back + to the original device. + 3. The original tensor. + """ + original_device = tensor.device + copy_back = self.forced_cuda_comm and not tensor.is_cuda + if self.forced_cuda_comm: + tensor_distributed = tensor.cuda() + else: + tensor_distributed = tensor + + return tensor_distributed, (original_device, copy_back, tensor) + + def _revert_to_original_device(self, tensor_distributed, orig_data): + """ + Internal utility used to move the tensor back to the original device + (if needed). + + :param tensor_distributed: The tensor obtained from a torch.distributed API call. + :param orig_data: The descriptor in the format of + :meth:`_prepare_for_distributed_comm`. + :return: The tensor moved to the appropriate device. + """ + + original_device, copy_back, tensor = orig_data + if copy_back: + if tensor is None: + tensor = tensor_distributed.to(original_device) + else: + tensor[:] = tensor_distributed + + return tensor + + @property + def rank(self) -> int: + """ + The current rank. + + :return: The rank of the current process. + Returns 0 if not running a distributed training. + """ + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 + + @property + def world_size(self) -> int: + """ + The world size. + + :return: The world size of the default group. + Returns 1 if not running a distributed training. + """ + + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + return 1 + + @property + def is_distributed(self) -> bool: + """ + Check if running a distributed training. + + :return: True if running a distributed training, False + otherwise. + """ + return torch.distributed.is_initialized() + + @property + def is_main_process(self) -> bool: + """ + Check if the current process is the main process. + + :return: True if running in the main process (or not running a + distributed training), False otherwise. + """ + return self.rank == 0 + + @property + def backend(self) -> str: + """ + Obtain the name of the backend. + + :return: The name of the backend.
+ """ + return torch.distributed.get_backend() + + @property + def forced_cuda_comm(self) -> bool: + """ + Check if input tensors must be moved to the default cuda device + before passing them to torch.distributed API calls. + + :return: True if tensors must be moved to the default cuda device, + False otherwise. + """ + return self.backend == 'nccl' + + @property + def device_map(self) -> Dict[str, str]: + """ + Obtain the default device map, commonly used when unpickling elements + coming from other processes (or via `torch.load`). + + :return: A device map mapping devices to the current one. + """ + return self._dev_map + + @staticmethod + def _make_map(device_or_map) -> Dict[str, str]: + # TODO: borrowed from checkpointing plugins + # it would be better to have a single function in a shared utils module + if not isinstance(device_or_map, (torch.device, str)): + return device_or_map + + device = torch.device(device_or_map) + map_location = dict() + + map_location['cpu'] = 'cpu' + for cuda_idx in range(100): + map_location[f'cuda:{cuda_idx}'] = str(device) + return map_location + + +BASE_TYPES = [str, int, float, bool, type(None)] + + +def _base_typed(obj): + """ + Improved version of https://stackoverflow.com/a/62420097 + """ + T = type(obj) + from_numpy = T.__module__ == 'numpy' + from_pytorch = T.__module__ == 'torch' + + if from_numpy or from_pytorch: + return obj.tolist() + + if T in BASE_TYPES or callable(obj) or ((from_numpy or from_pytorch) + and not isinstance(T, Iterable)): + return obj + + if isinstance(obj, Dict): + return {_base_typed(k): _base_typed(v) for k, v in obj.items()} + elif isinstance(obj, Iterable): + base_items = [_base_typed(item) for item in obj] + return base_items if (from_numpy or from_pytorch) else T(base_items) + + d = obj if T is dict else obj.__dict__ + + return {k: _base_typed(v) for k, v in d.items()} + + +def fix(): + return lambda b: torch.load(BytesIO(b), + map_location=DistributedHelper.device_map) + + +class MappedUnpickler(pickle.Unpickler): + """ + An unpickler that maps incoming tensors to the default device + of this process, thus preventing issues when moving objects containing + nested `Tensor`s. + + This unpickler will be used to replace the + `torch.distributed.distributed_c10d._unpickler`.
+ """ + # Based on: + # https://github.com/pytorch/pytorch/issues/16797#issuecomment-777059657 + + # In turn based on: + # https://github.com/pytorch/pytorch/issues/16797#issuecomment-633423219 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def find_class(self, module, name): + if module == 'torch.storage' and name == '_load_from_bytes': + return fix() + else: + return super().find_class(module, name) + + +torch.distributed.distributed_c10d._unpickler = MappedUnpickler + +DistributedHelper = _DistributedHelperCls() + + +__all__ = [ + 'RollingSeedContext', + 'BroadcastSeedContext', + '_DistributedHelperCls', + 'DistributedHelper' +] diff --git a/avalanche/evaluation/metrics/detection.py b/avalanche/evaluation/metrics/detection.py index bcfd38c41..c10f3a11e 100644 --- a/avalanche/evaluation/metrics/detection.py +++ b/avalanche/evaluation/metrics/detection.py @@ -390,7 +390,7 @@ def _check_evaluator(self): ) def __str__(self): - return "LvisMetrics" + return "DetectionMetrics" def lvis_evaluator_factory(lvis_gt: LVIS, iou_types: List[str]): diff --git a/avalanche/logging/base_logger.py b/avalanche/logging/base_logger.py index 77b86864e..4ce79ad77 100644 --- a/avalanche/logging/base_logger.py +++ b/avalanche/logging/base_logger.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, List +from avalanche.distributed.distributed_helper import DistributedHelper + if TYPE_CHECKING: from avalanche.evaluation.metric_results import MetricValue from avalanche.training.templates import SupervisedTemplate @@ -28,6 +30,31 @@ class BaseLogger(ABC): def __init__(self): super().__init__() + if not DistributedHelper.is_main_process: + + raise RuntimeError( + 'You are creating a logger in a non-main process during a ' + 'distributed training session. ' + 'Jump to this error for an example on how to fix this.') + + # You have to create the loggers in the main process only. Otherwise, + # metrics will end up duplicated in your log files and consistency + # errors may arise. When creating the EvaluationPlugin in a + # non-main process, just pass loggers=None. + # + # Recommended way: + # if DistributedHelper.is_main_process: + # # Define the loggers + # loggers = [...] + # else: + # loggers = None + # + # # Instantiate the evaluation plugin + # eval_plugin = EvaluationPlugin(metricA, metricB, ..., loggers=loggers) + # + # # Instantiate the strategy + # strategy = MyStrategy(..., evaluator=eval_plugin) + def log_single_metric(self, name, value, x_plot): """Log a metric value. diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py index 551fefbed..835f9058b 100644 --- a/avalanche/models/dynamic_modules.py +++ b/avalanche/models/dynamic_modules.py @@ -75,6 +75,15 @@ def eval_adaptation(self, experience: CLExperience): """ pass + @property + def _adaptation_device(self): + """ + The device to use when expanding (or otherwise adapting) + the model. Defaults to the current device of the first + parameter listed using :meth:`parameters`. + """ + return next(self.parameters()).device + class MultiTaskModule(DynamicModule): """Base pytorch Module with support for task labels.
@@ -217,7 +226,7 @@ def __init__( self.mask_value = mask_value self.classifier = torch.nn.Linear(in_features, initial_out_features) - au_init = torch.zeros(initial_out_features, dtype=torch.bool) + au_init = torch.zeros(initial_out_features, dtype=torch.int8) self.register_buffer("active_units", au_init) self.active_units: torch.Tensor = au_init # Needed for type checks @@ -228,6 +237,7 @@ def adaptation(self, experience: CLExperience): :param experience: data from the current experience. :return: """ + device = self._adaptation_device in_features = self.classifier.in_features old_nclasses = self.classifier.out_features curr_classes = experience.classes_in_this_experience @@ -237,7 +247,10 @@ def adaptation(self, experience: CLExperience): if self.masking: if old_nclasses != new_nclasses: # expand active_units mask old_act_units = self.active_units - self.active_units = torch.zeros(new_nclasses, dtype=torch.bool) + self.active_units = torch.zeros( + new_nclasses, + dtype=torch.int8, + device=device) self.active_units[: old_act_units.shape[0]] = old_act_units # update with new active classes if self.training: @@ -247,7 +260,7 @@ def adaptation(self, experience: CLExperience): if old_nclasses == new_nclasses: return old_w, old_b = self.classifier.weight, self.classifier.bias - self.classifier = torch.nn.Linear(in_features, new_nclasses) + self.classifier = torch.nn.Linear(in_features, new_nclasses).to(device) self.classifier.weight[:old_nclasses] = old_w self.classifier.bias[:old_nclasses] = old_b @@ -320,14 +333,14 @@ def __init__( self.classifiers["0"] = first_head self.max_class_label = max(self.max_class_label, initial_out_features) - au_init = torch.zeros(initial_out_features, dtype=torch.bool) + au_init = torch.zeros(initial_out_features, dtype=torch.int8) self.register_buffer("active_units_T0", au_init) @property def active_units(self): res = {} for tid in self.known_train_tasks_labels: - mask = getattr(self, f"active_units_T{tid}") + mask = getattr(self, f"active_units_T{tid}").to(torch.bool) au = torch.arange(0, mask.shape[0])[mask].tolist() res[tid] = au return res @@ -336,7 +349,7 @@ def active_units(self): def task_masks(self): res = {} for tid in self.known_train_tasks_labels: - res[tid] = getattr(self, f"active_units_T{tid}") + res[tid] = getattr(self, f"active_units_T{tid}").to(torch.bool) return res def adaptation(self, experience: CLExperience): @@ -346,6 +359,7 @@ def adaptation(self, experience: CLExperience): :return: """ super().adaptation(experience) + device = self._adaptation_device curr_classes = experience.classes_in_this_experience task_labels = experience.task_labels if isinstance(task_labels, ConstantSequence): @@ -357,12 +371,16 @@ def adaptation(self, experience: CLExperience): # head adaptation if tid not in self.classifiers: # create new head new_head = IncrementalClassifier( - self.in_features, self.starting_out_features - ) + self.in_features, + self.starting_out_features, + masking=False + ).to(device) self.classifiers[tid] = new_head au_init = torch.zeros( - self.starting_out_features, dtype=torch.bool + self.starting_out_features, + dtype=torch.int8, + device=device ) self.register_buffer(f"active_units_T{tid}", au_init) @@ -390,7 +408,9 @@ def adaptation(self, experience: CLExperience): if old_nunits != new_nclasses: # expand active_units mask old_act_units = self._buffers[au_name] self._buffers[au_name] = torch.zeros( - new_nclasses, dtype=torch.bool + new_nclasses, + dtype=torch.int8, + device=device ) self._buffers[au_name][ : old_act_units.shape[0] 
@@ -407,6 +427,7 @@ def forward_single_task(self, x, task_label): :param task_label: :return: """ + device = self._adaptation_device task_label = str(task_label) out = self.classifiers[task_label](x) if self.masking: @@ -415,7 +436,10 @@ def forward_single_task(self, x, task_label): nunits, oldsize = out.shape[-1], curr_au.shape[0] if oldsize < nunits: # we have to update the mask old_mask = self._buffers[au_name] - self._buffers[au_name] = torch.zeros(nunits, dtype=torch.bool) + self._buffers[au_name] = torch.zeros( + nunits, + dtype=torch.int8, + device=device) self._buffers[au_name][:oldsize] = old_mask curr_au = self._buffers[au_name] out[..., torch.logical_not(curr_au)] = self.mask_value diff --git a/avalanche/models/helper_method.py b/avalanche/models/helper_method.py index a1f5587f5..a7d3e07f4 100644 --- a/avalanche/models/helper_method.py +++ b/avalanche/models/helper_method.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import torch import torch.nn as nn diff --git a/avalanche/models/utils.py b/avalanche/models/utils.py index 5a1ef3153..f847bafaa 100644 --- a/avalanche/models/utils.py +++ b/avalanche/models/utils.py @@ -1,19 +1,29 @@ from avalanche.benchmarks.utils import make_classification_dataset from avalanche.models.dynamic_modules import MultiTaskModule, DynamicModule import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel from collections import OrderedDict from avalanche.benchmarks.scenarios import CLExperience +def is_multi_task_module(model: nn.Module) -> bool: + return isinstance(model, MultiTaskModule) or \ + (isinstance(model, DistributedDataParallel) and + isinstance(model.module, MultiTaskModule)) + + def avalanche_forward(model, x, task_labels): - if isinstance(model, MultiTaskModule): + if is_multi_task_module(model): return model(x, task_labels) else: # no task labels return model(x) def avalanche_model_adaptation(model: nn.Module, experience: CLExperience): + if isinstance(model, DistributedDataParallel): + raise RuntimeError('The model is wrapped in DistributedDataParallel. ' + 'Please unwrap it before calling this method.') for module in model.modules(): if isinstance(module, DynamicModule): module.adaptation(experience) diff --git a/avalanche/training/determinism/rng_manager.py b/avalanche/training/determinism/rng_manager.py index fd30a639c..09c418d22 100644 --- a/avalanche/training/determinism/rng_manager.py +++ b/avalanche/training/determinism/rng_manager.py @@ -1,4 +1,3 @@ -import hashlib import random from collections import OrderedDict from typing import Any, Dict, Type diff --git a/avalanche/training/plugins/clock.py b/avalanche/training/plugins/clock.py index 535ef3f72..c04bf05ac 100644 --- a/avalanche/training/plugins/clock.py +++ b/avalanche/training/plugins/clock.py @@ -11,7 +11,7 @@ from avalanche.training.plugins import SupervisedPlugin -class Clock(SupervisedPlugin): +class Clock(SupervisedPlugin, supports_distributed=True): """Counter for strategy events. 
WARNING: Clock needs to be the last plugin, otherwise counters will be @@ -61,3 +61,8 @@ def after_training_exp(self, strategy, **kwargs): def after_eval_iteration(self, strategy, **kwargs): self.total_iterations += 1 + + +__all__ = [ + 'Clock' +] diff --git a/avalanche/training/plugins/evaluation.py b/avalanche/training/plugins/evaluation.py index b5b277199..baa51ad76 100644 --- a/avalanche/training/plugins/evaluation.py +++ b/avalanche/training/plugins/evaluation.py @@ -1,7 +1,18 @@ import warnings from copy import copy from collections import defaultdict -from typing import Any, Dict, List, Tuple, Union, Sequence, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + Sequence, + TYPE_CHECKING, +) +from avalanche.distributed.distributed_helper import DistributedHelper from avalanche.evaluation.metric_results import MetricValue from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics @@ -37,7 +48,10 @@ class EvaluationPlugin: def __init__( self, *metrics: Union["PluginMetric", Sequence["PluginMetric"]], - loggers: Union["BaseLogger", Sequence["BaseLogger"]] = None, + loggers: Optional[Union[ + "BaseLogger", + Sequence["BaseLogger"], + Callable[[], Sequence["BaseLogger"]]]] = None, collect_all=True, strict_checks=False ): @@ -52,6 +66,7 @@ def __init__( is used when calling `eval`. An error will be raised otherwise. """ super().__init__() + self.supports_distributed = True self.collect_all = collect_all self.strict_checks = strict_checks @@ -65,12 +80,14 @@ def __init__( if loggers is None: loggers = [] + elif callable(loggers): + loggers = loggers() elif not isinstance(loggers, Sequence): loggers = [loggers] self.loggers: Sequence["BaseLogger"] = loggers - if len(self.loggers) == 0: + if len(self.loggers) == 0 and DistributedHelper.is_main_process: warnings.warn("No loggers specified, metrics will not be logged") self.all_metric_results: Dict[str, Tuple[List[int], List[Any]]] @@ -219,14 +236,24 @@ def before_eval(self, strategy: "SupervisedTemplate", **kwargs): raise ValueError(msge) -def default_evaluator(): +def default_loggers() -> Sequence["BaseLogger"]: + if DistributedHelper.is_main_process: + return [InteractiveLogger()] + else: + return [] + + +def default_evaluator() -> EvaluationPlugin: return EvaluationPlugin( accuracy_metrics( minibatch=False, epoch=True, experience=True, stream=True ), loss_metrics(minibatch=False, epoch=True, experience=True, stream=True), - loggers=[InteractiveLogger()], + loggers=default_loggers, ) -__all__ = ["EvaluationPlugin", "default_evaluator"] +__all__ = [ + "EvaluationPlugin", + "default_evaluator" +] diff --git a/avalanche/training/plugins/gdumb.py b/avalanche/training/plugins/gdumb.py index 45bb3189e..11f7e6fd5 100644 --- a/avalanche/training/plugins/gdumb.py +++ b/avalanche/training/plugins/gdumb.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from avalanche.training.plugins.strategy_plugin import SupervisedPlugin from avalanche.training.storage_policy import ClassBalancedBuffer @@ -8,7 +8,7 @@ from avalanche.training.templates import SupervisedTemplate -class GDumbPlugin(SupervisedPlugin): +class GDumbPlugin(SupervisedPlugin, supports_distributed=True): """GDumb plugin. 
At each experience the model is trained from scratch using a buffer of @@ -21,7 +21,9 @@ class GDumbPlugin(SupervisedPlugin): https://www.robots.ox.ac.uk/~tvg/publications/2020/gdumb.pdf """ - def __init__(self, mem_size: int = 200): + def __init__( + self, + mem_size: int = 200): super().__init__() self.mem_size = mem_size @@ -52,3 +54,8 @@ def after_train_dataset_adaptation( ): self.storage_policy.update(strategy, **kwargs) strategy.adapted_dataset = self.storage_policy.buffer + + +__all__ = [ + 'GDumbPlugin' +] diff --git a/avalanche/training/plugins/mir.py b/avalanche/training/plugins/mir.py index 38abd39d9..2c35de2b9 100644 --- a/avalanche/training/plugins/mir.py +++ b/avalanche/training/plugins/mir.py @@ -1,6 +1,5 @@ -#!/usr/bin/env python3 import copy -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch from avalanche.benchmarks.utils import concat_datasets from avalanche.models.utils import avalanche_forward diff --git a/avalanche/training/plugins/replay.py b/avalanche/training/plugins/replay.py index 6637ddc5b..ecefd746f 100644 --- a/avalanche/training/plugins/replay.py +++ b/avalanche/training/plugins/replay.py @@ -12,7 +12,7 @@ from avalanche.training.templates import SupervisedTemplate -class ReplayPlugin(SupervisedPlugin): +class ReplayPlugin(SupervisedPlugin, supports_distributed=True): """ Experience replay plugin. @@ -50,7 +50,7 @@ def __init__( batch_size: Optional[int] = None, batch_size_mem: Optional[int] = None, task_balanced_dataloader: bool = False, - storage_policy: Optional["ExemplarsBuffer"] = None, + storage_policy: Optional["ExemplarsBuffer"] = None ): super().__init__() self.mem_size = mem_size diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index 65f93852b..8c706ad52 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Sequence, List, Tuple +from typing import Callable, Optional, List, Tuple, Union import torch from torch import Tensor @@ -59,9 +59,12 @@ def __init__( ewc_lambda: float = 0, train_mb_size: int = 128, eval_mb_size: int = 128, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, ): """ @@ -271,13 +274,19 @@ def make_train_dataloader(self, num_workers=0, shuffle=True, **kwargs): if hasattr(self.adapted_dataset, "collate_fn") else None ) + + other_dataloader_args = self._obtain_common_dataloader_parameters( + batch_size=current_batch_mb_size, + num_workers=num_workers, + shuffle=shuffle, + **kwargs + ) + # AR1 only supports SIT scenarios (no task labels). 
self.dataloader = DataLoader( self.adapted_dataset, - num_workers=num_workers, - batch_size=current_batch_mb_size, - shuffle=shuffle, collate_fn=collate_fn, + **other_dataloader_args ) def training_epoch(self, **kwargs): diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index 9acd8fd9c..875526a37 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -1,4 +1,5 @@ -from typing import Optional, List +from typing import Callable, Optional, List, Union +import torch from torch.nn import Module from torch.optim import Optimizer @@ -26,9 +27,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, ): """Init. diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 22d741df9..cb20c659b 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -1,12 +1,15 @@ import warnings -from typing import Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union import os import torch from avalanche.training.plugins import SupervisedPlugin from avalanche.training.templates import SupervisedTemplate -from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.plugins.evaluation import ( + EvaluationPlugin, + default_evaluator, +) from avalanche.models.dynamic_modules import MultiTaskModule from avalanche.models import FeatureExtractorBackbone @@ -37,7 +40,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, ): """Init function for the SLDA model. 
diff --git a/avalanche/training/supervised/der.py b/avalanche/training/supervised/der.py index d166795f1..78f0981b2 100644 --- a/avalanche/training/supervised/der.py +++ b/avalanche/training/supervised/der.py @@ -1,5 +1,14 @@ from collections import defaultdict -from typing import Dict, List, Optional, Sequence, Set, SupportsInt, Union +from typing import ( + Callable, + Dict, + List, + Optional, + Sequence, + Set, + SupportsInt, + Union, +) import torch import torch.nn.functional as F @@ -10,7 +19,10 @@ from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute from avalanche.core import SupervisedPlugin -from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.plugins.evaluation import ( + EvaluationPlugin, + default_evaluator, +) from avalanche.training.storage_policy import ( BalancedExemplarsBuffer, ReservoirSamplingBuffer, @@ -149,7 +161,10 @@ def __init__( eval_mb_size: Optional[int] = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", ): diff --git a/avalanche/training/supervised/er_ace.py b/avalanche/training/supervised/er_ace.py index 44c7384b9..cd13df9ca 100644 --- a/avalanche/training/supervised/er_ace.py +++ b/avalanche/training/supervised/er_ace.py @@ -1,18 +1,16 @@ -#!/usr/bin/env python3 -import copy -from typing import List, Optional, Sequence, Union +from typing import Callable, List, Optional, Union -import numpy as np import torch -import torch.nn.functional as F from torch.nn import CrossEntropyLoss, Module from torch.optim import Optimizer -from avalanche.benchmarks.utils import concat_datasets from avalanche.core import SupervisedPlugin from avalanche.models.utils import avalanche_forward from avalanche.training import ACECriterion -from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.plugins.evaluation import ( + EvaluationPlugin, + default_evaluator, +) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import (OnlineSupervisedTemplate, SupervisedTemplate) @@ -45,7 +43,10 @@ def __init__( eval_mb_size: Optional[int] = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="experience", ): @@ -201,7 +202,10 @@ def __init__( eval_mb_size: Optional[int] = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", ): diff --git a/avalanche/training/supervised/icarl.py b/avalanche/training/supervised/icarl.py index f1601ff73..df58d83b3 100644 --- a/avalanche/training/supervised/icarl.py +++ b/avalanche/training/supervised/icarl.py @@ -1,11 +1,9 @@ -import copy import itertools -from typing import TYPE_CHECKING, Optional, List +from typing import Callable, Optional, List, Union import torch from torch.optim import Optimizer from avalanche.benchmarks.utils import ( - concat_classification_datasets, make_tensor_classification_dataset, classification_subset, ) @@ -40,9 +38,12 @@ 
def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, ): """Init. diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 3e78e8c73..112bf91fe 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -9,7 +9,7 @@ # Website: avalanche.continualai.org # ################################################################################ -from typing import Iterable, List, Optional, Sequence, TypeVar, Union +from typing import Callable, Iterable, List, Optional, Sequence, TypeVar, Union import torch from torch.nn import Module @@ -18,7 +18,10 @@ from avalanche.benchmarks.scenarios.generic_scenario import DatasetExperience from avalanche.benchmarks.utils.utils import concat_datasets from avalanche.core import BasePlugin -from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.plugins.evaluation import ( + EvaluationPlugin, + default_evaluator, +) from avalanche.training.templates import SupervisedTemplate from avalanche.models import DynamicModule from avalanche.training.templates.base import ( @@ -66,7 +69,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence[TPluginType]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, ): """Init. 
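Throughout these strategy constructors the `evaluator` default changes from the eagerly built `default_evaluator()` instance to the `default_evaluator` factory, and the annotation widens to `Union[EvaluationPlugin, Callable[[], EvaluationPlugin]]`. Below is a minimal sketch of what callers can now pass, assuming the template resolves a callable default by invoking it (as the new default implies); the model/optimizer setup is illustrative only and not part of the patch.

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins.evaluation import (
    EvaluationPlugin,
    default_evaluator,
)
from avalanche.training.supervised import Naive

model = SimpleMLP(num_classes=10)
optimizer = SGD(model.parameters(), lr=0.01)

# Option 1: pass the factory (the new default). Each strategy instance
# builds its own EvaluationPlugin, so metric state is not shared between
# strategies created from the same default argument.
strategy_a = Naive(model, optimizer, CrossEntropyLoss(),
                   evaluator=default_evaluator)

# Option 2: pass an explicit EvaluationPlugin when custom metrics or
# loggers are needed; plain instances remain supported.
strategy_b = Naive(model, optimizer, CrossEntropyLoss(),
                   evaluator=EvaluationPlugin(
                       accuracy_metrics(experience=True, stream=True),
                       loggers=[InteractiveLogger()]))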
diff --git a/avalanche/training/supervised/l2p.py b/avalanche/training/supervised/l2p.py index b7afd9a6d..1062e7550 100644 --- a/avalanche/training/supervised/l2p.py +++ b/avalanche/training/supervised/l2p.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -36,9 +36,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = 1, - device: str = "cpu", + device: Union[str, torch.device] = "cpu", plugins: Optional[List["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every: int = -1, peval_mode: str = "epoch", prompt_pool: bool = True, @@ -88,6 +91,10 @@ def __init__( :param use_vit: Boolean to confirm the usage of a visual Transformer.\ Default True """ + + if device is None: + device = torch.device("cpu") + self.num_classes = num_classes self.lr = lr self.sim_coefficient = sim_coefficient @@ -113,7 +120,7 @@ def __init__( if n.startswith(tuple(["blocks", "patch_embed", "cls_token", "norm", "pos_embed"])): p.requires_grad = False - + model.head = torch.nn.Linear(768, num_classes).to(device) optimizer = torch.optim.Adam( diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index 5877885f8..9fa949bb4 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -1,4 +1,4 @@ -from typing import Sequence, Optional, Union +from typing import Callable, Sequence, Optional, Union import torch import torch.nn as nn @@ -39,7 +39,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", ): diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index 4c88e241d..b008539dc 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Optional, Union +from typing import Callable, List, Sequence, Optional, Union import pkg_resources from pkg_resources import DistributionNotFound, VersionConflict @@ -43,7 +43,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", ): diff --git a/avalanche/training/supervised/mer.py b/avalanche/training/supervised/mer.py index a3b8e3539..37f6234c3 100644 --- a/avalanche/training/supervised/mer.py +++ b/avalanche/training/supervised/mer.py @@ -1,4 +1,4 @@ -from typing import Sequence, Optional, Union +from typing import Callable, Sequence, Optional, Union import torch import torch.nn.functional as F @@ -62,7 +62,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", ): diff --git 
a/avalanche/training/supervised/naive_object_detection.py b/avalanche/training/supervised/naive_object_detection.py index 7fa92a82d..d670d3a1d 100644 --- a/avalanche/training/supervised/naive_object_detection.py +++ b/avalanche/training/supervised/naive_object_detection.py @@ -8,7 +8,7 @@ # E-mail: contact@continualai.org # # Website: avalanche.continualai.org # ################################################################################ -from typing import Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union import torch from pkg_resources import parse_version @@ -17,6 +17,7 @@ from torch.utils.data import DataLoader from avalanche.benchmarks.utils.data_loader import ( + collate_from_data_or_kwargs, detection_collate_fn, TaskBalancedDataLoader, detection_collate_mbatches_fn, @@ -56,7 +57,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", scaler=None, @@ -127,7 +131,7 @@ def make_train_dataloader( self, num_workers=0, shuffle=True, - pin_memory=True, + pin_memory=None, persistent_workers=False, **kwargs ): @@ -139,31 +143,39 @@ def make_train_dataloader( :param num_workers: number of thread workers for the data loading. :param shuffle: True if the data should be shuffled, False otherwise. :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. + pinned memory before returning them. Defaults to None, which means + that the value will be determined by looking at the strategy + `device` field. :param persistent_workers: If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. Used only if `PyTorch >= 1.7.0`. """ - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - assert self.adapted_dataset is not None - self.dataloader = TaskBalancedDataLoader( - self.adapted_dataset, - oversample_small_groups=True, - num_workers=num_workers, + other_dataloader_args = self._obtain_common_dataloader_parameters( batch_size=self.train_mb_size, + num_workers=num_workers, shuffle=shuffle, pin_memory=pin_memory, - collate_fn=detection_collate_fn, + persistent_workers=persistent_workers, + **kwargs + ) + + self.dataloader = TaskBalancedDataLoader( + self.adapted_dataset, + oversample_small_groups=True, **other_dataloader_args ) - def make_eval_dataloader(self, num_workers=0, pin_memory=True, **kwargs): + def make_eval_dataloader( + self, + num_workers=0, + shuffle=False, + pin_memory=None, + persistent_workers=False, + drop_last=False, + **kwargs): """ Initializes the eval data loader. :param num_workers: How many subprocesses to use for data loading. 
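In the object detection template above (and in `base_sgd.py` later in this patch) train and eval dataloader creation is routed through a shared `_obtain_common_dataloader_parameters` helper, and `pin_memory` now defaults to `None`, which is resolved from the strategy `device`. The standalone function below is a rough, illustrative mirror of that resolution logic; the free-function form and its name are assumptions for the sketch, not part of the patch.

import torch
from pkg_resources import parse_version


def resolve_common_loader_args(device, batch_size,
                               pin_memory=None,
                               persistent_workers=False,
                               **kwargs):
    # Collect the arguments shared by train and eval dataloaders.
    loader_args = dict(batch_size=batch_size, **kwargs)

    # persistent_workers is only forwarded on PyTorch >= 1.7.0.
    if parse_version(torch.__version__) >= parse_version("1.7.0"):
        loader_args["persistent_workers"] = persistent_workers

    # pin_memory=None means "pin only when running on CUDA".
    if pin_memory is None:
        pin_memory = device.type == "cuda"
    loader_args["pin_memory"] = pin_memory
    return loader_args


# A CPU strategy ends up with pin_memory=False, a CUDA one with True.
print(resolve_common_loader_args(torch.device("cpu"),
                                 batch_size=32, num_workers=2))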
@@ -177,12 +189,23 @@ def make_eval_dataloader(self, num_workers=0, pin_memory=True, **kwargs): assert self.adapted_dataset is not None - self.dataloader = DataLoader( - self.adapted_dataset, - num_workers=num_workers, + other_dataloader_args = self._obtain_common_dataloader_parameters( batch_size=self.eval_mb_size, + num_workers=num_workers, + shuffle=shuffle, pin_memory=pin_memory, - collate_fn=detection_collate_fn, + persistent_workers=persistent_workers, + drop_last=drop_last, + **kwargs + ) + + collate_from_data_or_kwargs( + self.adapted_dataset, + other_dataloader_args) + + self.dataloader = DataLoader( + self.adapted_dataset, + **other_dataloader_args ) def criterion(self): diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index 815073c1f..e6ad4f417 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -8,15 +8,18 @@ # E-mail: contact@continualai.org # # Website: avalanche.continualai.org # ################################################################################ -from typing import Optional, Sequence, List, Union +from typing import Callable, Optional, Sequence, List, Union import torch from torch.nn.parameter import Parameter from torch.nn import Module, CrossEntropyLoss -from torch.optim import Optimizer, SGD +from torch.optim import Optimizer from avalanche.models.pnn import PNN -from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.plugins.evaluation import ( + default_evaluator, + default_loggers, +) from avalanche.training.plugins import ( SupervisedPlugin, CWRStarPlugin, @@ -64,9 +67,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -122,7 +128,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -175,9 +184,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -239,9 +251,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -314,9 +329,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + 
EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, generator_strategy: Optional[BaseTemplate] = None, replay_size: Optional[int] = None, @@ -417,7 +435,7 @@ def __init__( def get_default_vae_logger(): - return EvaluationPlugin(loggers=[InteractiveLogger()]) + return EvaluationPlugin(loggers=default_loggers) class VAETraining(SupervisedTemplate): @@ -441,9 +459,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = get_default_vae_logger(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = get_default_vae_logger, eval_every=-1, **base_kwargs ): @@ -507,9 +528,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -573,9 +597,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -638,9 +665,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -706,9 +736,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -774,9 +807,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -844,9 +880,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -932,7 +971,10 @@ def __init__( eval_mb_size: int = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, 
eval_every=-1, **base_kwargs ): @@ -1007,9 +1049,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -1080,9 +1125,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -1147,9 +1195,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: int = 1, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -1225,9 +1276,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -1309,9 +1363,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: int = 1, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): @@ -1386,9 +1443,12 @@ def __init__( train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, - device=None, + device: Union[str, torch.device] = "cpu", plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, **base_kwargs ): diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index a14df8cb4..1da839ac9 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -1,3 +1,4 @@ +import sys import warnings from collections import defaultdict from typing import Generic, Iterable, Sequence, Optional, TypeVar, Union, List @@ -7,6 +8,7 @@ from avalanche.benchmarks import CLExperience, CLStream from avalanche.core import BasePlugin +from avalanche.distributed.distributed_helper import DistributedHelper from avalanche.training.templates.strategy_mixin_protocol import \ BaseStrategyProtocol from avalanche.training.utils import trigger_plugins @@ -49,6 +51,9 @@ def __init__( """ PyTorch model. """ if device is None: + warnings.warn( + 'When instantiating a strategy, please pass a non-None device.' + ) device = 'cpu' self.device = torch.device(device) @@ -73,6 +78,12 @@ def __init__( self.current_eval_stream: Iterable[TExperienceType] = [] """ Current evaluation stream. 
""" + self._distributed_check: bool = False + """ + Internal flag used to verify the support for distributed + training only once. + """ + ################################################################### # Other variables # ################################################################### @@ -106,6 +117,12 @@ def train( when calling `eval`. If you use multiple streams, they must have different names. """ + if not self._distributed_check: + # Checks if the strategy elements are compatible with + # distributed training + self._check_distributed_training_compatibility() + self._distributed_check = True + self.is_training = True self._stop_training = False @@ -154,6 +171,12 @@ def eval( :return: dictionary containing last recorded value for each metric name """ + if not self._distributed_check: + # Checks if the strategy elements are compatible with + # distributed training + self._check_distributed_training_compatibility() + self._distributed_check = True + # eval can be called inside the train method. # Save the shared state here to restore before returning. prev_train_state = self._save_train_state() @@ -245,6 +268,28 @@ def is_callback(x): f"callbacks: {cb_p - cb_supported}", ) return + + def _check_distributed_training_compatibility(self): + """ + Check if strategy elements (plugins, ...) are compatible with + distributed training. + This check does nothing if not training in distributed mode. + """ + if not DistributedHelper.is_distributed: + return True + + unsupported_plugins = [] + for plugin in self.plugins: + if not getattr(plugin, "supports_distributed", False): + unsupported_plugins.append(plugin) + + if len(unsupported_plugins) > 0: + warnings.warn('You are using plugins that are not compatible' + 'with distributed training:') + for plugin in unsupported_plugins: + print(type(plugin), file=sys.stderr) + + return len(unsupported_plugins) == 0 ######################################################### # Plugin Triggers # diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 8b04c9e76..35e21a49a 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -345,11 +345,50 @@ def _before_eval_exp(self, **kwargs): super()._before_eval_exp(**kwargs) + def _obtain_common_dataloader_parameters(self, **kwargs): + """ + Utility function that returns the dictionary of parameters to be passed + to the train and eval dataloaders. + + This function can be useful when in need to customize the data loading + parameters but no radical changes are needed. When overriding to + add/customize parameters, it is recommended to first call this + implementation (super) to obtain a base dictionary of parameters. + + However, if a more deep change is needed in the data loading procedure, + it is better to overrride :meth:`make_train_dataloader` and/or + :meth:`make_eval_dataloader` directly. + + Note: the resulting dictionary does not include the collate function + unless explicitly passed. + + :param kwargs: The dataloader arguments as passed to the `train` + or `eval` method. + :return: A dictionary of parameters to be passed to the DataLoader class + or to one of the Avalanche dataloaders. 
+ """ + other_dataloader_args = {} + + if 'persistent_workers' in kwargs: + if parse_version(torch.__version__) >= parse_version("1.7.0"): + other_dataloader_args["persistent_workers"] = \ + kwargs['persistent_workers'] + else: + del kwargs['persistent_workers'] + + for k, v in kwargs.items(): + other_dataloader_args[k] = v + + if other_dataloader_args.get('pin_memory', None) is None: + other_dataloader_args['pin_memory'] = self.device.type == 'cuda' + + return other_dataloader_args + def make_train_dataloader( self, num_workers=0, shuffle=True, - pin_memory=True, + pin_memory=None, persistent_workers=False, **kwargs ): @@ -364,27 +403,31 @@ def make_train_dataloader( pinned memory before returning them. Defaults to True. """ - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v - assert self.adapted_dataset is not None - self.dataloader = TaskBalancedDataLoader( - self.adapted_dataset, - oversample_small_groups=True, - num_workers=num_workers, + other_dataloader_args = self._obtain_common_dataloader_parameters( batch_size=self.train_mb_size, + num_workers=num_workers, shuffle=shuffle, pin_memory=pin_memory, + persistent_workers=persistent_workers, + **kwargs + ) + + self.dataloader = TaskBalancedDataLoader( + self.adapted_dataset, + oversample_small_groups=True, **other_dataloader_args ) def make_eval_dataloader( - self, num_workers=0, pin_memory=True, persistent_workers=False, **kwargs + self, + num_workers=0, + shuffle=False, + pin_memory=None, + persistent_workers=False, + drop_last=False, + **kwargs ): """ Initializes the eval data loader. @@ -396,20 +439,25 @@ def make_eval_dataloader( :param kwargs: :return: """ - other_dataloader_args = {} - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v + assert self.adapted_dataset is not None - collate_from_data_or_kwargs(self.adapted_dataset, - other_dataloader_args) - self.dataloader = DataLoader( - self.adapted_dataset, - num_workers=num_workers, + other_dataloader_args = self._obtain_common_dataloader_parameters( batch_size=self.eval_mb_size, + num_workers=num_workers, + shuffle=shuffle, pin_memory=pin_memory, + persistent_workers=persistent_workers, + drop_last=drop_last, + **kwargs + ) + + collate_from_data_or_kwargs( + self.adapted_dataset, + other_dataloader_args) + + self.dataloader = DataLoader( + self.adapted_dataset, **other_dataloader_args ) @@ -487,13 +535,17 @@ def _after_eval_dataset_adaptation(self, **kwargs): trigger_plugins(self, "after_eval_dataset_adaptation", **kwargs) -class PeriodicEval(BaseSGDPlugin): +class PeriodicEval(BaseSGDPlugin, supports_distributed=True): """Schedules periodic evaluation during training. This plugin is automatically configured and added by the BaseTemplate. """ - def __init__(self, eval_every=-1, peval_mode="epoch", do_initial=True): + def __init__( + self, + eval_every=-1, + peval_mode="epoch", + do_initial=True): """Init. 
:param eval_every: the frequency of the calls to `eval` inside the @@ -513,7 +565,7 @@ def __init__(self, eval_every=-1, peval_mode="epoch", do_initial=True): self.eval_every = eval_every self.peval_mode = peval_mode self.do_initial = do_initial and eval_every > -1 - self.do_final = None + self.do_final: Optional[bool] = None self._is_eval_updated = False def before_training(self, strategy, **kwargs): diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 66b724c7e..d9b686338 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -36,7 +36,7 @@ class SupervisedTemplate( TDatasetExperience, TMBInput, TMBOutput - ],): + ]): """Base class for continual learning strategies. @@ -94,7 +94,10 @@ def __init__( eval_mb_size: Optional[int] = 1, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence[BasePlugin]] = None, - evaluator=default_evaluator, + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, eval_every=-1, peval_mode="epoch", ): @@ -208,18 +211,21 @@ class SupervisedMetaLearningTemplate( PLUGIN_CLASS = SupervisedPlugin def __init__( - self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), - train_mb_size: int = 1, - train_epochs: int = 1, - eval_mb_size: Optional[int] = 1, - device: Union[str, torch.device] = "cpu", - plugins: Optional[Sequence[BasePlugin]] = None, - evaluator=default_evaluator, - eval_every=-1, - peval_mode="epoch", + self, + model: Module, + optimizer: Optimizer, + criterion=CrossEntropyLoss(), + train_mb_size: int = 1, + train_epochs: int = 1, + eval_mb_size: Optional[int] = 1, + device: Union[str, torch.device] = "cpu", + plugins: Optional[Sequence[BasePlugin]] = None, + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, + eval_every=-1, + peval_mode="epoch", ): """Init. @@ -448,18 +454,21 @@ class OnlineSupervisedMetaLearningTemplate( PLUGIN_CLASS = SupervisedPlugin def __init__( - self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), - train_mb_size: int = 1, - train_passes: int = 1, - eval_mb_size: Optional[int] = 1, - device: Union[str, torch.device] = "cpu", - plugins: Optional[Sequence[BasePlugin]] = None, - evaluator=default_evaluator, - eval_every=-1, - peval_mode="epoch", + self, + model: Module, + optimizer: Optimizer, + criterion=CrossEntropyLoss(), + train_mb_size: int = 1, + train_passes: int = 1, + eval_mb_size: Optional[int] = 1, + device: Union[str, torch.device] = "cpu", + plugins: Optional[Sequence[BasePlugin]] = None, + evaluator: Union[ + EvaluationPlugin, + Callable[[], EvaluationPlugin] + ] = default_evaluator, + eval_every=-1, + peval_mode="epoch", ): """Init. diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index 53a28b54a..e548ad885 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -20,7 +20,7 @@ def model_adaptation(self, model=None): def make_optimizer(self): """Optimizer initialization. - Called before each training experiene to configure the optimizer. + Called before each training experience to configure the optimizer. """ # we reset the optimizer's state after each experience. 
# This allows to add new parameters (new heads) and diff --git a/avalanche/training/utils.py b/avalanche/training/utils.py index 86c54914e..f6010d1d7 100644 --- a/avalanche/training/utils.py +++ b/avalanche/training/utils.py @@ -429,6 +429,7 @@ def __str__(self): __all__ = [ + "trigger_plugins", "load_all_dataset", "zerolike_params_dict", "copy_params_dict", diff --git a/examples/detection.py b/examples/detection.py index 7e71b5dda..0f21e21b4 100644 --- a/examples/detection.py +++ b/examples/detection.py @@ -16,6 +16,7 @@ """ import argparse +from pkg_resources import parse_version import torch import torchvision import logging @@ -78,9 +79,7 @@ def main(args): if args.detection_only: # Ingore the segmentation task # load a model pre-trained on COCO - model = torchvision.models.detection.fasterrcnn_resnet50_fpn( - pretrained=True - ) + model = obtain_base_model(segmentation=False) # Replace the classifier with a new one, that has "num_classes" outputs # 1) Get number of input features for the classifier @@ -91,9 +90,7 @@ def main(args): ) else: # Detection + Segmentation - model = torchvision.models.detection.maskrcnn_resnet50_fpn( - pretrained=True - ) + model = obtain_base_model(segmentation=True) # Replace the classifier with a new one, that has "num_classes" outputs # 1) Get number of input features for the classifier @@ -165,6 +162,35 @@ def main(args): print("Evaluation completed") +def obtain_base_model(segmentation: bool): + torchvision_is_old_version = \ + parse_version(torch.__version__) < parse_version("0.13") + + pretrain_argument = dict() + + if torchvision_is_old_version: + pretrain_argument['pretrained'] = True + else: + if segmentation: + pretrain_argument['weights'] = \ + torchvision.models.detection.mask_rcnn.\ + MaskRCNN_ResNet50_FPN_Weights.DEFAULT + else: + pretrain_argument['weights'] = \ + torchvision.models.detection.faster_rcnn.\ + FasterRCNN_ResNet50_FPN_Weights.DEFAULT + + if segmentation: + model = torchvision.models.detection.maskrcnn_resnet50_fpn( + **pretrain_argument + ) + else: + model = torchvision.models.detection.fasterrcnn_resnet50_fpn( + **pretrain_argument + ) + return model + + def split_penn_fudan( n_experiences: int, train_transform=None, diff --git a/examples/detection_lvis.py b/examples/detection_lvis.py index 64d3ba9da..ec011e5fa 100644 --- a/examples/detection_lvis.py +++ b/examples/detection_lvis.py @@ -19,6 +19,8 @@ from pathlib import Path from typing import Union +from pkg_resources import parse_version + from avalanche.benchmarks.datasets.lvis_dataset import LvisDataset from avalanche.evaluation.metrics.detection import make_lvis_metrics from avalanche.training.supervised.naive_object_detection import ( @@ -71,9 +73,7 @@ def main(args): # MODEL CREATION # load a model pre-trained on COCO - model = torchvision.models.detection.fasterrcnn_resnet50_fpn( - pretrained=True - ) + model = obtain_base_model(segmentation=False) # Just tune the box predictor for p in model.parameters(): @@ -140,6 +140,35 @@ def main(args): print("Evaluation completed") +def obtain_base_model(segmentation: bool): + torchvision_is_old_version = \ + parse_version(torch.__version__) < parse_version("0.13") + + pretrain_argument = dict() + + if torchvision_is_old_version: + pretrain_argument['pretrained'] = True + else: + if segmentation: + pretrain_argument['weights'] = \ + torchvision.models.detection.mask_rcnn.\ + MaskRCNN_ResNet50_FPN_Weights.DEFAULT + else: + pretrain_argument['weights'] = \ + torchvision.models.detection.faster_rcnn.\ + 
FasterRCNN_ResNet50_FPN_Weights.DEFAULT + + if segmentation: + model = torchvision.models.detection.maskrcnn_resnet50_fpn( + **pretrain_argument + ) + else: + model = torchvision.models.detection.fasterrcnn_resnet50_fpn( + **pretrain_argument + ) + return model + + def split_lvis( n_experiences: int, train_transform=None, diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/distributed/distributed_test_utils.py b/tests/distributed/distributed_test_utils.py new file mode 100644 index 000000000..4e17e8f4b --- /dev/null +++ b/tests/distributed/distributed_test_utils.py @@ -0,0 +1,42 @@ +import contextlib +import os + +import torch + +from avalanche.distributed import DistributedHelper + + +def common_dst_tests_setup(): + use_gpu_in_tests = os.environ.get('USE_GPU', 'false').lower() in [ + '1', 'true'] + use_gpu_in_tests = use_gpu_in_tests and torch.cuda.is_available() + DistributedHelper.init_distributed(1234, use_cuda=use_gpu_in_tests) + return use_gpu_in_tests + + +def check_skip_distributed_test() -> bool: + return os.environ.get('DISTRIBUTED_TESTS', 'false').lower() \ + not in ['1', 'true'] + + +def check_skip_distributed_slow_test() -> bool: + return check_skip_distributed_test() or \ + os.environ.get('FAST_TEST', 'false').lower() in ['1', 'true'] + + +@contextlib.contextmanager +def suppress_dst_tests_output(): + if os.environ['LOCAL_RANK'] != 0: + with contextlib.redirect_stderr(None): + with contextlib.redirect_stdout(None): + yield + else: + yield + + +__all__ = [ + 'common_dst_tests_setup', + 'check_skip_distributed_test', + 'check_skip_distributed_slow_test', + 'suppress_dst_tests_output' +] diff --git a/tests/distributed/test_distributed_helper.py b/tests/distributed/test_distributed_helper.py new file mode 100644 index 000000000..123c281b5 --- /dev/null +++ b/tests/distributed/test_distributed_helper.py @@ -0,0 +1,510 @@ +import os +import random +import shutil +import tempfile +import time +import unittest +import numpy as np + +import torch +import torch.distributed as dst +from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel +from avalanche.benchmarks.generators.benchmark_generators import \ + dataset_benchmark +from avalanche.benchmarks.utils.classification_dataset import \ + make_tensor_classification_dataset + +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_helper import \ + RollingSeedContext, BroadcastSeedContext +from avalanche.models import SimpleMLP, as_multitask +from avalanche.models.utils import avalanche_model_adaptation + +from avalanche.training.determinism.rng_manager import RNGManager +from tests.distributed.distributed_test_utils import \ + check_skip_distributed_slow_test, check_skip_distributed_test, \ + suppress_dst_tests_output, common_dst_tests_setup + + +class DistributedHelperTests(unittest.TestCase): + + def setUp(self) -> None: + self.use_gpu_in_tests = common_dst_tests_setup() + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_device_id(self): + if self.use_gpu_in_tests: + self.assertEqual(dst.get_rank(), DistributedHelper.get_device_id()) + self.assertEqual(torch.device(f'cuda:{dst.get_rank()}'), + DistributedHelper.make_device()) + else: + self.assertEqual(-1, DistributedHelper.get_device_id()) + self.assertEqual(torch.device('cpu'), + DistributedHelper.make_device()) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests 
ignored') + def test_wrap_model(self): + mb_size = 1*2*2*3*5 + num_classes = 11 + torch.manual_seed(1234 + DistributedHelper.rank) + mb_x = torch.randn((mb_size, 32)) + mb_y = torch.randint(0, num_classes, (mb_size,)) + mb_t = torch.full((mb_size,), 1) + model = SimpleMLP(num_classes=num_classes, input_size=32) + model = as_multitask(model, 'classifier') + self.assertIsInstance(model, Module) + + device = DistributedHelper.make_device() + + if device.type == 'cuda': + # Additional test: must raise an error if the model + # is not already in the correct device + with self.assertRaises(Exception): + model_wrapped = DistributedHelper.wrap_model(model) + + model = model.to(device) + + model_wrapped = DistributedHelper.wrap_model(model) + self.assertIsInstance(model_wrapped, DistributedDataParallel) + self.assertNotIsInstance(model, DistributedDataParallel) + + device = DistributedHelper.make_device() + mb_x = mb_x.to(device) + mb_y = mb_y.to(device) + mb_t = mb_t.to(device) + model = model.to(device) + + model.eval() + model_wrapped.eval() + + benchmark = dataset_benchmark( + [make_tensor_classification_dataset( + mb_x, mb_y, mb_t, task_labels=mb_t.tolist() + )], + [make_tensor_classification_dataset( + mb_x, mb_y, mb_t, task_labels=mb_t.tolist() + )] + ) + + avalanche_model_adaptation(model, benchmark.train_stream[0]) + + with torch.no_grad(): + mb_out1 = model(mb_x, mb_t).detach() + self.assertEqual(mb_out1.device, device) + self.assertSequenceEqual([mb_size, num_classes], mb_out1.shape) + + mb_out2 = model_wrapped(mb_x, mb_t).detach() + self.assertEqual(mb_out2.device, device) + self.assertSequenceEqual([mb_size, num_classes], mb_out2.shape) + + self.assertTrue(torch.equal(mb_out1, mb_out2)) + + mb_out_all = DistributedHelper.cat_all(mb_out2) + + start_idx = mb_size * DistributedHelper.rank + end_idx = start_idx + mb_size + + self.assertTrue(torch.equal(mb_out1, + mb_out_all[start_idx: end_idx])) + + self.assertTrue(model is DistributedHelper.unwrap_model(model_wrapped)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_broadcast_tensor_or_objects(self): + ts = torch.full((10,), DistributedHelper.rank, dtype=torch.long) + DistributedHelper.broadcast(ts) + self.assertTrue(torch.equal(ts, torch.zeros((10,), dtype=torch.long))) + + device = DistributedHelper.make_device() + ts = ts.to(device) + + my_object = {'a': DistributedHelper.rank, 'b': ts} + my_object_from_main = DistributedHelper.broadcast_object(my_object) + + expect = { + 'a': 0, + 'b': torch.full((10,), 0, dtype=torch.long).tolist()} + + self.assertEqual(device, my_object_from_main['b'].device) + my_object_from_main['b'] = my_object_from_main['b'].tolist() + self.assertEqual(expect, my_object_from_main) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_objects(self): + ts = torch.full((10,), DistributedHelper.rank, dtype=torch.long) + + device = DistributedHelper.make_device() + ts = ts.to(device) + + my_object = {'a': DistributedHelper.rank, 'b': ts} + all_objects = DistributedHelper.gather_all_objects(my_object) + self.assertIsInstance(all_objects, list) + self.assertEqual(DistributedHelper.world_size, len(all_objects)) + + for rank in range(DistributedHelper.world_size): + expect = { + 'a': rank, + 'b': torch.full((10,), rank, dtype=torch.long).tolist()} + + self.assertEqual(device, all_objects[rank]['b'].device) + all_objects[rank]['b'] = all_objects[rank]['b'].tolist() + self.assertEqual(expect, all_objects[rank]) + + 
@unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_cat_all(self): + if DistributedHelper.rank == 0: + ts = torch.full((10+1, 5), DistributedHelper.rank, dtype=torch.long) + else: + ts = torch.full((10, 5), DistributedHelper.rank, dtype=torch.long) + device = DistributedHelper.make_device() + + if device.type == 'cuda': + # Additional test: tensors do not need to be on the default device + DistributedHelper.cat_all(ts) + + ts = ts.to(device) + + concatenated_tensor = DistributedHelper.cat_all(ts) + + self.assertEqual(device, concatenated_tensor.device) + + expect = torch.empty((DistributedHelper.world_size * 10 + 1, 5), + dtype=torch.long).to(device) + for rank in range(DistributedHelper.world_size): + if rank == 0: + expect[rank * 10: (rank + 1) * 10 + 1] = rank + else: + expect[1 + rank * 10: 1 + (rank + 1) * 10] = rank + + self.assertTrue(torch.equal(concatenated_tensor, expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_same_size(self): + ts = torch.full((10, 5), DistributedHelper.rank, dtype=torch.long) + device = DistributedHelper.make_device() + + if device.type == 'cuda': + # Additional test: tensors do not need to be on the default device + DistributedHelper.gather_all(ts) + + # On the other hand, PyTorch all_gather requires tensors to be on + # the default device + with self.assertRaises(Exception): + + out_t = [torch.empty_like(ts) + for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather(out_t, ts) + + # ... while this should work + out_t = [torch.empty_like(ts).to(device) + for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather(out_t, ts.to(device)) + + ts = ts.to(device) + + for same_shape in [False, True]: + print(f'same_shape={same_shape}') + # with self.subTest(same_shape=same_shape): + tensor_list = DistributedHelper.gather_all( + ts, same_shape=same_shape) + + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full((10, 5), rank, dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_slow_test(), + 'Distributed tests ignored') + def test_gather_all_performance_known_same_shape(self): + ts = torch.full((128, 224, 224, 3), + DistributedHelper.rank, + dtype=torch.float32) + device = DistributedHelper.make_device() + ts = ts.to(device) + + resulting_tensors = [torch.empty_like(ts).to(device) + for _ in range(DistributedHelper.world_size)] + + from tqdm import tqdm + n_times = 30 + torch.distributed.all_gather(resulting_tensors, ts) + start_time = time.time() + for _ in tqdm(range(n_times)): + torch.distributed.all_gather(resulting_tensors, ts) + end_time = time.time() + print('Time taken by PyTorch all_gather', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + start_time = time.time() + out_list = [None for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather_object(out_list, ts) + + for _ in tqdm(range(n_times)): + torch.distributed.all_gather_object(out_list, ts) + end_time = time.time() + print('Time taken by PyTorch all_gather_object', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + @unittest.skipIf(check_skip_distributed_slow_test(), + 'Distributed tests ignored') + def test_gather_all_performance_sync_shape(self): + max_shape_size = 10 + shape = [128, 6, 
DistributedHelper.rank+1] + \ + ([3] * DistributedHelper.rank) + + device = DistributedHelper.make_device() + + def shape_all_gather(): + ts = torch.zeros((max_shape_size,), dtype=torch.int64) + for i in range(len(shape)): + ts[i] = shape[i] + + ts = ts.to(device) + all_tensors_shape = [torch.empty_like(ts) + for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather(all_tensors_shape, ts) + all_tensors_shape = [t.cpu() for t in all_tensors_shape] + + for i, t in enumerate(all_tensors_shape): + for x in range(len(t)): + if t[x] == 0: + if x == 0: + # Tensor with 0-length shape + all_tensors_shape[i] = t[:x+1] + else: + all_tensors_shape[i] = t[:x] + break + + def shape_all_gather_objects(): + out_list = [None for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather_object(out_list, shape) + + from tqdm import tqdm + n_times = 1000 + shape_all_gather() + start_time = time.time() + for _ in tqdm(range(n_times)): + shape_all_gather() + end_time = time.time() + print('Time taken by PyTorch all_gather', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + start_time = time.time() + shape_all_gather_objects() + + for _ in tqdm(range(n_times)): + shape_all_gather_objects() + end_time = time.time() + print('Time taken by PyTorch all_gather_object', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_same_dim0(self): + ts = torch.full((10, DistributedHelper.rank+1), + DistributedHelper.rank, + dtype=torch.long) + device = DistributedHelper.make_device() + + ts = ts.to(device) + + tensor_list = DistributedHelper.gather_all(ts) + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full((10, rank+1), + rank, + dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_same_dim1_n(self): + ts = torch.full((10+DistributedHelper.rank, 5), + DistributedHelper.rank, + dtype=torch.long) + device = DistributedHelper.make_device() + + ts = ts.to(device) + + tensor_list = DistributedHelper.gather_all(ts) + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full((10+rank, 5), + rank, + dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_zero_shaped(self): + ts = torch.full(tuple(), DistributedHelper.rank, dtype=torch.long) + device = DistributedHelper.make_device() + + ts = ts.to(device) + + for same_shape in [False, True]: + print(f'same_shape={same_shape}') + # with self.subTest(same_shape=same_shape): + tensor_list = DistributedHelper.gather_all( + ts, + same_shape=same_shape) + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full(tuple(), rank, dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def 
test_check_equal_tensors(self): + if DistributedHelper.world_size == 1 and \ + DistributedHelper.get_device_id() >= 0: + self.skipTest('When using CUDA, there must be at ' + 'least two processes to run this test') + torch.manual_seed(1234) + ts = torch.randn((100,)) + DistributedHelper.check_equal_tensors(ts) + + torch.manual_seed(1234 + DistributedHelper.rank) + ts = torch.randn((100,)) + with self.assertRaises(Exception): + DistributedHelper.check_equal_tensors(ts) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_fields(self): + self.assertEqual(dst.get_rank(), DistributedHelper.rank) + self.assertEqual(dst.get_world_size(), DistributedHelper.world_size) + self.assertEqual(True, DistributedHelper.is_distributed) + self.assertEqual(dst.get_rank() == 0, DistributedHelper.is_main_process) + + if self.use_gpu_in_tests: + self.assertEqual('nccl', DistributedHelper.backend) + self.assertTrue(DistributedHelper.forced_cuda_comm) + else: + self.assertEqual('gloo', DistributedHelper.backend) + self.assertFalse(DistributedHelper.forced_cuda_comm) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_set_random_seeds_and_align(self): + DistributedHelper.set_random_seeds(5678) + + self.assertEqual(297076, np.random.randint(0, 1000000)) + self.assertEqual(643380, torch.randint(0, 1000000, (1,)).item()) + self.assertEqual(683410, random.randint(0, 1000000)) + + if DistributedHelper.is_main_process: + np.random.randint(0, 1000000) + torch.randint(0, 1000000, (1,)) + random.randint(0, 1000000) + + DistributedHelper.align_seeds() + + ref_values = ( + int(np.random.randint(0, 1000000)), + int(torch.randint(0, 1000000, (1,))), + int(random.randint(0, 1000000)) + ) + + DistributedHelper.check_equal_objects(ref_values) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_rolling_seed_aligner(self): + RNGManager.set_random_seeds(4321) + + with RollingSeedContext(): + RNGManager.set_random_seeds(1234 + DistributedHelper.rank) + random.randint(0, 2 ** 64 - 1) + + final_value = random.randint(0, 2 ** 64 - 1) + self.assertEqual(14732185405572191734, final_value) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_broadcast_seed_aligner(self): + RNGManager.set_random_seeds(4321) + + with BroadcastSeedContext(): + RNGManager.set_random_seeds(1234 + DistributedHelper.rank) + random.randint(0, 2 ** 64 - 1) + + final_value = random.randint(0, 2 ** 64 - 1) + self.assertEqual(15306775005444441373, final_value) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_main_process_first(self): + tmpdirname = '' + try: + my_rank = DistributedHelper.rank + if DistributedHelper.is_main_process: + tmpdirname = tempfile.mkdtemp() + + tmpdirname = DistributedHelper.broadcast_object(tmpdirname) + + with DistributedHelper.main_process_first(): + + for _ in range(2): + time.sleep(0.1 + my_rank * 0.05) + files = list(os.listdir(tmpdirname)) + if DistributedHelper.is_main_process: + self.assertEqual(0, len(files)) + else: + self.assertIn(f'rank0', files) + self.assertNotIn(f'rank{my_rank}', files) + + with open(os.path.join(tmpdirname, f'rank{my_rank}'), 'w') \ + as f: + f.write('ok') + + for _ in range(2): + time.sleep(0.1 + my_rank * 0.05) + files = list(os.listdir(tmpdirname)) + if DistributedHelper.is_main_process: + self.assertEqual(1, len(files)) + self.assertIn(f'rank0', files) + else: + self.assertIn(f'rank0', files) + 
self.assertIn(f'rank{my_rank}', files) + + DistributedHelper.barrier() + files = set(os.listdir(tmpdirname)) + expect = set([f'rank{rnk}' + for rnk in range(DistributedHelper.world_size)]) + self.assertSetEqual(expect, files) + DistributedHelper.barrier() + finally: + if tmpdirname is not None and DistributedHelper.is_main_process: + shutil.rmtree(tmpdirname) + + +if __name__ == "__main__": + with suppress_dst_tests_output(): + verbosity = 1 + if DistributedHelper.rank > 0: + verbosity = 0 + unittest.main(verbosity=verbosity) diff --git a/tests/run_dist_tests.py b/tests/run_dist_tests.py new file mode 100644 index 000000000..1ef96d7a4 --- /dev/null +++ b/tests/run_dist_tests.py @@ -0,0 +1,105 @@ +import os +import signal +import sys +import unittest +from subprocess import Popen +from typing import Union, Set +from unittest import TestSuite, TestCase + +import click + + +def get_distributed_test_cases(suite: Union[TestCase, TestSuite]) -> Set[str]: + found_cases = set() + if isinstance(suite, TestSuite): + for x in suite: + found_cases.update(get_distributed_test_cases(x)) + + if isinstance(suite, TestCase): + case_id = suite.id() + + if case_id.startswith('distributed.') or \ + case_id.startswith('tests.distributed.'): + found_cases.add(case_id) + + if '_FailedTest' in case_id: + raise RuntimeError( + f'Errors encountered while listing test cases: {case_id}') + + return found_cases + + +@click.command() +@click.argument('test_cases', nargs=-1) +def run_distributed_suites(test_cases): + cases_names = get_distributed_test_cases( + unittest.defaultTestLoader.discover('.')) # Don't change the path! + cases_names = list(sorted(cases_names)) + print(cases_names) + if len(test_cases) > 0: + test_cases = set(test_cases) + cases_names = [x for x in cases_names if x in test_cases] + + if set(cases_names) != test_cases: + print('Some cases have not been found!', + test_cases - set(cases_names)) + sys.exit(1) + + print('Running', len(cases_names), 'tests') + p = None + success = True + exited = False + failed_test_cases = set() + + use_gpu_in_tests = os.environ.get('USE_GPU', 'false').lower() in [ + '1', 'true'] + if use_gpu_in_tests: + print('Running tests using GPUs') + import torch + nproc_per_node = torch.cuda.device_count() + else: + print('Running tests using CPU only') + nproc_per_node = 2 + + for case_name in cases_names: + if exited: + print('Exiting due to keyboard interrupt') + break + print('Running test:', case_name, flush=True) + try: + my_env = os.environ.copy() + my_env['DISTRIBUTED_TESTS'] = '1' + p = Popen( + ['python', '-m', 'torch.distributed.run', '--nnodes=1', + f'--nproc_per_node={nproc_per_node}', + '-m', 'unittest', case_name], + stdout=sys.stdout, + stderr=sys.stderr, + env=my_env) + p.communicate() + except KeyboardInterrupt: + success = False + exited = True + p.send_signal(signal.SIGINT) + finally: + exit_code = p.wait() + print('Test completed with code', exit_code) + success = success and exit_code == 0 + p = None + + if exit_code != 0: + failed_test_cases.add(case_name) + + if success: + print('Tests completed successfully') + sys.exit(0) + else: + print('The following tests terminated with errors:') + for failed_case in sorted(failed_test_cases): + print(failed_case) + + sys.exit(1) + + +if __name__ == '__main__': + run_distributed_suites() diff --git a/tests/training/test_online_strategies.py b/tests/training/test_online_strategies.py index b3270347f..14021764c 100644 --- a/tests/training/test_online_strategies.py +++ b/tests/training/test_online_strategies.py @@ 
-53,7 +53,7 @@ def test_naive(self): train_mb_size=1, device=self.device, eval_mb_size=50, - evaluator=default_evaluator(), + evaluator=default_evaluator, ) ocl_benchmark = OnlineCLScenario(benchmark_streams, access_task_boundaries=True) @@ -68,7 +68,7 @@ def test_naive(self): train_mb_size=1, device=self.device, eval_mb_size=50, - evaluator=default_evaluator(), + evaluator=default_evaluator, ) ocl_benchmark = OnlineCLScenario(benchmark_streams, access_task_boundaries=False) diff --git a/tests/unit_tests_utils.py b/tests/unit_tests_utils.py index 922bdd1e5..e672c4c2c 100644 --- a/tests/unit_tests_utils.py +++ b/tests/unit_tests_utils.py @@ -29,7 +29,7 @@ if "UPDATE_METRICS" in os.environ: UPDATE_METRICS = os.environ["UPDATE_METRICS"].lower() == "true" -print(f"UPDATE_METRICS: {UPDATE_METRICS}") +# print(f"UPDATE_METRICS: {UPDATE_METRICS}") def is_github_action():
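One last point on the distributed-training support added in `templates/base.py`: `_check_distributed_training_compatibility` only inspects `getattr(plugin, "supports_distributed", False)`, and built-in plugins such as `PeriodicEval` opt in via the `supports_distributed=True` class keyword. The sketch below shows how a custom plugin could declare the same support; the plain class-attribute form is a conservative fallback that the getattr-based check accepts, while the class-keyword form used by `PeriodicEval` assumes the base class consumes that keyword.

from avalanche.core import SupervisedPlugin


class StatefulLoggingPlugin(SupervisedPlugin):
    # Plain class attribute: sufficient for the getattr-based check in
    # BaseTemplate._check_distributed_training_compatibility.
    supports_distributed = True

    def after_training_epoch(self, strategy, **kwargs):
        # Only reads rank-local strategy state, so it is expected to be
        # safe under DistributedDataParallel training.
        print("epoch finished")

Such plugins can then be exercised with the new `tests/run_dist_tests.py` launcher, which runs each `tests.distributed.*` case under `torch.distributed.run` with `DISTRIBUTED_TESTS=1` set in the child process environment.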