feat: Allow empty datasets in init_data_loaders (#300)
xavier-owkin authored Jan 25, 2024
1 parent 3cab86e commit 0651176
Showing 5 changed files with 39 additions and 30 deletions.
2 changes: 1 addition & 1 deletion environment.yml
@@ -3,7 +3,7 @@ channels:
   - conda
   - conda-forge
 dependencies:
-  - python=3.8
+  - python=3.10
   - pip
   - pytest
   - ipython
38 changes: 25 additions & 13 deletions flamby/benchmarks/benchmark_utils.py
@@ -198,22 +198,32 @@ def init_data_loaders(
     batch_size_test = batch_size if batch_size_test is None else batch_size_test
     if not pooled:
         training_dls = [
-            dl(
-                dataset(center=i, train=True, pooled=False),
-                batch_size=batch_size,
-                shuffle=True,
-                num_workers=num_workers,
-                collate_fn=collate_fn,
+            (
+                dl(
+                    center_dataset,
+                    batch_size=batch_size,
+                    shuffle=True,
+                    num_workers=num_workers,
+                    collate_fn=collate_fn,
+                )
+                if len(center_dataset := dataset(center=i, train=True, pooled=False))
+                > 0
+                else None
             )
             for i in range(num_clients)
         ]
         test_dls = [
-            dl(
-                dataset(center=i, train=False, pooled=False),
-                batch_size=batch_size_test,
-                shuffle=False,
-                num_workers=num_workers,
-                collate_fn=collate_fn,
+            (
+                dl(
+                    center_dataset,
+                    batch_size=batch_size_test,
+                    shuffle=False,
+                    num_workers=num_workers,
+                    collate_fn=collate_fn,
+                )
+                if len(center_dataset := dataset(center=i, train=False, pooled=False))
+                > 0
+                else None
             )
             for i in range(num_clients)
         ]
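The rewritten comprehensions use an assignment expression (`:=`, available since Python 3.8, so well within the new 3.10 floor) to build each center's dataset once, and map empty centers to None instead of constructing a DataLoader over zero samples. A self-contained sketch of the same guard pattern, with illustrative names (make_loaders is not part of FLamby):

# Minimal sketch of the empty-dataset guard: one DataLoader per center,
# or None when a center has no samples. Names here are illustrative.
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loaders(center_datasets, batch_size=4):
    return [
        (
            DataLoader(ds, batch_size=batch_size, shuffle=True)
            if len(ds) > 0
            else None  # empty center: no loader, callers must skip it
        )
        for ds in center_datasets
    ]

centers = [
    TensorDataset(torch.randn(10, 3)),  # center with 10 samples
    TensorDataset(torch.randn(0, 3)),   # empty center
]
loaders = make_loaders(centers)
assert loaders[0] is not None and loaders[1] is None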
@@ -569,7 +579,9 @@ def ensemble_perf_from_predictions(
     return ensemble_perf


-def set_dataset_specific_config(dataset_name, compute_ensemble_perf=False, use_gpu=True):
+def set_dataset_specific_config(
+    dataset_name, compute_ensemble_perf=False, use_gpu=True
+):
     """_summary_
     Parameters
12 changes: 5 additions & 7 deletions flamby/strategies/fed_avg.py
@@ -178,13 +178,11 @@ def perform_round(self):
             None for _ in range(len(local_updates[0]["updates"]))
         ]
         for idx_weight in range(len(local_updates[0]["updates"])):
-            aggregated_delta_weights[idx_weight] = sum(
-                [
-                    local_updates[idx_client]["updates"][idx_weight]
-                    * local_updates[idx_client]["n_samples"]
-                    for idx_client in range(self.num_clients)
-                ]
-            )
+            aggregated_delta_weights[idx_weight] = sum([
+                local_updates[idx_client]["updates"][idx_weight]
+                * local_updates[idx_client]["n_samples"]
+                for idx_client in range(self.num_clients)
+            ])
             aggregated_delta_weights[idx_weight] /= float(self.total_number_of_samples)

         # Update models
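This change only reflows the sum([...]) call; the aggregation is still the standard FedAvg sample-weighted average of client updates. As a standalone illustration of that computation, with made-up client updates and shapes:

# Toy illustration of FedAvg's sample-weighted aggregation (not FLamby's API).
import torch

local_updates = [
    {"updates": [torch.ones(2)], "n_samples": 30},
    {"updates": [torch.zeros(2)], "n_samples": 10},
]
total_number_of_samples = sum(u["n_samples"] for u in local_updates)

aggregated = [
    sum(u["updates"][k] * u["n_samples"] for u in local_updates)
    / float(total_number_of_samples)
    for k in range(len(local_updates[0]["updates"]))
]
# Client 0 holds 75% of the samples, so the average is 0.75 everywhere.
assert torch.allclose(aggregated[0], torch.full((2,), 0.75))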
13 changes: 6 additions & 7 deletions flamby/strategies/utils.py
@@ -270,7 +270,9 @@ def _prox_local_train(self, dataloader_with_memory, num_updates, mu):
             _loss = _prox_loss.detach()

             if mu > 0.0:
-                squared_norm = compute_model_diff_squared_norm(model_initial, self.model)
+                squared_norm = compute_model_diff_squared_norm(
+                    model_initial, self.model
+                )
                 _prox_loss += mu / 2 * squared_norm

             # Backpropagation
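Again only line wrapping: the FedProx objective still adds mu/2 times the squared distance between the current local model and its round-start copy. A self-contained sketch under that reading, with a hypothetical helper standing in for compute_model_diff_squared_norm:

# Sketch of the FedProx proximal term: loss + mu/2 * ||w - w_initial||^2.
# squared_norm_of_diff is a stand-in for compute_model_diff_squared_norm.
import copy
import torch

def squared_norm_of_diff(model_a, model_b):
    return sum(
        torch.sum((p_a - p_b) ** 2)
        for p_a, p_b in zip(model_a.parameters(), model_b.parameters())
    )

mu = 0.1
model = torch.nn.Linear(3, 1)
model_initial = copy.deepcopy(model)  # snapshot taken at the start of the round

_prox_loss = torch.tensor(0.5)  # stand-in for the task loss
if mu > 0.0:
    _prox_loss = _prox_loss + mu / 2 * squared_norm_of_diff(model_initial, model)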
@@ -451,12 +453,9 @@ def check_exchange_compliance(tensors_list, max_bytes, units="bytes"):
     """
     assert units in ["bytes", "bits", "megabytes", "gigabytes"]
     assert isinstance(tensors_list, list), "You should provide a list of tensors."
-    assert all(
-        [
-            (isinstance(t, np.ndarray) or isinstance(t, torch.Tensor))
-            for t in tensors_list
-        ]
-    )
+    assert all([
+        (isinstance(t, np.ndarray) or isinstance(t, torch.Tensor)) for t in tensors_list
+    ])
     bytes_count = 0
    for t in tensors_list:
        if isinstance(t, np.ndarray):
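The assert here only becomes more compact; the membership check and the byte accounting are unchanged. For reference, a simplified standalone version of the check plus the size counting it guards (the .nbytes and element_size() calls are standard numpy/torch APIs, but this is a sketch, not the exact FLamby function):

# Simplified sketch of counting the bytes exchanged as a list of arrays/tensors.
import numpy as np
import torch

def count_exchanged_bytes(tensors_list):
    assert isinstance(tensors_list, list), "You should provide a list of tensors."
    assert all(isinstance(t, (np.ndarray, torch.Tensor)) for t in tensors_list)
    bytes_count = 0
    for t in tensors_list:
        if isinstance(t, np.ndarray):
            bytes_count += t.nbytes
        else:
            bytes_count += t.element_size() * t.nelement()
    return bytes_count

# A float32 array of 10 elements (40 B) plus a float32 tensor of 5 (20 B).
assert count_exchanged_bytes([np.zeros(10, dtype=np.float32), torch.zeros(5)]) == 60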
4 changes: 2 additions & 2 deletions setup.py
@@ -77,8 +77,8 @@ def run(self):
     "pydicom",
     "requests",
     "scipy",
-    "sphinx==4.5.0",
-    "sphinx-rtd-theme==1.0.0",
+    "sphinx",
+    "sphinx-rtd-theme",
 ]
 tests = ["albumentations", "pytest"]
 all_extra = camelyon16 + heart + isic2019 + ixi + kits19 + lidc + tcga + docs + tests
