From edb635f52bf9a2033debd0bdf123eb6595ad49ab Mon Sep 17 00:00:00 2001
From: b-reyes <53541061+b-reyes@users.noreply.github.com>
Date: Thu, 15 Dec 2022 08:59:26 -0800
Subject: [PATCH 1/8] Implement subset portion of `channel_selection` input
 for `combine_echodata` (#892)

* create initial framework for checking for channel consistency across all
  EchoData groups
* create initial framework for construction of the channel_selection dictionary
* document _check_channel_consistency and add a test for it
* sort non-None values in dict created in create_channel_selection_dict and
  start constructing tests for create_channel_selection_dict
* document create_channel_selection_dict and create a test for it
* correct logic in consistency check for when channel_selection=None, finish
  documenting _check_echodata_channels, and remove channel check in zarr_combine
* remove channel check in zarr_combine as it is no longer necessary
* add docstring for channel_selection in combine_echodata and add a routine to
  check the type of channel_selection
* perform channel selection in zarr_combine using the new input
  ed_group_chan_sel and add a test that will be skipped for channel_selection
  input to combine_echodata
* Improve docstring descriptions

Co-authored-by: Wu-Jung Lee

* fix failing test_combine_consolidated by adding keyword ed_group_chan_sel to
  zarr_combine.combine
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add variable name in Returns docstring

Co-authored-by: Wu-Jung Lee

* add notes to _create_channel_selection_dict directing individuals to the
  tests for example outputs of the function
* check for Sonar/Beam_group using regular expression
* simplify the logic provided in _create_channel_selection_dict code block
* move list type check so that it is evaluated first in the if statement

Co-authored-by: Wu-Jung Lee
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 echopype/echodata/combine.py                 | 320 +++++++++++++++++-
 echopype/echodata/zarr_combine.py            |  59 +---
 .../tests/echodata/test_echodata_combine.py  | 164 +++++++++
 echopype/tests/echodata/test_zarr_combine.py |   7 +-
 4 files changed, 506 insertions(+), 44 deletions(-)

diff --git a/echopype/echodata/combine.py b/echopype/echodata/combine.py
index dd9e5e4a0c..f460cdf125 100644
--- a/echopype/echodata/combine.py
+++ b/echopype/echodata/combine.py
@@ -1,3 +1,5 @@
+import itertools
+import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn
@@ -89,6 +91,59 @@ def check_zarr_path(
     return validated_path
 
 
+def _check_channel_selection_form(
+    channel_selection: Optional[Union[List, Dict[str, list]]] = None
+) -> None:
+    """
+    Ensures that the provided user input ``channel_selection`` is in
+    an acceptable form.
+
+    Parameters
+    ----------
+    channel_selection: list of str or dict, optional
+        Specifies what channels should be selected for an ``EchoData`` group
+        with a ``channel`` dimension (before combination).
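+        Acceptable forms are ``None``, a list of channel-name strings
+        (e.g. ``["chan1", "chan2"]``; names illustrative), or a dict whose keys
+        are beam group paths of the form ``Sonar/Beam_group1`` and whose values
+        are lists of channel-name strings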
+    """
+
+    # check that channel selection is None, a list, or a dict
+    if not isinstance(channel_selection, (type(None), list, dict)):
+        raise TypeError("The input channel_selection does not have an acceptable type!")
+
+    if isinstance(channel_selection, list):
+        # make sure each element is a string
+        are_elem_str = [isinstance(elem, str) for elem in channel_selection]
+        if not all(are_elem_str):
+            raise TypeError("Each element of channel_selection must be a string!")
+
+    if isinstance(channel_selection, dict):
+
+        # make sure all keys are strings
+        are_keys_str = [isinstance(elem, str) for elem in channel_selection.keys()]
+        if not all(are_keys_str):
+            raise TypeError("Each key of channel_selection must be a string!")
+
+        # make sure all keys are of the form Sonar/Beam_group using regular expression
+        are_keys_right_form = [
+            True if re.match(r"Sonar/Beam_group(\d{1})", elem) else False
+            for elem in channel_selection.keys()
+        ]
+        if not all(are_keys_right_form):
+            raise TypeError(
+                "Each key of channel_selection can only be a beam group path of "
+                "the form Sonar/Beam_group1!"
+            )
+
+        # make sure all values are a list
+        are_vals_list = [isinstance(elem, list) for elem in channel_selection.values()]
+        if not all(are_vals_list):
+            raise TypeError("Each value of channel_selection must be a list!")
+
+        # make sure all values are a list of strings
+        are_vals_list_str = [
+            set(map(type, elem)) == {str} for elem in channel_selection.values()
+        ]
+        if not all(are_vals_list_str):
+            raise TypeError("Each value of channel_selection must be a list of strings!")
+
+
 def check_echodatas_input(echodatas: List[EchoData]) -> Tuple[str, List[str]]:
     """
     Ensures that the input list of ``EchoData`` objects for ``combine_echodata``
@@ -154,12 +209,249 @@ def check_echodatas_input(echodatas: List[EchoData]) -> Tuple[str, List[str]]:
     return sonar_model, echodata_filenames
 
 
+def _check_channel_consistency(
+    all_chan_list: List, ed_group: str, channel_selection: Optional[List[str]] = None
+) -> None:
+    """
+    If ``channel_selection = None``, checks that each element in ``all_chan_list`` is
+    the same, else makes sure that each element in ``all_chan_list`` contains all channel
+    names in ``channel_selection``.
+
+    Parameters
+    ----------
+    all_chan_list: list of list
+        A list whose elements correspond to the Datasets to be combined with
+        their values set as a list of the channel dimension names in the Dataset
+    ed_group: str
+        The EchoData group path that produced ``all_chan_list``
+    channel_selection: list of str, optional
+        A list of channel names, which should be a subset of each
+        element in ``all_chan_list``
+
+    Raises
+    ------
+    RuntimeError
+        If ``channel_selection=None`` and all ``channel`` dimensions are not the
+        same across all Datasets.
+    NotImplementedError
+        If ``channel_selection`` is a list and the listed channels are not contained
+        in the ``EchoData`` group for all Datasets and need to be created and
+        padded with NaN. This "expansion" type of combination has not been implemented.
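+
+    For example (channel names illustrative): ``all_chan_list = [["a", "b", "c"], ["a", "b"]]``
+    passes with ``channel_selection = ["a", "b"]``, but raises the ``RuntimeError``
+    above when ``channel_selection = None``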
+ """ + + if channel_selection is None: + + # sort each element in list, so correct comparison can be made + all_chan_list = list(map(sorted, all_chan_list)) + + # determine if the channels are the same across all Datasets + all_chans_equal = [all_chan_list[0]] * len(all_chan_list) == all_chan_list + + if not all_chans_equal: + + # obtain all unique channel names + unique_channels = set(itertools.chain.from_iterable(all_chan_list)) + + # raise an error if we have varying channel lengths + raise RuntimeError( + f"For the EchoData group {ed_group} the channels: {unique_channels} are " + f"not found in all EchoData objects being combined. Select which " + f"channels should be included in the combination using the keyword argument " + f"channel_selection in combine_echodata." + ) + + else: + + # make channel_selection a set, so it is easier to use + channel_selection = set(channel_selection) + + # TODO: if we will allow for expansion, then the below code should be + # replaced with a code section that makes sure the selected channels + # appear at least once in one of the other Datasets + + # determine if channel selection is in each element of all_chan_list + eds_num_chan = [ + channel_selection.intersection(set(ed_chans)) == channel_selection + for ed_chans in all_chan_list + ] + + if not all(eds_num_chan): + # raise a not implemented error if expansion (i.e. padding is necessary) + raise NotImplementedError( + f"For the EchoData group {ed_group}, some EchoData objects do " + f"not contain the selected channels. This type of combine is " + f"not currently implemented." + ) + + +def _create_channel_selection_dict( + sonar_model: str, + has_chan_dim: Dict[str, bool], + user_channel_selection: Optional[Union[List, Dict[str, list]]] = None, +) -> Dict[str, Optional[list]]: + """ + Constructs the dictionary ``channel_selection_dict``, which specifies + the ``channel`` dimension names that should be selected for each + ``EchoData`` group. 
If a group does not have a ``channel`` dimension + the dictionary value will be set to ``None`` + + Parameters + ---------- + sonar_model: str + The name of the sonar model corresponding to ``has_chan_dim`` + has_chan_dim: dict + A dictionary created using an ``EchoData`` object whose keys are + the ``EchoData`` groups and whose values specify if that + particular group has a ``channel`` dimension + user_channel_selection: list or dict, optional + A user provided input that will be used to construct the values of + ``channel_selection_dict`` (see below for further details) + + Returns + ------- + channel_selection_dict : dict + A dictionary with the same keys as ``has_chan_dim`` and values + determined by ``sonar_model`` and ``user_channel_selection`` as follows: + - If ``user_channel_selection=None``, then the values of the dictionary + will be set to ``None`` + - If ``user_channel_selection`` is a list, then all keys corresponding to + an ``EchoData`` group with a ``channel`` dimension will have their values + set to the provided list and all other groups will be set to ``None`` + - If ``user_channel_selection`` is a dictionary, then all keys corresponding to + an ``EchoData`` group without a ``channel`` dimension will have their values + set as ``None`` and the other group's values will be set as follows: + - If ``sonar_model`` is not EK80-like then all values will be set to + the union of the values of ``user_channel_selection`` + - If ``sonar_model`` is EK80-like then the groups ``Sonar, Platform, Vendor_specific`` + will be set to the union of the values of ``user_channel_selection`` and the rest of + the groups will be set to the same value in ``user_channel_selection`` with the same key + + Notes + ----- + See ``tests/echodata/test_echodata_combine.py::test_create_channel_selection_dict`` for example + outputs from this function. 
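+
+    For instance (channel names illustrative), for an EK80-like model with two
+    beam groups and ``user_channel_selection = {"Sonar/Beam_group1": ["a", "b"],
+    "Sonar/Beam_group2": ["c", "d"]}``, the ``Platform``, ``Sonar``, and
+    ``Vendor_specific`` entries are set to the union ``["a", "b", "c", "d"]``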
+ """ + + # base case where the user did not provide selected channels (will be used downstream) + if user_channel_selection is None: + return {grp: None for grp in has_chan_dim.keys()} + + # obtain the union of all channels for each beam group + if isinstance(user_channel_selection, list): + union_beam_chans = user_channel_selection[:] + else: + union_beam_chans = list(set(itertools.chain.from_iterable(user_channel_selection.values()))) + + # make channel_selection dictionary where the keys are the EchoData groups and the + # values are based on the user provided input user_channel_selection + channel_selection_dict = dict() + for ed_group, has_chan in has_chan_dim.items(): + + # if there are no channel dimensions in the group, set the value to None + if has_chan: + + if ( + (not isinstance(user_channel_selection, list)) + and (sonar_model in ["EK80", "ES80", "EA640"]) + and (ed_group not in ["Sonar", "Platform", "Vendor_specific"]) + ): + # set value to the user provided input with the same key + channel_selection_dict[ed_group] = user_channel_selection[ed_group] + + else: + # set value to the union of the values of user_channel_selection + channel_selection_dict[ed_group] = union_beam_chans + + # sort channel names to produce consistent output (since we may be using sets) + channel_selection_dict[ed_group].sort() + + else: + channel_selection_dict[ed_group] = None + + return channel_selection_dict + + +def _check_echodata_channels( + echodatas: List[EchoData], user_channel_selection: Optional[Union[List, Dict[str, list]]] = None +) -> Dict[str, Optional[List[str]]]: + """ + Coordinates the routines that check to make sure each ``EchoData`` group with a ``channel`` + dimension has consistent channels for all elements in ``echodatas``, taking into account + the input ``user_channel_selection``. + + Parameters + ---------- + echodatas: list of EchoData object + The list of ``EchoData`` objects to be combined + user_channel_selection: list or dict, optional + A user provided input that will be used to specify which channels will be + selected for each ``EchoData`` group + + Returns + ------- + dict + A dictionary with keys corresponding to the ``EchoData`` groups and + values specifying the channels that should be selected within that group. + For more information on this dictionary see the function ``_create_channel_selection_dict``. + + Raises + ------ + RuntimeError + If any ``EchoData`` group has a ``channel`` dimension value + with a duplicate value. + + Notes + ----- + For further information on what is deemed consistent, please see the + function ``_check_channel_consistency``. 
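+
+    For example, an ``EchoData`` object whose ``channel`` coordinate contains a
+    repeated name such as ``["chan1", "chan1"]`` (illustrative) triggers the
+    duplicate-channel ``RuntimeError`` above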
+ """ + + # determine if the EchoData group contains a channel dimension + has_chan_dim = {grp: "channel" in echodatas[0][grp].dims for grp in echodatas[0].group_paths} + + # create dictionary specifying the channels that should be selected for each group + channel_selection = _create_channel_selection_dict( + echodatas[0].sonar_model, has_chan_dim, user_channel_selection + ) + + for ed_group in echodatas[0].group_paths: + + if "channel" in echodatas[0][ed_group].dims: + + # get each EchoData's channels as a list of list + all_chan_list = [list(ed[ed_group].channel.values) for ed in echodatas] + + # make sure each EchoData does not have repeating channels + all_chan_unique = [len(set(ed_chans)) == len(ed_chans) for ed_chans in all_chan_list] + + if not all(all_chan_unique): + # get indices of EchoData objects with repeating channel names + false_ind = [ind for ind, x in enumerate(all_chan_unique) if not x] + + # get files that produced the EchoData objects with repeated channels + files_w_rep_chan = [ + echodatas[ind]["Provenance"].source_filenames.values[0] for ind in false_ind + ] + + raise RuntimeError( + f"The EchoData objects produced by the following files " + f"have a channel dimension with repeating values, " + f"combine cannot be used: {files_w_rep_chan}" + ) + + # perform a consistency check for the channel dims across all Datasets + _check_channel_consistency(all_chan_list, ed_group, channel_selection[ed_group]) + + return channel_selection + + def combine_echodata( echodatas: List[EchoData] = None, zarr_path: Optional[Union[str, Path]] = None, overwrite: bool = False, storage_options: Dict[str, Any] = {}, client: Optional[dask.distributed.Client] = None, + channel_selection: Optional[Union[List, Dict[str, list]]] = None, consolidated: bool = True, ) -> EchoData: """ @@ -182,6 +474,16 @@ def combine_echodata( backend (ignored for local paths) client: dask.distributed.Client, optional An initialized Dask distributed client + channel_selection: list of str or dict, optional + Specifies what channels should be selected for an ``EchoData`` group + with a ``channel`` dimension (before combination). + + - if a list is provided, then each ``EchoData`` group with a ``channel`` dimension + will only contain the channels in the provided list + - if a dictionary is provided, the dictionary should have keys specifying only beam + groups (e.g. "Sonar/Beam_group1") and values as a list of channel names to select + within that beam group. The rest of the ``EchoData`` groups with a ``channel`` dimension + will have their selected channels chosen automatically. consolidated: bool Flag to consolidate zarr metadata. Defaults to ``True`` @@ -224,6 +526,15 @@ def combine_echodata( - the values are not identical - the keys ``date_created`` or ``conversion_time`` do not have the same types + RuntimeError + If any ``EchoData`` group has a ``channel`` dimension value + with a duplicate value. + RuntimeError + If ``channel_selection=None`` and the ``channel`` dimensions are not the + same across the same group under each object in ``echodatas``. + NotImplementedError + If ``channel_selection`` is a list and the listed channels are not contained + in the ``EchoData`` group across all objects in ``echodatas``. 
Notes ----- @@ -242,7 +553,7 @@ def combine_echodata( -------- Combine lazy loaded ``EchoData`` objects: - >>> ed1 = echopype.open_converted("file1.nc") + >>> ed1 = echopype.open_converted("file1.zarr") >>> ed2 = echopype.open_converted("file2.zarr") >>> combined = echopype.combine_echodata(echodatas=[ed1, ed2], >>> zarr_path="path/to/combined.zarr", @@ -283,6 +594,12 @@ def combine_echodata( # Ensure the list of all EchoData objects to be combined are valid sonar_model, echodata_filenames = check_echodatas_input(echodatas) + # make sure channel_selection is the appropriate type and only contains the beam groups + _check_channel_selection_form(channel_selection) + + # perform channel check and get channel selection for each EchoData group + ed_group_chan_sel = _check_echodata_channels(echodatas, channel_selection) + # initiate ZarrCombine object comb = ZarrCombine() @@ -293,6 +610,7 @@ def combine_echodata( storage_options=storage_options, sonar_model=sonar_model, echodata_filenames=echodata_filenames, + ed_group_chan_sel=ed_group_chan_sel, consolidated=consolidated, ) diff --git a/echopype/echodata/zarr_combine.py b/echopype/echodata/zarr_combine.py index 59fdd184c3..1a9a36d1d0 100644 --- a/echopype/echodata/zarr_combine.py +++ b/echopype/echodata/zarr_combine.py @@ -1,6 +1,6 @@ from collections import defaultdict from itertools import islice -from typing import Any, Dict, Hashable, List, Set, Tuple +from typing import Any, Dict, Hashable, List, Optional, Set, Tuple import dask import dask.array @@ -97,43 +97,6 @@ def _check_ascending_ds_times(self, ds_list: List[xr.Dataset], ed_name: str) -> f"group {ed_name}, combine cannot be used!" ) - @staticmethod - def _check_channels(ds_list: List[xr.Dataset], ed_name: str) -> None: - """ - Makes sure that each Dataset in ``ds_list`` has the - same number of channels and the same name for each - of these channels. - - Parameters - ---------- - ds_list: list of xr.Dataset - List of Datasets to be combined - ed_name: str - The name of the ``EchoData`` group being combined - """ - - if "channel" in ds_list[0].dims: - - # check to make sure we have the same number of channels in each ds - if np.unique([len(ds["channel"].values) for ds in ds_list]).size == 1: - - # make each array an element of a numpy array - channel_arrays = np.array([ds["channel"].values for ds in ds_list]) - - # check for unique rows - if np.unique(channel_arrays, axis=0).shape[0] > 1: - - raise RuntimeError( - f"All {ed_name} groups do not have that same channel coordinate, " - f"combine cannot be used!" - ) - - else: - raise RuntimeError( - f"All {ed_name} groups do not have that same number of channel coordinates, " - f"combine cannot be used!" 
-                )
-
     @staticmethod
     def _compare_attrs(attr1: dict, attr2: dict) -> List[str]:
         """
@@ -242,7 +205,6 @@ def _get_ds_info(self, ds_list: List[xr.Dataset], ed_name: str) -> None:
         """
 
         self._check_ascending_ds_times(ds_list, ed_name)
-        self._check_channels(ds_list, ed_name)
 
         # Dataframe with column as dim names and rows as the different Datasets
         self.dims_df = pd.DataFrame([ds.dims for ds in ds_list])
@@ -943,6 +905,7 @@ def combine(
         storage_options: Dict[str, Any] = {},
         sonar_model: str = None,
         echodata_filenames: List[str] = [],
+        ed_group_chan_sel: Dict[str, Optional[List[str]]] = {},
         consolidated: bool = True,
     ) -> EchoData:
         """
@@ -956,6 +919,7 @@
             The full path of the final combined zarr store
         eds: list of EchoData object
             The list of ``EchoData`` objects to be combined
         storage_options: dict
             Any additional parameters for the storage
             backend (ignored for local paths)
@@ -963,6 +927,11 @@
             The sonar model used for all elements in ``eds``
         echodata_filenames : list of str
             The source file names for all elements in ``eds``
+        ed_group_chan_sel: dict
+            A dictionary with keys corresponding to the ``EchoData`` groups
+            and values specifying what channels should be selected within that
+            group. If a value is ``None``, then no channel subset will be
+            selected for that group.
         consolidated: bool
             Flag to consolidate zarr metadata. Defaults to ``True``
@@ -1010,8 +979,16 @@
             else:
                 ed_group = "Top-level"
 
-            # collect the group Dataset from all eds
-            ds_list = [ed[ed_group] for ed in eds if ed_group in ed.group_paths]
+            # collect the group Dataset from all eds, before any channel selection
+            all_chan_ds_list = [ed[ed_group] for ed in eds if ed_group in ed.group_paths]
+
+            # select only the appropriate channels from each Dataset
+            ds_list = [
+                ds.sel(channel=ed_group_chan_sel[ed_group])
+                if ed_group_chan_sel[ed_group] is not None
+                else ds
+                for ds in all_chan_ds_list
+            ]
 
             if ds_list:  # necessary because a group may not be present
diff --git a/echopype/tests/echodata/test_echodata_combine.py b/echopype/tests/echodata/test_echodata_combine.py
index aea73e9893..6c3f5fe951 100644
--- a/echopype/tests/echodata/test_echodata_combine.py
+++ b/echopype/tests/echodata/test_echodata_combine.py
@@ -12,6 +12,9 @@ import tempfile
 
 from dask.distributed import Client
 
+from echopype.echodata.combine import _create_channel_selection_dict, _check_echodata_channels, \
+    _check_channel_consistency
+
 
 @pytest.fixture
 def ek60_diff_range_sample_test_data(test_path):
@@ -217,6 +220,19 @@ def test_combine_echodata(raw_datasets):
     client.close()
 
 
+def test_combine_echodata_channel_selection():
+    """
+    This test ensures that the ``channel_selection`` input
+    of ``combine_echodata`` is producing the correct output
+    for all sonar models except AD2CP.
+    """
+
+    # TODO: Once a mock EchoData structure can be easily formed,
+    # we should implement this test.
+ + pytest.skip("This test will not be implemented until after a mock EchoData object can be created.") + + def test_attr_storage(ek60_test_data): # check storage of attributes before combination in provenance group eds = [echopype.open_raw(file, "EK60") for file in ek60_test_data] @@ -345,3 +361,151 @@ def test_combined_echodata_repr(ek60_test_data): # close client client.close() + + +@pytest.mark.parametrize( + ("all_chan_list", "channel_selection"), + [ + ( + [['a', 'b', 'c'], ['a', 'b', 'c']], + None + ), + pytest.param( + [['a', 'b', 'c'], ['a', 'b']], + None, + marks=pytest.mark.xfail(strict=True, + reason="This test should not pass because the channels are not consistent") + ), + ( + [['a', 'b', 'c'], ['a', 'b', 'c']], + ['a', 'b', 'c'] + ), + ( + [['a', 'b', 'c'], ['a', 'b', 'c']], + ['a', 'b'] + ), + ( + [['a', 'b', 'c'], ['a', 'b']], + ['a', 'b'] + ), + pytest.param( + [['a', 'c'], ['a', 'b', 'c']], + ['a', 'b'], + marks=pytest.mark.xfail(strict=True, + reason="This test should not pass because we are selecting " + "channels that do not occur in each Dataset") + ), + ], + ids=["chan_sel_none_pass", "chan_sel_none_fail", + "chan_sel_same_as_given_chans", "chan_sel_subset_of_given_chans", + "chan_sel_subset_of_given_chans_uneven", "chan_sel_diff_from_some_given_chans"] +) +def test_check_channel_consistency(all_chan_list, channel_selection): + """ + Ensures that the channel consistency check for combine works + as expected using mock data. + """ + + _check_channel_consistency(all_chan_list, "test_group", channel_selection) + + +# create duplicated dictionaries used within pytest parameterize +has_chan_dim_1_beam = {'Top-level': False, 'Environment': False, 'Platform': True, + 'Platform/NMEA': False, 'Provenance': False, 'Sonar': True, + 'Sonar/Beam_group1': True, 'Vendor_specific': True} + +has_chan_dim_2_beam = {'Top-level': False, 'Environment': False, 'Platform': True, + 'Platform/NMEA': False, 'Provenance': False, 'Sonar': True, + 'Sonar/Beam_group1': True, 'Sonar/Beam_group2': True, 'Vendor_specific': True} + +expected_1_beam_none = {'Top-level': None, 'Environment': None, 'Platform': None, + 'Platform/NMEA': None, 'Provenance': None, 'Sonar': None, + 'Sonar/Beam_group1': None, 'Vendor_specific': None} + +expected_1_beam_a_b_sel = {'Top-level': None, 'Environment': None, 'Platform': ['a', 'b'], + 'Platform/NMEA': None, 'Provenance': None, 'Sonar': ['a', 'b'], + 'Sonar/Beam_group1': ['a', 'b'], 'Vendor_specific': ['a', 'b']} + + +@pytest.mark.parametrize( + ("sonar_model", "has_chan_dim", "user_channel_selection", "expected_dict"), + [ + ( + ["EK60", "ES70", "AZFP"], + has_chan_dim_1_beam, + [None], + expected_1_beam_none + ), + ( + ["EK80", "ES80", "EA640"], + has_chan_dim_1_beam, + [None], + expected_1_beam_none + ), + ( + ["EK80", "ES80", "EA640"], + has_chan_dim_2_beam, + [None], + {'Top-level': None, 'Environment': None, 'Platform': None, 'Platform/NMEA': None, + 'Provenance': None, 'Sonar': None, 'Sonar/Beam_group1': None, + 'Sonar/Beam_group2': None, 'Vendor_specific': None} + ), + ( + ["EK60", "ES70", "AZFP"], + has_chan_dim_1_beam, + [['a', 'b'], {'Sonar/Beam_group1': ['a', 'b']}], + expected_1_beam_a_b_sel + ), + ( + ["EK80", "ES80", "EA640"], + has_chan_dim_1_beam, + [['a', 'b'], {'Sonar/Beam_group1': ['a', 'b']}], + expected_1_beam_a_b_sel + ), + ( + ["EK80", "ES80", "EA640"], + has_chan_dim_2_beam, + [['a', 'b']], + {'Top-level': None, 'Environment': None, 'Platform': ['a', 'b'], 'Platform/NMEA': None, + 'Provenance': None, 'Sonar': ['a', 'b'], 'Sonar/Beam_group1': 
['a', 'b'],
+             'Sonar/Beam_group2': ['a', 'b'], 'Vendor_specific': ['a', 'b']}
+        ),
+        (
+            ["EK80", "ES80", "EA640"],
+            has_chan_dim_2_beam,
+            [{'Sonar/Beam_group1': ['a', 'b'], 'Sonar/Beam_group2': ['c', 'd']}],
+            {'Top-level': None, 'Environment': None, 'Platform': ['a', 'b', 'c', 'd'], 'Platform/NMEA': None,
+             'Provenance': None, 'Sonar': ['a', 'b', 'c', 'd'], 'Sonar/Beam_group1': ['a', 'b'],
+             'Sonar/Beam_group2': ['c', 'd'], 'Vendor_specific': ['a', 'b', 'c', 'd']}
+        ),
+        (
+            ["EK80", "ES80", "EA640"],
+            has_chan_dim_2_beam,
+            [{'Sonar/Beam_group1': ['a', 'b'], 'Sonar/Beam_group2': ['b', 'c', 'd']}],
+            {'Top-level': None, 'Environment': None, 'Platform': ['a', 'b', 'c', 'd'], 'Platform/NMEA': None,
+             'Provenance': None, 'Sonar': ['a', 'b', 'c', 'd'], 'Sonar/Beam_group1': ['a', 'b'],
+             'Sonar/Beam_group2': ['b', 'c', 'd'], 'Vendor_specific': ['a', 'b', 'c', 'd']}
+        ),
+    ],
+    ids=["EK60_no_sel", "EK80_no_sel_1_beam", "EK80_no_sel_2_beam", "EK60_chan_sel",
+         "EK80_chan_sel_1_beam", "EK80_list_chan_sel_2_beam", "EK80_dict_chan_sel_2_beam_diff_beam_group_chans",
+         "EK80_dict_chan_sel_2_beam_overlap_beam_group_chans"]
+)
+def test_create_channel_selection_dict(sonar_model, has_chan_dim,
+                                       user_channel_selection, expected_dict):
+    """
+    Ensures that ``_create_channel_selection_dict`` is constructing the correct output
+    for the sonar models ``EK60, EK80, AZFP`` and varying inputs for the input
+    ``user_channel_selection``.
+
+    Notes
+    -----
+    The input ``has_chan_dim`` is unchanged except for the case where we are considering
+    an EK80 sonar model with two beam groups.
+    """
+
+    for model in sonar_model:
+        for usr_sel_chan in user_channel_selection:
+
+            channel_selection_dict = _create_channel_selection_dict(model, has_chan_dim, usr_sel_chan)
+            assert channel_selection_dict == expected_dict
diff --git a/echopype/tests/echodata/test_zarr_combine.py b/echopype/tests/echodata/test_zarr_combine.py
index 3ae8bb1bf8..b9ce7d391b 100644
@@ -7,7 +7,7 @@ import echopype
 from echopype.utils.coding import set_time_encodings
 from pathlib import Path
-from echopype.echodata.combine import check_echodatas_input, check_zarr_path
+from echopype.echodata.combine import check_echodatas_input, check_zarr_path, _check_echodata_channels
 from typing import List, Tuple, Dict
 import tempfile
 import pytest
@@ -400,16 +400,19 @@ def test_combine_consolidated(self, ek60_test_data, consolidated):
 
         _, echodata_filenames = check_echodatas_input(eds)
 
+        # get channel selection for each EchoData group
+        ed_group_chan_sel = _check_echodata_channels(eds, user_channel_selection=None)
+
         # create dask client
         client = Client()
 
-        # combined = echopype.combine_echodata(eds, zarr_file_name, client=client)
         # combine all elements in echodatas by writing to a zarr store
         combined_echodata = zarr_combine.combine(
             zarr_path,
             eds,
             sonar_model=self.sonar_model,
             echodata_filenames=echodata_filenames,
+            ed_group_chan_sel=ed_group_chan_sel,
             consolidated=consolidated,
         )

From 584dfff3f9b54889bcae4601b804c965062c6560 Mon Sep 17 00:00:00 2001
From: b-reyes <53541061+b-reyes@users.noreply.github.com>
Date: Fri, 16 Dec 2022 08:24:53 -0800
Subject: [PATCH 2/8] Implement frequency-differencing mask (#901)

* add mask sub-package and begin creating frequency_difference function
* establish checks for the inputs of frequency_difference
* split frequency_difference input checks into different functions and perform
  the frequency differencing assuming Sv is a Dataset
* allow source_Sv to be a DataArray
* start working on test for mask produced by frequency_difference
* change frequency_difference so that it does not accept a DataArray and
  assign a name to the returned DataArray
* move validate_source_Sv to echopype/utils/data_proc_lvls.py and create a
  simple test for it
* start constructing a test for check_source_Sv_freq_diff
* improve checks in _check_source_Sv_freq_diff and complete the tests for the
  function
* add pytest parameterize to test_frequency_difference_mask
* remove notes docstring from test_check_source_Sv_freq_diff
* complete frequency_difference_mask test
* finish documentation for frequency_difference function
* add mask sub-package to Docs API reference
* include >>> for spaces in frequency_difference function example docstrings,
  so that all code can be copied at once
* add three dots for continuation of docstring code block and add
  echopype.mask to frequency_difference function call in docs
* add a missing ellipsis in example docstring
* improve input file_type docstring in validate_source_Sv

Co-authored-by: Wu-Jung Lee

* rename validate_source_Sv to validate_source_ds and remove Sv references
  throughout the function
* move validate_source_ds to utils.io and the corresponding test to
  test_utils_io.py
* fix bug associated with checking for repeated channel or frequency_nominal
  values
* change frequency_differencing DataArray output name to mask

Co-authored-by: Wu-Jung Lee

* change DataArray name check in test_frequency_difference_mask
* change frequency_difference function name to frequency_differencing

Co-authored-by: Wu-Jung Lee

* change all frequency_difference occurrences to frequency_differencing
* Modify frequency_differencing so that the disclaimer that either freqAB or
  chanAB should be provided, but not both is in the docstrings for the
  variables

Co-authored-by: Wu-Jung Lee

* remove whitespace caused by GH commit
* move validation of source_Sv and opening up of Dataset out of
  _check_source_Sv_freq_diff and into frequency_differencing
* replace get_positions method for determining if freqAB or chanAB are in
  source_Sv with a simpler list comprehension check
* allow variable number of channels/frequencies in get_mock_freq_diff_data
  and modify tests accordingly
* improve documentation of test_utils_io.py/test_validate_source_ds

Co-authored-by: Wu-Jung Lee
Co-authored-by: Wu-Jung Lee
---
 docs/source/api.rst                   |   7 +
 echopype/__init__.py                  |   3 +-
 echopype/mask/__init__.py             |   3 +
 echopype/mask/api.py                  | 279 ++++++++++++++++++++++++++
 echopype/tests/mask/test_mask.py      | 224 +++++++++++++++++++++
 echopype/tests/utils/test_utils_io.py |  42 +++-
 echopype/utils/io.py                  |  58 +++++-
 7 files changed, 613 insertions(+), 3 deletions(-)
 create mode 100644 echopype/mask/__init__.py
 create mode 100644 echopype/mask/api.py
 create mode 100644 echopype/tests/mask/test_mask.py

diff --git a/docs/source/api.rst b/docs/source/api.rst
index 8205605e12..650a16081f 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -72,6 +72,13 @@ qc
    :no-inheritance-diagram:
    :no-heading:
 
+mask
+^^^^
+
+.. automodapi:: echopype.mask
+   :no-inheritance-diagram:
+   :no-heading:
+
 Utilities
 ---------
 
diff --git a/echopype/__init__.py b/echopype/__init__.py
index 93d9a3be9d..523995343d 100644
--- a/echopype/__init__.py
+++ b/echopype/__init__.py
@@ -2,7 +2,7 @@
 
 from _echopype_version import version as __version__  # noqa
 
-from . import calibrate, consolidate, preprocess, utils
+from .
import calibrate, consolidate, mask, preprocess, utils from .convert.api import open_raw from .core import init_ep_dir from .echodata.api import open_converted @@ -21,6 +21,7 @@ "combine_echodata", "calibrate", "consolidate", + "mask", "preprocess", "utils", "verbose", diff --git a/echopype/mask/__init__.py b/echopype/mask/__init__.py new file mode 100644 index 0000000000..6e7e529c8c --- /dev/null +++ b/echopype/mask/__init__.py @@ -0,0 +1,3 @@ +from .api import frequency_differencing + +__all__ = ["frequency_differencing"] diff --git a/echopype/mask/api.py b/echopype/mask/api.py new file mode 100644 index 0000000000..0e060f0e08 --- /dev/null +++ b/echopype/mask/api.py @@ -0,0 +1,279 @@ +import operator as op +import pathlib +from typing import List, Optional, Union + +import numpy as np +import xarray as xr + +from ..utils.io import validate_source_ds + +# lookup table with key string operator and value as corresponding Python operator +str2ops = { + ">": op.gt, + "<": op.lt, + "<=": op.le, + ">=": op.ge, + "==": op.eq, +} + + +def _check_freq_diff_non_data_inputs( + freqAB: Optional[List[float]] = None, + chanAB: Optional[List[str]] = None, + operator: str = ">", + diff: Union[float, int] = None, +) -> None: + """ + Checks that the non-data related inputs of ``frequency_differencing`` (i.e. ``freqAB``, + ``chanAB``, ``operator``, ``diff``) were correctly provided. + + Parameters + ---------- + freqAB: list of float, optional + The pair of nominal frequencies to be used for frequency-differencing, where + the first element corresponds to ``freqA`` and the second element corresponds + to ``freqB`` + chanAB: list of float, optional + The pair of channels that will be used to select the nominal frequencies to be + used for frequency-differencing, where the first element corresponds to ``freqA`` + and the second element corresponds to ``freqB`` + operator: {">", "<", "<=", ">=", "=="} + The operator for the frequency-differencing + diff: float or int + The threshold of Sv difference between frequencies + """ + + # check that either freqAB or chanAB are provided and they are a list of length 2 + if (freqAB is None) and (chanAB is None): + raise RuntimeError("Either freqAB or chanAB must be given!") + elif (freqAB is not None) and (chanAB is not None): + raise RuntimeError("Only freqAB or chanAB must be given, but not both!") + elif freqAB is not None: + if not isinstance(freqAB, list): + raise TypeError("freqAB must be a list!") + elif len(set(freqAB)) != 2: + raise RuntimeError("freqAB must be a list of length 2 with unique elements!") + else: + if not isinstance(chanAB, list): + raise TypeError("chanAB must be a list!") + elif len(set(chanAB)) != 2: + raise RuntimeError("chanAB must be a list of length 2 with unique elements!") + + # check that operator is a string and a valid operator + if not isinstance(operator, str): + raise TypeError("operator must be a string!") + else: + if operator not in [">", "<", "<=", ">=", "=="]: + raise RuntimeError("Invalid operator!") + + # ensure that diff is a float or an int + if not isinstance(diff, (float, int)): + raise TypeError("diff must be a float or int!") + + +def _check_source_Sv_freq_diff( + source_Sv: xr.Dataset, + freqAB: Optional[List[float]] = None, + chanAB: Optional[List[str]] = None, +) -> None: + """ + Ensures that ``source_Sv`` contains ``channel`` as a coordinate and + ``frequency_nominal`` as a variable, the provided list input + (``freqAB`` or ``chanAB``) are contained in the coordinate ``channel`` + or variable 
``frequency_nominal``, and ``source_Sv`` does not have + repeated values for ``channel`` and ``frequency_nominal``. + + Parameters + ---------- + source_Sv: xr.Dataset + A Dataset that contains the Sv data to create a mask for + freqAB: list of float, optional + The pair of nominal frequencies to be used for frequency-differencing, where + the first element corresponds to ``freqA`` and the second element corresponds + to ``freqB`` + chanAB: list of float, optional + The pair of channels that will be used to select the nominal frequencies to be + used for frequency-differencing, where the first element corresponds to ``freqA`` + and the second element corresponds to ``freqB`` + """ + + # check that channel and frequency nominal are in source_Sv + if "channel" not in source_Sv.coords: + raise RuntimeError("The Dataset defined by source_Sv must have channel as a coordinate!") + elif "frequency_nominal" not in source_Sv.variables: + raise RuntimeError( + "The Dataset defined by source_Sv must have frequency_nominal as a variable!" + ) + + # make sure that the channel and frequency_nominal values are not repeated in source_Sv + if len(set(source_Sv.channel.values)) < source_Sv.channel.size: + raise RuntimeError( + "The provided source_Sv contains repeated channel values, " "this is not allowed!" + ) + + if len(set(source_Sv.frequency_nominal.values)) < source_Sv.frequency_nominal.size: + raise RuntimeError( + "The provided source_Sv contains repeated frequency_nominal " + "values, this is not allowed!" + ) + + # check that the elements of freqAB are in frequency_nominal + if (freqAB is not None) and (not all([freq in source_Sv.frequency_nominal for freq in freqAB])): + raise RuntimeError( + "The provided list input freqAB contains values that " + "are not in the frequency_nominal variable!" + ) + + # check that the elements of chanAB are in channel + if (chanAB is not None) and (not all([chan in source_Sv.channel for chan in chanAB])): + raise RuntimeError( + "The provided list input chanAB contains values that are " + "not in the channel coordinate!" + ) + + +def frequency_differencing( + source_Sv: Union[xr.Dataset, str, pathlib.Path], + storage_options: Optional[dict] = {}, + freqAB: Optional[List[float]] = None, + chanAB: Optional[List[str]] = None, + operator: str = ">", + diff: Union[float, int] = None, +) -> xr.DataArray: + """ + Create a mask based on the differences of Sv values using a pair of + frequencies. This method is often referred to as the "frequency-differencing" + or "dB-differencing" method. + + Parameters + ---------- + source_Sv: xr.Dataset or str or pathlib.Path + If a Dataset this value contains the Sv data to create a mask for, + else it specifies the path to a zarr or netcdf file containing + a Dataset. This input must correspond to a Dataset that has the + coordinate ``channel`` and variables ``frequency_nominal`` and ``Sv``. + storage_options: dict, optional + Any additional parameters for the storage backend, corresponding to the + path provided for ``source_Sv`` + freqAB: list of float, optional + The pair of nominal frequencies to be used for frequency-differencing, where + the first element corresponds to ``freqA`` and the second element corresponds + to ``freqB``. Only one of ``freqAB`` and ``chanAB`` should be provided, and not both. 
+    chanAB: list of str, optional
+        The pair of channels that will be used to select the nominal frequencies to be
+        used for frequency-differencing, where the first element corresponds to ``freqA``
+        and the second element corresponds to ``freqB``. Only one of ``freqAB`` and ``chanAB``
+        should be provided, and not both.
+    operator: {">", "<", "<=", ">=", "=="}
+        The operator for the frequency-differencing
+    diff: float or int
+        The threshold of Sv difference between frequencies
+
+    Returns
+    -------
+    xr.DataArray
+        A DataArray containing the mask for the Sv data. Regions satisfying the thresholding
+        criteria are filled with ``True``, else the regions are filled with ``False``.
+
+    Raises
+    ------
+    RuntimeError
+        If neither ``freqAB`` nor ``chanAB`` is given
+    RuntimeError
+        If both ``freqAB`` and ``chanAB`` are given
+    TypeError
+        If any input is not of the correct type
+    RuntimeError
+        If either ``freqAB`` or ``chanAB`` is provided and the list
+        does not contain 2 distinct elements
+    RuntimeError
+        If ``freqAB`` contains values that are not contained in ``frequency_nominal``
+    RuntimeError
+        If ``chanAB`` contains values that are not contained in ``channel``
+    RuntimeError
+        If ``operator`` is not one of the following: ``">", "<", "<=", ">=", "=="``
+    RuntimeError
+        If the path provided for ``source_Sv`` is not a valid path
+    RuntimeError
+        If ``freqAB`` or ``chanAB`` is provided and the Dataset produced by ``source_Sv``
+        does not contain the coordinate ``channel`` and variable ``frequency_nominal``
+
+    Notes
+    -----
+    This function computes the frequency differencing as follows:
+    ``Sv_freqA - Sv_freqB operator diff``. Thus, if ``operator = "<"``
+    and ``diff = 5`` the following would be calculated:
+    ``Sv_freqA - Sv_freqB < 5``.
+
+    Examples
+    --------
+    Compute frequency-differencing mask using a mock Dataset and channel selection:
+
+    >>> n = 5  # set the number of ping times and range samples
+    ...
+    >>> # create mock Sv data
+    >>> Sv_da = xr.DataArray(data=np.stack([np.arange(n**2).reshape(n,n), np.identity(n)]),
+    ...                      coords={"channel": ['chan1', 'chan2'],
+    ...                              "ping_time": np.arange(n), "range_sample": np.arange(n)})
+    ...
+    >>> # obtain mock frequency_nominal data
+    >>> freq_nom = xr.DataArray(data=np.array([1.0, 2.0]),
+    ...                         coords={"channel": ['chan1', 'chan2']})
+    ...
+    >>> # construct mock Sv Dataset
+    >>> Sv_ds = xr.Dataset(data_vars={"Sv": Sv_da, "frequency_nominal": freq_nom})
+    ...
+    >>> # compute frequency-differencing mask using channel names
+    >>> echopype.mask.frequency_differencing(source_Sv=Sv_ds, storage_options={}, freqAB=None,
+    ...                                      chanAB = ['chan1', 'chan2'],
+    ...
operator = ">=", diff=10.0) + + array([[False, False, False, False, False], + [False, False, False, False, False], + [ True, True, True, True, True], + [ True, True, True, True, True], + [ True, True, True, True, True]]) + Coordinates: + * ping_time (ping_time) int64 0 1 2 3 4 + * range_sample (range_sample) int64 0 1 2 3 4 + """ + + # check that non-data related inputs were correctly provided + _check_freq_diff_non_data_inputs(freqAB, chanAB, operator, diff) + + # validate the source_Sv type or path (if it is provided) + source_Sv, file_type = validate_source_ds(source_Sv, storage_options) + + if isinstance(source_Sv, str): + # open up Dataset using source_Sv path + source_Sv = xr.open_dataset(source_Sv, engine=file_type, chunks="auto", **storage_options) + + # check the source_Sv with respect to channel and frequency_nominal + _check_source_Sv_freq_diff(source_Sv, freqAB, chanAB) + + # determine chanA and chanB + if freqAB is not None: + + # obtain position of frequency provided in frequency_nominal + freqA_pos = np.argwhere(source_Sv.frequency_nominal.values == freqAB[0]).flatten()[0] + freqB_pos = np.argwhere(source_Sv.frequency_nominal.values == freqAB[1]).flatten()[0] + + # get channel corresponding to frequency provided + chanA = source_Sv.channel.isel(channel=freqA_pos) + chanB = source_Sv.channel.isel(channel=freqB_pos) + + else: + # get individual channels + chanA = chanAB[0] + chanB = chanAB[1] + + # get the left-hand side of condition + lhs = source_Sv["Sv"].sel(channel=chanA) - source_Sv["Sv"].sel(channel=chanB) + + # create mask using operator lookup table + da = xr.where(str2ops[operator](lhs, diff), True, False) + + # assign a name to DataArray + da.name = "mask" + + return da diff --git a/echopype/tests/mask/test_mask.py b/echopype/tests/mask/test_mask.py new file mode 100644 index 0000000000..543ce2722f --- /dev/null +++ b/echopype/tests/mask/test_mask.py @@ -0,0 +1,224 @@ +import pytest + +import numpy as np +import xarray as xr + +import echopype as ep +from echopype.mask.api import _check_source_Sv_freq_diff + +from typing import List, Union + + +def get_mock_freq_diff_data(n: int, n_chan_freq: int, add_chan: bool, + add_freq_nom: bool) -> xr.Dataset: + """ + Creates an in-memory mock Sv Dataset. + + Parameters + ---------- + n: int + The number of rows (``ping_time``) and columns (``range_sample``) of + each channel matrix + n_chan_freq: int + Determines the size of the ``channel`` coordinate and ``frequency_nominal`` + variable. To create mock data with known outcomes for ``frequency_differencing``, + this value must be greater than or equal to 3. + add_chan: bool + If True the ``channel`` dimension will be named "channel", else it will + be named "data_coord" + add_freq_nom: bool + If True the ``frequency_nominal`` variable will be added to the Dataset + + Returns + ------- + mock_Sv_ds: xr.Dataset + A mock Sv dataset to be used for ``frequency_differencing`` tests. The Sv + data values for the channel coordinate ``chan1`` will be equal to ``mat_A``, + ``chan3`` will be equal to ``mat_B``, and all other channel coordinates + will retain the value of ``np.identity(n)``. + + Notes + ----- + The mock Sv Data is created in such a way where ``mat_A - mat_B`` will be + the identity matrix. 
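+
+    Because ``mat_A - mat_B`` equals ``np.identity(n)``, a call such as
+    ``frequency_differencing(source_Sv=mock_Sv_ds, freqAB=[1.0, 3.0], operator="==", diff=1.0)``
+    should return ``np.identity(n)`` as its mask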
+ """ + + if n_chan_freq < 3: + raise RuntimeError("The input n_chan_freq must be greater than or equal to 3!") + + # matrix representing freqB + mat_B = np.arange(n ** 2).reshape(n, n) - np.identity(n) + + # matrix representing freqA + mat_A = np.arange(n ** 2).reshape(n, n) + + # construct channel values + chan_vals = ['chan' + str(i) for i in range(1, n_chan_freq+1)] + + # construct mock Sv data + mock_Sv_data = [mat_A, np.identity(n), mat_B] + [np.identity(n) for i in range(3, n_chan_freq)] + + # set channel coordinate name (used for testing purposes) + if not add_chan: + channel_coord_name = "data_coord" + else: + channel_coord_name = "channel" + + # create mock Sv DataArray + mock_Sv_da = xr.DataArray(data=np.stack(mock_Sv_data), + coords={channel_coord_name: chan_vals, "ping_time": np.arange(n), + "range_sample": np.arange(n)}) + + # create data variables for the Dataset + data_vars = {"Sv": mock_Sv_da} + + if add_freq_nom: + # construct frequency_values + freq_vals = [float(i) for i in range(1, n_chan_freq + 1)] + + # create mock frequency_nominal and add it to the Dataset variables + mock_freq_nom = xr.DataArray(data=freq_vals, coords={channel_coord_name: chan_vals}) + data_vars["frequency_nominal"] = mock_freq_nom + + # create mock Dataset with Sv and frequency_nominal + mock_Sv_ds = xr.Dataset(data_vars=data_vars) + + return mock_Sv_ds + + +@pytest.mark.parametrize( + ("n", "n_chan_freq", "add_chan", "add_freq_nom", "freqAB", "chanAB"), + [ + (5, 3, True, True, [1.0, 3.0], None), + (5, 3, True, True, None, ['chan1', 'chan3']), + pytest.param(5, 3, False, True, [1.0, 3.0], None, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because the Dataset " + "will not have the channel coordinate.")), + pytest.param(5, 3, True, False, [1.0, 3.0], None, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because the Dataset " + "will not have the frequency_nominal variable.")), + pytest.param(5, 3, True, True, [1.0, 4.0], None, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because not all selected frequencies" + "are in the frequency_nominal variable.")), + pytest.param(5, 3, True, True, None, ['chan1', 'chan4'], + marks=pytest.mark.xfail(strict=True, + reason="This should fail because not all selected channels" + "are in the channel coordinate.")), + ], + ids=["dataset_input_freqAB_provided", "dataset_input_chanAB_provided", "dataset_no_channel", + "dataset_no_frequency_nominal", "dataset_missing_freqAB_in_freq_nom", + "dataset_missing_chanAB_in_channel"] +) +def test_check_source_Sv_freq_diff(n: int, n_chan_freq: int, add_chan: bool, add_freq_nom: bool, + freqAB: List[float], + chanAB: List[str]): + """ + Test the inputs ``source_Sv, freqAB, chanAB`` for ``_check_source_Sv_freq_diff``. + + Parameters + ---------- + n: int + The number of rows (``ping_time``) and columns (``range_sample``) of + each channel matrix + n_chan_freq: int + Determines the size of the ``channel`` coordinate and ``frequency_nominal`` + variable. To create mock data with known outcomes for ``frequency_differencing``, + this value must be greater than or equal to 3. 
+ add_chan: bool + If True the ``channel`` dimension will be named "channel", else it will + be named "data_coord" + add_freq_nom: bool + If True the ``frequency_nominal`` variable will be added to the Dataset + freqAB: list of float, optional + The pair of nominal frequencies to be used for frequency-differencing, where + the first element corresponds to ``freqA`` and the second element corresponds + to ``freqB`` + chanAB: list of float, optional + The pair of channels that will be used to select the nominal frequencies to be + used for frequency-differencing, where the first element corresponds to ``freqA`` + and the second element corresponds to ``freqB`` + """ + + source_Sv = get_mock_freq_diff_data(n, n_chan_freq, add_chan, add_freq_nom) + + _check_source_Sv_freq_diff(source_Sv, freqAB=freqAB, chanAB=chanAB) + + +@pytest.mark.parametrize( + ("n", "n_chan_freq", "freqAB", "chanAB", "diff", "operator", "mask_truth"), + [ + (5, 4, [1.0, 3.0], None, 1.0, "==", np.identity(5)), + (5, 4, None, ['chan1', 'chan3'], 1.0, "==", np.identity(5)), + (5, 4, [3.0, 1.0], None, 1.0, "==", np.zeros((5, 5))), + (5, 4, None, ['chan3', 'chan1'], 1.0, "==", np.zeros((5, 5))), + (5, 4, [1.0, 3.0], None, 1.0, ">=", np.identity(5)), + (5, 4, None, ['chan1', 'chan3'], 1.0, ">=", np.identity(5)), + (5, 4, [1.0, 3.0], None, 1.0, ">", np.zeros((5, 5))), + (5, 4, None, ['chan1', 'chan3'], 1.0, ">", np.zeros((5, 5))), + (5, 4, [1.0, 3.0], None, 1.0, "<=", np.ones((5, 5))), + (5, 4, None, ['chan1', 'chan3'], 1.0, "<=", np.ones((5, 5))), + (5, 4, [1.0, 3.0], None, 1.0, "<", np.ones((5, 5)) - np.identity(5)), + (5, 4, None, ['chan1', 'chan3'], 1.0, "<", np.ones((5, 5)) - np.identity(5)), + ], + ids=["freqAB_sel_op_equals", "chanAB_sel_op_equals", "reverse_freqAB_sel_op_equals", + "reverse_chanAB_sel_op_equals", "freqAB_sel_op_ge", "chanAB_sel_op_ge", + "freqAB_sel_op_greater", "chanAB_sel_op_greater", "freqAB_sel_op_le", + "chanAB_sel_op_le", "freqAB_sel_op_less", "chanAB_sel_op_less"] +) +def test_frequency_differencing(n: int, n_chan_freq: int, + freqAB: List[float], chanAB: List[str], + diff: Union[float, int], operator: str, + mask_truth: np.ndarray): + """ + Tests that the output values of ``frequency_differencing`` are what we + expect, the output is a DataArray, and that the name of the DataArray is correct. + + Parameters + ---------- + n: int + The number of rows (``ping_time``) and columns (``range_sample``) of + each channel matrix + n_chan_freq: int + Determines the size of the ``channel`` coordinate and ``frequency_nominal`` + variable. To create mock data with known outcomes for ``frequency_differencing``, + this value must be greater than or equal to 3. 
+    freqAB: list of float, optional
+        The pair of nominal frequencies to be used for frequency-differencing, where
+        the first element corresponds to ``freqA`` and the second element corresponds
+        to ``freqB``
+    chanAB: list of str, optional
+        The pair of channels that will be used to select the nominal frequencies to be
+        used for frequency-differencing, where the first element corresponds to ``freqA``
+        and the second element corresponds to ``freqB``
+    diff: float or int
+        The threshold of Sv difference between frequencies
+    operator: {">", "<", "<=", ">=", "=="}
+        The operator for the frequency-differencing
+    mask_truth: np.ndarray
+        The truth value for the output mask, given the provided inputs
+    """
+
+    # obtain mock Sv Dataset
+    mock_Sv_ds = get_mock_freq_diff_data(n, n_chan_freq, add_chan=True, add_freq_nom=True)
+
+    # obtain the frequency-difference mask for mock_Sv_ds
+    out = ep.mask.frequency_differencing(source_Sv=mock_Sv_ds, storage_options={}, freqAB=freqAB,
+                                         chanAB=chanAB,
+                                         operator=operator, diff=diff)
+
+    # ensure that the output values are correct
+    assert np.all(out == mask_truth)
+
+    # ensure that the output is a DataArray
+    assert isinstance(out, xr.DataArray)
+
+    # test that the output DataArray is correctly named
+    assert out.name == "mask"
diff --git a/echopype/tests/utils/test_utils_io.py b/echopype/tests/utils/test_utils_io.py
index 20effbcdca..1173f13ae9 100644
--- a/echopype/tests/utils/test_utils_io.py
+++ b/echopype/tests/utils/test_utils_io.py
@@ -4,8 +4,9 @@ import pytest
 from typing import Tuple
 import platform
+import xarray as xr
 
-from echopype.utils.io import sanitize_file_path, validate_output_path, env_indep_joinpath
+from echopype.utils.io import sanitize_file_path, validate_output_path, env_indep_joinpath, validate_source_ds
 
 
 @pytest.mark.parametrize(
@@ -258,3 +259,42 @@ def test_env_indep_joinpath_os_dependent(save_path: str, is_windows: bool, is_cl
     pytest.skip("Skipping Unix parameters because we are not on a Unix machine.")
 
 
+@pytest.mark.parametrize(
+    ("source_ds_input", "storage_options_input", "true_file_type"),
+    [
+        pytest.param(42, {}, None,
+                     marks=pytest.mark.xfail(
+                         strict=True,
+                         reason='This test should fail because source_ds is not of the correct type.')
+                     ),
+        pytest.param(xr.DataArray(), {}, None,
+                     marks=pytest.mark.xfail(
+                         strict=True,
+                         reason='This test should fail because source_ds is not of the correct type.')
+                     ),
+        pytest.param({}, 42, None,
+                     marks=pytest.mark.xfail(
+                         strict=True,
+                         reason='This test should fail because storage_options is not of the correct type.')
+                     ),
+        (xr.Dataset(attrs={"test": 42}), {}, None),
+        (os.path.join('folder', 'new_test.nc'), {}, 'netcdf4'),
+        (os.path.join('folder', 'new_test.zarr'), {}, 'zarr')
+    ]
+
+)
+def test_validate_source_ds(source_ds_input, storage_options_input, true_file_type):
+    """
+    Tests that ``validate_source_ds`` has the appropriate outputs.
+    An exhaustive list of combinations of ``source_ds`` and ``storage_options``
+    are tested in ``test_validate_output_path`` and are therefore not included here.
+ """ + + source_ds_output, file_type_output = validate_source_ds(source_ds_input, storage_options_input) + + if isinstance(source_ds_input, xr.Dataset): + assert source_ds_output.identical(source_ds_input) + assert file_type_output is None + else: + assert isinstance(source_ds_output, str) + assert file_type_output == true_file_type diff --git a/echopype/utils/io.py b/echopype/utils/io.py index 7b582bea4b..b9c793e76d 100644 --- a/echopype/utils/io.py +++ b/echopype/utils/io.py @@ -5,9 +5,10 @@ import platform import sys from pathlib import Path, WindowsPath -from typing import TYPE_CHECKING, Dict, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union import fsspec +import xarray as xr from fsspec import FSMap from fsspec.implementations.local import LocalFileSystem @@ -315,3 +316,58 @@ def env_indep_joinpath(*args: Tuple[str, ...]) -> str: joined_path = os.path.join(*args) return joined_path + + +def validate_source_ds( + source_ds: Union[xr.Dataset, str, Path], storage_options: Optional[dict] +) -> Tuple[Union[xr.Dataset, str, xr.DataArray], Optional[str]]: + """ + This function ensures that ``source_ds`` is of the correct + type and validates the path of ``source_ds``, if it is provided. + + Parameters + ---------- + source_ds: xr.Dataset or str or pathlib.Path + A source that points to a Dataset. If the input is a path, it specifies + the path to a zarr or netcdf file. + storage_options: dict, optional + Any additional parameters for the storage backend, corresponding to the + path provided for ``source_ds`` + + Returns + ------- + source_ds: xr.Dataset or str + A Dataset which will be the same as the input ``source_ds`` or a validated + path to a zarr or netcdf file + file_type: {"netcdf4", "zarr"}, optional + The file type of the input path if ``source_ds`` is a path, otherwise ``None`` + """ + + # initialize file_type + file_type = None + + # make sure that storage_options is of the appropriate type + if not isinstance(storage_options, dict): + raise TypeError("storage_options must be a dict!") + + # check that source_ds is of the correct type, if it is a path validate + # the path and open the dataset using xarray + if not isinstance(source_ds, (xr.Dataset, str, Path)): + raise TypeError("source_ds must be a Dataset or str or pathlib.Path!") + elif isinstance(source_ds, (str, Path)): + + # determine if we obtained a zarr or netcdf file + file_type = get_file_format(source_ds) + + # validate source_ds if it is a path + source_ds = validate_output_path( + source_file="blank", # will be unused since source_ds cannot be none + engine=file_type, + output_storage_options=storage_options, + save_path=source_ds, + ) + + # check that the path exists + check_file_existence(file_path=source_ds, storage_options=storage_options) + + return source_ds, file_type From f03a75b9d093eafb9686e7452df0fd0e755d0a64 Mon Sep 17 00:00:00 2001 From: Wu-Jung Lee Date: Thu, 22 Dec 2022 12:50:00 -0800 Subject: [PATCH 3/8] Redo: Add `apply_mask` function to `mask` sub-package (#912) See #905 for all conversations and detailed commits. 
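Below is a minimal usage sketch of how the two functions in this sub-package
compose: frequency_differencing builds a boolean mask and apply_mask applies it
to a Dataset variable. The mock dataset mirrors the one used in the tests;
channel names and values are illustrative only, and the apply_mask call assumes
the signature added in this patch:

    import numpy as np
    import xarray as xr
    import echopype as ep

    # mock Sv dataset with two channels (names and values illustrative)
    n = 5
    Sv_da = xr.DataArray(
        data=np.stack([np.arange(n**2).reshape(n, n), np.identity(n)]),
        coords={"channel": ["chan1", "chan2"],
                "ping_time": np.arange(n), "range_sample": np.arange(n)},
    )
    freq_nom = xr.DataArray(data=np.array([1.0, 2.0]),
                            coords={"channel": ["chan1", "chan2"]})
    Sv_ds = xr.Dataset(data_vars={"Sv": Sv_da, "frequency_nominal": freq_nom})

    # boolean mask that is True where Sv(chan1) - Sv(chan2) >= 10 dB
    mask = ep.mask.frequency_differencing(
        source_Sv=Sv_ds, chanAB=["chan1", "chan2"], operator=">=", diff=10.0
    )

    # keep "Sv" where the mask is True; masked samples become NaN by default
    masked_ds = ep.mask.apply_mask(source_ds=Sv_ds, mask=mask, var_name="Sv")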
--- echopype/mask/__init__.py | 4 +- echopype/mask/api.py | 269 +++++++++++++++++++++++--- echopype/tests/mask/test_mask.py | 185 +++++++++++++++++- echopype/tests/utils/test_utils_io.py | 22 +-- echopype/utils/io.py | 46 ++--- 5 files changed, 464 insertions(+), 62 deletions(-) diff --git a/echopype/mask/__init__.py b/echopype/mask/__init__.py index 6e7e529c8c..24b6c172f8 100644 --- a/echopype/mask/__init__.py +++ b/echopype/mask/__init__.py @@ -1,3 +1,3 @@ -from .api import frequency_differencing +from .api import apply_mask, frequency_differencing -__all__ = ["frequency_differencing"] +__all__ = ["frequency_differencing", "apply_mask"] diff --git a/echopype/mask/api.py b/echopype/mask/api.py index 0e060f0e08..e45e730e66 100644 --- a/echopype/mask/api.py +++ b/echopype/mask/api.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr -from ..utils.io import validate_source_ds +from ..utils.io import validate_source_ds_da # lookup table with key string operator and value as corresponding Python operator str2ops = { @@ -17,6 +17,229 @@ } +def validate_and_collect_mask_input( + mask: Union[ + Union[xr.DataArray, str, pathlib.Path], List[Union[xr.DataArray, str, pathlib.Path]] + ], + storage_options_mask: Union[dict, List[dict]], +) -> Union[xr.DataArray, List[xr.DataArray]]: + """ + Validate that the input ``mask`` and associated ``storage_options_mask`` are correctly + provided to ``apply_mask``. Additionally, form the mask input that should be used + in the core routine of ``apply_mask``. + + Parameters + ---------- + mask: xr.DataArray, str, pathlib.Path, or a list of these datatypes + The mask(s) to be applied. Can be a single input or list that corresponds to a + DataArray or a path. If a path is provided this should point to a zarr or netcdf + file with only one data variable in it. + storage_options_mask: dict or list of dict, default={} + Any additional parameters for the storage backend, corresponding to the + path provided for ``mask``. If ``mask`` is a list, then this input should either + be a list of dictionaries or a single dictionary with storage options that + correspond to all elements in ``mask`` that are paths. 
+
+    Returns
+    -------
+    xr.DataArray or list of xr.DataArray
+        If the ``mask`` input is a single value, then the corresponding DataArray will be
+        returned, else a list of DataArrays corresponding to the input masks will be returned
+
+    Raises
+    ------
+    ValueError
+        If ``mask`` is a single element and ``storage_options_mask`` is not a single dict
+    TypeError
+        If ``storage_options_mask`` is not a list of dict or a dict
+    """
+
+    if isinstance(mask, list):
+
+        # if storage_options_mask is not a list, create a list of
+        # length len(mask) whose elements are all storage_options_mask
+        if not isinstance(storage_options_mask, list):
+
+            if not isinstance(storage_options_mask, dict):
+                raise TypeError("storage_options_mask must be a list of dict or a dict!")
+
+            storage_options_mask = [storage_options_mask] * len(mask)
+        else:
+            # ensure all elements of storage_options_mask are dicts
+            if not all([isinstance(elem, dict) for elem in storage_options_mask]):
+                raise TypeError("storage_options_mask must be a list of dict or a dict!")
+
+        for mask_ind in range(len(mask)):
+
+            # validate the mask type or path (if it is provided)
+            mask_val, file_type = validate_source_ds_da(
+                mask[mask_ind], storage_options_mask[mask_ind]
+            )
+
+            # replace mask element path with its corresponding DataArray
+            if isinstance(mask_val, str):
+                # open up DataArray using mask path
+                mask[mask_ind] = xr.open_dataarray(
+                    mask_val, engine=file_type, chunks={}, **storage_options_mask[mask_ind]
+                )
+
+    else:
+
+        if not isinstance(storage_options_mask, dict):
+            raise ValueError(
+                "The provided input storage_options_mask should be a single "
+                "dict because mask is a single value!"
+            )
+
+        # validate the mask type or path (if it is provided)
+        mask, file_type = validate_source_ds_da(mask, storage_options_mask)
+
+        if isinstance(mask, str):
+            # open up DataArray using mask path
+            mask = xr.open_dataarray(mask, engine=file_type, chunks={}, **storage_options_mask)
+
+    return mask
+
+
+def _check_var_name_fill_value(
+    source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, np.ndarray, xr.DataArray]
+) -> None:
+    """
+    Ensures that the inputs ``var_name`` and ``fill_value`` for the function
+    ``apply_mask`` were appropriately provided.
+
+    Parameters
+    ----------
+    source_ds: xr.Dataset
+        A Dataset that contains the variable ``var_name``
+    var_name: str
+        The variable name in ``source_ds`` that the mask should be applied to
+    fill_value: int or float or np.ndarray or xr.DataArray
+        Specifies the value(s) at false indices
+
+    Raises
+    ------
+    TypeError
+        If ``var_name`` or ``fill_value`` are not an accepted type
+    ValueError
+        If the Dataset ``source_ds`` does not contain ``var_name``
+    ValueError
+        If ``fill_value`` is an array and not the same shape as ``var_name``
+    """
+
+    # check the type of var_name
+    if not isinstance(var_name, str):
+        raise TypeError("The input var_name must be a string!")
+
+    # ensure var_name is in source_ds
+    if var_name not in source_ds.variables:
+        raise ValueError("The Dataset source_ds does not contain the variable var_name!")
+
+    # check the type of fill_value
+    if not isinstance(fill_value, (int, float, np.ndarray, xr.DataArray)):
+        raise TypeError(
+            "The input fill_value must be of type int or float or np.ndarray or xr.DataArray!"
+        )
+
+    # make sure that fill_value is the same shape as var_name, if it is an array
+    if isinstance(fill_value, (np.ndarray, xr.DataArray)) and (
+        fill_value.shape != source_ds[var_name].shape
+    ):
+        raise ValueError("If fill_value is an array, it must be of the same shape as var_name!")
+
+
+def apply_mask(
+    source_ds: Union[xr.Dataset, str, pathlib.Path],
+    mask: Union[
+        Union[xr.DataArray, str, pathlib.Path], List[Union[xr.DataArray, str, pathlib.Path]]
+    ],
+    var_name: str = "Sv",
+    fill_value: Union[int, float, np.ndarray, xr.DataArray] = np.nan,
+    storage_options_ds: dict = {},
+    storage_options_mask: Union[dict, List[dict]] = {},
+) -> xr.Dataset:
+    """
+    Applies the provided mask(s) to the variable ``var_name``
+    in the provided Dataset ``source_ds``.
+
+    Parameters
+    ----------
+    source_ds: xr.Dataset, str, or pathlib.Path
+        Points to a Dataset that contains the variable the mask should be applied to
+    mask: xr.DataArray, str, pathlib.Path, or a list of these datatypes
+        The mask(s) to be applied. Can be a single input or list that corresponds to
+        a DataArray or a path. If a path is provided, it should point to a zarr or
+        netcdf file with only one data variable in it.
+    var_name: str, default="Sv"
+        The variable name in ``source_ds`` that the mask should be applied to
+    fill_value: int, float, np.ndarray, or xr.DataArray, default=np.nan
+        Value(s) at masked indices
+    storage_options_ds: dict, default={}
+        Any additional parameters for the storage backend, corresponding to the
+        path provided for ``source_ds``
+    storage_options_mask: dict or list of dict, default={}
+        Any additional parameters for the storage backend, corresponding to the
+        path provided for ``mask``. If ``mask`` is a list, then this input should either
+        be a list of dictionaries or a single dictionary with storage options that
+        correspond to all elements in ``mask`` that are paths.
+
+    Returns
+    -------
+    xr.Dataset
+        A Dataset with the same format as ``source_ds`` with the mask(s) applied to ``var_name``
+
+    Notes
+    -----
+    If the input ``mask`` is a list, then a logical AND will be used to produce the final
+    mask that will be applied to ``var_name``.
+ """ + + # validate the source_ds type or path (if it is provided) + source_ds, file_type = validate_source_ds_da(source_ds, storage_options_ds) + + if isinstance(source_ds, str): + # open up Dataset using source_ds path + source_ds = xr.open_dataset(source_ds, engine=file_type, chunks={}, **storage_options_ds) + + # validate and form the mask input to be used downstream + mask = validate_and_collect_mask_input(mask, storage_options_mask) + + # ensure that var_name and fill_value were correctly provided + _check_var_name_fill_value(source_ds, var_name, fill_value) + + # select data only, if fill_value is a DataArray (necessary since + # xr.where(keep_attrs=True) is not functioning correctly) + if isinstance(fill_value, xr.DataArray): + fill_value = fill_value.data + + # obtain final mask to be applied to var_name + if isinstance(mask, list): + # perform a logical AND element-wise operation across the masks + final_mask = np.logical_and.reduce(mask) + + # xr.where has issues with attrs when final_mask is an array, so we make it a DataArray + final_mask = xr.DataArray(final_mask, coords=mask[0].coords) + else: + final_mask = mask + + # sanity check to make sure final_mask is the same shape as source_ds[var_name] + if final_mask.shape != source_ds[var_name].shape: + raise ValueError("Final constructed mask is not the same shape as source_ds[var_name]!") + + # apply the mask to var_name + var_name_masked = xr.where(final_mask, x=source_ds[var_name], y=fill_value, keep_attrs=True) + + # obtain a shallow copy of source_ds + output_ds = source_ds.copy(deep=False) + + # replace var_name with var_name_masked + output_ds[var_name] = var_name_masked + + # TODO: add provenance or attributes specifying that a mask was applied here! + + return output_ds + + def _check_freq_diff_non_data_inputs( freqAB: Optional[List[float]] = None, chanAB: Optional[List[str]] = None, @@ -45,26 +268,26 @@ def _check_freq_diff_non_data_inputs( # check that either freqAB or chanAB are provided and they are a list of length 2 if (freqAB is None) and (chanAB is None): - raise RuntimeError("Either freqAB or chanAB must be given!") + raise ValueError("Either freqAB or chanAB must be given!") elif (freqAB is not None) and (chanAB is not None): - raise RuntimeError("Only freqAB or chanAB must be given, but not both!") + raise ValueError("Only freqAB or chanAB must be given, but not both!") elif freqAB is not None: if not isinstance(freqAB, list): raise TypeError("freqAB must be a list!") elif len(set(freqAB)) != 2: - raise RuntimeError("freqAB must be a list of length 2 with unique elements!") + raise ValueError("freqAB must be a list of length 2 with unique elements!") else: if not isinstance(chanAB, list): raise TypeError("chanAB must be a list!") elif len(set(chanAB)) != 2: - raise RuntimeError("chanAB must be a list of length 2 with unique elements!") + raise ValueError("chanAB must be a list of length 2 with unique elements!") # check that operator is a string and a valid operator if not isinstance(operator, str): raise TypeError("operator must be a string!") else: if operator not in [">", "<", "<=", ">=", "=="]: - raise RuntimeError("Invalid operator!") + raise ValueError("Invalid operator!") # ensure that diff is a float or an int if not isinstance(diff, (float, int)): @@ -99,34 +322,34 @@ def _check_source_Sv_freq_diff( # check that channel and frequency nominal are in source_Sv if "channel" not in source_Sv.coords: - raise RuntimeError("The Dataset defined by source_Sv must have channel as a coordinate!") + raise 
ValueError("The Dataset defined by source_Sv must have channel as a coordinate!") elif "frequency_nominal" not in source_Sv.variables: - raise RuntimeError( + raise ValueError( "The Dataset defined by source_Sv must have frequency_nominal as a variable!" ) # make sure that the channel and frequency_nominal values are not repeated in source_Sv if len(set(source_Sv.channel.values)) < source_Sv.channel.size: - raise RuntimeError( - "The provided source_Sv contains repeated channel values, " "this is not allowed!" + raise ValueError( + "The provided source_Sv contains repeated channel values, this is not allowed!" ) if len(set(source_Sv.frequency_nominal.values)) < source_Sv.frequency_nominal.size: - raise RuntimeError( + raise ValueError( "The provided source_Sv contains repeated frequency_nominal " "values, this is not allowed!" ) # check that the elements of freqAB are in frequency_nominal if (freqAB is not None) and (not all([freq in source_Sv.frequency_nominal for freq in freqAB])): - raise RuntimeError( + raise ValueError( "The provided list input freqAB contains values that " "are not in the frequency_nominal variable!" ) # check that the elements of chanAB are in channel if (chanAB is not None) and (not all([chan in source_Sv.channel for chan in chanAB])): - raise RuntimeError( + raise ValueError( "The provided list input chanAB contains values that are " "not in the channel coordinate!" ) @@ -177,24 +400,24 @@ def frequency_differencing( Raises ------ - RuntimeError + ValueError If neither ``freqAB`` or ``chanAB`` are given - RuntimeError + ValueError If both ``freqAB`` and ``chanAB`` are given TypeError If any input is not of the correct type - RuntimeError + ValueError If either ``freqAB`` or ``chanAB`` are provided and the list does not contain 2 distinct elements - RuntimeError + ValueError If ``freqAB`` contains values that are not contained in ``frequency_nominal`` - RuntimeError + ValueError If ``chanAB`` contains values that not contained in ``channel`` - RuntimeError + ValueError If ``operator`` is not one of the following: ``">", "<", "<=", ">=", "=="`` - RuntimeError + ValueError If the path provided for ``source_Sv`` is not a valid path - RuntimeError + ValueError If ``freqAB`` or ``chanAB`` is provided and the Dataset produced by ``source_Sv`` does not contain the coordinate ``channel`` and variable ``frequency_nominal`` @@ -242,11 +465,11 @@ def frequency_differencing( _check_freq_diff_non_data_inputs(freqAB, chanAB, operator, diff) # validate the source_Sv type or path (if it is provided) - source_Sv, file_type = validate_source_ds(source_Sv, storage_options) + source_Sv, file_type = validate_source_ds_da(source_Sv, storage_options) if isinstance(source_Sv, str): # open up Dataset using source_Sv path - source_Sv = xr.open_dataset(source_Sv, engine=file_type, chunks="auto", **storage_options) + source_Sv = xr.open_dataset(source_Sv, engine=file_type, chunks={}, **storage_options) # check the source_Sv with respect to channel and frequency_nominal _check_source_Sv_freq_diff(source_Sv, freqAB, chanAB) diff --git a/echopype/tests/mask/test_mask.py b/echopype/tests/mask/test_mask.py index 543ce2722f..de24e6c5fb 100644 --- a/echopype/tests/mask/test_mask.py +++ b/echopype/tests/mask/test_mask.py @@ -1,12 +1,18 @@ +import pathlib + import pytest import numpy as np import xarray as xr +import dask.array +import tempfile +import os import echopype as ep +import echopype.mask from echopype.mask.api import _check_source_Sv_freq_diff -from typing import List, Union +from typing 
import List, Union, Optional def get_mock_freq_diff_data(n: int, n_chan_freq: int, add_chan: bool, @@ -86,6 +92,56 @@ def get_mock_freq_diff_data(n: int, n_chan_freq: int, add_chan: bool, return mock_Sv_ds +def get_mock_source_ds_apply_mask(n: int, n_chan: int, is_delayed: bool) -> xr.Dataset: + """ + Constructs a mock ``source_ds`` Dataset input for the + ``apply_mask`` function. + + Parameters + ---------- + n: int + The number of rows (``x``) and columns (``y``) of + each channel matrix + n_chan: int + Determines the size of the ``channel`` coordinate + is_delayed: bool + If True, the returned Dataset variables ``var1`` and ``var2`` will be + a Dask arrays, else they will be in-memory arrays + + Returns + ------- + xr.Dataset + A Dataset with coordinates ``channel, x, y`` and + variables ``var1, var2`` (with the created coordinates). The + variables will contain square matrices of ones for each ``channel``. + """ + + # construct channel values + chan_vals = ['chan' + str(i) for i in range(1, n_chan + 1)] + + # construct mock variable data for each channel + if is_delayed: + mock_var_data = [dask.array.ones((n, n)) for i in range(n_chan)] + else: + mock_var_data = [np.ones((n, n)) for i in range(n_chan)] + + # create mock var1 and var2 DataArrays + mock_var1_da = xr.DataArray(data=np.stack(mock_var_data), + coords={"channel": ("channel", chan_vals, {"long_name": "channel name"}), + "x": np.arange(n), "y": np.arange(n)}, + attrs={"long_name": "variable 1"}) + mock_var2_da = xr.DataArray(data=np.stack(mock_var_data), + coords={"channel": ("channel", chan_vals, {"long_name": "channel name"}), + "x": np.arange(n), + "y": np.arange(n)}, + attrs={"long_name": "variable 2"}) + + # create mock Dataset + mock_ds = xr.Dataset(data_vars={"var1": mock_var1_da, "var2": mock_var2_da}) + + return mock_ds + + @pytest.mark.parametrize( ("n", "n_chan_freq", "add_chan", "add_freq_nom", "freqAB", "chanAB"), [ @@ -219,6 +275,133 @@ def test_frequency_differencing(n: int, n_chan_freq: int, assert out.name == "mask" +@pytest.mark.parametrize( + ("n", "n_chan", "var_name", "mask", "mask_file", "fill_value", "is_delayed", "var_masked_truth"), + [ + (2, 1, "var1", np.identity(2), None, np.nan, False, np.array([[1, np.nan], [np.nan, 1]])), + (2, 1, "var1", np.identity(2), None, 2.0, False, np.array([[1, 2.0], [2.0, 1]])), + (2, 1, "var1", np.identity(2), None, np.array([[[np.nan, np.nan], [np.nan, np.nan]]]), + False, np.array([[1, np.nan], [np.nan, 1]])), + (2, 1, "var1", np.identity(2), None, xr.DataArray(data=np.array([[[np.nan, np.nan], [np.nan, np.nan]]]), + coords={"channel": ["chan1"], + "ping_time": [0, 1], + "range_sample": [0, 1]}), + False, np.array([[1, np.nan], [np.nan, 1]])), + (2, 1, "var1", [np.identity(2), np.array([[0, 1], [0, 1]])], [None, None], 2.0, + False, np.array([[2.0, 2.0], [2.0, 1]])), + (2, 1, "var1", np.identity(2), None, 2.0, True, np.array([[1, 2.0], [2.0, 1]])), + (2, 1, "var1", np.identity(2), "test.zarr", 2.0, True, np.array([[1, 2.0], [2.0, 1]])), + (2, 1, "var1", [np.identity(2), np.array([[0, 1], [0, 1]])], ["test0.zarr", "test1.zarr"], 2.0, + False, np.array([[2.0, 2.0], [2.0, 1]])), + (2, 1, "var1", [np.identity(2), np.array([[0, 1], [0, 1]])], ["test0.zarr", None], 2.0, + False, np.array([[2.0, 2.0], [2.0, 1]])), + ], + ids=["single_mask_default_fill", "single_mask_float_fill", "single_mask_np_array_fill", + "single_mask_DataArray_fill", "list_mask_all_np", "single_mask_ds_delayed", + "single_mask_as_path", "list_mask_all_path", "list_mask_some_path"] +) +def 
test_apply_mask(n: int, n_chan: int, var_name: str, + mask: Union[np.ndarray, List[np.ndarray]], + mask_file: Optional[Union[str, List[str]]], + fill_value: Union[int, float, np.ndarray, xr.DataArray], + is_delayed: bool, var_masked_truth: np.ndarray): + """ + Ensures that ``apply_mask`` functions correctly. + + Parameters + ---------- + n: int + The number of rows (``x``) and columns (``y``) of + each channel matrix + n_chan: int + Determines the size of the ``channel`` coordinate + var_name: {"var1", "var2"} + The variable name in the mock Dataset to apply the mask to + mask: np.ndarray or list of np.ndarray + The mask(s) that should be applied to ``var_name`` + mask_file: str or list of str, optional + If provided, the ``mask`` input will be written to a temporary directory + with file name ``mask_file``. This will then be used in ``apply_mask``. + var_masked_truth: np.ndarray + The true value of ``var_name`` values after the mask has been applied + is_delayed: bool + If True, makes all variables in constructed mock Dataset Dask arrays, + else they will be in-memory arrays + """ + + # obtain mock Dataset containing var_name + mock_ds = get_mock_source_ds_apply_mask(n, n_chan, is_delayed) + + # initialize temp_dir + temp_dir = None + + # make input numpy array masks into DataArrays + if isinstance(mask, list): + + # create temporary directory if mask_file is provided + if any([isinstance(elem, str) for elem in mask_file]): + + # create temporary directory for mask_file + temp_dir = tempfile.TemporaryDirectory() + + for mask_ind in range(len(mask)): + + # form DataArray from given mask data + mask_da = xr.DataArray(data=np.stack([mask[mask_ind] for i in range(n_chan)]), + coords=mock_ds.coords, name='mask_' + str(mask_ind)) + + if mask_file[mask_ind] is None: + + # set mask value to the DataArray given + mask[mask_ind] = mask_da + else: + + # write DataArray to temporary directory + zarr_path = os.path.join(temp_dir.name, mask_file[mask_ind]) + mask_da.to_dataset().to_zarr(zarr_path) + + # set mask value to created path + mask[mask_ind] = zarr_path + + elif isinstance(mask, np.ndarray): + + # form DataArray from given mask data + mask_da = xr.DataArray(data=np.stack([mask for i in range(n_chan)]), + coords=mock_ds.coords, name='mask_0') + + if mask_file is None: + + # set mask to the DataArray formed + mask = mask_da + else: + + # create temporary directory for mask_file + temp_dir = tempfile.TemporaryDirectory() + + # write DataArray to temporary directory + zarr_path = os.path.join(temp_dir.name, mask_file) + mask_da.to_dataset().to_zarr(zarr_path) + + # set mask index to path + mask = zarr_path + + # create DataArray form of the known truth value + var_masked_truth = xr.DataArray(data=np.stack([var_masked_truth for i in range(n_chan)]), + coords=mock_ds[var_name].coords, attrs=mock_ds[var_name].attrs) + var_masked_truth.name = mock_ds[var_name].name + + # apply the mask to var_name + masked_ds = echopype.mask.apply_mask(source_ds=mock_ds, var_name=var_name, mask=mask, + fill_value=fill_value, storage_options_ds={}, + storage_options_mask={}) + # check that masked_ds[var_name] == var_masked_truth + assert masked_ds[var_name].identical(var_masked_truth) + # check that the output Dataset has lazy elements, if the input was lazy + if is_delayed: + assert isinstance(masked_ds[var_name].data, dask.array.Array) + if temp_dir: + # remove the temporary directory, if it was created + temp_dir.cleanup() diff --git a/echopype/tests/utils/test_utils_io.py b/echopype/tests/utils/test_utils_io.py 
index 1173f13ae9..79734d3282 100644 --- a/echopype/tests/utils/test_utils_io.py +++ b/echopype/tests/utils/test_utils_io.py @@ -6,7 +6,7 @@ import platform import xarray as xr -from echopype.utils.io import sanitize_file_path, validate_output_path, env_indep_joinpath, validate_source_ds +from echopype.utils.io import sanitize_file_path, validate_output_path, env_indep_joinpath, validate_source_ds_da @pytest.mark.parametrize( @@ -260,18 +260,14 @@ def test_env_indep_joinpath_os_dependent(save_path: str, is_windows: bool, is_cl @pytest.mark.parametrize( - ("source_ds_input", "storage_options_input", "true_file_type"), + ("source_ds_da_input", "storage_options_input", "true_file_type"), [ pytest.param(42, {}, None, marks=pytest.mark.xfail( strict=True, reason='This test should fail because source_ds is not of the correct type.') ), - pytest.param(xr.DataArray(), {}, None, - marks=pytest.mark.xfail( - strict=True, - reason='This test should fail because source_ds is not of the correct type.') - ), + pytest.param(xr.DataArray(), {}, None), pytest.param({}, 42, None, marks=pytest.mark.xfail( strict=True, @@ -283,17 +279,17 @@ def test_env_indep_joinpath_os_dependent(save_path: str, is_windows: bool, is_cl ] ) -def test_validate_source_ds(source_ds_input, storage_options_input, true_file_type): +def test_validate_source_ds_da(source_ds_da_input, storage_options_input, true_file_type): """ - Tests that ``validate_source_ds`` has the appropriate outputs. - An exhaustive list of combinations of ``source_ds`` and ``storage_options`` + Tests that ``validate_source_ds_da`` has the appropriate outputs. + An exhaustive list of combinations of ``source_ds_da`` and ``storage_options`` are tested in ``test_validate_output_path`` and are therefore not included here. """ - source_ds_output, file_type_output = validate_source_ds(source_ds_input, storage_options_input) + source_ds_output, file_type_output = validate_source_ds_da(source_ds_da_input, storage_options_input) - if isinstance(source_ds_input, xr.Dataset): - assert source_ds_output.identical(source_ds_input) + if isinstance(source_ds_da_input, (xr.Dataset, xr.DataArray)): + assert source_ds_output.identical(source_ds_da_input) assert file_type_output is None else: assert isinstance(source_ds_output, str) diff --git a/echopype/utils/io.py b/echopype/utils/io.py index b9c793e76d..dd8d77ffff 100644 --- a/echopype/utils/io.py +++ b/echopype/utils/io.py @@ -318,29 +318,29 @@ def env_indep_joinpath(*args: Tuple[str, ...]) -> str: return joined_path -def validate_source_ds( - source_ds: Union[xr.Dataset, str, Path], storage_options: Optional[dict] +def validate_source_ds_da( + source_ds_da: Union[xr.Dataset, xr.DataArray, str, Path], storage_options: Optional[dict] ) -> Tuple[Union[xr.Dataset, str, xr.DataArray], Optional[str]]: """ - This function ensures that ``source_ds`` is of the correct - type and validates the path of ``source_ds``, if it is provided. + This function ensures that ``source_ds_da`` is of the correct + type and validates the path of ``source_ds_da``, if it is provided. Parameters ---------- - source_ds: xr.Dataset or str or pathlib.Path - A source that points to a Dataset. If the input is a path, it specifies - the path to a zarr or netcdf file. + source_ds_da: xr.Dataset, xr.DataArray, str or pathlib.Path + A source that points to a Dataset or DataArray. If the input is a path, + it specifies the path to a zarr or netcdf file. 
storage_options: dict, optional Any additional parameters for the storage backend, corresponding to the - path provided for ``source_ds`` + path provided for ``source_ds_da`` Returns ------- - source_ds: xr.Dataset or str - A Dataset which will be the same as the input ``source_ds`` or a validated - path to a zarr or netcdf file + source_ds_da: xr.Dataset or xr.DataArray or str + A Dataset or DataArray which will be the same as the input ``source_ds_da`` or + a validated path to a zarr or netcdf file file_type: {"netcdf4", "zarr"}, optional - The file type of the input path if ``source_ds`` is a path, otherwise ``None`` + The file type of the input path if ``source_ds_da`` is a path, otherwise ``None`` """ # initialize file_type @@ -350,24 +350,24 @@ def validate_source_ds( if not isinstance(storage_options, dict): raise TypeError("storage_options must be a dict!") - # check that source_ds is of the correct type, if it is a path validate - # the path and open the dataset using xarray - if not isinstance(source_ds, (xr.Dataset, str, Path)): - raise TypeError("source_ds must be a Dataset or str or pathlib.Path!") - elif isinstance(source_ds, (str, Path)): + # check that source_ds_da is of the correct type, if it is a path validate + # the path and open the Dataset or DataArray using xarray + if not isinstance(source_ds_da, (xr.Dataset, xr.DataArray, str, Path)): + raise TypeError("source_ds_da must be a Dataset or DataArray or str or pathlib.Path!") + elif isinstance(source_ds_da, (str, Path)): # determine if we obtained a zarr or netcdf file - file_type = get_file_format(source_ds) + file_type = get_file_format(source_ds_da) - # validate source_ds if it is a path - source_ds = validate_output_path( + # validate source_ds_da if it is a path + source_ds_da = validate_output_path( source_file="blank", # will be unused since source_ds cannot be none engine=file_type, output_storage_options=storage_options, - save_path=source_ds, + save_path=source_ds_da, ) # check that the path exists - check_file_existence(file_path=source_ds, storage_options=storage_options) + check_file_existence(file_path=source_ds_da, storage_options=storage_options) - return source_ds, file_type + return source_ds_da, file_type From abca664fa774f1b9dfcd045fb9241ef83cbc78a4 Mon Sep 17 00:00:00 2001 From: Wu-Jung Lee Date: Thu, 22 Dec 2022 18:35:07 -0800 Subject: [PATCH 4/8] Redo: Fix meta_source_filenames bug and enable (meta)source_filenames appending of path and list (#913) * Expand prov.source_files_vars to support path sequences that mix a str/path and another sequence; will add error checking * Enable source_files_vars to handle a sequence within a sequence * For EK80, ES70, ES80, EA640, conversion was inserting an unnecessary, empty meta_source_filenames variable * Add more comprehensive and readable prov source-files type hints; rename _source_files * Fix prov type hint bug with Py 3.8 * Add unit test for prov _sanitize_source_files; plus small fixes to np.ndarray type references in prov Co-authored-by: Emilio Mayorga --- echopype/tests/utils/test_source_filenames.py | 33 +++++++ echopype/utils/prov.py | 90 ++++++++++++++----- 2 files changed, 99 insertions(+), 24 deletions(-) create mode 100644 echopype/tests/utils/test_source_filenames.py diff --git a/echopype/tests/utils/test_source_filenames.py b/echopype/tests/utils/test_source_filenames.py new file mode 100644 index 0000000000..d0a935b712 --- /dev/null +++ b/echopype/tests/utils/test_source_filenames.py @@ -0,0 +1,33 @@ +from pathlib import Path + +import 
numpy as np + +from echopype.utils.prov import _sanitize_source_files + + +def test_scalars(): + """One or more scalar values""" + path1 = "/my/path1" + path2 = Path("/my/path2") + + # Single scalars + assert _sanitize_source_files(path1) == [path1] + assert _sanitize_source_files(path2) == [str(path2)] + # List of scalars + assert _sanitize_source_files([path1, path2]) == [path1, str(path2)] + + +def test_mixed(): + """A scalar value and a list or ndarray""" + path1 = "/my/path1" + path2 = Path("/my/path2") + # Mixed-type list + path_list1 = [path1, path2] + # String-type ndarray + path_list2 = np.array([path1, str(path2)]) + + # A scalar and a list + target_path_list = [path1, path1, str(path2)] + assert _sanitize_source_files([path1, path_list1]) == target_path_list + # A scalar and an ndarray + assert _sanitize_source_files([path1, path_list2]) == target_path_list diff --git a/echopype/utils/prov.py b/echopype/utils/prov.py index d79280029f..dd877172fc 100644 --- a/echopype/utils/prov.py +++ b/echopype/utils/prov.py @@ -1,14 +1,20 @@ from datetime import datetime as dt from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union +import numpy as np from _echopype_version import version as ECHOPYPE_VERSION +from numpy.typing import NDArray from typing_extensions import Literal -# TODO: It'd be cleaner to use PathHint, but it leads to a circular import error -# from ..core import PathHint +from .log import _init_logger ProcessType = Literal["conversion", "processing"] +# Note that this PathHint is defined differently from the one in ..core +PathHint = Union[str, Path] +PathSequenceHint = Union[List[PathHint], Tuple[PathHint], NDArray[PathHint]] + +logger = _init_logger(__name__) def echopype_prov_attrs(process_type: ProcessType) -> Dict[str, str]: @@ -30,8 +36,49 @@ def echopype_prov_attrs(process_type: ProcessType) -> Dict[str, str]: return prov_dict +def _sanitize_source_files(paths: Union[PathHint, PathSequenceHint]): + """ + Create sanitized list of string paths from heterogeneous path inputs. + + Parameters + ---------- + paths : Union[PathHint, PathSequenceHint] + File paths as either a single path string or pathlib Path, + a sequence (tuple, list or np.ndarray) of strings or pathlib Paths, + or a mixed sequence that may contain another sequence as an element. + + Returns + ------- + paths_list : List[str] + List of file paths. Empty list if no source path element was parsed successfully. + """ + sequence_types = (list, tuple, np.ndarray) + if isinstance(paths, (str, Path)): + return [str(paths)] + elif isinstance(paths, sequence_types): + paths_list = [] + for p in paths: + if isinstance(p, (str, Path)): + paths_list.append(str(p)) + elif isinstance(p, sequence_types): + paths_list += [str(pp) for pp in p if isinstance(pp, (str, Path))] + else: + logger.warning( + "Unrecognized file path element type, path element will not be" + f" written to (meta)source_file provenance attribute. {p}" + ) + return paths_list + else: + logger.warning( + "Unrecognized file path element type, path element will not be" + f" written to (meta)source_file provenance attribute. 
{paths}" + ) + return [] + + def source_files_vars( - source_paths: Union[str, List[Any]], meta_source_paths: Union[str, List[Any]] = None + source_paths: Union[PathHint, PathSequenceHint], + meta_source_paths: Union[PathHint, PathSequenceHint] = None, ) -> Dict[str, Dict[str, Tuple]]: """ Create source_filenames and meta_source_filenames provenance @@ -39,11 +86,15 @@ def source_files_vars( Parameters ---------- - source_paths : Union[str, List[Any]] - Source file paths as either a single path string or a list of Path-type paths - meta_source_paths : Union[str, List[Any]] - Source file paths for metadata files (often as XML files), - as either a single path string or a list of Path-type paths + source_paths : Union[PathHint, PathSequenceHint] + Source file paths as either a single path string or pathlib Path, + a sequence (tuple, list or np.ndarray) of strings or pathlib Paths, + or a mixed sequence that may contain another sequence as an element. + meta_source_paths : Union[PathHint, PathSequenceHint] + Source file paths for metadata files (often as XML files), as either a + single path string or pathlib Path, a sequence (tuple, list or np.ndarray) + of strings or pathlib Paths, or a mixed sequence that may contain another + sequence as an element. Returns ------- @@ -57,19 +108,10 @@ def source_files_vars( meta_source_filenames xarray DataArray with filenames dimension source_files_coord : Dict[str, Tuple] Single-element dict containing a tuple for creating the - filenames coordinate variable DataArray + filenames coordinate variable xarray DataArray """ - def _source_files(paths): - """Handle a plain string containing a single path, - a single pathlib Path, or a list of strings or pathlib paths - """ - if isinstance(paths, (str, Path)): - return [str(paths)] - else: - return [str(p) for p in paths if isinstance(p, (str, Path))] - - source_files = _source_files(source_paths) + source_files = _sanitize_source_files(source_paths) files_vars = dict() files_vars["source_files_var"] = { @@ -80,8 +122,10 @@ def _source_files(paths): ), } - if meta_source_paths is not None: - meta_source_files = _source_files(meta_source_paths) + if meta_source_paths is None or meta_source_paths == "": + files_vars["meta_source_files_var"] = None + else: + meta_source_files = _sanitize_source_files(meta_source_paths) files_vars["meta_source_files_var"] = { "meta_source_filenames": ( "filenames", @@ -89,8 +133,6 @@ def _source_files(paths): {"long_name": "Metadata source filenames"}, ), } - else: - files_vars["meta_source_files_var"] = None files_vars["source_files_coord"] = { "filenames": ( From e8e5fabdff3d823fe36b9413c631c1089873eac6 Mon Sep 17 00:00:00 2001 From: Wu-Jung Lee Date: Thu, 22 Dec 2022 20:15:12 -0800 Subject: [PATCH 5/8] Redo: Fix mkdir to not raise error and skip directory creation if exists (#914) See #909 for conversations and detailed commits. --- echopype/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/echopype/core.py b/echopype/core.py index b1453f85a1..bf06bd4fe3 100644 --- a/echopype/core.py +++ b/echopype/core.py @@ -30,7 +30,7 @@ def init_ep_dir(): """Initialize hidden directory for echopype""" if not ECHOPYPE_DIR.exists(): - ECHOPYPE_DIR.mkdir() + ECHOPYPE_DIR.mkdir(exist_ok=True) def validate_azfp_ext(test_ext: str): From 818722600d46a4accec0d2fb6882b394608460ef Mon Sep 17 00:00:00 2001 From: Wu-Jung Lee Date: Thu, 22 Dec 2022 20:16:43 -0800 Subject: [PATCH 6/8] Redo: Ignore docs/source/_build (#915) See #910 for all conversations and detailed commits. 
--- .pre-commit-config.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 01eb398a96..adf844983f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,4 +43,6 @@ repos: rev: v2.1.0 hooks: - id: codespell - args: ["--skip=*.ipynb", "-w", "docs/source", "echopype"] + # Checks spelling in `docs/source` and `echopype` dirs ONLY + # Ignores `.ipynb` files and `_build` folders + args: ["--skip=*.ipynb,docs/source/_build", "-w", "docs/source", "echopype"] From 7cad4d5e6209664f3fb859f3cc7f03e0913651f1 Mon Sep 17 00:00:00 2001 From: Wu-Jung Lee Date: Thu, 22 Dec 2022 20:23:26 -0800 Subject: [PATCH 7/8] Redo: Implement `add_splitbeam_angle` function (#916) See #894 for all conversations and detailed commits. --- echopype/calibrate/calibrate_ek.py | 105 +--- echopype/consolidate/__init__.py | 4 +- echopype/consolidate/api.py | 177 ++++++- echopype/consolidate/split_beam_angle.py | 501 ++++++++++++++++++ echopype/echodata/simrad.py | 229 ++++++++ .../tests/consolidate/test_consolidate.py | 158 +++++- .../tests/echodata/test_echodata_simrad.py | 69 +++ 7 files changed, 1159 insertions(+), 84 deletions(-) create mode 100644 echopype/consolidate/split_beam_angle.py create mode 100644 echopype/echodata/simrad.py create mode 100644 echopype/tests/echodata/test_echodata_simrad.py diff --git a/echopype/calibrate/calibrate_ek.py b/echopype/calibrate/calibrate_ek.py index 06bdb0b2ac..11b03d7e20 100644 --- a/echopype/calibrate/calibrate_ek.py +++ b/echopype/calibrate/calibrate_ek.py @@ -3,6 +3,7 @@ from scipy import signal from ..echodata import EchoData +from ..echodata.simrad import retrieve_correct_beam_group from ..utils import uwa from ..utils.log import _init_logger from .calibrate_base import CAL_PARAMS, CalibrateBase @@ -134,19 +135,16 @@ def get_cal_params(self, cal_params, waveform_mode, encode_mode): else beam["equivalent_beam_angle"] ) - def _cal_power(self, cal_type, use_beam_power=False) -> xr.Dataset: + def _cal_power(self, cal_type: str, power_ed_group: str = None) -> xr.Dataset: """Calibrate power data from EK60 and EK80. Parameters ---------- - cal_type : str + cal_type: str 'Sv' for calculating volume backscattering strength, or 'TS' for calculating target strength - use_beam_power : bool - whether to use beam_power. - If ``True`` use ``echodata["Sonar/Beam_group2"]``; - if ``False`` use ``echodata["Sonar/Beam_group1"]``. - Note ``echodata["Sonar/Beam_group2"]`` could only exist for EK80 data. 
+ power_ed_group: + The ``EchoData`` beam group path containing the power data Returns ------- @@ -154,10 +152,7 @@ def _cal_power(self, cal_type, use_beam_power=False) -> xr.Dataset: The calibrated dataset containing Sv or TS """ # Select source of backscatter data - if use_beam_power: - beam = self.echodata["Sonar/Beam_group2"] - else: - beam = self.echodata["Sonar/Beam_group1"] + beam = self.echodata[power_ed_group] # Harmonize time coordinate between Beam_groupX data and env_params for p in self.env_params.keys(): @@ -284,10 +279,16 @@ def get_env_params(self, **kwargs): ) def compute_Sv(self, **kwargs): - return self._cal_power(cal_type="Sv") + power_ed_group = retrieve_correct_beam_group( + echodata=self.echodata, waveform_mode="CW", encode_mode="power", pulse_compression=False + ) + return self._cal_power(cal_type="Sv", power_ed_group=power_ed_group) def compute_TS(self, **kwargs): - return self._cal_power(cal_type="TS") + power_ed_group = retrieve_correct_beam_group( + echodata=self.echodata, waveform_mode="CW", encode_mode="power", pulse_compression=False + ) + return self._cal_power(cal_type="TS", power_ed_group=power_ed_group) class CalibrateEK80(CalibrateEK): @@ -883,78 +884,22 @@ def _compute_cal(self, cal_type, waveform_mode, encode_mode) -> xr.Dataset: xr.Dataset An xarray Dataset containing either Sv or TS. """ - # Raise error for wrong inputs - if waveform_mode not in ("BB", "CW"): - raise ValueError( - "Input waveform_mode not recognized! " - "waveform_mode must be either 'BB' or 'CW' for EK80 data." - ) - if encode_mode not in ("complex", "power"): - raise ValueError( - "Input encode_mode not recognized! " - "encode_mode must be either 'complex' or 'power' for EK80 data." - ) + + power_ed_group = retrieve_correct_beam_group( + echodata=self.echodata, + waveform_mode=waveform_mode, + encode_mode=encode_mode, + pulse_compression=False, + ) # Set flag_complex # - True: complex cal # - False: power cal - # BB: complex only, CW: complex or power + flag_complex = False if waveform_mode == "BB": - if encode_mode == "power": # BB waveform forces to collect complex samples - raise ValueError("encode_mode='power' not allowed when waveform_mode='BB'!") flag_complex = True - else: # waveform_mode="CW" - if encode_mode == "complex": - flag_complex = True - else: - flag_complex = False - - # Raise error when waveform_mode and actual recording mode do not match - # This simple check is only possible for BB-only data, - # since for data with both BB and CW complex samples, - # frequency_start will exist in echodata["Sonar/Beam_group1"] for the BB channels - if waveform_mode == "BB" and "frequency_start" not in self.echodata["Sonar/Beam_group1"]: - raise ValueError("waveform_mode='BB' but broadband data not found!") - - # Set use_beam_power - # - True: use self.echodata["Sonar/Beam_group2"] for cal - # - False: use self.echodata["Sonar/Beam_group1"] for cal - use_beam_power = False - - # Warn user about additional data in the raw file if another type exists - # When both power and complex samples exist: - # complex samples will be stored in echodata["Sonar/Beam_group1"] - # power samples will be stored in echodata["Sonar/Beam_group2"] - # When only one type of samples exist, - # all samples with be stored in echodata["Sonar/Beam_group1"] - if self.echodata["Sonar/Beam_group2"] is not None: # both power and complex samples exist - # If both beam and beam_power groups exist, - # this means that CW data are encoded as power samples and in beam_power group - if waveform_mode == "CW" and 
encode_mode == "complex": - raise ValueError("File does not contain CW complex samples") - - if encode_mode == "power": - use_beam_power = True # switch source of backscatter data - logger.info( - "Only power samples are calibrated, but complex samples also exist in the raw data file!" # noqa - ) - else: - logger.info( - "Only complex samples are calibrated, but power samples also exist in the raw data file!" # noqa - ) - else: # only power OR complex samples exist - if ( - "backscatter_i" in self.echodata["Sonar/Beam_group1"].variables - ): # data contain only complex samples - if encode_mode == "power": - raise TypeError( - "File does not contain power samples! Use encode_mode='complex'" - ) # user selects the wrong encode_mode - else: # data contain only power samples - if encode_mode == "complex": - raise TypeError( - "File does not contain complex samples! Use encode_mode='power'" - ) # user selects the wrong encode_mode + elif encode_mode == "complex": + flag_complex = True # Compute Sv if flag_complex: @@ -964,7 +909,7 @@ def _compute_cal(self, cal_type, waveform_mode, encode_mode) -> xr.Dataset: else: # Power samples only make sense for CW mode data self.compute_range_meter(waveform_mode="CW", encode_mode=encode_mode) - ds_cal = self._cal_power(cal_type=cal_type, use_beam_power=use_beam_power) + ds_cal = self._cal_power(cal_type=cal_type, power_ed_group=power_ed_group) return ds_cal diff --git a/echopype/consolidate/__init__.py b/echopype/consolidate/__init__.py index d26813ab89..acca09fd80 100644 --- a/echopype/consolidate/__init__.py +++ b/echopype/consolidate/__init__.py @@ -1,3 +1,3 @@ -from .api import add_depth, add_location, swap_dims_channel_frequency +from .api import add_depth, add_location, add_splitbeam_angle, swap_dims_channel_frequency -__all__ = ["swap_dims_channel_frequency", "add_depth", "add_location"] +__all__ = ["swap_dims_channel_frequency", "add_depth", "add_location", "add_splitbeam_angle"] diff --git a/echopype/consolidate/api.py b/echopype/consolidate/api.py index de81481bc0..d80d0e3f32 100644 --- a/echopype/consolidate/api.py +++ b/echopype/consolidate/api.py @@ -1,10 +1,20 @@ import datetime -from typing import Optional +import pathlib +from typing import Optional, Union import numpy as np import xarray as xr from ..echodata import EchoData +from ..echodata.simrad import retrieve_correct_beam_group +from ..utils.io import validate_source_ds_da +from .split_beam_angle import ( + add_angle_to_ds, + get_angle_complex_BB_nopc, + get_angle_complex_BB_pc, + get_angle_complex_CW, + get_angle_power_CW, +) def swap_dims_channel_frequency(ds: xr.Dataset) -> xr.Dataset: @@ -174,3 +184,168 @@ def sel_interp(var): interp_ds["longitude"] = interp_ds["longitude"].assign_attrs({"history": history}) return interp_ds.drop_vars("time1") + + +def add_splitbeam_angle( + source_Sv: Union[xr.Dataset, str, pathlib.Path], + echodata: EchoData, + waveform_mode: str, + encode_mode: str, + pulse_compression: bool = False, + storage_options: dict = {}, + return_dataset: bool = True, +) -> xr.Dataset: + """ + Add split-beam (alongship/athwartship) angles into the Sv dataset. + This function calculates the alongship/athwartship angle using data stored + in the beam groups of ``echodata``. In cases when angle data does not already exist + or cannot be computed from the data, an error is issued and no angle variables are + added to the dataset. 
+
+    Parameters
+    ----------
+    source_Sv: xr.Dataset or str or pathlib.Path
+        The Sv Dataset or path to a file containing the Sv Dataset, which will have the
+        split-beam angle data added to it
+    echodata: EchoData
+        An ``EchoData`` object holding the raw data
+    waveform_mode : {"CW", "BB"}
+        Type of transmit waveform
+
+        - ``"CW"`` for narrowband transmission,
+          returned echoes recorded either as complex or power/angle samples
+        - ``"BB"`` for broadband transmission,
+          returned echoes recorded as complex samples
+
+    encode_mode : {"complex", "power"}
+        Type of encoded return echo data
+
+        - ``"complex"`` for complex samples
+        - ``"power"`` for power/angle samples, only allowed when
+          the echosounder is configured for narrowband transmission
+    pulse_compression: bool, default=False
+        Whether pulse compression should be used (only valid for
+        ``waveform_mode="BB"`` and ``encode_mode="complex"``)
+    storage_options: dict, default={}
+        Any additional parameters for the storage backend, corresponding to the
+        path provided for ``source_Sv``
+    return_dataset: bool, default=True
+        If True, ``source_Sv`` with the split-beam angle data added to it
+        will be returned, else it will not be returned. A value of ``False``
+        is useful in the situation where ``source_Sv`` is a path and the user
+        only wants to write the split-beam angle data to the path provided.
+
+    Returns
+    -------
+    xr.Dataset or None
+        If ``return_dataset=False``, nothing will be returned. If ``return_dataset=True``
+        either the input dataset ``source_Sv`` or a lazy-loaded Dataset (obtained from
+        the path provided by ``source_Sv``) with the split-beam angle data added
+        will be returned.
+
+    Raises
+    ------
+    ValueError
+        If ``echodata`` has a sonar model that is not analogous to either EK60 or EK80
+    ValueError
+        If the input ``source_Sv`` does not have a ``channel`` dimension
+    ValueError
+        If ``source_Sv`` does not have appropriate dimension lengths in
+        comparison to ``echodata`` data
+    ValueError
+        If the provided ``waveform_mode``, ``encode_mode``, and ``pulse_compression`` are not valid
+    NotImplementedError
+        If an unknown ``beam_type`` is encountered during the split-beam calculation
+
+    Notes
+    -----
+    Split-beam angle data potentially exist for the following echosounders depending on
+    the instrument configuration and recording setting:
+
+    - Simrad EK60 echosounder paired with split-beam transducers and
+      configured to store angle data
+    - Simrad EK80 echosounder paired with split-beam transducers and
+      configured to store angle data
+
+    In most cases where the type of samples collected by the echosounder (power/angle
+    samples or complex samples) and the transmit waveform (broadband or narrowband)
+    are identical across all channels, the channels existing in ``source_Sv`` and
+    ``echodata`` will be identical. If this is not the case, only angle data corresponding
+    to channels existing in ``source_Sv`` will be added.
+
+    For EK80 broadband data, the split-beam angles can be estimated from the complex data.
+    The current implementation generates angles estimated *without* applying pulse compression.
+    Estimating the angle with pulse compression will be added in the near future.
+    """
+
+    # ensure that echodata was produced by EK60 or EK80-like sensors
+    if echodata.sonar_model not in ["EK60", "ES70", "EK80", "ES80", "EA640"]:
+        raise ValueError(
+            "The sonar model that produced echodata does not have split-beam "
+            "transducers, split-beam angles cannot be added to source_Sv!"
+        )
+
+    # validate the source_Sv type or path (if it is provided)
+    source_Sv, file_type = validate_source_ds_da(source_Sv, storage_options)
+
+    # initialize source_Sv_path
+    source_Sv_path = None
+
+    if isinstance(source_Sv, str):
+
+        # store source_Sv path so we can use it to write to later
+        source_Sv_path = source_Sv
+
+        # TODO: In the future we can improve this by obtaining the variable names, channels,
+        #  and dimension lengths directly from source_Sv using zarr or netcdf4. This would
+        #  prevent the unnecessary loading in of the coordinates, which the below statement does.
+        # open up Dataset using source_Sv path
+        source_Sv = xr.open_dataset(source_Sv, engine=file_type, chunks={}, **storage_options)
+
+    # raise not implemented error if source_Sv corresponds to MVBS
+    if source_Sv.attrs["processing_function"] == "preprocess.compute_MVBS":
+        raise NotImplementedError("Adding split-beam data to MVBS has not been implemented!")
+
+    # check that the appropriate waveform and encode mode have been given
+    # and obtain the echodata group path corresponding to encode_mode
+    encode_mode_ed_group = retrieve_correct_beam_group(
+        echodata, waveform_mode, encode_mode, pulse_compression
+    )
+
+    # check that source_Sv at least has a channel dimension
+    if "channel" not in source_Sv.variables:
+        raise ValueError("The input source_Sv Dataset must have a channel dimension!")
+
+    # set ds_beam, select the same channels that are in source_Sv
+    ds_beam = echodata[encode_mode_ed_group].sel(channel=source_Sv.channel.values)
+
+    # fail if source_Sv and ds_beam do not have the same lengths
+    # for ping_time, range_sample, and channel
+    same_dim_lens = [
+        ds_beam.dims[dim] == source_Sv.dims[dim] for dim in ["channel", "ping_time", "range_sample"]
+    ]
+    if not all(same_dim_lens):
+        raise ValueError(
+            "Input source_Sv does not have the same dimension lengths as all dimensions in ds_beam!"
+        )
+
+    # obtain split-beam angles from
+    # CW mode data
+    if waveform_mode == "CW":
+        if encode_mode == "power":  # power data
+            theta, phi = get_angle_power_CW(ds_beam=ds_beam)
+        else:  # complex data
+            theta, phi = get_angle_complex_CW(ds_beam=ds_beam)
+    # BB mode data
+    else:
+        if pulse_compression:  # with pulse compression
+            theta, phi = get_angle_complex_BB_pc(ds_beam=ds_beam)
+        else:  # without pulse compression
+            theta, phi = get_angle_complex_BB_nopc(ds_beam=ds_beam, ed=echodata)
+
+    # add theta and phi to source_Sv input
+    source_Sv = add_angle_to_ds(
+        theta, phi, source_Sv, return_dataset, source_Sv_path, file_type, storage_options
+    )
+
+    return source_Sv
diff --git a/echopype/consolidate/split_beam_angle.py b/echopype/consolidate/split_beam_angle.py
new file mode 100644
index 0000000000..96796c8519
--- /dev/null
+++ b/echopype/consolidate/split_beam_angle.py
@@ -0,0 +1,501 @@
+"""
+Contains functions necessary to compute the split-beam (alongship/athwartship)
+angles and add them to a Dataset.
+"""
+from typing import List, Optional, Tuple
+
+import numpy as np
+import xarray as xr
+
+from ..echodata import EchoData
+
+
+def get_angle_power_CW(ds_beam: xr.Dataset) -> Tuple[xr.Dataset, xr.Dataset]:
+    """
+    Obtains the split-beam angle data from power encoded data with CW waveform.
+
+    Parameters
+    ----------
+    ds_beam: xr.Dataset
+        An ``EchoData`` beam group containing angle information needed for
+        split-beam angle calculation
+
+    Returns
+    -------
+    theta: xr.Dataset
+        The calculated split-beam alongship angle
+    phi: xr.Dataset
+        The calculated split-beam athwartship angle
+
+    Raises
+    ------
+    ValueError
+        If all ``beam_type`` values are equal to 0, i.e. the data come
+        only from single-beam transducers
+
+    Notes
+    -----
+    Can be used on both EK60 and EK80 data
+
+    Computation done for ``beam_type=1``:
+    ``physical_angle = ((raw_angle * 180 / 128) / sensitivity) - offset``
+    """
+
+    # raw_angle scaling constant
+    conversion_const = 180.0 / 128.0
+
+    def _e2f(angle_type: str) -> xr.Dataset:
+        """Convert electric angle to physical angle for split-beam data"""
+        return (
+            conversion_const
+            * ds_beam[f"angle_{angle_type}"]
+            / ds_beam[f"angle_sensitivity_{angle_type}"]
+            - ds_beam[f"angle_offset_{angle_type}"]
+        )
+
+    # add split-beam angle if at least one channel is split-beam;
+    # in the case when some channels are split-beam and some single-beam,
+    # the single-beam channels will be all NaN and _e2f will run through and output NaN
+    if not np.all(ds_beam["beam_type"].data == 0):
+
+        # obtain split-beam alongship angle
+        theta = _e2f(angle_type="alongship")
+
+        # obtain split-beam athwartship angle
+        phi = _e2f(angle_type="athwartship")
+
+    else:
+        raise ValueError(
+            "Computing physical split-beam angle is only available for data "
+            "from split-beam transducers!"
+        )
+
+    # drop the beam dimension in theta and phi, if it exists
+    if "beam" in theta.dims:
+        theta = theta.drop("beam").squeeze(dim="beam")
+        phi = phi.drop("beam").squeeze(dim="beam")
+
+    return theta, phi
+
+
+def get_angle_complex_CW(ds_beam: xr.Dataset) -> Tuple[xr.DataArray, xr.DataArray]:
+    """
+    Obtains the split-beam angle data from complex encoded data with CW waveform.
+ + Parameters + ---------- + ds_beam: xr.Dataset + An ``EchoData`` beam group containing angle information needed for + split-beam angle calculation + + Returns + ------- + theta: xr.Dataset + The calculated split-beam alongship angle + phi: xr.Dataset + The calculated split-beam athwartship angle + """ + + # ensure that the beam_type is appropriate for calculation + if np.all(ds_beam["beam_type"].data == 1): + + # get complex representation of backscatter + backscatter = ds_beam["backscatter_r"] + 1j * ds_beam["backscatter_i"] + + # get angle sensitivity alongship and athwartship + angle_sensitivity_alongship = ds_beam["angle_sensitivity_alongship"].isel( + ping_time=0, beam=0 + ) + angle_sensitivity_athwartship = ds_beam["angle_sensitivity_athwartship"].isel( + ping_time=0, beam=0 + ) + + # get angle offset alongship and athwartship + angle_offset_alongship = ds_beam["angle_offset_alongship"].isel(ping_time=0, beam=0) + angle_offset_athwartship = ds_beam["angle_offset_athwartship"].isel(ping_time=0, beam=0) + + # obtain the split-beam angle data + theta, phi = _compute_angle_from_complex( + bs=backscatter, + beam_type=1, + sens=[angle_sensitivity_alongship, angle_sensitivity_athwartship], + offset=[angle_offset_alongship, angle_offset_athwartship], + ) + + else: + raise NotImplementedError("Computing split-beam angle is only available for beam_type=1!") + + # drop the beam dimension in theta and phi, if it exists + if "beam" in theta.coords: + theta = theta.drop_vars("beam") + phi = phi.drop("beam") + + return theta, phi + + +def _get_interp_offset( + param: str, chan_id: str, freq_center: xr.DataArray, ed: EchoData +) -> np.ndarray: + """ + Obtains an angle offset by first interpolating the + ``angle_offset_alongship`` or ``angle_offset_athwartship`` + data found in the ``Vendor_specific`` group and then + selecting the offset corresponding to the center frequency + value for ``channel=chan_id``. + + Parameters + ---------- + param: {"angle_offset_alongship", "angle_offset_athwartship"} + The angle offset data to select in the ``Vendor_specific`` group + chan_id: str + The channel used to select the center frequency value + freq_center: xr.DataArray + A DataArray filled with center frequency values with coordinate ``channel`` + ed: EchoData + An ``EchoData`` object holding the raw data + + Returns + ------- + np.ndarray + Array filled with the requested angle offset values + """ + + freq_wanted = freq_center.sel(channel=chan_id) + return ( + ed["Vendor_specific"][param].sel(cal_channel_id=chan_id).interp(cal_frequency=freq_wanted) + ).values + + +def _get_offset( + ds_beam: xr.Dataset, fc: xr.DataArray, freq_nominal: xr.DataArray, ed: EchoData +) -> Tuple[xr.DataArray, xr.DataArray]: + """ + Obtains the alongship and athwartship angle offsets. 
+ + Parameters + ---------- + ds_beam: xr.Dataset + The dataset corresponding to a beam group + fc: xr.DataArray + Array corresponding to the center frequency + freq_nominal: xr.DataArray + Array of frequency nominal values + ed: EchoData + An ``EchoData`` object holding the raw data + + Returns + ------- + offset_along: xr.DataArray + Array corresponding to the angle alongship offset + offset_athwart: xr.DataArray + Array corresponding to the angle athwartship offset + """ + + # initialize lists that will hold offsets + offset_along = [] + offset_athwart = [] + + # obtain the offsets for each channel + for ch in fc["channel"].values: + if ch in ed["Vendor_specific"]["cal_channel_id"]: + # calculate offsets using Vendor_specific values + offset_along.append( + _get_interp_offset( + param="angle_offset_alongship", chan_id=ch, freq_center=fc, ed=ed + ) + ) + offset_athwart.append( + _get_interp_offset( + param="angle_offset_athwartship", chan_id=ch, freq_center=fc, ed=ed + ) + ) + else: + # calculate offsets using data in ds_beam + offset_along.append( + ds_beam["angle_offset_alongship"].sel(channel=ch).isel(ping_time=0, beam=0) + * fc.sel(channel=ch) + / freq_nominal.sel(channel=ch) + ) + offset_athwart.append( + ds_beam["angle_offset_athwartship"].sel(channel=ch).isel(ping_time=0, beam=0) + * fc.sel(channel=ch) + / freq_nominal.sel(channel=ch) + ) + + # construct offset DataArrays from lists + offset_along = xr.DataArray( + offset_along, coords={"channel": fc["channel"], "ping_time": fc["ping_time"]} + ) + offset_athwart = xr.DataArray( + offset_athwart, coords={"channel": fc["channel"], "ping_time": fc["ping_time"]} + ) + return offset_along, offset_athwart + + +def _compute_angle_from_complex( + bs: xr.Dataset, beam_type: int, sens: List[xr.DataArray], offset: List[xr.DataArray] +): + """ + Obtains the split-beam angle data alongship and athwartship + using data from a single channel. + + Parameters + ---------- + bs: xr.Dataset + Complex representation of backscatter + beam_type: int + The type of beam being considered + sens: list of xr.DataArray + A list of length two where the first element corresponds to the + angle sensitivity alongship and the second corresponds to the + angle sensitivity athwartship + offset: list of xr.DataArray + A list of length two where the first element corresponds to the + angle offset alongship and the second corresponds to the + angle offset athwartship + + Returns + ------- + theta: xr.Dataset + The calculated split-beam alongship angle for a specific channel + phi: xr.Dataset + The calculated split-beam athwartship angle for a specific channel + + Notes + ----- + This function should only be used for data with complex backscatter. 
+ """ + + # 4-sector transducer + if beam_type == 1: + + bs_fore = (bs.isel(beam=2) + bs.isel(beam=3)) / 2 # forward + bs_aft = (bs.isel(beam=0) + bs.isel(beam=1)) / 2 # aft + bs_star = (bs.isel(beam=0) + bs.isel(beam=3)) / 2 # starboard + bs_port = (bs.isel(beam=1) + bs.isel(beam=2)) / 2 # port + + bs_theta = bs_fore * np.conj(bs_aft) + bs_phi = bs_star * np.conj(bs_port) + theta = np.arctan2(np.imag(bs_theta), np.real(bs_theta)) / np.pi * 180 + phi = np.arctan2(np.imag(bs_phi), np.real(bs_phi)) / np.pi * 180 + + # 3-sector transducer with or without center element + elif beam_type in [17, 49, 65, 81]: + # 3-sector + if beam_type == 17: + bs_star = bs.isel(beam=0) + bs_port = bs.isel(beam=1) + bs_fore = bs.isel(beam=2) + else: + # 3-sector + 1 center element + bs_star = (bs.isel(beam=0) + bs.isel(beam=3)) / 2 + bs_port = (bs.isel(beam=1) + bs.isel(beam=3)) / 2 + bs_fore = (bs.isel(beam=2) + bs.isel(beam=3)) / 2 + + bs_fac1 = bs_fore * np.conj(bs_star) + bs_fac2 = bs_fore * np.conj(bs_port) + fac1 = np.arctan2(np.imag(bs_fac1), np.real(bs_fac1)) / np.pi * 180 + fac2 = np.arctan2(np.imag(bs_fac2), np.real(bs_fac2)) / np.pi * 180 + + theta = (fac1 + fac2) / np.sqrt(3) + phi = fac2 - fac1 + + # EC150–3C + elif beam_type == 97: + raise NotImplementedError + + else: + raise ValueError("beam_type not recognized!") + + theta = theta / sens[0] - offset[0] + phi = phi / sens[1] - offset[1] + + return theta, phi + + +def get_angle_complex_BB_nopc( + ds_beam: xr.Dataset, ed: EchoData +) -> Tuple[xr.DataArray, xr.DataArray]: + """ + Obtains the split-beam angle data from complex samples from broadband transmit signals + without pulse compression. + + Parameters + ---------- + ds_beam: xr.Dataset + An ``EchoData`` beam group containing angle information needed for + split-beam angle calculation + ed: EchoData + An ``EchoData`` object holding the raw data + + Returns + ------- + theta: xr.Dataset + The calculated split-beam alongship angle + phi: xr.Dataset + The calculated split-beam athwartship angle + """ + + # nominal frequency [Hz] + freq_nominal = ds_beam["frequency_nominal"] + + # calculate center frequency + freq_center = (ds_beam["frequency_start"] + ds_beam["frequency_end"]).isel(beam=0) / 2 + + # obtain the angle alongship and athwartship offsets + offset_along, offset_athwart = _get_offset( + ds_beam=ds_beam, fc=freq_center, freq_nominal=freq_nominal, ed=ed + ) + + # obtain the angle sensitivity values alongship and athwartship + sens_along = ds_beam["angle_sensitivity_alongship"].isel(beam=0) * freq_center / freq_nominal + sens_athwart = ( + ds_beam["angle_sensitivity_athwartship"].isel(beam=0) * freq_center / freq_nominal + ) + + # get complex representation of backscatter + backscatter = ds_beam["backscatter_r"] + 1j * ds_beam["backscatter_i"] + + # initialize list that will hold split-beam angle data for each channel + theta_channels = [] + phi_channels = [] + + # obtain the split-beam angle data for each channel + for chan_id in backscatter.channel.values: + theta, phi = _compute_angle_from_complex( + bs=backscatter.sel(channel=chan_id), + beam_type=int(ds_beam["beam_type"].sel(channel=chan_id).isel(ping_time=0)), + sens=[sens_along.sel(channel=chan_id), sens_athwart.sel(channel=chan_id)], + offset=[offset_along.sel(channel=chan_id), offset_athwart.sel(channel=chan_id)], + ) + + theta_channels.append(theta) + phi_channels.append(phi) + + # collect and construct final DataArrays for split-beam angle data + theta = xr.DataArray( + data=theta_channels, + coords={ + "channel": 
backscatter.channel, + "ping_time": theta_channels[0].ping_time, + "range_sample": theta_channels[0].range_sample, + }, + ) + + phi = xr.DataArray( + data=phi_channels, + coords={ + "channel": backscatter.channel, + "ping_time": phi_channels[0].ping_time, + "range_sample": phi_channels[0].range_sample, + }, + ) + + return theta, phi + + +def get_angle_complex_BB_pc(ds_beam: xr.Dataset) -> Tuple[xr.DataArray, xr.DataArray]: + """ + Obtains the split-beam angle data from complex samples from broadband transmit signals + after pulse compression. + + Parameters + ---------- + ds_beam: xr.Dataset + An ``EchoData`` beam group containing angle information needed for + split-beam angle calculation + + Returns + ------- + theta: xr.Dataset + The calculated split-beam alongship angle + phi: xr.Dataset + The calculated split-beam athwartship angle + """ + + # TODO: make sure to check that the appropriate beam_type is being used + raise NotImplementedError( + "Obtaining the split-beam angle data using pulse compressed " + "backscatter has not been implemented!" + ) + + return xr.DataArray(), xr.DataArray() + + +def add_angle_to_ds( + theta: xr.Dataset, + phi: xr.Dataset, + ds: xr.Dataset, + return_dataset: bool, + source_ds_path: Optional[str] = None, + file_type: Optional[str] = None, + storage_options: dict = {}, +) -> Optional[xr.Dataset]: + """ + Adds the split-beam angle data to the provided input ``ds``. + + Parameters + ---------- + theta: xr.Dataset + The calculated split-beam alongship angle + phi: xr.Dataset + The calculated split-beam athwartship angle + ds: xr.Dataset + The Dataset that ``theta`` and ``phi`` will be added to + return_dataset: bool + Whether a dataset will be returned or not + source_ds_path: str, optional + The path to the file corresponding to ``ds``, if it exists + file_type: {"netcdf4", "zarr"}, optional + The file type corresponding to ``source_ds_path`` + storage_options: dict, default={} + Any additional parameters for the storage backend, corresponding to the + path ``source_ds_path`` + + Returns + ------- + xr.Dataset or None + If ``return_dataset=False``, nothing will be returned. If ``return_dataset=True`` + either the input dataset ``ds`` or a lazy-loaded Dataset (obtained from + the path provided by ``source_ds_path``) with the split-beam angle data added + will be returned. + """ + + # TODO: do we want to add anymore attributes to these variables? 
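# A minimal, self-contained sketch of the 4-sector phase-difference
# computation in _compute_angle_from_complex (beam_type=1), using synthetic
# complex samples; the sensitivity/offset normalization is omitted, and
# np.angle(z, deg=True) is equivalent to the
# np.arctan2(np.imag(z), np.real(z)) / np.pi * 180 form used above:

import numpy as np

rng = np.random.default_rng(0)
bs = rng.normal(size=(4, 100)) + 1j * rng.normal(size=(4, 100))  # 4 sectors x samples

bs_fore = (bs[2] + bs[3]) / 2  # forward half
bs_aft = (bs[0] + bs[1]) / 2   # aft half
bs_star = (bs[0] + bs[3]) / 2  # starboard half
bs_port = (bs[1] + bs[2]) / 2  # port half

# electrical angles [deg] from the phase of the conjugate products
theta_elec = np.angle(bs_fore * np.conj(bs_aft), deg=True)  # alongship
phi_elec = np.angle(bs_star * np.conj(bs_port), deg=True)   # athwartship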
+ # add appropriate attributes to theta and phi + theta.attrs["long_name"] = "split-beam alongship angle" + phi.attrs["long_name"] = "split-beam athwartship angle" + + if source_ds_path is not None: + + # put the variables into a Dataset, so they can be written at the same time + # add ds attributes to splitb_ds since they will be overwritten by to_netcdf/zarr + splitb_ds = xr.Dataset( + data_vars={"angle_alongship": theta, "angle_athwartship": phi}, + coords=theta.coords, + attrs=ds.attrs, + ) + + # release any resources linked to ds (necessary for to_netcdf) + ds.close() + + # write the split-beam angle data to the provided path + if file_type == "netcdf4": + splitb_ds.to_netcdf(path=source_ds_path, mode="a", **storage_options) + else: + splitb_ds.to_zarr(store=source_ds_path, mode="a", **storage_options) + + else: + + # add the split-beam angles to the provided Dataset + ds["angle_alongship"] = theta + ds["angle_athwartship"] = phi + + if return_dataset and (source_ds_path is not None): + + # open up and return Dataset in source_ds_path + return xr.open_dataset(source_ds_path, engine=file_type, chunks={}, **storage_options) + + elif return_dataset: + + # return input dataset with split-beam angle data + return ds diff --git a/echopype/echodata/simrad.py b/echopype/echodata/simrad.py new file mode 100644 index 0000000000..195f68e1f5 --- /dev/null +++ b/echopype/echodata/simrad.py @@ -0,0 +1,229 @@ +""" +Contains functions that are specific to Simrad echo sounders +""" +from typing import Optional, Tuple + +from .echodata import EchoData + + +def _check_input_args_combination( + waveform_mode: str, encode_mode: str, pulse_compression: bool +) -> None: + """ + Checks that the ``waveform_mode`` and ``encode_mode`` have + the correct values and that the combination of input arguments are valid, without + considering the actual data. + + Parameters + ---------- + waveform_mode: str + Type of transmit waveform + encode_mode: str + Type of encoded return echo data + pulse_compression: bool + States whether pulse compression should be used + """ + + if waveform_mode not in ["CW", "BB"]: + raise ValueError("The input waveform_mode must be either 'CW' or 'BB'!") + + if encode_mode not in ["complex", "power"]: + raise ValueError("The input encode_mode must be either 'complex' or 'power'!") + + # BB has complex data only, but CW can have complex or power data + if (waveform_mode == "BB") and (encode_mode == "power"): + raise ValueError("encode_mode='power' not allowed when waveform_mode='BB'!") + + # make sure that we have BB and complex inputs, if pulse compression is selected + if pulse_compression and ((waveform_mode != "BB") or (encode_mode != "complex")): + raise ValueError( + "Pulse compression can only be used with " + "waveform_mode='BB' and encode_mode='complex'" + ) + + +def _retrieve_correct_beam_group_EK60( + echodata: EchoData, waveform_mode: str, encode_mode: str +) -> Optional[str]: + """ + Ensures that the provided ``waveform_mode`` and ``encode_mode`` are consistent + with the EK60-like data supplied by ``echodata``. Additionally, select the + appropriate beam group corresponding to this input. 
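A quick usage sketch of ``_check_input_args_combination`` as defined above: it returns ``None`` for valid combinations and raises ``ValueError`` otherwise. The import path matches the test module added later in this patch:

from echopype.echodata.simrad import _check_input_args_combination

# valid combinations pass silently
_check_input_args_combination("CW", "power", pulse_compression=False)
_check_input_args_combination("BB", "complex", pulse_compression=True)

# invalid combinations raise ValueError
try:
    _check_input_args_combination("BB", "power", pulse_compression=False)
except ValueError as err:
    print(err)  # encode_mode='power' not allowed when waveform_mode='BB'!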
+ + Parameters + ---------- + echodata: EchoData + An ``EchoData`` object holding the data + waveform_mode : {"CW", "BB"} + Type of transmit waveform + encode_mode : {"complex", "power"} + Type of encoded return echo data + + Returns + ------- + power_ed_group: str, optional + The ``EchoData`` beam group path containing the power data + """ + + # initialize power EchoData group value + power_ed_group = None + + # EK60-like sensors must have 'power' and 'CW' modes only + if waveform_mode != "CW": + raise RuntimeError("Incorrect waveform_mode input provided!") + if encode_mode != "power": + raise RuntimeError("Incorrect encode_mode input provided!") + + # ensure that no complex data exists (this should never be triggered) + if "backscatter_i" in echodata["Sonar/Beam_group1"].variables: + raise RuntimeError( + "Provided echodata object does not correspond to an EK60-like " + "sensor, but is labeled as data from an EK60-like sensor!" + ) + else: + power_ed_group = "Sonar/Beam_group1" + + return power_ed_group + + +def _retrieve_correct_beam_group_EK80( + echodata: EchoData, waveform_mode: str, encode_mode: str +) -> Tuple[Optional[str], Optional[str]]: + """ + Ensures that the provided ``waveform_mode`` and ``encode_mode`` are consistent + with the EK80-like data supplied by ``echodata``. Additionally, select the + appropriate beam group corresponding to this input. + + Parameters + ---------- + echodata: EchoData + An ``EchoData`` object holding the data + waveform_mode : {"CW", "BB"} + Type of transmit waveform + encode_mode : {"complex", "power"} + Type of encoded return echo data + + Returns + ------- + power_ed_group: str, optional + The ``EchoData`` beam group path containing the power data + complex_ed_group: str, optional + The ``EchoData`` beam group path containing the complex data + """ + + # initialize power and complex EchoData group values + power_ed_group = None + complex_ed_group = None + + if waveform_mode == "BB": + + # check BB waveform_mode, BB must always have complex data, can have 2 beam groups + # when echodata contains CW power and BB complex samples, and frequency_start + # variable in Beam_group1 + if waveform_mode == "BB" and "frequency_start" not in echodata["Sonar/Beam_group1"]: + raise RuntimeError("waveform_mode='BB', but broadband data not found!") + elif "backscatter_i" not in echodata["Sonar/Beam_group1"].variables: + raise RuntimeError("waveform_mode='BB', but complex data does not exist!") + elif echodata["Sonar/Beam_group2"] is not None: + power_ed_group = "Sonar/Beam_group2" + complex_ed_group = "Sonar/Beam_group1" + else: + complex_ed_group = "Sonar/Beam_group1" + + else: + + # CW can have complex or power data, so we just need to make sure that + # 1) complex samples always exist in Sonar/Beam_group1 + # 2) power samples are in Sonar/Beam_group1 if only one beam group exists + # 3) power samples are in Sonar/Beam_group2 if two beam groups exist + if echodata["Sonar/Beam_group2"] is None: + + if encode_mode == "power": + # power samples must be in Sonar/Beam_group1 (thus no complex samples) + if "backscatter_i" in echodata["Sonar/Beam_group1"].variables: + raise RuntimeError("Data provided does not correspond to encode_mode='power'!") + else: + power_ed_group = "Sonar/Beam_group1" + elif encode_mode == "complex": + # complex samples must be in Sonar/Beam_group1 + if "backscatter_i" not in echodata["Sonar/Beam_group1"].variables: + raise RuntimeError( + "Data provided does not correspond to encode_mode='complex'!" 
+ ) + else: + complex_ed_group = "Sonar/Beam_group1" + + else: + + # complex should be in Sonar/Beam_group1 and power in Sonar/Beam_group2 + # the RuntimeErrors below should never be triggered + if "backscatter_i" not in echodata["Sonar/Beam_group1"].variables: + raise RuntimeError( + "Complex data does not exist in Sonar/Beam_group1, " + "input echodata object must have been incorrectly constructed!" + ) + elif "backscatter_r" not in echodata["Sonar/Beam_group2"].variables: + raise RuntimeError( + "Power data does not exist in Sonar/Beam_group2, " + "input echodata object must have been incorrectly constructed!" + ) + else: + complex_ed_group = "Sonar/Beam_group1" + power_ed_group = "Sonar/Beam_group2" + + return power_ed_group, complex_ed_group + + +def retrieve_correct_beam_group( + echodata: EchoData, waveform_mode: str, encode_mode: str, pulse_compression: bool +) -> str: + """ + A function to make sure that the user has provided the correct + ``waveform_mode`` and ``encode_mode`` inputs based off of the + supplied ``echodata`` object. Additionally, determine the + ``EchoData`` beam group corresponding to ``encode_mode``. + + Parameters + ---------- + echodata: EchoData + An ``EchoData`` object holding the data corresponding to the + waveform and encode modes + waveform_mode : {"CW", "BB"} + Type of transmit waveform + encode_mode : {"complex", "power"} + Type of encoded return echo data + pulse_compression: bool + States whether pulse compression should be used + + Returns + ------- + str + The ``EchoData`` beam group path corresponding to the ``encode_mode`` input + """ + + # checks input and logic of modes without referencing data + _check_input_args_combination(waveform_mode, encode_mode, pulse_compression) + + if echodata.sonar_model in ["EK60", "ES70"]: + + # initialize complex_data_location (needed only for EK60) + complex_ed_group = None + + # check modes against data for EK60 and get power EchoData group + power_ed_group = _retrieve_correct_beam_group_EK60(echodata, waveform_mode, encode_mode) + + elif echodata.sonar_model in ["EK80", "ES80", "EA640"]: + + # check modes against data for EK80 and get power/complex EchoData groups + power_ed_group, complex_ed_group = _retrieve_correct_beam_group_EK80( + echodata, waveform_mode, encode_mode + ) + + else: + # raise error for unknown or unaccounted for sonar model + raise RuntimeError("EchoData was produced by a non-Simrad or unknown Simrad echo sounder!") + + if encode_mode == "complex": + return complex_ed_group + else: + return power_ed_group diff --git a/echopype/tests/consolidate/test_consolidate.py b/echopype/tests/consolidate/test_consolidate.py index a7d9162a7e..41070a5746 100644 --- a/echopype/tests/consolidate/test_consolidate.py +++ b/echopype/tests/consolidate/test_consolidate.py @@ -1,10 +1,25 @@ +import pathlib + import pytest import numpy as np import pandas as pd import xarray as xr - +import scipy.io as io import echopype as ep +from typing import List +import tempfile +import os + +""" +For future reference: + +For ``test_add_splitbeam_angle`` the test data is in the following locations: +- the EK60 raw file is in `test_data/ek60/DY1801_EK60-D20180211-T164025.raw` and the +associated echoview split-beam data is in `test_data/ek60/splitbeam`. 
+- the EK80 raw file is in `test_data/ek80_bb_with_calibration/2018115-D20181213-T094600.raw` and +the associated echoview split-beam data is in `test_data/ek80_bb_with_calibration/splitbeam` +""" @pytest.fixture( @@ -181,3 +196,144 @@ def _check_var(ds_test): ds_sel = ep.consolidate.add_location(ds=ds, echodata=ed, nmea_sentence="GGA") _check_var(ds_sel) + + +def _create_array_list_from_echoview_mats(paths_to_echoview_mat: List[pathlib.Path]) -> List[np.ndarray]: + """ + Opens each mat file in ``paths_to_echoview_mat``, selects the first ``ping_time``, + and then stores the array in a list. + + Parameters + ---------- + paths_to_echoview_mat: list of pathlib.Path + A list of paths corresponding to mat files, where each mat file contains the + echoview generated angle alongship and athwartship data for a channel + + Returns + ------- + list of np.ndarray + A list of numpy arrays generated by choosing the appropriate data from the mat files. + This list will have the same length as ``paths_to_echoview_mat`` + """ + + list_of_mat_arrays = [] + for mat_file in paths_to_echoview_mat: + + # open mat file and grab appropriate data + list_of_mat_arrays.append(io.loadmat(file_name=mat_file)["P0"]["Data_values"][0][0]) + + return list_of_mat_arrays + + +@pytest.mark.parametrize( + ("sonar_model", "test_path_key", "raw_file_name", "paths_to_echoview_mat", + "waveform_mode", "encode_mode", "pulse_compression", "write_Sv_to_file"), + [ + ( + "EK60", "EK60", "DY1801_EK60-D20180211-T164025.raw", + [ + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T1.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T2.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T3.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T4.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T5.mat' + ], + "CW", "power", False, False + ), + ( + "EK60", "EK60", "DY1801_EK60-D20180211-T164025.raw", + [ + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T1.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T2.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T3.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T4.mat', + 'splitbeam/DY1801_EK60-D20180211-T164025_angles_T5.mat' + ], + "CW", "power", False, True + ), + ( + "EK80", "EK80_CAL", "2018115-D20181213-T094600.raw", + [ + 'splitbeam/2018115-D20181213-T094600_angles_T1.mat', + 'splitbeam/2018115-D20181213-T094600_angles_T4.mat', + 'splitbeam/2018115-D20181213-T094600_angles_T6.mat', + 'splitbeam/2018115-D20181213-T094600_angles_T5.mat' + ], + "CW", "complex", False, False + ), + ( + "EK80", "EK80_CAL", "2018115-D20181213-T094600.raw", + [ + 'splitbeam/2018115-D20181213-T094600_angles_T3_nopc.mat', + 'splitbeam/2018115-D20181213-T094600_angles_T2_nopc.mat', + ], + "BB", "complex", False, False, + ), + ], + ids=["ek60_CW_power", "ek60_CW_power_Sv_path", "ek80_CW_complex", "ek80_BB_complex_no_pulse"] +) +def test_add_splitbeam_angle(sonar_model, test_path_key, raw_file_name, test_path, + paths_to_echoview_mat, waveform_mode, encode_mode, + pulse_compression, write_Sv_to_file): + + # obtain the EchoData object with the data needed for the calculation + ed = ep.open_raw(test_path[test_path_key] / raw_file_name, sonar_model=sonar_model) + + # compute Sv as it is required for the split-beam angle calculation + ds_Sv = ep.calibrate.compute_Sv(ed, waveform_mode=waveform_mode, encode_mode=encode_mode) + + # initialize temporary directory object + temp_dir = None + + # allows us to test for the case when source_Sv is a path + if write_Sv_to_file: + + # 
create a temporary directory for the Sv dataset
+        temp_dir = tempfile.TemporaryDirectory()
+
+        # write DataArray to temporary directory
+        zarr_path = os.path.join(temp_dir.name, "Sv_data.zarr")
+        ds_Sv.to_zarr(zarr_path)
+
+        # assign input to a path
+        ds_Sv = zarr_path
+
+    # add the split-beam angles to the Sv dataset (or the path pointing to it)
+    ds_Sv = ep.consolidate.add_splitbeam_angle(source_Sv=ds_Sv, echodata=ed,
+                                               waveform_mode=waveform_mode,
+                                               encode_mode=encode_mode,
+                                               pulse_compression=pulse_compression)
+
+    # obtain corresponding echoview output
+    full_echoview_path = [test_path[test_path_key] / path for path in paths_to_echoview_mat]
+    echoview_arr_list = _create_array_list_from_echoview_mats(full_echoview_path)
+
+    # compare echoview output against computed output for all channels
+    for chan_ind in range(len(echoview_arr_list)):
+
+        # grab the appropriate ds data to compare against
+        reduced_angle_alongship = ds_Sv.isel(channel=chan_ind, ping_time=0).angle_alongship.dropna("range_sample")
+        reduced_angle_athwartship = ds_Sv.isel(channel=chan_ind, ping_time=0).angle_athwartship.dropna("range_sample")
+
+        # TODO: make "start" below a parameter in the input so that this is not ad-hoc but something known
+        # for some files the echoview data is shifted by one index, here we account for that
+        if reduced_angle_alongship.shape == (echoview_arr_list[chan_ind].shape[1], ):
+            start = 0
+        else:
+            start = 1
+
+        # check the computed angle_alongship values against the echoview output
+        assert np.allclose(reduced_angle_alongship.values[start:],
+                           echoview_arr_list[chan_ind][0, :], rtol=1e-1, atol=1e-2)
+
+        # check the computed angle_athwartship values against the echoview output
+        assert np.allclose(reduced_angle_athwartship.values[start:],
+                           echoview_arr_list[chan_ind][1, :], rtol=1e-1, atol=1e-2)
+
+    if temp_dir:
+        # remove the temporary directory, if it was created
+        temp_dir.cleanup()
+
+
+# TODO: need a test for power/angle data, with mock EchoData object
+# containing some channels with single-beam data and some channels with split-beam data
\ No newline at end of file
diff --git a/echopype/tests/echodata/test_echodata_simrad.py b/echopype/tests/echodata/test_echodata_simrad.py
new file mode 100644
index 0000000000..e8b3184727
--- /dev/null
+++ b/echopype/tests/echodata/test_echodata_simrad.py
@@ -0,0 +1,69 @@
+"""
+Tests functions contained within echodata/simrad.py
+"""
+import pytest
+from echopype.echodata.simrad import retrieve_correct_beam_group, _check_input_args_combination
+
+
+@pytest.mark.parametrize(
+    ("waveform_mode", "encode_mode", "pulse_compression"),
+    [
+        pytest.param("CW", "comp_power", None,
+                     marks=pytest.mark.xfail(strict=True,
+                                             reason='This test should fail since comp_power '
+                                                    'is not an acceptable choice for encode_mode.')),
+        pytest.param("CB", None, None,
+                     marks=pytest.mark.xfail(strict=True,
+                                             reason='This test should fail since CB is not an '
+                                                    'acceptable choice for waveform_mode.')),
+        pytest.param("BB", "power", None,
+                     marks=pytest.mark.xfail(strict=True,
+                                             reason='This test should fail since BB and power is '
+                                                    'not an acceptable combination.')),
+        pytest.param("BB", "power", True,
+                     marks=pytest.mark.xfail(strict=True,
+                                             reason='This test should fail since BB and complex '
+                                                    'must be used if pulse_compression is True.')),
+        pytest.param("CW", "complex", True,
+                     marks=pytest.mark.xfail(strict=True,
+                                             reason='This test should fail since BB and complex '
+                                                    'must be used if pulse_compression is True.')),
+        pytest.param("CW", "power", True,
+                     marks=pytest.mark.xfail(strict=True,
reason='This test should fail since BB and complex ' + 'must be used if pulse_compression is True.')), + ("CW", "complex", False), + ("CW", "power", False), + ("BB", "complex", False), + ("BB", "complex", True), + + ], + ids=["incorrect_encode_mode", "incorrect_waveform_mode", "BB_power_combo", + "BB_power_pc_True", "CW_complex_pc_True", "CW_power_pc_True", "CW_complex_pc_False", + "CW_power_pc_False", "BB_complex_pc_False", "BB_complex_pc_True"] +) +def test_check_input_args_combination(waveform_mode: str, encode_mode: str, + pulse_compression: bool): + """ + Ensures that ``check_input_args_combination`` functions correctly when + provided various combinations of the input parameters. + + Parameters + ---------- + waveform_mode: str + Type of transmit waveform + encode_mode: str + Type of encoded return echo data + pulse_compression: bool + States whether pulse compression should be used + """ + + _check_input_args_combination(waveform_mode, encode_mode, pulse_compression) + + +def test_retrieve_correct_beam_group(): + + # TODO: create this test once we are happy with the form of retrieve_correct_beam_group + + pytest.skip("We need to add tests for retrieve_correct_beam_group!") + From 9e19736f65ca5b3c62349197d725a41b8055ce43 Mon Sep 17 00:00:00 2001 From: Wu-Jung Lee Date: Thu, 22 Dec 2022 21:22:00 -0800 Subject: [PATCH 8/8] Redo: Add unit tests for functions used in `apply_mask` (#917) See #911 for all conversations and detailed commits. --- echopype/mask/api.py | 8 +- echopype/tests/mask/test_mask.py | 284 +++++++++++++++++++++++++------ echopype/utils/io.py | 9 +- 3 files changed, 241 insertions(+), 60 deletions(-) diff --git a/echopype/mask/api.py b/echopype/mask/api.py index e45e730e66..eba7ca858f 100644 --- a/echopype/mask/api.py +++ b/echopype/mask/api.py @@ -17,7 +17,7 @@ } -def validate_and_collect_mask_input( +def _validate_and_collect_mask_input( mask: Union[ Union[xr.DataArray, str, pathlib.Path], List[Union[xr.DataArray, str, pathlib.Path]] ], @@ -77,7 +77,7 @@ def validate_and_collect_mask_input( ) # replace mask element path with its corresponding DataArray - if isinstance(mask_val, str): + if isinstance(mask_val, (str, pathlib.Path)): # open up DataArray using mask path mask[mask_ind] = xr.open_dataarray( mask_val, engine=file_type, chunks={}, **storage_options_mask[mask_ind] @@ -94,7 +94,7 @@ def validate_and_collect_mask_input( # validate the mask type or path (if it is provided) mask, file_type = validate_source_ds_da(mask, storage_options_mask) - if isinstance(mask, str): + if isinstance(mask, (str, pathlib.Path)): # open up DataArray using mask path mask = xr.open_dataarray(mask, engine=file_type, chunks={}, **storage_options_mask) @@ -202,7 +202,7 @@ def apply_mask( source_ds = xr.open_dataset(source_ds, engine=file_type, chunks={}, **storage_options_ds) # validate and form the mask input to be used downstream - mask = validate_and_collect_mask_input(mask, storage_options_mask) + mask = _validate_and_collect_mask_input(mask, storage_options_mask) # ensure that var_name and fill_value were correctly provided _check_var_name_fill_value(source_ds, var_name, fill_value) diff --git a/echopype/tests/mask/test_mask.py b/echopype/tests/mask/test_mask.py index de24e6c5fb..47613f378e 100644 --- a/echopype/tests/mask/test_mask.py +++ b/echopype/tests/mask/test_mask.py @@ -10,7 +10,11 @@ import echopype as ep import echopype.mask -from echopype.mask.api import _check_source_Sv_freq_diff +from echopype.mask.api import ( + _check_source_Sv_freq_diff, + 
_validate_and_collect_mask_input,
+    _check_var_name_fill_value
+)
 
 from typing import List, Union, Optional
 
@@ -59,7 +63,7 @@ def get_mock_freq_diff_data(n: int, n_chan_freq: int, add_chan: bool,
     mat_A = np.arange(n ** 2).reshape(n, n)
 
     # construct channel values
-    chan_vals = ['chan' + str(i) for i in range(1, n_chan_freq+1)]
+    chan_vals = ['chan' + str(i) for i in range(1, n_chan_freq + 1)]
 
     # construct mock Sv data
     mock_Sv_data = [mat_A, np.identity(n), mat_B] + [np.identity(n) for i in range(3, n_chan_freq)]
@@ -142,6 +146,95 @@ def get_mock_source_ds_apply_mask(n: int, n_chan: int, is_delayed: bool) -> xr.D
     return mock_ds
 
 
+def create_input_mask(
+    mask: Union[np.ndarray, List[np.ndarray]],
+    mask_file: Optional[Union[str, pathlib.Path, List[Union[str, pathlib.Path]]]],
+    mask_coords: Union[xr.core.coordinates.DataArrayCoordinates, dict],
+    n_chan: int
+):
+    """
+    A helper function that correctly constructs the mask input, so it can be
+    used for ``apply_mask`` related tests.
+
+    Parameters
+    ----------
+    mask: np.ndarray or list of np.ndarray
+        The mask(s) that should be applied to ``var_name``
+    mask_file: str, pathlib.Path, or list of str/pathlib.Path, optional
+        If provided, the ``mask`` input will be written to a temporary directory
+        with file name ``mask_file``. This will then be used in ``apply_mask``.
+    mask_coords: xr.core.coordinates.DataArrayCoordinates or dict
+        The DataArray coordinates that should be used for each mask DataArray created
+    n_chan: int
+        Determines the size of the ``channel`` coordinate
+    """
+
+    # initialize temp_dir
+    temp_dir = None
+
+    # make input numpy array masks into DataArrays
+    if isinstance(mask, list):
+
+        # initialize final mask
+        mask_out = []
+
+        # create temporary directory if a mask file name or path is provided
+        if any([isinstance(elem, (str, pathlib.Path)) for elem in mask_file]):
+            # create temporary directory for mask_file
+            temp_dir = tempfile.TemporaryDirectory()
+
+        for mask_ind in range(len(mask)):
+
+            # form DataArray from given mask data
+            mask_da = xr.DataArray(data=[mask[mask_ind] for i in range(n_chan)],
+                                   coords=mask_coords, name='mask_' + str(mask_ind))
+
+            if mask_file[mask_ind] is None:
+
+                # set mask value to the DataArray given
+                mask_out.append(mask_da)
+            else:
+
+                # write DataArray to temporary directory
+                zarr_path = os.path.join(temp_dir.name, mask_file[mask_ind])
+                mask_da.to_dataset().to_zarr(zarr_path)
+
+                if isinstance(mask_file[mask_ind], pathlib.Path):
+                    # make zarr_path into a Path object
+                    zarr_path = pathlib.Path(zarr_path)
+
+                # set mask value to created path
+                mask_out.append(zarr_path)
+
+    elif isinstance(mask, np.ndarray):
+
+        # form DataArray from given mask data
+        mask_da = xr.DataArray(data=[mask for i in range(n_chan)],
+                               coords=mask_coords, name='mask_0')
+
+        if mask_file is None:
+
+            # set mask to the DataArray formed
+            mask_out = mask_da
+        else:
+
+            # create temporary directory for mask_file
+            temp_dir = tempfile.TemporaryDirectory()
+
+            # write DataArray to temporary directory
+            zarr_path = os.path.join(temp_dir.name, mask_file)
+            mask_da.to_dataset().to_zarr(zarr_path)
+
+            if isinstance(mask_file, pathlib.Path):
+                # make zarr_path into a Path object
+                zarr_path = pathlib.Path(zarr_path)
+
+            # set mask index to path
+            mask_out = zarr_path
+
+    return mask_out, temp_dir
+
+
+@pytest.mark.parametrize(
+    ("n", "n_chan", "mask_np", "mask_file", "storage_options_mask"),
+    [
+        (5, 1, 
np.identity(5), None, {}), + (5, 1, [np.identity(5), np.identity(5)], [None, None], {}), + (5, 1, [np.identity(5), np.identity(5)], [None, None], [{}, {}]), + (5, 1, np.identity(5), "path/to/mask.zarr", {}), + (5, 1, [np.identity(5), np.identity(5)], ["path/to/mask0.zarr", "path/to/mask1.zarr"], {}), + (5, 1, np.identity(5), pathlib.Path("path/to/mask.zarr"), {}), + (5, 1, [np.identity(5), np.identity(5), np.identity(5)], + [None, "path/to/mask0.zarr", pathlib.Path("path/to/mask1.zarr")], {}) + ], + ids=["mask_da", "mask_list_da_single_storage", "mask_list_da_list_storage", "mask_str_path", + "mask_list_str_path", "mask_pathlib", "mask_mixed_da_str_pathlib"] +) +def test_validate_and_collect_mask_input( + n: int, + n_chan: int, + mask_np: Union[np.ndarray, List[np.ndarray]], + mask_file: Optional[Union[str, pathlib.Path, List[Union[str, pathlib.Path]]]], + storage_options_mask: Union[dict, List[dict]]): + """ + Tests the allowable types for the mask input and corresponding storage options. + + Parameters + ---------- + n: int + The number of rows (``x``) and columns (``y``) of + each channel matrix + n_chan: int + Determines the size of the ``channel`` coordinate + mask_np: np.ndarray or list of np.ndarray + The mask(s) that should be applied to ``var_name`` + mask_file: str or list of str, optional + If provided, the ``mask`` input will be written to a temporary directory + with file name ``mask_file``. This will then be used in ``apply_mask``. + storage_options_mask: dict or list of dict, default={} + Any additional parameters for the storage backend, corresponding to the + path provided for ``mask`` + + Notes + ----- + The input for ``storage_options_mask`` will only contain the value `{}` or a list of + empty dictionaries as other options are already tested in + ``test_utils_io.py::test_validate_output_path`` and are therefore not included here. 
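A minimal in-memory sketch of the behavior this test exercises; path inputs would be resolved the same way via ``xr.open_dataarray``. The toy mask below mirrors the shapes used in these tests:

import numpy as np
import xarray as xr
from echopype.mask.api import _validate_and_collect_mask_input

mask = xr.DataArray(
    data=np.identity(3)[np.newaxis, :, :],
    dims=["channel", "x", "y"],
    coords={"channel": ["chan1"], "x": np.arange(3), "y": np.arange(3)},
    name="mask_0",
)

# a single DataArray passes through unchanged
out = _validate_and_collect_mask_input(mask=mask, storage_options_mask={})
assert out.identical(mask)

# a list of DataArrays comes back as a list of DataArrays
out_list = _validate_and_collect_mask_input(mask=[mask, mask], storage_options_mask={})
assert all(m.identical(mask) for m in out_list)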
+ """ + + # construct channel values + chan_vals = ['chan' + str(i) for i in range(1, n_chan + 1)] + + # create coordinates that will be used by all DataArrays created + coords = {"channel": ("channel", chan_vals, {"long_name": "channel name"}), + "x": np.arange(n), "y": np.arange(n)} + + # create input mask and obtain temporary directory, if it was created + mask, _ = create_input_mask(mask_np, mask_file, coords, n_chan) + + mask_out = _validate_and_collect_mask_input(mask=mask, storage_options_mask=storage_options_mask) + + if isinstance(mask_out, list): + for ind, da in enumerate(mask_out): + + # create known solution for mask + mask_da = xr.DataArray(data=[mask_np[ind] for i in range(n_chan)], + coords=coords, name='mask_' + str(ind)) + + assert da.identical(mask_da) + else: + + # create known solution for mask + mask_da = xr.DataArray(data=[mask_np for i in range(n_chan)], + coords=coords, name='mask_0') + assert mask_out.identical(mask_da) + + +@pytest.mark.parametrize( + ("n", "n_chan", "var_name", "fill_value"), + [ + pytest.param(4, 2, 2.0, np.nan, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because the var_name is not a string.")), + pytest.param(4, 2, "var3", np.nan, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because mock_ds will " + "not have var_name=var3 in it.")), + pytest.param(4, 2, "var1", "1.0", + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is an incorrect type.")), + (4, 2, "var1", 1), + (4, 2, "var1", 1.0), + (2, 1, "var1", np.identity(2)[None, :]), + (2, 1, "var1", xr.DataArray(data=np.array([[[1.0, 0], [0, 1]]]), + coords={"channel": ["chan1"], "ping_time": [0, 1], "range_sample": [0, 1]}) + ), + pytest.param(4, 2, "var1", np.identity(2), + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is not the right shape.")), + pytest.param(4, 2, "var1", + xr.DataArray(data=np.array([[1.0, 0], [0, 1]]), + coords={"ping_time": [0, 1], "range_sample": [0, 1]}), + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is not the right shape.")), + ], + ids=["wrong_var_name_type", "no_var_name_ds", "wrong_fill_value_type", "fill_value_int", + "fill_value_float", "fill_value_np_array", "fill_value_DataArray", + "fill_value_np_array_wrong_shape", "fill_value_DataArray_wrong_shape"] +) +def test_check_var_name_fill_value(n: int, n_chan: int, var_name: str, + fill_value: Union[int, float, np.ndarray, xr.DataArray]): + """ + Ensures that the function ``_check_var_name_fill_value`` is behaving as expected. + + Parameters + ---------- + n: int + The number of rows (``x``) and columns (``y``) of + each channel matrix + n_chan: int + Determines the size of the ``channel`` coordinate + var_name: {"var1", "var2"} + The variable name in the mock Dataset to apply the mask to + fill_value: int, float, np.ndarray, or xr.DataArray + Value(s) at masked indices + """ + + # obtain mock Dataset containing var_name + mock_ds = get_mock_source_ds_apply_mask(n, n_chan, is_delayed=False) + + _check_var_name_fill_value(source_ds=mock_ds, var_name=var_name, fill_value=fill_value) + + @pytest.mark.parametrize( ("n", "n_chan", "var_name", "mask", "mask_file", "fill_value", "is_delayed", "var_masked_truth"), [ @@ -322,6 +546,8 @@ def test_apply_mask(n: int, n_chan: int, var_name: str, mask_file: str or list of str, optional If provided, the ``mask`` input will be written to a temporary directory with file name ``mask_file``. This will then be used in ``apply_mask``. 
+ fill_value: int, float, np.ndarray, or xr.DataArray + Value(s) at masked indices var_masked_truth: np.ndarray The true value of ``var_name`` values after the mask has been applied is_delayed: bool @@ -332,58 +558,8 @@ def test_apply_mask(n: int, n_chan: int, var_name: str, # obtain mock Dataset containing var_name mock_ds = get_mock_source_ds_apply_mask(n, n_chan, is_delayed) - # initialize temp_dir - temp_dir = None - - # make input numpy array masks into DataArrays - if isinstance(mask, list): - - # create temporary directory if mask_file is provided - if any([isinstance(elem, str) for elem in mask_file]): - - # create temporary directory for mask_file - temp_dir = tempfile.TemporaryDirectory() - - for mask_ind in range(len(mask)): - - # form DataArray from given mask data - mask_da = xr.DataArray(data=np.stack([mask[mask_ind] for i in range(n_chan)]), - coords=mock_ds.coords, name='mask_' + str(mask_ind)) - - if mask_file[mask_ind] is None: - - # set mask value to the DataArray given - mask[mask_ind] = mask_da - else: - - # write DataArray to temporary directory - zarr_path = os.path.join(temp_dir.name, mask_file[mask_ind]) - mask_da.to_dataset().to_zarr(zarr_path) - - # set mask value to created path - mask[mask_ind] = zarr_path - - elif isinstance(mask, np.ndarray): - - # form DataArray from given mask data - mask_da = xr.DataArray(data=np.stack([mask for i in range(n_chan)]), - coords=mock_ds.coords, name='mask_0') - - if mask_file is None: - - # set mask to the DataArray formed - mask = mask_da - else: - - # create temporary directory for mask_file - temp_dir = tempfile.TemporaryDirectory() - - # write DataArray to temporary directory - zarr_path = os.path.join(temp_dir.name, mask_file) - mask_da.to_dataset().to_zarr(zarr_path) - - # set mask index to path - mask = zarr_path + # create input mask and obtain temporary directory, if it was created + mask, temp_dir = create_input_mask(mask, mask_file, mock_ds.coords, n_chan) # create DataArray form of the known truth value var_masked_truth = xr.DataArray(data=np.stack([var_masked_truth for i in range(n_chan)]), diff --git a/echopype/utils/io.py b/echopype/utils/io.py index dd8d77ffff..c42c0569e7 100644 --- a/echopype/utils/io.py +++ b/echopype/utils/io.py @@ -2,6 +2,7 @@ echopype utilities for file handling """ import os +import pathlib import platform import sys from pathlib import Path, WindowsPath @@ -60,9 +61,13 @@ def get_file_format(file): elif isinstance(file, FSMap): file = file.root - if file.endswith(".nc"): + if isinstance(file, str) and file.endswith(".nc"): return "netcdf4" - elif file.endswith(".zarr"): + elif isinstance(file, str) and file.endswith(".zarr"): + return "zarr" + elif isinstance(file, pathlib.Path) and file.suffix == ".nc": + return "netcdf4" + elif isinstance(file, pathlib.Path) and file.suffix == ".zarr": return "zarr" else: raise ValueError(f"Unsupported file format: {os.path.splitext(file)[1]}")
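With the ``pathlib`` branches added in this last hunk, string and ``Path`` inputs to ``get_file_format`` resolve identically. A short usage sketch, with the import path as in the diff:

import pathlib
from echopype.utils.io import get_file_format

assert get_file_format("data/Sv.nc") == "netcdf4"
assert get_file_format(pathlib.Path("data/Sv.zarr")) == "zarr"
# any other extension raises ValueError, e.g. get_file_format("data/Sv.csv")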