From 57ceab23dadf568175d424235c21aa8eb4dc4c6a Mon Sep 17 00:00:00 2001 From: Melisande Croft <63270704+melisande-c@users.noreply.github.com> Date: Fri, 5 Jul 2024 10:17:54 +0100 Subject: [PATCH] Feature: Add File IO package (#174) ### Description This is a precursor for the soon to come save prediction feature. Save functions needed to be added; then because they mirror the structure of the read functions, the read functions have been moved from `dataset.dataset_utils` to a new `file_io` package. This package has `read` and `write` subpackages that mirror each other's structures, for ease of understanding. Additionally, the `get_read_func` is slightly refactored to match how the `get_write_func` is implemented. Read functions are stored in a module level dictionary with the keys being `SupportedData`, the `get_read_func` indexes this dictionary based on the `data_type` passed. This removes the eventuality of a long list of if/else statements. This wasn't strictly necessary as we do not plan to support a large number of data types, but the option is always there. An additional extra change that snuck into this PR is renaming `SupportedData.get_extension` to `SupportedData.get_extension_pattern`, and adding a different `SupportedData.get_extension`. The new `SupportedData.get_extension` returns the literal string without the unix wildcard patterns and will be used for saving predictions in a future PR. - **What**: - Added a `file_io` package to contain functions to read and write image files. - `SupportedData.get_extension` modification and addition. - **Why**: Partly an aesthetic choice, removes the responsibility of file reading and writing from the `datasets` package in accordance with trying to follow the single-responsibility principle. - **How**: Added write functions. Moved and slightly refactored read functions. ### Changes Made - **Added**: - `file_io`, `file_io.read`, `file_io.write` packages. - `write_tiff` function - `get_write_func` function - New `SupportedData.get_extension` to return literal extension string - Tests for all the above new functions - **Modified**: - Old `SupportedData.get_extension` renamed to `SupportedData.get_extension_pattern` - Renamed tests to mirror name change in function - **Removed**: - File reading from `datasets.dataset_utils` ### Breaking changes - Code calling read functions from `datasets.dataset_utils` directly. - Code using `SupportedData.get_extension` directly. ### Additional Notes and Examples In a future PR `SupportedData.get_extension` and `SupportedData.get_extension_pattern` could be moved to the `file_io` package. These functions do not need to be bound to `SupportedData` and none of the other "support" `Enum` classes have additional methods. It might make sense to store these functions closer to where they are used. This is again a stylistic choice, feel free to share other opinions. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --------- Co-authored-by: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../config/support/supported_data.py | 33 +++++++++-- .../dataset/dataset_utils/__init__.py | 6 -- .../dataset/dataset_utils/file_utils.py | 2 +- .../dataset_utils/iterate_over_files.py | 2 +- .../dataset/dataset_utils/read_utils.py | 27 --------- src/careamics/dataset/in_memory_dataset.py | 2 +- src/careamics/dataset/iterable_dataset.py | 3 +- .../dataset/iterable_pred_dataset.py | 3 +- .../dataset/iterable_tiled_pred_dataset.py | 3 +- src/careamics/file_io/__init__.py | 7 +++ src/careamics/file_io/read/__init__.py | 11 ++++ src/careamics/file_io/read/get_func.py | 56 ++++++++++++++++++ .../read_tiff.py => file_io/read/tiff.py} | 4 +- .../read_zarr.py => file_io/read/zarr.py} | 0 src/careamics/file_io/write/__init__.py | 9 +++ src/careamics/file_io/write/get_func.py | 59 +++++++++++++++++++ src/careamics/file_io/write/tiff.py | 39 ++++++++++++ .../lightning/predict_data_module.py | 6 +- src/careamics/lightning/train_data_module.py | 2 +- tests/config/support/test_supported_data.py | 49 +++++++++++---- tests/dataset/test_iterable_dataset.py | 2 +- tests/file_io/read/test_get_read_func.py | 14 +++++ .../read}/test_read_tiff.py | 2 +- tests/file_io/write/test_get_write_func.py | 14 +++++ tests/file_io/write/test_write_tiff.py | 27 +++++++++ 25 files changed, 320 insertions(+), 62 deletions(-) delete mode 100644 src/careamics/dataset/dataset_utils/read_utils.py create mode 100644 src/careamics/file_io/__init__.py create mode 100644 src/careamics/file_io/read/__init__.py create mode 100644 src/careamics/file_io/read/get_func.py rename src/careamics/{dataset/dataset_utils/read_tiff.py => file_io/read/tiff.py} (92%) rename src/careamics/{dataset/dataset_utils/read_zarr.py => file_io/read/zarr.py} (100%) create mode 100644 src/careamics/file_io/write/__init__.py create mode 100644 src/careamics/file_io/write/get_func.py create mode 100644 src/careamics/file_io/write/tiff.py create mode 100644 tests/file_io/read/test_get_read_func.py rename tests/{dataset/dataset_utils => file_io/read}/test_read_tiff.py (91%) create mode 100644 tests/file_io/write/test_get_write_func.py create mode 100644 tests/file_io/write/test_write_tiff.py diff --git a/src/careamics/config/support/supported_data.py b/src/careamics/config/support/supported_data.py index 73f32975..691a0cb8 100644 --- a/src/careamics/config/support/supported_data.py +++ b/src/careamics/config/support/supported_data.py @@ -60,9 +60,9 @@ def _missing_(cls, value: object) -> str: return super()._missing_(value) @classmethod - def get_extension(cls, data_type: Union[str, SupportedData]) -> str: + def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str: """ - Path.rglob and fnmatch compatible extension. + Get Path.rglob and fnmatch compatible extension. Parameters ---------- @@ -72,13 +72,38 @@ def get_extension(cls, data_type: Union[str, SupportedData]) -> str: Returns ------- str - Corresponding extension. + Corresponding extension pattern. """ if data_type == cls.ARRAY: - raise NotImplementedError(f"Data {data_type} are not loaded from file.") + raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.") elif data_type == cls.TIFF: return "*.tif*" elif data_type == cls.CUSTOM: return "*.*" else: raise ValueError(f"Data type {data_type} is not supported.") + + @classmethod + def get_extension(cls, data_type: Union[str, SupportedData]) -> str: + """ + Get file extension of corresponding data type. + + Parameters + ---------- + data_type : str or SupportedData + Data type. + + Returns + ------- + str + Corresponding extension. + """ + if data_type == cls.ARRAY: + raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.") + elif data_type == cls.TIFF: + return ".tiff" + elif data_type == cls.CUSTOM: + # TODO: improve this message + raise NotImplementedError("Custom extensions have to be passed elsewhere.") + else: + raise ValueError(f"Data type {data_type} is not supported.") diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py index b6a626aa..d6340b73 100644 --- a/src/careamics/dataset/dataset_utils/__init__.py +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -6,9 +6,6 @@ "get_files_size", "list_files", "validate_source_target_files", - "read_tiff", - "get_read_func", - "read_zarr", "iterate_over_files", "WelfordStatistics", ] @@ -19,7 +16,4 @@ ) from .file_utils import get_files_size, list_files, validate_source_target_files from .iterate_over_files import iterate_over_files -from .read_tiff import read_tiff -from .read_utils import get_read_func -from .read_zarr import read_zarr from .running_stats import WelfordStatistics, compute_normalization_stats diff --git a/src/careamics/dataset/dataset_utils/file_utils.py b/src/careamics/dataset/dataset_utils/file_utils.py index a37905a0..7bdbde57 100644 --- a/src/careamics/dataset/dataset_utils/file_utils.py +++ b/src/careamics/dataset/dataset_utils/file_utils.py @@ -75,7 +75,7 @@ def list_files( raise FileNotFoundError(f"Data path {data_path} does not exist.") # get extension compatible with fnmatch and rglob search - extension = SupportedData.get_extension(data_type) + extension = SupportedData.get_extension_pattern(data_type) if data_type == SupportedData.CUSTOM and extension_filter != "": extension = extension_filter diff --git a/src/careamics/dataset/dataset_utils/iterate_over_files.py b/src/careamics/dataset/dataset_utils/iterate_over_files.py index 330a8839..b1b011a4 100644 --- a/src/careamics/dataset/dataset_utils/iterate_over_files.py +++ b/src/careamics/dataset/dataset_utils/iterate_over_files.py @@ -9,10 +9,10 @@ from torch.utils.data import get_worker_info from careamics.config import DataConfig, InferenceConfig +from careamics.file_io.read import read_tiff from careamics.utils.logging import get_logger from .dataset_utils import reshape_array -from .read_tiff import read_tiff logger = get_logger(__name__) diff --git a/src/careamics/dataset/dataset_utils/read_utils.py b/src/careamics/dataset/dataset_utils/read_utils.py deleted file mode 100644 index 75373253..00000000 --- a/src/careamics/dataset/dataset_utils/read_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Read function utilities.""" - -from typing import Callable, Union - -from careamics.config.support import SupportedData - -from .read_tiff import read_tiff - - -def get_read_func(data_type: Union[SupportedData, str]) -> Callable: - """ - Get the read function for the data type. - - Parameters - ---------- - data_type : SupportedData - Data type. - - Returns - ------- - Callable - Read function. - """ - if data_type == SupportedData.TIFF: - return read_tiff - else: - raise NotImplementedError(f"Data type {data_type} is not supported.") diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 0b8c3ab4..0918513a 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -9,12 +9,12 @@ import numpy as np from torch.utils.data import Dataset +from careamics.file_io.read import read_tiff from careamics.transforms import Compose from ..config import DataConfig from ..config.transformations import NormalizeModel from ..utils.logging import get_logger -from .dataset_utils import read_tiff from .patching.patching import ( PatchedOutput, Stats, diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 06cb02af..680355e3 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -12,10 +12,11 @@ from careamics.config import DataConfig from careamics.config.transformations import NormalizeModel +from careamics.file_io.read import read_tiff from careamics.transforms import Compose from ..utils.logging import get_logger -from .dataset_utils import iterate_over_files, read_tiff +from .dataset_utils import iterate_over_files from .dataset_utils.running_stats import WelfordStatistics from .patching.patching import Stats from .patching.random_patching import extract_patches_random diff --git a/src/careamics/dataset/iterable_pred_dataset.py b/src/careamics/dataset/iterable_pred_dataset.py index 6f5306b6..f9d385da 100644 --- a/src/careamics/dataset/iterable_pred_dataset.py +++ b/src/careamics/dataset/iterable_pred_dataset.py @@ -8,11 +8,12 @@ from numpy.typing import NDArray from torch.utils.data import IterableDataset +from careamics.file_io.read import read_tiff from careamics.transforms import Compose from ..config import InferenceConfig from ..config.transformations import NormalizeModel -from .dataset_utils import iterate_over_files, read_tiff +from .dataset_utils import iterate_over_files class IterablePredDataset(IterableDataset): diff --git a/src/careamics/dataset/iterable_tiled_pred_dataset.py b/src/careamics/dataset/iterable_tiled_pred_dataset.py index 2ea030b5..7242e31b 100644 --- a/src/careamics/dataset/iterable_tiled_pred_dataset.py +++ b/src/careamics/dataset/iterable_tiled_pred_dataset.py @@ -8,12 +8,13 @@ from numpy.typing import NDArray from torch.utils.data import IterableDataset +from careamics.file_io.read import read_tiff from careamics.transforms import Compose from ..config import InferenceConfig from ..config.tile_information import TileInformation from ..config.transformations import NormalizeModel -from .dataset_utils import iterate_over_files, read_tiff +from .dataset_utils import iterate_over_files from .tiling import extract_tiles diff --git a/src/careamics/file_io/__init__.py b/src/careamics/file_io/__init__.py new file mode 100644 index 00000000..26ab6a67 --- /dev/null +++ b/src/careamics/file_io/__init__.py @@ -0,0 +1,7 @@ +"""Functions relating reading and writing image files.""" + +__all__ = ["read", "write", "get_read_func", "get_write_func"] + +from . import read, write +from .read import get_read_func +from .write import get_write_func diff --git a/src/careamics/file_io/read/__init__.py b/src/careamics/file_io/read/__init__.py new file mode 100644 index 00000000..0c52555e --- /dev/null +++ b/src/careamics/file_io/read/__init__.py @@ -0,0 +1,11 @@ +"""Functions relating to reading image files of different formats.""" + +__all__ = [ + "get_read_func", + "read_tiff", + "read_zarr", +] + +from .get_func import get_read_func +from .tiff import read_tiff +from .zarr import read_zarr diff --git a/src/careamics/file_io/read/get_func.py b/src/careamics/file_io/read/get_func.py new file mode 100644 index 00000000..719fda9a --- /dev/null +++ b/src/careamics/file_io/read/get_func.py @@ -0,0 +1,56 @@ +"""Module to get read functions.""" + +from pathlib import Path +from typing import Callable, Dict, Protocol, Union + +from numpy.typing import NDArray + +from careamics.config.support import SupportedData + +from .tiff import read_tiff + + +# This is very strict, function signature has to match including arg names +# See WriteFunc notes +class ReadFunc(Protocol): + """Protocol for type hinting read functions.""" + + def __call__(self, file_path: Path, *args, **kwargs) -> NDArray: + """ + Type hinted callables must match this function signature (not including self). + + Parameters + ---------- + file_path : pathlib.Path + Path to file. + *args + Other positional arguments. + **kwargs + Other keyword arguments. + """ + + +READ_FUNCS: Dict[SupportedData, ReadFunc] = { + SupportedData.TIFF: read_tiff, +} + + +def get_read_func(data_type: Union[str, SupportedData]) -> Callable: + """ + Get the read function for the data type. + + Parameters + ---------- + data_type : SupportedData + Data type. + + Returns + ------- + callable + Read function. + """ + if data_type in READ_FUNCS: + data_type = SupportedData(data_type) # mypy complaining about dict key type + return READ_FUNCS[data_type] + else: + raise NotImplementedError(f"Data type '{data_type}' is not supported.") diff --git a/src/careamics/dataset/dataset_utils/read_tiff.py b/src/careamics/file_io/read/tiff.py similarity index 92% rename from src/careamics/dataset/dataset_utils/read_tiff.py rename to src/careamics/file_io/read/tiff.py index 0cea0f69..1745b3ba 100644 --- a/src/careamics/dataset/dataset_utils/read_tiff.py +++ b/src/careamics/file_io/read/tiff.py @@ -44,7 +44,9 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray: ValueError If the axes length is incorrect. """ - if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)): + if fnmatch( + file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF) + ): try: array = tifffile.imread(file_path) except (ValueError, OSError) as e: diff --git a/src/careamics/dataset/dataset_utils/read_zarr.py b/src/careamics/file_io/read/zarr.py similarity index 100% rename from src/careamics/dataset/dataset_utils/read_zarr.py rename to src/careamics/file_io/read/zarr.py diff --git a/src/careamics/file_io/write/__init__.py b/src/careamics/file_io/write/__init__.py new file mode 100644 index 00000000..c99fe074 --- /dev/null +++ b/src/careamics/file_io/write/__init__.py @@ -0,0 +1,9 @@ +"""Functions relating to writing image files of different formats.""" + +__all__ = [ + "get_write_func", + "write_tiff", +] + +from .get_func import get_write_func +from .tiff import write_tiff diff --git a/src/careamics/file_io/write/get_func.py b/src/careamics/file_io/write/get_func.py new file mode 100644 index 00000000..76d3c711 --- /dev/null +++ b/src/careamics/file_io/write/get_func.py @@ -0,0 +1,59 @@ +"""Module to get write functions.""" + +from pathlib import Path +from typing import Protocol, Union + +from numpy.typing import NDArray + +from careamics.config.support import SupportedData + +from .tiff import write_tiff + + +# This is very strict, arguments have to be called file_path & img +# Alternative? - doesn't capture *args & **kwargs +# WriteFunc = Callable[[Path, NDArray], None] +class WriteFunc(Protocol): + """Protocol for type hinting write functions.""" + + def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None: + """ + Type hinted callables must match this function signature (not including self). + + Parameters + ---------- + file_path : pathlib.Path + Path to file. + img : numpy.ndarray + Image data to save. + *args + Other positional arguments. + **kwargs + Other keyword arguments. + """ + + +WRITE_FUNCS: dict[SupportedData, WriteFunc] = { + SupportedData.TIFF: write_tiff, +} + + +def get_write_func(data_type: Union[str, SupportedData]) -> WriteFunc: + """ + Get the write function for the data type. + + Parameters + ---------- + data_type : SupportedData + Data type. + + Returns + ------- + callable + Write function. + """ + if data_type in WRITE_FUNCS: + data_type = SupportedData(data_type) # mypy complaining about dict key type + return WRITE_FUNCS[data_type] + else: + raise NotImplementedError(f"Data type {data_type} is not supported.") diff --git a/src/careamics/file_io/write/tiff.py b/src/careamics/file_io/write/tiff.py new file mode 100644 index 00000000..75a0872a --- /dev/null +++ b/src/careamics/file_io/write/tiff.py @@ -0,0 +1,39 @@ +"""Write tiff function.""" + +from fnmatch import fnmatch +from pathlib import Path + +import tifffile +from numpy.typing import NDArray + +from careamics.config.support import SupportedData + + +def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None: + """ + Write tiff files. + + Parameters + ---------- + file_path : pathlib.Path + Path to file. + img : numpy.ndarray + Image data to save. + *args + Positional arguments passed to `tifffile.imwrite`. + **kwargs + Keyword arguments passed to `tifffile.imwrite`. + + Raises + ------ + ValueError + When the file extension of `file_path` does not match the Unix shell-style + pattern '*.tif*'. + """ + if not fnmatch( + file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF) + ): + raise ValueError( + f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'." + ) + tifffile.imwrite(file_path, img, *args, **kwargs) diff --git a/src/careamics/lightning/predict_data_module.py b/src/careamics/lightning/predict_data_module.py index 71670b42..f0abce80 100644 --- a/src/careamics/lightning/predict_data_module.py +++ b/src/careamics/lightning/predict_data_module.py @@ -16,11 +16,9 @@ IterablePredDataset, IterableTiledPredDataset, ) -from careamics.dataset.dataset_utils import ( - get_read_func, - list_files, -) +from careamics.dataset.dataset_utils import list_files from careamics.dataset.tiling.collate_tiles import collate_tiles +from careamics.file_io.read import get_read_func from careamics.utils import get_logger PredictDatasetType = Union[ diff --git a/src/careamics/lightning/train_data_module.py b/src/careamics/lightning/train_data_module.py index 220daf1f..54b4ee3e 100644 --- a/src/careamics/lightning/train_data_module.py +++ b/src/careamics/lightning/train_data_module.py @@ -13,7 +13,6 @@ from careamics.config.support import SupportedData from careamics.dataset.dataset_utils import ( get_files_size, - get_read_func, list_files, validate_source_target_files, ) @@ -23,6 +22,7 @@ from careamics.dataset.iterable_dataset import ( PathIterableDataset, ) +from careamics.file_io.read import get_read_func from careamics.utils import get_logger, get_ram_size DatasetType = Union[InMemoryDataset, PathIterableDataset] diff --git a/tests/config/support/test_supported_data.py b/tests/config/support/test_supported_data.py index ef8bf22d..bc0fff67 100644 --- a/tests/config/support/test_supported_data.py +++ b/tests/config/support/test_supported_data.py @@ -8,18 +8,18 @@ from careamics.config.support import SupportedData -def test_extension_tiff_fnmatch(tmp_path: Path): +def test_extension_pattern_tiff_fnmatch(tmp_path: Path): """Test that the TIFF extension is compatible with fnmatch.""" path = tmp_path / "test.tif" # test as str - assert fnmatch(str(path), SupportedData.get_extension(SupportedData.TIFF)) + assert fnmatch(str(path), SupportedData.get_extension_pattern(SupportedData.TIFF)) # test as Path - assert fnmatch(path, SupportedData.get_extension(SupportedData.TIFF)) + assert fnmatch(path, SupportedData.get_extension_pattern(SupportedData.TIFF)) -def test_extension_tiff_rglob(tmp_path: Path): +def test_extension_pattern_tiff_rglob(tmp_path: Path): """Test that the TIFF extension is compatible with Path.rglob.""" # create text file text_path = tmp_path / "test.txt" @@ -31,23 +31,25 @@ def test_extension_tiff_rglob(tmp_path: Path): tifffile.imwrite(path, image) # search for files - files = list(tmp_path.rglob(SupportedData.get_extension(SupportedData.TIFF))) + files = list( + tmp_path.rglob(SupportedData.get_extension_pattern(SupportedData.TIFF)) + ) assert len(files) == 1 assert files[0] == path -def test_extension_custom_fnmatch(tmp_path: Path): +def test_extension_pattern_custom_fnmatch(tmp_path: Path): """Test that the custom extension is compatible with fnmatch.""" path = tmp_path / "test.czi" # test as str - assert fnmatch(str(path), SupportedData.get_extension(SupportedData.CUSTOM)) + assert fnmatch(str(path), SupportedData.get_extension_pattern(SupportedData.CUSTOM)) # test as Path - assert fnmatch(path, SupportedData.get_extension(SupportedData.CUSTOM)) + assert fnmatch(path, SupportedData.get_extension_pattern(SupportedData.CUSTOM)) -def test_extension_custom_rglob(tmp_path: Path): +def test_extension_pattern_custom_rglob(tmp_path: Path): """Test that the custom extension is compatible with Path.rglob.""" # create text file text_path = tmp_path / "test.txt" @@ -59,18 +61,43 @@ def test_extension_custom_rglob(tmp_path: Path): np.save(path, image) # search for files - files = list(tmp_path.rglob(SupportedData.get_extension(SupportedData.CUSTOM))) + files = list( + tmp_path.rglob(SupportedData.get_extension_pattern(SupportedData.CUSTOM)) + ) assert len(files) == 2 assert set(files) == {path, text_path} +def test_extension_pattern_array_error(): + """Test that the array extension raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + SupportedData.get_extension_pattern(SupportedData.ARRAY) + + +def test_extension_pattern_any_error(): + """Test that any extension raises ValueError.""" + with pytest.raises(ValueError): + SupportedData.get_extension_pattern("some random") + + def test_extension_array_error(): """Test that the array extension raises NotImplementedError.""" with pytest.raises(NotImplementedError): SupportedData.get_extension(SupportedData.ARRAY) +def test_extension_tiff(): + """Test that the tiff extension is .tiff.""" + assert SupportedData.get_extension(SupportedData.TIFF) == ".tiff" + + +def test_extension_custom_error(): + """Test that the custom extension returns NotImplementedError.""" + with pytest.raises(NotImplementedError): + SupportedData.get_extension(SupportedData.CUSTOM) + + def test_extension_any_error(): - """Test that any extension raises NotImplementedError.""" + """Test that any extension raises ValueError.""" with pytest.raises(ValueError): SupportedData.get_extension("some random") diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index d3e90feb..b61ec9a1 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -7,7 +7,7 @@ from careamics.config import DataConfig from careamics.config.support import SupportedData from careamics.dataset import PathIterableDataset -from careamics.dataset.dataset_utils import read_tiff +from careamics.file_io.read import read_tiff @pytest.mark.parametrize( diff --git a/tests/file_io/read/test_get_read_func.py b/tests/file_io/read/test_get_read_func.py new file mode 100644 index 00000000..c3717f31 --- /dev/null +++ b/tests/file_io/read/test_get_read_func.py @@ -0,0 +1,14 @@ +import pytest + +from careamics.config.support import SupportedData +from careamics.file_io import get_read_func +from careamics.file_io.read import read_tiff + + +def test_get_read_tiff(): + assert get_read_func(SupportedData.TIFF) is read_tiff + + +def test_get_read_any_error(): + with pytest.raises(NotImplementedError): + get_read_func("some random") diff --git a/tests/dataset/dataset_utils/test_read_tiff.py b/tests/file_io/read/test_read_tiff.py similarity index 91% rename from tests/dataset/dataset_utils/test_read_tiff.py rename to tests/file_io/read/test_read_tiff.py index 5d534b8f..cd9bd3e0 100644 --- a/tests/dataset/dataset_utils/test_read_tiff.py +++ b/tests/file_io/read/test_read_tiff.py @@ -2,7 +2,7 @@ import pytest import tifffile -from careamics.dataset.dataset_utils.read_tiff import read_tiff +from careamics.file_io.read import read_tiff def test_read_tiff(tmp_path, ordered_array): diff --git a/tests/file_io/write/test_get_write_func.py b/tests/file_io/write/test_get_write_func.py new file mode 100644 index 00000000..ef923a54 --- /dev/null +++ b/tests/file_io/write/test_get_write_func.py @@ -0,0 +1,14 @@ +import pytest + +from careamics.config.support import SupportedData +from careamics.file_io import get_write_func +from careamics.file_io.write import write_tiff + + +def test_get_write_tiff(): + assert get_write_func(SupportedData.TIFF) is write_tiff + + +def test_get_write_any_error(): + with pytest.raises(NotImplementedError): + get_write_func("some random") diff --git a/tests/file_io/write/test_write_tiff.py b/tests/file_io/write/test_write_tiff.py new file mode 100644 index 00000000..6fcdb719 --- /dev/null +++ b/tests/file_io/write/test_write_tiff.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest + +from careamics.file_io.write import write_tiff + + +def test_write_tiff(tmp_path, ordered_array): + """Test writing a tiff file.""" + # create an array + array: np.ndarray = ordered_array((10, 10)) + + # save files + file = tmp_path / "test.tiff" + write_tiff(file, array) + + assert file.is_file() + + +def test_invalid_extension_error(tmp_path, ordered_array): + """Test error is raised when a path with an invalid extension is used.""" + # create an array + array: np.ndarray = ordered_array((10, 10)) + + # save files + file = tmp_path / "test.invalid" + with pytest.raises(ValueError): + write_tiff(file, array)