From 0651176b7fc981522c298ec7ee4cda2b781005af Mon Sep 17 00:00:00 2001
From: xavier-owkin <153643450+xavier-owkin@users.noreply.github.com>
Date: Thu, 25 Jan 2024 17:50:28 +0100
Subject: [PATCH] feat: Allow empty datasets in init_data_loaders (#300)

---
 environment.yml                      |  2 +-
 flamby/benchmarks/benchmark_utils.py | 38 ++++++++++++++++++----------
 flamby/strategies/fed_avg.py         | 12 ++++-----
 flamby/strategies/utils.py           | 13 +++++-----
 setup.py                             |  4 +--
 5 files changed, 39 insertions(+), 30 deletions(-)

diff --git a/environment.yml b/environment.yml
index bf5e3ce4b..0006c1adc 100644
--- a/environment.yml
+++ b/environment.yml
@@ -3,7 +3,7 @@ channels:
   - conda
   - conda-forge
 dependencies:
-  - python=3.8
+  - python=3.10
   - pip
   - pytest
   - ipython
\ No newline at end of file
diff --git a/flamby/benchmarks/benchmark_utils.py b/flamby/benchmarks/benchmark_utils.py
index a7498a576..a771ef1db 100644
--- a/flamby/benchmarks/benchmark_utils.py
+++ b/flamby/benchmarks/benchmark_utils.py
@@ -198,22 +198,32 @@ def init_data_loaders(
     batch_size_test = batch_size if batch_size_test is None else batch_size_test
     if not pooled:
         training_dls = [
-            dl(
-                dataset(center=i, train=True, pooled=False),
-                batch_size=batch_size,
-                shuffle=True,
-                num_workers=num_workers,
-                collate_fn=collate_fn,
+            (
+                dl(
+                    center_dataset,
+                    batch_size=batch_size,
+                    shuffle=True,
+                    num_workers=num_workers,
+                    collate_fn=collate_fn,
+                )
+                if len(center_dataset := dataset(center=i, train=True, pooled=False))
+                > 0
+                else None
             )
             for i in range(num_clients)
         ]
         test_dls = [
-            dl(
-                dataset(center=i, train=False, pooled=False),
-                batch_size=batch_size_test,
-                shuffle=False,
-                num_workers=num_workers,
-                collate_fn=collate_fn,
+            (
+                dl(
+                    center_dataset,
+                    batch_size=batch_size_test,
+                    shuffle=False,
+                    num_workers=num_workers,
+                    collate_fn=collate_fn,
+                )
+                if len(center_dataset := dataset(center=i, train=False, pooled=False))
+                > 0
+                else None
             )
             for i in range(num_clients)
         ]
@@ -569,7 +579,9 @@ def ensemble_perf_from_predictions(
     return ensemble_perf
 
 
-def set_dataset_specific_config(dataset_name, compute_ensemble_perf=False, use_gpu=True):
+def set_dataset_specific_config(
+    dataset_name, compute_ensemble_perf=False, use_gpu=True
+):
     """_summary_
 
     Parameters
diff --git a/flamby/strategies/fed_avg.py b/flamby/strategies/fed_avg.py
index e72645913..02e80270a 100644
--- a/flamby/strategies/fed_avg.py
+++ b/flamby/strategies/fed_avg.py
@@ -178,13 +178,11 @@ def perform_round(self):
             None for _ in range(len(local_updates[0]["updates"]))
         ]
         for idx_weight in range(len(local_updates[0]["updates"])):
-            aggregated_delta_weights[idx_weight] = sum(
-                [
-                    local_updates[idx_client]["updates"][idx_weight]
-                    * local_updates[idx_client]["n_samples"]
-                    for idx_client in range(self.num_clients)
-                ]
-            )
+            aggregated_delta_weights[idx_weight] = sum([
+                local_updates[idx_client]["updates"][idx_weight]
+                * local_updates[idx_client]["n_samples"]
+                for idx_client in range(self.num_clients)
+            ])
             aggregated_delta_weights[idx_weight] /= float(self.total_number_of_samples)
 
         # Update models
diff --git a/flamby/strategies/utils.py b/flamby/strategies/utils.py
index 4f552cd86..e4d185ecd 100644
--- a/flamby/strategies/utils.py
+++ b/flamby/strategies/utils.py
@@ -270,7 +270,9 @@ def _prox_local_train(self, dataloader_with_memory, num_updates, mu):
             _loss = _prox_loss.detach()
 
             if mu > 0.0:
-                squared_norm = compute_model_diff_squared_norm(model_initial, self.model)
+                squared_norm = compute_model_diff_squared_norm(
+                    model_initial, self.model
+                )
                 _prox_loss += mu / 2 * squared_norm
 
             # Backpropagation
@@ -451,12 +453,9 @@ def check_exchange_compliance(tensors_list, max_bytes, units="bytes"):
     """
     assert units in ["bytes", "bits", "megabytes", "gigabytes"]
    assert isinstance(tensors_list, list), "You should provide a list of tensors."
-    assert all(
-        [
-            (isinstance(t, np.ndarray) or isinstance(t, torch.Tensor))
-            for t in tensors_list
-        ]
-    )
+    assert all([
+        (isinstance(t, np.ndarray) or isinstance(t, torch.Tensor)) for t in tensors_list
+    ])
     bytes_count = 0
     for t in tensors_list:
         if isinstance(t, np.ndarray):
diff --git a/setup.py b/setup.py
index 6e543e227..c3e2dcc78 100644
--- a/setup.py
+++ b/setup.py
@@ -77,8 +77,8 @@ def run(self):
     "pydicom",
     "requests",
     "scipy",
-    "sphinx==4.5.0",
-    "sphinx-rtd-theme==1.0.0",
+    "sphinx",
+    "sphinx-rtd-theme",
 ]
 tests = ["albumentations", "pytest"]
 all_extra = camelyon16 + heart + isic2019 + ixi + kits19 + lidc + tcga + docs + tests
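
For context, here is a minimal, self-contained sketch (not part of the patch) of the pattern the new init_data_loaders relies on: a DataLoader is built only when a center's dataset is non-empty, otherwise the list entry is None, which callers then have to skip. The ToyCenterDataset class, its per-center sizes, and the unused train/pooled arguments are invented purely for illustration.

# Illustrative sketch only; ToyCenterDataset and its sizes are hypothetical.
from torch.utils.data import DataLoader, Dataset


class ToyCenterDataset(Dataset):
    """Per-center dataset; center 1 is deliberately empty."""

    def __init__(self, center, train=True, pooled=False):
        # train/pooled mirror the real per-center dataset signature but are ignored here.
        self.samples = list(range([5, 0, 3][center]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


num_clients = 3
training_dls = [
    (
        DataLoader(center_dataset, batch_size=2, shuffle=True)
        # The walrus assignment builds the dataset once and reuses it in the
        # branch above instead of instantiating it twice.
        if len(center_dataset := ToyCenterDataset(center=i, train=True, pooled=False)) > 0
        else None
    )
    for i in range(num_clients)
]

# Empty centers yield None, so downstream code must skip them explicitly.
for i, loader in enumerate(training_dls):
    if loader is None:
        continue
    for batch in loader:
        print(f"center {i}: {batch.tolist()}")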