Use pre-downsampled training data (#8)
* handle pre-resampled images and tidy up dataloader

* update data source and test download func now that data is lightweight

* lint
alisterburt authored Dec 6, 2022
1 parent 741d1d0 commit e93ff65
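In short: the training data on Zenodo are now stored pre-resampled at 8 Å/px, so the dataloader no longer carries per-dataset pixel sizes or resamples images on the fly. A minimal sketch of the resulting workflow, assuming FidderDataset is importable from fidder.data.training_dataset and that download=False skips the fetch when data are already on disk (the test below only confirms that download_training_data is exposed via fidder.data):

from pathlib import Path

from fidder.data import download_training_data
from fidder.data.training_dataset import FidderDataset

data_dir = Path("training_data")
download_training_data(data_dir)  # fetches Zenodo record 7404985 and unpacks fidder_data.zip
dataset = FidderDataset(data_dir, train=True, download=False)
image, mask = dataset[0]  # image (1, 512, 512) float32, mask (512, 512) int64 per the docstring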
Showing 4 changed files with 45 additions and 45 deletions.
Empty file.
7 changes: 7 additions & 0 deletions src/fidder/data/_tests/test_download_data.py
@@ -0,0 +1,7 @@
+from fidder.data import download_training_data
+
+
+def test_download_training_data(tmp_path):
+    download_training_data(tmp_path)
+    assert (tmp_path / 'images').exists()
+    assert (tmp_path / 'masks').exists()
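The test relies on pytest's built-in tmp_path fixture and exercises the download end-to-end, which is practical now that the archive is lightweight. The same check can be run outside pytest against a scratch directory (an illustrative sketch, not part of the commit; it needs network access and the zenodo_get CLI on PATH):

import tempfile
from pathlib import Path

from fidder.data import download_training_data

with tempfile.TemporaryDirectory() as tmp:
    download_training_data(Path(tmp))
    assert (Path(tmp) / "images").exists()
    assert (Path(tmp) / "masks").exists()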
4 changes: 2 additions & 2 deletions src/fidder/data/download.py
@@ -12,12 +12,12 @@ def download_training_data(output_directory: Path):
     subprocess.run(
         [
             "zenodo_get",
-            "7104305",
+            "7404985",
             "--output-dir",
             str(output_directory),
         ]
     )
-    zipped_archive = output_directory / "tilt_images_with_fiducials.zip"
+    zipped_archive = output_directory / "fidder_data.zip"
     shutil.unpack_archive(
         zipped_archive,
         extract_dir=output_directory,
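Only the Zenodo record ID and the archive filename change; the surrounding logic is untouched. Since the hunk starts at line 12, the imports and function head above it are not shown — reconstructed as an assumption, the post-commit function plausibly reads:

import shutil
import subprocess
from pathlib import Path


def download_training_data(output_directory: Path):
    # download Zenodo record 7404985 with the zenodo_get CLI
    subprocess.run(
        [
            "zenodo_get",
            "7404985",
            "--output-dir",
            str(output_directory),
        ]
    )
    # unpack the pre-downsampled archive next to the zip
    zipped_archive = output_directory / "fidder_data.zip"
    shutil.unpack_archive(zipped_archive, extract_dir=output_directory)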
79 changes: 36 additions & 43 deletions src/fidder/data/training_dataset.py
@@ -11,30 +11,26 @@
 from torchvision.transforms import functional as TF

 from .augmentation import random_flip, random_square_crop_at_scale
-from ..utils import calculate_resampling_factor
-from ..constants import TRAINING_IMAGE_DIMENSIONS, TRAINING_PIXEL_SIZE
+from ..constants import TRAINING_IMAGE_DIMENSIONS


 class FidderDataset(Dataset):
     """Fiducial segmentation dataset.
     https://zenodo.org/record/7104305
     - Images are in subfolders of root_dir called 'images' and 'masks'
-    - data are resampled to 8 Å/px for training
-    - data are cropped to (512, 512) then normalised to mean=0, std=1
-      before being passed to the network
+    - data on disk are resampled to 8 Å/px
+    - data are cropped in memory to (512, 512) then normalised to mean=0,
+      std=1 before being passed to the network
     - train/eval mode activated via methods of the same name
     """

-    PIXEL_SIZE_MAP = {
-        "EMPIAR-10164": 1.35,
-        "EMPIAR-10814": 2.96,
-        "EMPIAR-10453": 1.33,
-        "EMPIAR-10364": 2.24,
-        "EMPIAR-10631": 1.38,
-    }
-
-    def __init__(self, directory: PathLike, train: bool = True, download: bool = True):
+    def __init__(
+        self,
+        directory: PathLike,
+        train: bool = True,
+        download: bool = True
+    ):
         self.dataset_directory = Path(directory)
         self.image_directory = self.dataset_directory / "images"
         self.mask_directory = self.dataset_directory / "masks"
@@ -83,48 +79,29 @@ def __len__(self):
         return len(self.image_files)

     def __getitem__(self, idx):
-        name = self.image_files[idx].name
-
-        image_file, mask_file = self.image_directory / name, self.mask_directory / name
+        filename = self.image_files[idx].name
+        image_file = self.image_directory / filename
+        mask_file = self.mask_directory / filename
         self.check_files(image_file, mask_file)

         image = torch.tensor(imageio.imread(image_file), dtype=torch.float32)
         mask = torch.tensor(imageio.imread(mask_file), dtype=torch.float32)
+        image, mask = self.preprocess(image, mask)  # (1, h, w), (h, w)
         return image, mask

+    def preprocess(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # add channel dim for torchvision transforms
         image = einops.rearrange(image, "h w -> 1 h w")
         mask = einops.rearrange(mask, "h w -> 1 h w")

-        # resample images to standard pixel size
-        image_pixel_size = self.PIXEL_SIZE_MAP[name[:12]]
-        resampling_factor = calculate_resampling_factor(
-            source=image_pixel_size,
-            target=TRAINING_PIXEL_SIZE,
-        )
-        _, h, w = TF.get_dimensions(image)
-        target_size = int(min(h, w) * resampling_factor)
-        image = TF.resize(
-            image, target_size, interpolation=TF.InterpolationMode.BICUBIC
-        )
-        mask = TF.resize(mask, target_size, interpolation=TF.InterpolationMode.NEAREST)
-
         # augment if training, random crop if validating
         if self._is_training:
             image, mask = self.augment(image, mask)
         else:
-            if self._validation_crop_parameters is None:
-                self._validation_crop_parameters = T.RandomCrop.get_params(
-                    image, output_size=TRAINING_IMAGE_DIMENSIONS
-                )
-            image = TF.crop(image, *self._validation_crop_parameters)
-            mask = TF.crop(mask, *self._validation_crop_parameters)
-
-        # normalise image
-        image = (image - torch.mean(image)) / torch.std(image)
-
+            image, mask = self.crop_for_validation(image, mask)
+        image = self.normalise(image)
         image = image.float().contiguous()
         mask = mask.long().contiguous()

         mask = einops.rearrange(mask, "1 h w -> h w")
         return image, mask

@@ -147,3 +124,19 @@ def augment(
         )
         image, mask = random_flip(image, mask, p=0.5)
         return image, mask
+
+    def crop_for_validation(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self._validation_crop_parameters is None:
+            self._validation_crop_parameters = T.RandomCrop.get_params(
+                image, output_size=TRAINING_IMAGE_DIMENSIONS
+            )
+        image = TF.crop(image, *self._validation_crop_parameters)
+        mask = TF.crop(mask, *self._validation_crop_parameters)
+        return image, mask
+
+    def normalise(self, image: torch.Tensor) -> torch.Tensor:
+        mean, std = torch.mean(image), torch.std(image)
+        # assign the result: torch.nan_to_num is not in-place, and `nan` expects a float
+        image = torch.nan_to_num(image, nan=float(mean))
+        return (image - mean) / std
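Splitting preprocess into augment, crop_for_validation, and normalise makes each stage checkable in isolation. A sketch of the eval-mode tensor flow on synthetic data (fixed crop parameters stand in for the cached T.RandomCrop.get_params output; no files on disk are involved):

import einops
import torch
from torchvision.transforms import functional as TF

# synthetic 8 Å/px image and fiducial mask, larger than the (512, 512) training crop
image = torch.randn(768, 768)
mask = (torch.rand(768, 768) > 0.9).float()

# the eval-mode path of FidderDataset.preprocess, inlined
image = einops.rearrange(image, "h w -> 1 h w")
mask = einops.rearrange(mask, "h w -> 1 h w")
top, left, height, width = 0, 0, 512, 512  # stand-in crop parameters
image = TF.crop(image, top, left, height, width)
mask = TF.crop(mask, top, left, height, width)
image = (image - image.mean()) / image.std()  # normalise to mean 0, std 1
mask = einops.rearrange(mask, "1 h w -> h w")

print(image.shape, mask.shape)  # torch.Size([1, 512, 512]) torch.Size([512, 512])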
