Skip to content

Commit

Permalink
Don't compute FID clean and FID-CLIP for single-channel images; fixes #6
Browse files · Browse the repository at this point in the history
  • Loading branch information
giulio98 committed May 5, 2024
1 parent cb13fb0 commit c1aa1f6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Before you begin with any experiments, ensure to create a `.env` file with the f
export WANDB_API_KEY=<your wandb api key>
export HOME=<your_home_directory> # e.g., /home/username
export CUDA_HOME=/usr/local/cuda
export PROJECT_ROOT=<your_project_directory> # /home/username/functional_diffusion_processes
export PROJECT_ROOT=<your_project_directory> # /home/username/functional-diffusion-processes
export DATA_ROOT=${PROJECT_ROOT}/data
export LOGS_ROOT=${PROJECT_ROOT}/logs
export TFDS_DATA_DIR=${DATA_ROOT}/tensorflow_datasets
Expand Down
9 changes: 7 additions & 2 deletions conf/experiments_maml/eval_mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@



# @package _global_
defaults:
- override /trainers: trainer_maml
Expand All @@ -24,7 +27,7 @@ trainers:
evaluation_config:
seed: 43 # random seed for reproducibility
eval_dir: ${oc.env:LOGS_ROOT}/inr_mnist # directory where evaluation results are saved
num_samples: 10000 # number of samples to be generated for evaluation
num_samples: 16000 # number of samples to be generated for evaluation

sdes:
sde_config:
Expand Down Expand Up @@ -65,4 +68,6 @@ models:
datasets:
test:
data_config:
batch_size: 1
image_height_size: 32
image_width_size: 32
batch_size: 512
2 changes: 1 addition & 1 deletion conf/trainers/trainer_maml.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ trainer_logging:
use_wandb: True # if True, wandb is used for logging
wandb_init:
name: ${trainers.training_config.save_dir}
project: "your_project"
project: "fdp"
entity: "your_entity"
save_code: False
28 changes: 15 additions & 13 deletions src/functional_diffusion_processes/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,22 +687,24 @@ class EvalMeta:
fid_score, inception_score = fid_metric.compute_fid(self.eval_dir, num_sampling_rounds)
pylogger.info("FID: %.6e" % fid_score)
pylogger.info("Inception score %.6e" % inception_score)
if not ds_test.data_config.output_size == 1:
# Compute FID clean
clean_dataset = os.path.join(
fid_metric.metric_config.real_features_path,
f"{fid_metric.metric_config.dataset_name.lower()}_clean",
)
clean_fake_dataset = os.path.join(self.eval_dir, "clean")
fid_clean = fid.compute_fid(clean_fake_dataset, clean_dataset, mode="clean")
pylogger.info("FID clean: %.6e" % fid_clean)

# Compute FID clean
clean_dataset = os.path.join(
fid_metric.metric_config.real_features_path, f"{fid_metric.metric_config.dataset_name.lower()}_clean"
)
clean_fake_dataset = os.path.join(self.eval_dir, "clean")
fid_clean = fid.compute_fid(clean_fake_dataset, clean_dataset, mode="clean")
pylogger.info("FID clean: %.6e" % fid_clean)

# Compute FID-CLIP
fid_clip = fid.compute_fid(clean_fake_dataset, clean_dataset, mode="clean", model_name="clip_vit_b_32")
pylogger.info("FID-CLIP: %.6e" % fid_clip)
# Compute FID-CLIP
fid_clip = fid.compute_fid(clean_fake_dataset, clean_dataset, mode="clean", model_name="clip_vit_b_32")
pylogger.info("FID-CLIP: %.6e" % fid_clip)

if self.logging.use_wandb:
wandb.log({"FID": float(fid_score)})
wandb.log({"inception score": float(inception_score)})
wandb.log({"FID clean": float(fid_clean)})
wandb.log({"FID-CLIP": float(fid_clip)})
if not ds_test.data_config.output_size == 1:
wandb.log({"FID clean": float(fid_clean)})
wandb.log({"FID-CLIP": float(fid_clip)})
wandb.finish()

0 comments on commit c1aa1f6

Please sign in to comment.