feat: Use ssl features for fed_benchmark
xavier-owkin committed Jan 8, 2024
1 parent 37330d0 commit 6209218
Showing 1 changed file with 66 additions and 31 deletions.
97 changes: 66 additions & 31 deletions flamby/benchmarks/fed_benchmark.py
@@ -1,5 +1,6 @@
 import argparse
 import copy
+import traceback

 import numpy as np
 import pandas as pd
@@ -113,6 +114,8 @@ def main(args_cli):

     # We can now instantiate the dataset specific model on CPU
     global_init = Baseline()
+    if args_cli.use_ssl_features and dataset_name == "fed_camelyon16":
+        global_init = Baseline(768)

     # We parse the hyperparams from the config or from the CLI if strategy is given
     strategy_specific_hp_dicts = get_strategies(
@@ -211,20 +214,29 @@ def main(args_cli):
     df.drop(index_of_interest, inplace=True)
     m = copy.deepcopy(global_init)
     set_seed(args_cli.seed)
-    m = train_single_centric(
-        m,
-        train_pooled,
-        use_gpu,
-        "Pooled",
-        pooled_hyperparameters["optimizer_class"],
-        pooled_hyperparameters["learning_rate"],
-        BaselineLoss,
-        NUM_EPOCHS_POOLED,
-        dp_target_epsilon=pooled_hyperparameters["dp_target_epsilon"],
-        dp_target_delta=pooled_hyperparameters["dp_target_delta"],
-        dp_max_grad_norm=pooled_hyperparameters["dp_max_grad_norm"],
-        seed=args_cli.seed,
-    )
+    try:
+        m = train_single_centric(
+            m,
+            train_pooled,
+            use_gpu,
+            "Pooled",
+            pooled_hyperparameters["optimizer_class"],
+            pooled_hyperparameters["learning_rate"],
+            BaselineLoss,
+            NUM_EPOCHS_POOLED,
+            dp_target_epsilon=pooled_hyperparameters["dp_target_epsilon"],
+            dp_target_delta=pooled_hyperparameters["dp_target_delta"],
+            dp_max_grad_norm=pooled_hyperparameters["dp_max_grad_norm"],
+            seed=args_cli.seed,
+        )
+    except RuntimeError:
+        traceback.print_exc()
+        raise RuntimeError(
+            "The previous error might be linked to the use of phikon features. If"
+            " you want to use them make sure to have the parameter"
+            " --use-ssl-features enabled as well as the correct path in the"
+            " dataset_location.yaml file."
+        )
     (
         perf_dict,
         pooled_perf_dict,
@@ -461,7 +473,6 @@ def main(args_cli):


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
@@ -513,33 +524,41 @@ def main(args_cli):
"-nft",
type=int,
default=None,
help="The number of SGD fine-tuning updates to be"
"performed on the model at the personalization step,"
"if strategy is given and that it is FedAvgFineTuning",
help=(
"The number of SGD fine-tuning updates to be"
"performed on the model at the personalization step,"
"if strategy is given and that it is FedAvgFineTuning"
),
)
parser.add_argument(
"--tau",
"-tau",
type=float,
default=None,
help="FedOpt tau parameter used only if strategy is "
"given and that it is a fedopt strategy",
help=(
"FedOpt tau parameter used only if strategy is "
"given and that it is a fedopt strategy"
),
)
parser.add_argument(
"--beta1",
"-b1",
type=float,
default=None,
help="FedOpt beta1 parameter used only if strategy is "
"given and that it is a fedopt strategy",
help=(
"FedOpt beta1 parameter used only if strategy is "
"given and that it is a fedopt strategy"
),
)
parser.add_argument(
"--beta2",
"-b2",
type=float,
default=None,
help="FedOpt beta2 parameter used only if strategy is"
" given and that it is a fedopt strategy",
help=(
"FedOpt beta2 parameter used only if strategy is"
" given and that it is a fedopt strategy"
),
)
parser.add_argument(
"--strategy",
@@ -578,22 +597,24 @@ def main(args_cli):
"-dpe",
type=float,
default=None,
help="the target epsilon for (epsilon, delta)-differential" "private guarantee",
help="the target epsilon for (epsilon, delta)-differential private guarantee",
)
parser.add_argument(
"--dp_target_delta",
"-dpd",
type=float,
default=None,
help="the target delta for (epsilon, delta)-differential" "private guarantee",
help="the target delta for (epsilon, delta)-differential private guarantee",
)
parser.add_argument(
"--dp_max_grad_norm",
"-mgn",
type=float,
default=None,
help="the maximum L2 norm of per-sample gradients; "
"used to enforce differential privacy",
help=(
"the maximum L2 norm of per-sample gradients; "
"used to enforce differential privacy"
),
)
parser.add_argument(
"--log",
@@ -621,15 +642,29 @@ def main(args_cli):
"-scb",
default=None,
type=str,
help="Whether or not to compute only one single-centric baseline and which one.",
help=(
"Whether or not to compute only one single-centric baseline and which one."
),
choices=["Pooled", "Local"],
)
parser.add_argument(
"--nlocal",
default=0,
type=int,
help="Will only be used if --single-centric-baseline Local, will test"
"only training on Local {nlocal}.",
help=(
"Will only be used if --single-centric-baseline Local, will test"
"only training on Local {nlocal}."
),
)
parser.add_argument(
"--use-ssl-features",
action="store_true",
help=(
"Whether to use the much more performant phikon feature extractor on"
" Camelyon16, trained with self-supervised learning on histology datasets"
" from https://www.medrxiv.org/content/10.1101/2023.07.21.23292757v2"
" instead of imagenet-trained resnet."
),
)
parser.add_argument("--seed", default=0, type=int, help="Seed")


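With this commit, the Camelyon16 benchmark can be switched to the phikon SSL features from the command line. A minimal invocation sketch, assuming FLamby is installed, fed_camelyon16 has been downloaded, and dataset_location.yaml points at the extracted phikon features (this flag combination is illustrative, not prescribed by the commit):

    python flamby/benchmarks/fed_benchmark.py --seed 42 --use-ssl-features --single-centric-baseline Pooled

If the stored features do not match the model's expected input size, the wrapped training call above prints the traceback and re-raises a RuntimeError with the hint to enable --use-ssl-features and to check dataset_location.yaml.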