From 749520bfa230d8bc7c13641c40e4ad4e400a5d44 Mon Sep 17 00:00:00 2001
From: Philippe Miron
Date: Tue, 12 Mar 2024 16:10:37 -0400
Subject: [PATCH] Switch to row instead of trajectory (#376)

* Use `rows` to name the legacy `traj` dimension; `traj` is mostly relevant to oceanographic datasets, while `rows` is more general
* Remove the coordinate-to-dimension map and map coordinates to dimension aliases instead
* Map each dimension alias to the dimension names the library requires

---------

Co-authored-by: Philippe Miron
Co-authored-by: Shane Elipot
Co-authored-by: Kevin Santana
---
 clouddrift/adapters/andro.py             |   8 +-
 clouddrift/adapters/gdp.py               |   8 +-
 clouddrift/adapters/gdp1h.py             |  69 +++++---
 clouddrift/adapters/gdp6h.py             |  66 +++++---
 clouddrift/adapters/glad.py              |   1 +
 clouddrift/adapters/mosaic.py            |   1 +
 clouddrift/adapters/subsurface_floats.py |   4 +-
 clouddrift/adapters/utils.py             |   6 +-
 clouddrift/adapters/yomaha.py            |  11 +-
 clouddrift/datasets.py                   |   5 +-
 clouddrift/pairs.py                      |   1 +
 clouddrift/ragged.py                     |  88 +++++-----
 clouddrift/raggedarray.py                | 198 ++++++++++++++---------
 pyproject.toml                           |   2 +-
 tests/datasets_tests.py                  |   7 +-
 tests/pairs_tests.py                     |   4 +-
 tests/ragged_tests.py                    |  42 ++---
 tests/raggedarray_tests.py               |  42 ++++-
 18 files changed, 348 insertions(+), 215 deletions(-)
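Reviewer note: the core API change is that RaggedArray no longer takes a `coord_dim_map` list of (coordinate, dimension) tuples. Callers now pass `name_coords` (coordinate names), `name_dims` (a map from each dataset's own dimension names to the two roles the library understands, "rows" and "obs"), and, when reading back, `coord_dims` (a map from coordinate name to dimension name). A minimal sketch of the new convention, assuming the post-patch API; the toy dataset below is illustrative and not part of this patch:

    import numpy as np
    import xarray as xr
    from clouddrift.raggedarray import RaggedArray

    # Two rows of lengths 3 and 4, stored back to back along "obs";
    # per-row variables ("id", "rowsize") live along the new "rows" dimension.
    ds = xr.Dataset(
        coords={
            "id": (["rows"], np.array([100, 101])),
            "time": (["obs"], np.arange(7, dtype="float64")),
        },
        data_vars={
            "rowsize": (["rows"], np.array([3, 4])),
            "lon": (["obs"], np.linspace(-98.0, -78.0, 7)),
        },
    )

    # Dimension names default to "rows"/"obs"; a dataset that kept the
    # legacy "traj" name would pass rows_dim_name="traj" instead.
    ra = RaggedArray.from_xarray(ds)
    ds_roundtrip = ra.to_xarray()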
diff --git a/clouddrift/adapters/andro.py b/clouddrift/adapters/andro.py
index e318c1ce..80fb4193 100644
--- a/clouddrift/adapters/andro.py
+++ b/clouddrift/adapters/andro.py
@@ -1,6 +1,6 @@
 """
-This module defines functions used to adapt the ANDRO: An Argo-based 
-deep displacement dataset as a ragged-arrays dataset. 
+This module defines functions used to adapt the ANDRO: An Argo-based
+deep displacement dataset as a ragged-arrays dataset.
 
 The dataset is hosted at https://www.seanoe.org/data/00360/47077/ and the user manual
 is available at https://archimer.ifremer.fr/doc/00360/47126/.
@@ -12,8 +12,8 @@
 
 Reference
 ---------
-Ollitrault Michel, Rannou Philippe, Brion Emilie, Cabanes Cecile, Piron Anne, Reverdin Gilles, 
-Kolodziejczyk Nicolas (2022). ANDRO: An Argo-based deep displacement dataset. 
+Ollitrault Michel, Rannou Philippe, Brion Emilie, Cabanes Cecile, Piron Anne, Reverdin Gilles,
+Kolodziejczyk Nicolas (2022). ANDRO: An Argo-based deep displacement dataset.
 SEANOE.
 https://doi.org/10.17882/47077
 """
diff --git a/clouddrift/adapters/gdp.py b/clouddrift/adapters/gdp.py
index 40005b9e..1fe8b1f8 100644
--- a/clouddrift/adapters/gdp.py
+++ b/clouddrift/adapters/gdp.py
@@ -14,9 +14,11 @@
 from clouddrift.adapters.utils import download_with_progress
 from clouddrift.raggedarray import DimNames
 
-GDP_COORDS: list[tuple[str, DimNames]] = [
-    ("id", "traj"),
-    ("time", "obs"),
+GDP_DIMS: dict[str, DimNames] = {"traj": "rows", "obs": "obs"}
+
+GDP_COORDS = [
+    "id",
+    "time",
 ]
 
 GDP_METADATA = [
diff --git a/clouddrift/adapters/gdp1h.py b/clouddrift/adapters/gdp1h.py
index 60f4d856..6b0c9526 100644
--- a/clouddrift/adapters/gdp1h.py
+++ b/clouddrift/adapters/gdp1h.py
@@ -23,9 +23,9 @@
 
 GDP_VERSION = "2.01"
 
-GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/hourly_product/v2.01/"
+GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/hourly_product/v2.01"
 GDP_DATA_URL_EXPERIMENTAL = (
-    "https://www.aoml.noaa.gov/ftp/pub/phod/lumpkin/hourly/experimental/"
+    "https://www.aoml.noaa.gov/ftp/pub/phod/lumpkin/hourly/experimental"
 )
 
 
@@ -113,7 +113,7 @@ def download(
     gdp_metadata = gdp.get_gdp_metadata()
 
     return gdp.order_by_date(
-        gdp_metadata, [int(f.split("_")[-1][:-3]) for f in filelist]
+        gdp_metadata, [int(f.split("_")[-1].removesuffix(".nc")) for f in filelist]
    )
 
 
@@ -215,24 +215,33 @@ def preprocess(index: int, **kwargs) -> xr.Dataset:
         [False if ds.get("location_type") == "Argos" else True],
     )  # 0 for Argos, 1 for GPS
     ds["DeployingShip"] = (("traj"), gdp.cut_str(ds.DeployingShip, 20))
-    ds["DeploymentStatus"] = (("traj"), gdp.cut_str(ds.DeploymentStatus, 20))
-    ds["BuoyTypeManufacturer"] = (("traj"), gdp.cut_str(ds.BuoyTypeManufacturer, 20))
-    ds["BuoyTypeSensorArray"] = (("traj"), gdp.cut_str(ds.BuoyTypeSensorArray, 20))
+    ds["DeploymentStatus"] = (
+        ("traj"),
+        gdp.cut_str(ds.DeploymentStatus, 20),
+    )
+    ds["BuoyTypeManufacturer"] = (
+        ("traj"),
+        gdp.cut_str(ds.BuoyTypeManufacturer, 20),
+    )
+    ds["BuoyTypeSensorArray"] = (
+        ("traj"),
+        gdp.cut_str(ds.BuoyTypeSensorArray, 20),
+    )
     ds["CurrentProgram"] = (
         ("traj"),
         np.array([gdp.str_to_float(ds.CurrentProgram, -1)], dtype=np.int32),
     )
-    ds["PurchaserFunding"] = (("traj"), gdp.cut_str(ds.PurchaserFunding, 20))
+    ds["PurchaserFunding"] = (
+        ("traj"),
+        gdp.cut_str(ds.PurchaserFunding, 20),
+    )
     ds["SensorUpgrade"] = (("traj"), gdp.cut_str(ds.SensorUpgrade, 20))
     ds["Transmissions"] = (("traj"), gdp.cut_str(ds.Transmissions, 20))
-    ds["DeployingCountry"] = (("traj"), gdp.cut_str(ds.DeployingCountry, 20))
+    ds["DeployingCountry"] = (
+        ("traj"),
+        gdp.cut_str(ds.DeployingCountry, 20),
+    )
     ds["DeploymentComments"] = (
         ("traj"),
         gdp.cut_str(
             ds.DeploymentComments.encode("ascii", "ignore").decode("ascii"), 20
         ),
     )  # remove non ascii char
     ds["ManufactureYear"] = (
         ("traj"),
         np.array([gdp.str_to_float(ds.ManufactureYear, -1)], dtype=np.int16),
     )
@@ -240,10 +249,13 @@ def preprocess(index: int, **kwargs) -> xr.Dataset:
         ("traj"),
         np.array([gdp.str_to_float(ds.ManufactureMonth, -1)], dtype=np.int16),
     )
-    ds["ManufactureSensorType"] = (("traj"), gdp.cut_str(ds.ManufactureSensorType, 20))
+    ds["ManufactureSensorType"] = (
+        ("traj"),
+        gdp.cut_str(ds.ManufactureSensorType, 20),
+    )
     ds["ManufactureVoltage"] = (
         ("traj"),
-        np.array([gdp.str_to_float(ds.ManufactureVoltage[:-6], -1)], dtype=np.int16),
+        np.array([gdp.str_to_float(ds.ManufactureVoltage[:-2], -1)], dtype=np.int16),
    )  # e.g.
56 V ds["FloatDiameter"] = ( ("traj"), @@ -270,12 +282,18 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: ("traj"), [gdp.str_to_float(ds.DragAreaOfDrogue[:-4])], ) # e.g. 416.6 m^2 - ds["DragAreaRatio"] = (("traj"), [gdp.str_to_float(ds.DragAreaRatio)]) # e.g. 39.08 + ds["DragAreaRatio"] = ( + ("traj"), + [gdp.str_to_float(ds.DragAreaRatio)], + ) # e.g. 39.08 ds["DrogueCenterDepth"] = ( ("traj"), [gdp.str_to_float(ds.DrogueCenterDepth[:-2])], ) # e.g. 20.0 m - ds["DrogueDetectSensor"] = (("traj"), gdp.cut_str(ds.DrogueDetectSensor, 20)) + ds["DrogueDetectSensor"] = ( + ("traj"), + gdp.cut_str(ds.DrogueDetectSensor, 20), + ) # vars attributes vars_attrs = { @@ -581,9 +599,10 @@ def to_raggedarray( ra = RaggedArray.from_files( indices=ids, preprocess_func=preprocess, - coord_dim_map=gdp.GDP_COORDS, + name_coords=gdp.GDP_COORDS, name_meta=gdp.GDP_METADATA, name_data=GDP_DATA, + name_dims=gdp.GDP_DIMS, rowsize_func=gdp.rowsize, filename_pattern=filename_pattern, tmp_path=tmp_path, @@ -591,11 +610,11 @@ def to_raggedarray( # set dynamic global attributes if ra.attrs_global: - ra.attrs_global[ - "time_coverage_start" - ] = f"{datetime(1970,1,1) + timedelta(seconds=int(np.min(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" - ra.attrs_global[ - "time_coverage_end" - ] = f"{datetime(1970,1,1) + timedelta(seconds=int(np.max(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ra.attrs_global["time_coverage_start"] = ( + f"{datetime(1970,1,1) + timedelta(seconds=int(np.min(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ) + ra.attrs_global["time_coverage_end"] = ( + f"{datetime(1970,1,1) + timedelta(seconds=int(np.max(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ) return ra diff --git a/clouddrift/adapters/gdp6h.py b/clouddrift/adapters/gdp6h.py index e8ece7a6..2ece76ce 100644 --- a/clouddrift/adapters/gdp6h.py +++ b/clouddrift/adapters/gdp6h.py @@ -21,7 +21,7 @@ GDP_VERSION = "September 2023" -GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/6h/" +GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/6h" GDP_TMP_PATH = os.path.join(tempfile.gettempdir(), "clouddrift", "gdp6h") GDP_DATA = [ "lon", @@ -82,7 +82,7 @@ def download( string = urlpath.read().decode("utf-8") filelist = list(set(re.compile(pattern).findall(string))) for f in filelist: - did = int(f[:-3].split("_")[2]) + did = int(f.split("_")[2].removesuffix(".nc")) if (drifter_ids is None or did in drifter_ids) and did not in added: drifter_urls.append(f"{url}/{dir}/{f}") added.add(did) @@ -187,7 +187,10 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: warnings.warn(f"Variable {var} not found in upstream data; skipping.") # new variables - ds["ids"] = (["traj", "obs"], [np.repeat(ds.ID.values, ds.sizes["obs"])]) + ds["ids"] = ( + ["traj", "obs"], + [np.repeat(ds.ID.values, ds.sizes["obs"])], + ) ds["drogue_status"] = ( ["traj", "obs"], [gdp.drogue_presence(ds.drogue_lost_date.data, ds.time.data[0])], @@ -199,17 +202,32 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: [False if ds.get("location_type") == "Argos" else True], ) # 0 for Argos, 1 for GPS ds["DeployingShip"] = (("traj"), gdp.cut_str(ds.DeployingShip, 20)) - ds["DeploymentStatus"] = (("traj"), gdp.cut_str(ds.DeploymentStatus, 20)) - ds["BuoyTypeManufacturer"] = (("traj"), gdp.cut_str(ds.BuoyTypeManufacturer, 20)) - ds["BuoyTypeSensorArray"] = (("traj"), gdp.cut_str(ds.BuoyTypeSensorArray, 20)) + ds["DeploymentStatus"] = ( + ("traj"), + gdp.cut_str(ds.DeploymentStatus, 20), + ) + ds["BuoyTypeManufacturer"] = ( + ("traj"), + 
gdp.cut_str(ds.BuoyTypeManufacturer, 20), + ) + ds["BuoyTypeSensorArray"] = ( + ("traj"), + gdp.cut_str(ds.BuoyTypeSensorArray, 20), + ) ds["CurrentProgram"] = ( ("traj"), [np.int32(gdp.str_to_float(ds.CurrentProgram, -1))], ) - ds["PurchaserFunding"] = (("traj"), gdp.cut_str(ds.PurchaserFunding, 20)) + ds["PurchaserFunding"] = ( + ("traj"), + gdp.cut_str(ds.PurchaserFunding, 20), + ) ds["SensorUpgrade"] = (("traj"), gdp.cut_str(ds.SensorUpgrade, 20)) ds["Transmissions"] = (("traj"), gdp.cut_str(ds.Transmissions, 20)) - ds["DeployingCountry"] = (("traj"), gdp.cut_str(ds.DeployingCountry, 20)) + ds["DeployingCountry"] = ( + ("traj"), + gdp.cut_str(ds.DeployingCountry, 20), + ) ds["DeploymentComments"] = ( ("traj"), gdp.cut_str( @@ -224,10 +242,13 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: ("traj"), [np.int16(gdp.str_to_float(ds.ManufactureMonth, -1))], ) - ds["ManufactureSensorType"] = (("traj"), gdp.cut_str(ds.ManufactureSensorType, 20)) + ds["ManufactureSensorType"] = ( + ("traj"), + gdp.cut_str(ds.ManufactureSensorType, 20), + ) ds["ManufactureVoltage"] = ( ("traj"), - [np.int16(gdp.str_to_float(ds.ManufactureVoltage[:-6], -1))], + [np.int16(gdp.str_to_float(ds.ManufactureVoltage[:-2], -1))], ) # e.g. 56 V ds["FloatDiameter"] = ( ("traj"), @@ -254,12 +275,18 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: ("traj"), [gdp.str_to_float(ds.DragAreaOfDrogue[:-4])], ) # e.g. 416.6 m^2 - ds["DragAreaRatio"] = (("traj"), [gdp.str_to_float(ds.DragAreaRatio)]) # e.g. 39.08 + ds["DragAreaRatio"] = ( + ("traj"), + [gdp.str_to_float(ds.DragAreaRatio)], + ) # e.g. 39.08 ds["DrogueCenterDepth"] = ( ("traj"), [gdp.str_to_float(ds.DrogueCenterDepth[:-2])], ) # e.g. 20.0 m - ds["DrogueDetectSensor"] = (("traj"), gdp.cut_str(ds.DrogueDetectSensor, 20)) + ds["DrogueDetectSensor"] = ( + ("traj"), + gdp.cut_str(ds.DrogueDetectSensor, 20), + ) # vars attributes vars_attrs = { @@ -481,20 +508,21 @@ def to_raggedarray( ra = RaggedArray.from_files( indices=ids, preprocess_func=preprocess, - coord_dim_map=gdp.GDP_COORDS, + name_coords=gdp.GDP_COORDS, name_meta=gdp.GDP_METADATA, name_data=GDP_DATA, + name_dims=gdp.GDP_DIMS, rowsize_func=gdp.rowsize, filename_pattern="drifter_6h_{id}.nc", tmp_path=tmp_path, ) # update dynamic global attributes - ra.attrs_global[ - "time_coverage_start" - ] = f"{datetime.datetime(1970,1,1) + datetime.timedelta(seconds=int(np.min(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" - ra.attrs_global[ - "time_coverage_end" - ] = f"{datetime.datetime(1970,1,1) + datetime.timedelta(seconds=int(np.max(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ra.attrs_global["time_coverage_start"] = ( + f"{datetime.datetime(1970,1,1) + datetime.timedelta(seconds=int(np.min(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ) + ra.attrs_global["time_coverage_end"] = ( + f"{datetime.datetime(1970,1,1) + datetime.timedelta(seconds=int(np.max(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ) return ra diff --git a/clouddrift/adapters/glad.py b/clouddrift/adapters/glad.py index 2878ead6..b867dff8 100644 --- a/clouddrift/adapters/glad.py +++ b/clouddrift/adapters/glad.py @@ -13,6 +13,7 @@ --------- Özgökmen, Tamay. 2013. GLAD experiment CODE-style drifter trajectories (low-pass filtered, 15 minute interval records), northern Gulf of Mexico near DeSoto Canyon, July-October 2012. Distributed by: Gulf of Mexico Research Initiative Information and Data Cooperative (GRIIDC), Harte Research Institute, Texas A&M University–Corpus Christi. 
doi:10.7266/N7VD6WC8 """ + from io import BytesIO import numpy as np diff --git a/clouddrift/adapters/mosaic.py b/clouddrift/adapters/mosaic.py index 61f77689..e020d243 100644 --- a/clouddrift/adapters/mosaic.py +++ b/clouddrift/adapters/mosaic.py @@ -18,6 +18,7 @@ >>> from clouddrift.adapters import mosaic >>> ds = mosaic.to_xarray() """ + import xml.etree.ElementTree as ET from datetime import datetime from io import BytesIO diff --git a/clouddrift/adapters/subsurface_floats.py b/clouddrift/adapters/subsurface_floats.py index 51723dfa..a93d6864 100644 --- a/clouddrift/adapters/subsurface_floats.py +++ b/clouddrift/adapters/subsurface_floats.py @@ -1,6 +1,6 @@ """ -This module defines functions to adapt as a ragged-array dataset a collection of data -from 2193 trajectories of SOFAR, APEX, and RAFOS subsurface floats from 52 experiments +This module defines functions to adapt as a ragged-array dataset a collection of data +from 2193 trajectories of SOFAR, APEX, and RAFOS subsurface floats from 52 experiments across the world between 1989 and 2015. The dataset is hosted at https://www.aoml.noaa.gov/phod/float_traj/index.php diff --git a/clouddrift/adapters/utils.py b/clouddrift/adapters/utils.py index 99a3bd82..d18545d9 100644 --- a/clouddrift/adapters/utils.py +++ b/clouddrift/adapters/utils.py @@ -45,9 +45,9 @@ def download_with_progress( retry_protocol = custom_retry_protocol # type: ignore executor = concurrent.futures.ThreadPoolExecutor() - futures: dict[ - concurrent.futures.Future, Tuple[str, Union[BufferedIOBase, str]] - ] = dict() + futures: dict[concurrent.futures.Future, Tuple[str, Union[BufferedIOBase, str]]] = ( + dict() + ) bar = None for src, dst, exp_size in download_map: diff --git a/clouddrift/adapters/yomaha.py b/clouddrift/adapters/yomaha.py index 6a2a3b17..5ca908e4 100644 --- a/clouddrift/adapters/yomaha.py +++ b/clouddrift/adapters/yomaha.py @@ -1,7 +1,7 @@ """ -This module defines functions used to adapt the YoMaHa'07: Velocity data assessed -from trajectories of Argo floats at parking level and at the sea surface as -a ragged-arrays dataset. +This module defines functions used to adapt the YoMaHa'07: Velocity data assessed +from trajectories of Argo floats at parking level and at the sea surface as +a ragged-arrays dataset. The dataset is hosted at http://apdrc.soest.hawaii.edu/projects/yomaha/ and the user manual is available at http://apdrc.soest.hawaii.edu/projects/yomaha/yomaha07/YoMaHa070612.pdf. @@ -52,7 +52,7 @@ def download(tmp_path: str): download_with_progress(download_requests) filename_gz = f"{tmp_path}/{YOMAHA_URLS[-1].split('/')[-1]}" - filename = filename_gz[:-3] + filename = filename_gz.removesuffix(".gz") buffer = BytesIO() download_with_progress([(YOMAHA_URLS[-1], buffer, None)]) @@ -153,7 +153,8 @@ def to_xarray(tmp_path: Union[str, None] = None): ) # open with pandas - filename = f"{tmp_path}/{YOMAHA_URLS[-1].split('/')[-1][:-3]}" + filename_gz = f"{tmp_path}/{YOMAHA_URLS[-1].split('/')[-1]}" + filename = filename_gz.removesuffix(".gz") df = pd.read_csv( filename, names=col_names, sep=r"\s+", header=None, na_values=na_col ) diff --git a/clouddrift/datasets.py b/clouddrift/datasets.py index f729611f..77489aad 100644 --- a/clouddrift/datasets.py +++ b/clouddrift/datasets.py @@ -1,9 +1,10 @@ """ -This module provides functions to easily access ragged array datasets. If the datasets are +This module provides functions to easily access ragged array datasets. 
If the datasets are
not accessed via cloud storage platforms or are not found on the local filesystem,
-they will be downloaded from their upstream repositories and stored for later access 
+they will be downloaded from their upstream repositories and stored for later access
 (~/.clouddrift for UNIX-based systems).
 """
+
 import os
 import platform
 from io import BytesIO
diff --git a/clouddrift/pairs.py b/clouddrift/pairs.py
index 9948c91e..575500bc 100644
--- a/clouddrift/pairs.py
+++ b/clouddrift/pairs.py
@@ -1,6 +1,7 @@
 """
 Functions to analyze pairs of contiguous data segments.
 """
+
 import itertools
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import List, Optional, Tuple, Union
diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py
index 1f0d40df..723793b8 100644
--- a/clouddrift/ragged.py
+++ b/clouddrift/ragged.py
@@ -109,7 +109,7 @@ def apply_ragged(
     if not np.sum(rowsize) == arr.shape[axis]:
         raise ValueError("The sum of rowsize must equal the length of arr.")
 
-    # split the array(s) into trajectories
+    # split the array(s) into rows
     arrays = [unpack(np.array(arr), rowsize, rows, axis) for arr in arrays]
 
     iter = [[arrays[i][j] for i in range(len(arrays))] for j in range(len(arrays[0]))]
@@ -542,9 +542,9 @@ def subset(
     criteria: dict,
     id_var_name: str = "id",
     rowsize_var_name: str = "rowsize",
-    traj_dim_name: str = "traj",
+    row_dim_name: str = "rows",
     obs_dim_name: str = "obs",
-    full_trajectories=False,
+    full_rows=False,
 ) -> xr.Dataset:
     """Subset a ragged array xarray dataset as a function of one or more criteria.
     The criteria are passed with a dictionary, where a dictionary key
@@ -553,8 +553,8 @@ def subset(
     masking function applied to any variable of the dataset.
 
     This function needs to know the names of the dimensions of the ragged array dataset
-    (`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
-    Default values corresponds to the clouddrift convention ("traj", "obs", and "rowsize") but should
+    (`row_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
+    Default values correspond to the clouddrift convention ("rows", "obs", and "rowsize") but should
     be changed as needed.
 
     Parameters
     ----------
     ds : xr.Dataset
         Dataset stored as ragged arrays
     criteria : dict
         Dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset.
     id_var_name : str, optional
-        Name of the variable with dimension `traj_dim_name` containing the ID of the trajectories (default is "id").
+        Name of the variable with dimension `row_dim_name` containing the identification number of the
+        rows (default is "id").
     rowsize_var_name : str, optional
-        Name of the variable containing the number of observations per trajectory (default is "rowsize").
-    traj_dim_name : str, optional
-        Name of the trajectory dimension (default is "traj").
+        Name of the variable containing the number of observations per row (default is "rowsize").
+    row_dim_name : str, optional
+        Name of the row dimension (default is "rows").
     obs_dim_name : str, optional
         Name of the observation dimension (default is "obs").
-    full_trajectories : bool, optional
-        If True, the function returns complete rows (trajectories) for which the criteria
+    full_rows : bool, optional
+        If True, the function returns complete rows for which the criteria
         are matched at least once. Default is False which means that only segments matching the criteria
         are returned when filtering along the observation dimension.
@@ -583,46 +584,47 @@ def subset( Examples -------- - Criteria are combined on any data (with dimension "obs") or metadata (with dimension "traj") variables + Criteria are combined on any data (with dimension "obs") or metadata (with dimension "rows") variables part of the Dataset. The following examples are based on NOAA GDP datasets which can be accessed with the - ``clouddrift.datasets`` module. + ``clouddrift.datasets`` module. In these datasets, each row of the ragged arrays corresponds to the data from + a single drifter trajectory and the `row_dim_name` is "traj" and the `obs_dim_name` is "obs". Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude: - >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)}) + >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)}, row_dim_name="traj") - The parameter `full_trajectories` can be used to retrieve trajectories passing through a region, for example all trajectories passing through the Gulf of Mexico: + The parameter `full_rows` can be used to retrieve trajectories passing through a region, for example all trajectories passing through the Gulf of Mexico: - >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)}, full_trajectories=True) + >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)}, full_rows=True, row_dim_name="traj") Retrieve drogued trajectory segments: - >>> subset(ds, {"drogue_status": True}) + >>> subset(ds, {"drogue_status": True}, row_dim_name="traj") Retrieve trajectory segments with temperature higher than 25°C (303.15K): - >>> subset(ds, {"sst": (303.15, np.inf)}) + >>> subset(ds, {"sst": (303.15, np.inf)}, row_dim_name="traj") You can use the same approach to return only the trajectories that are shorter than some number of observations (similar to :func:`prune` but for the entire dataset): - >>> subset(ds, {"rowsize": (0, 1000)}) + >>> subset(ds, {"rowsize": (0, 1000)}, row_dim_name="traj") Retrieve specific drifters using their IDs: - >>> subset(ds, {"id": [2578, 2582, 2583]}) + >>> subset(ds, {"id": [2578, 2582, 2583]}, row_dim_name="traj") Sometimes, you may want to retrieve specific rows of a ragged array. You can do that by filtering along the trajectory dimension directly, since this one corresponds to row numbers: >>> rows = [5, 6, 7] - >>> subset(ds, {"traj": rows}) + >>> subset(ds, {"traj": rows}, row_dim_name="traj") Retrieve a specific time period: - >>> subset(ds, {"time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}) + >>> subset(ds, {"time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}, row_dim_name="traj") Note that to subset time variable, the range has to be defined as a function type of the variable. By default, ``xarray`` uses ``np.datetime64`` to @@ -631,13 +633,13 @@ def subset( Those criteria can also be combined: - >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}) + >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}, row_dim_name="traj") You can also use a function to filter the data. For example, retrieve every other observation - of each trajectory (row): + of each trajectory: >>> func = (lambda arr: ((arr - arr[0]) % 2) == 0) - >>> subset(ds, {"time": func}) + >>> subset(ds, {"time": func}, row_dim_name="traj") The filtering function can accept several input variables passed as a tuple. 
For example, retrieve drifters released in the Mediterranean Sea, but exclude those released in the Bay of Biscay and the Black Sea: @@ -651,7 +653,7 @@ def subset( >>> # Black Sea >>> in_blacksea = np.logical_and(lon >= 27.4437, lat >= 40.9088) >>> return np.logical_and(in_med, np.logical_not(np.logical_or(in_biscay, in_blacksea))) - >>> subset(ds, {("start_lon", "start_lat"): mediterranean_mask}) + >>> subset(ds, {("start_lon", "start_lat"): mediterranean_mask}, row_dim_name="traj") Raises ------ @@ -666,8 +668,8 @@ def subset( -------- :func:`apply_ragged` """ - mask_traj = xr.DataArray( - data=np.ones(ds.sizes[traj_dim_name], dtype="bool"), dims=[traj_dim_name] + mask_row = xr.DataArray( + data=np.ones(ds.sizes[row_dim_name], dtype="bool"), dims=[row_dim_name] ) mask_obs = xr.DataArray( data=np.ones(ds.sizes[obs_dim_name], dtype="bool"), dims=[obs_dim_name] @@ -686,11 +688,11 @@ def subset( criterion = ds[key] criterion_dims = criterion.dims - if criterion_dims == (traj_dim_name,): - mask_traj = np.logical_and( - mask_traj, + if criterion_dims == (row_dim_name,): + mask_row = np.logical_and( + mask_row, _mask_var( - criterion, criteria[key], ds[rowsize_var_name], traj_dim_name + criterion, criteria[key], ds[rowsize_var_name], row_dim_name ), ) elif criterion_dims == (obs_dim_name,): @@ -703,33 +705,33 @@ def subset( else: raise ValueError(f"Unknown variable '{key}'.") - # remove data when trajectories are filtered + # remove data when rows are filtered traj_idx = rowsize_to_index(ds[rowsize_var_name].values) - for i in np.where(~mask_traj)[0]: + for i in np.where(~mask_row)[0]: mask_obs[slice(traj_idx[i], traj_idx[i + 1])] = False - # remove trajectory completely filtered in mask_obs + # remove rows completely filtered in mask_obs ids_with_mask_obs = np.repeat(ds[id_var_name].values, ds[rowsize_var_name].values)[ mask_obs ] - mask_traj = np.logical_and( - mask_traj, np.in1d(ds[id_var_name], np.unique(ids_with_mask_obs)) + mask_row = np.logical_and( + mask_row, np.in1d(ds[id_var_name], np.unique(ids_with_mask_obs)) ) - # reset mask_obs to True to keep complete trajectories - if full_trajectories: - for i in np.where(mask_traj)[0]: + # reset mask_obs to True if we want to keep complete rows + if full_rows: + for i in np.where(mask_row)[0]: mask_obs[slice(traj_idx[i], traj_idx[i + 1])] = True ids_with_mask_obs = np.repeat( ds[id_var_name].values, ds[rowsize_var_name].values )[mask_obs] - if not any(mask_traj): + if not any(mask_row): warnings.warn("No data matches the criteria; returning an empty dataset.") return xr.Dataset() else: # apply the filtering for both dimensions - ds_sub = ds.isel({traj_dim_name: mask_traj, obs_dim_name: mask_obs}) + ds_sub = ds.isel({row_dim_name: mask_row, obs_dim_name: mask_obs}) _, unique_idx, sorted_rowsize = np.unique( ids_with_mask_obs, return_index=True, return_counts=True ) @@ -820,7 +822,7 @@ def _mask_var( - tuple: (min, max) defining a range - list, np.ndarray, or xr.DataArray: An array-like defining multiples values - scalar: value defining a single value - - function: a function applied against each trajectory using ``apply_ragged`` and returning a mask + - function: a function applied against each row using ``apply_ragged`` and returning a mask rowsize : xr.DataArray, optional List of integers specifying the number of data points in each row dim_name : str, optional diff --git a/clouddrift/raggedarray.py b/clouddrift/raggedarray.py index 0801c4cc..40750c78 100644 --- a/clouddrift/raggedarray.py +++ b/clouddrift/raggedarray.py @@ -3,12 +3,12 @@ 
structure used by CloudDrift to process custom Lagrangian datasets to Xarray Datasets and Awkward Arrays. """ + from __future__ import annotations import warnings from collections.abc import Callable -from dataclasses import dataclass -from typing import Any, Literal, Optional, Tuple, Union +from typing import Any, Literal, Optional, Union import awkward as ak # type: ignore import numpy as np @@ -17,24 +17,19 @@ from clouddrift.ragged import rowsize_to_index -DimNames = Literal["traj", "obs"] - - -@dataclass -class Dim: - name: DimNames - size: int +DimNames = Literal["rows", "obs"] class RaggedArray: def __init__( self, - coord_dims: list[tuple[str, Dim]], coords: dict, metadata: dict, data: dict, - attrs_global: dict = {}, - attrs_variables: dict = {}, + attrs_global: Optional[dict] = {}, + attrs_variables: Optional[dict] = {}, + name_dims: dict[str, DimNames] = {}, + coord_dims: dict[str, str] = {}, ): self.coords = coords self.coord_dims = coord_dims @@ -42,13 +37,17 @@ def __init__( self.data = data self.attrs_global = attrs_global self.attrs_variables = attrs_variables + self.name_dims = name_dims + self._coord_dims = coord_dims self.validate_attributes() @classmethod def from_awkward( cls, array: ak.Array, - coord_dim_map: list[tuple[str, DimNames]] = [("time", "obs"), ("id", "traj")], + name_coords: list, + name_dims: dict[str, DimNames], + coord_dims: dict[str, str], ): """Load a RaggedArray instance from an Awkward Array. @@ -56,8 +55,12 @@ def from_awkward( ---------- array : ak.Array Awkward Array instance to load the data from - coord_dim_map : list[tuple[str, DimNames]] - List of the coordinate variables names and their dimension names. + name_coords : list, optional + Names of the coordinate variables in the ragged arrays + name_dims: dict + Map a dimension to an alias. + coord_dims: dict + Map a coordinate to a dimension alias. Returns ------- @@ -65,19 +68,18 @@ def from_awkward( A RaggedArray instance """ coords: dict[str, Any] = {} - coord_dims: list[tuple[str, Dim]] = list() metadata = {} data = {} attrs_variables = {} attrs_global = array.layout.parameters["attrs"] - for var, dimName in coord_dim_map: - if dimName == "obs": + for var in name_coords: + alias = coord_dims[var] + if name_dims[alias] == "obs": coords[var] = ak.flatten(array.obs[var]).to_numpy() else: coords[var] = array.obs[var].to_numpy() - coord_dims.append((var, Dim(dimName, len(coords[var])))) attrs_variables[var] = array.obs[var].layout.parameters["attrs"] @@ -90,7 +92,7 @@ def from_awkward( attrs_variables[var] = array.obs[var].layout.parameters["attrs"] return RaggedArray( - coord_dims, coords, metadata, data, attrs_global, attrs_variables + coords, metadata, data, attrs_global, attrs_variables, name_dims, coord_dims ) @classmethod @@ -98,13 +100,14 @@ def from_files( cls, indices: list, preprocess_func: Callable[[int], xr.Dataset], - coord_dim_map: list[tuple[str, DimNames]], + name_coords: list, name_meta: list = list(), name_data: list = list(), + name_dims: dict[str, DimNames] = {}, rowsize_func: Optional[Callable[[int], int]] = None, **kwargs, ): - """Generate a ragged array archive from a list of trajectory files + """Generate a ragged array archive from a list of files Parameters ---------- @@ -112,12 +115,12 @@ def from_files( Identification numbers list to iterate preprocess_func : Callable[[int], xr.Dataset] Returns a processed xarray Dataset from an identification number - coord_dim_map : list[tuple[str, DimNames]] - List of the coordinate variables names and their dimension names. 
+    name_coords : list
+        Names of the coordinate variables to include in the archive
     name_meta : list, optional
         Name of metadata variables to include in the archive (Defaults to [])
     name_data : list, optional
         Name of the data variables to include in the archive (Defaults to [])
+    name_dims: dict
+        Map an alias to a dimension.
     rowsize_func : Optional[Callable[[int], int]], optional
         Returns the number of observations from an identification number (to speed up processing) (Defaults to None)
@@ -137,24 +140,25 @@ def from_files(
             preprocess_func,
             indices,
             rowsize,
-            coord_dim_map,
+            name_coords,
             name_meta,
             name_data,
+            name_dims,
             **kwargs,
         )
 
         attrs_global, attrs_variables = cls.attributes(
             preprocess_func(indices[0], **kwargs),
-            coord_dim_map,
+            name_coords,
             name_meta,
             name_data,
         )
 
         return RaggedArray(
-            coord_dims, coords, metadata, data, attrs_global, attrs_variables
+            coords, metadata, data, attrs_global, attrs_variables, name_dims, coord_dims
         )
 
     @classmethod
-    def from_netcdf(cls, filename: str):
+    def from_netcdf(cls, filename: str, rows_dim_name="rows", obs_dim_name="obs"):
         """Read a ragged arrays archive from a NetCDF file.
 
         This is a thin wrapper around ``from_xarray()``.
@@ -169,13 +173,15 @@ def from_netcdf(cls, filename: str):
         RaggedArray
             A ragged array instance
         """
-        return cls.from_xarray(xr.open_dataset(filename))
+        return cls.from_xarray(xr.open_dataset(filename), rows_dim_name, obs_dim_name)
 
     @classmethod
     def from_parquet(
         cls,
         filename: str,
-        coord_dim_map: list[tuple[str, DimNames]] = [("time", "obs"), ("id", "traj")],
+        name_coords: list,
+        name_dims: dict[str, DimNames],
+        coord_dims: dict[str, str],
     ):
         """Read a ragged array from a parquet file.
 
@@ -183,24 +189,36 @@ def from_parquet(
         ----------
         filename : str
             File name of the parquet archive to read.
-        coord_dim_map : list[tuple[str, DimNames]]
-            List of the coordinate variables names and their dimension names.
+        name_coords : list
+            Names of the coordinate variables in the ragged arrays
+        name_dims: dict
+            Map an alias to a dimension.
+        coord_dims: dict
+            Map a coordinate to a dimension alias.
 
         Returns
         -------
         RaggedArray
             A ragged array instance
         """
-        return RaggedArray.from_awkward(ak.from_parquet(filename), coord_dim_map)
+        return RaggedArray.from_awkward(
+            ak.from_parquet(filename), name_coords, name_dims, coord_dims
+        )
 
     @classmethod
-    def from_xarray(cls, ds: xr.Dataset):
+    def from_xarray(
+        cls, ds: xr.Dataset, rows_dim_name: str = "rows", obs_dim_name: str = "obs"
+    ):
         """Populate a RaggedArray instance from an xarray Dataset instance.
Parameters ---------- ds : xr.Dataset Xarray Dataset from which to load the RaggedArray + rows_dim_name : str, optional + Name of the row dimension in the xarray Dataset + obs_dim_name : str, optional + Name of the observations dimension in the xarray Dataset Returns ------- @@ -208,9 +226,10 @@ def from_xarray(cls, ds: xr.Dataset): A RaggedArray instance """ coords = {} - coord_dims: list[tuple[str, Dim]] = list() metadata = {} data = {} + coord_dims = {} + name_dims: dict[str, DimNames] = {rows_dim_name: "rows", obs_dim_name: "obs"} attrs_global = {} attrs_variables = {} @@ -218,32 +237,28 @@ def from_xarray(cls, ds: xr.Dataset): for var in ds.coords.keys(): var = str(var) + dim = ds[var].dims[-1] + coord_dims[var] = str(dim) coords[var] = ds[var].data - dimName = str(ds[var].dims[0]) - dimSize = ds.sizes[dimName] - if dimName == "traj" or dimName == "obs": - coord_dims.append((var, Dim(dimName, dimSize))) # type: ignore - else: - raise RuntimeError(f"coord {var} has an unknown dim {dimName}") attrs_variables[var] = ds[var].attrs for var in ds.data_vars.keys(): - if len(ds[var]) == ds.sizes["traj"]: + if len(ds[var]) == ds.sizes.get(rows_dim_name): metadata[var] = ds[var].data - elif len(ds[var]) == ds.sizes["obs"]: + elif len(ds[var]) == ds.sizes.get(obs_dim_name): data[var] = ds[var].data else: warnings.warn( f""" - Variable '{var}' has unknown dimension size of - {len(ds[var])}, which is not traj={ds.sizes["traj"]} or - obs={ds.sizes["obs"]}; skipping. + Variable '{var}' has unknown dimension size of + {len(ds[var])}, which is not rows={ds.sizes.get(rows_dim_name)} or + obs={ds.sizes.get(obs_dim_name)}; skipping. """ ) attrs_variables[str(var)] = ds[var].attrs return RaggedArray( - coord_dims, coords, metadata, data, attrs_global, attrs_variables + coords, metadata, data, attrs_global, attrs_variables, name_dims, coord_dims ) @staticmethod @@ -255,7 +270,7 @@ def number_of_observations( Parameters ---------- rowsize_func : Callable[[int], int]] - Function that returns the number observations of a trajectory from + Function that returns the number observations of a row from its identification number indices : list Identification numbers list to iterate @@ -263,7 +278,7 @@ def number_of_observations( Returns ------- np.ndarray - Number of observations of each trajectory + Number of observations """ rowsize = np.zeros(len(indices), dtype="int") @@ -279,10 +294,10 @@ def number_of_observations( @staticmethod def attributes( ds: xr.Dataset, - coord_dim_map: list[tuple[str, DimNames]], + name_coords: list, name_meta: list, name_data: list, - ) -> Tuple[dict, dict]: + ) -> tuple[dict, dict]: """Return global attributes and the attributes of all variables (name_coords, name_meta, and name_data) from an Xarray Dataset. @@ -290,8 +305,8 @@ def attributes( ---------- ds : xr.Dataset _description_ - coord_dim_map : list[tuple[str, DimNames]] - List of the coordinate variables names and their dimension names. 
+    name_coords : list
+        Names of the coordinate variables to include in the archive
     name_meta : list, optional
         Name of metadata variables to include in the archive (default is [])
     name_data : list, optional
         Name of the data variables to include in the archive (default is [])
@@ -306,7 +321,7 @@ def attributes(
         # coordinates, metadata, and data
         attrs_variables = {}
 
-        for var in name_meta + name_data + [x for x, _ in coord_dim_map]:
+        for var in name_meta + name_data + name_coords:
             if var in ds.keys():
                 attrs_variables[var] = ds[var].attrs
             else:
@@ -319,11 +334,12 @@ def allocate(
         preprocess_func: Callable[[int], xr.Dataset],
         indices: list,
         rowsize: Union[list, np.ndarray, xr.DataArray],
-        coord_dim_map: list[tuple[str, DimNames]],
+        name_coords: list,
         name_meta: list,
         name_data: list,
+        name_dims: dict[str, DimNames],
         **kwargs,
-    ) -> Tuple[dict, dict, dict, list[tuple[str, Dim]]]:
+    ) -> tuple[dict, dict, dict, dict]:
         """
         Iterate through the files and fill for the ragged array associated
         with coordinates, and selected metadata and data variables.
@@ -333,42 +349,49 @@ def allocate(
         preprocess_func : Callable[[int], xr.Dataset]
             Returns a processed xarray Dataset from an identification number.
         indices : list
-            List of indices separating trajectory in the ragged arrays.
+            List of indices separating rows in the ragged arrays.
         rowsize : list
-            List of the number of observations per trajectory.
-        coord_dim_map : list[tuple[str, DimNames]]
-            List of the coordinate variables names and their dimension names.
+            List of the number of observations per row.
+        name_coords : list
+            Names of the coordinate variables to include in the archive.
         name_meta : list, optional
             Name of metadata variables to include in the archive (Defaults to []).
         name_data : list, optional
             Name of the data variables to include in the archive (Defaults to []).
+        name_dims: dict[str, DimNames]
+            Dimension alias mapped to the name used by clouddrift.
 
         Returns
        -------
-        Tuple[dict, dict, dict]
+        Tuple[dict, dict, dict, dict]
             Dictionaries containing numerical data and attributes of coordinates, metadata and data variables.
""" # open one file to get dtype of variables ds = preprocess_func(indices[0], **kwargs) - nb_traj = len(rowsize) + nb_rows = len(rowsize) nb_obs = np.sum(rowsize).astype("int") index_traj = rowsize_to_index(rowsize) + dim_sizes = {} + + for alias in name_dims.keys(): + if name_dims[alias] == "rows": + dim_sizes[alias] = nb_rows + else: + dim_sizes[alias] = nb_obs # allocate memory - coord_dims: list[tuple[str, Dim]] = list() coords = {} - for var, dimName in coord_dim_map: - if dimName == "traj": - dimSize = nb_traj - else: - dimSize = nb_obs - coords[var] = np.zeros(dimSize, dtype=ds[var].dtype) - coord_dims.append((var, Dim(dimName, dimSize))) + coord_dims: dict[str, str] = {} + for var in name_coords: + dim = ds[var].dims[-1] + dim_size = dim_sizes[dim] + coords[var] = np.zeros(dim_size, dtype=ds[var].dtype) + coord_dims[var] = dim metadata = {} for var in name_meta: try: - metadata[var] = np.zeros(nb_traj, dtype=ds[var].dtype) + metadata[var] = np.zeros(nb_rows, dtype=ds[var].dtype) except KeyError: warnings.warn(f"Variable {var} requested but not found; skipping.") @@ -391,8 +414,9 @@ def allocate( size = rowsize[i] oid = index_traj[i] - for var, dimName in coord_dim_map: - if dimName == "obs": + for var in name_coords: + dim = ds[var].dims[-1] + if name_dims[dim] == "obs": coords[var][oid : oid + size] = ds[var].data else: coords[var][i] = ds[var].data[0] @@ -425,7 +449,7 @@ def validate_attributes(self): if key not in self.attrs_variables: self.attrs_variables[key] = {} - def to_xarray(self, cast_to_float32: bool = True): + def to_xarray(self): """Convert ragged array object to a xarray Dataset. Parameters @@ -439,17 +463,30 @@ def to_xarray(self, cast_to_float32: bool = True): xr.Dataset Xarray Dataset containing the ragged arrays and their attributes """ + dim_name_map = {v: k for k, v in self.name_dims.items()} xr_coords = {} - for var, dim in self.coord_dims: - xr_coords[var] = ([dim.name], self.coords[var], self.attrs_variables[var]) + for var in self.coords.keys(): + xr_coords[var] = ( + [self._coord_dims[var]], + self.coords[var], + self.attrs_variables[var], + ) xr_data = {} for var in self.metadata.keys(): - xr_data[var] = (["traj"], self.metadata[var], self.attrs_variables[var]) + xr_data[var] = ( + [dim_name_map["rows"]], + self.metadata[var], + self.attrs_variables[var], + ) for var in self.data.keys(): - xr_data[var] = (["obs"], self.data[var], self.attrs_variables[var]) + xr_data[var] = ( + [dim_name_map["obs"]], + self.data[var], + self.attrs_variables[var], + ) return xr.Dataset(coords=xr_coords, data_vars=xr_data, attrs=self.attrs_global) @@ -465,8 +502,9 @@ def to_awkward(self): offset = ak.index.Index64(index_traj) data = [] - for var, dim in self.coord_dims: - if dim.name == "obs": + for var in self.coords.keys(): + dim = self._coord_dims[var] + if self.name_dims[dim] == "obs": data.append( ak.contents.ListOffsetArray( offset, diff --git a/pyproject.toml b/pyproject.toml index 33eb58a2..c52feaed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "clouddrift" -version = "0.31.0" +version = "0.32.0" authors = [ { name="Shane Elipot", email="selipot@miami.edu" }, { name="Philippe Miron", email="philippemiron@gmail.com" }, diff --git a/tests/datasets_tests.py b/tests/datasets_tests.py index 78eb1ed7..ce6f16b8 100644 --- a/tests/datasets_tests.py +++ b/tests/datasets_tests.py @@ -33,7 +33,12 @@ def test_glad_dims_coords(self): def test_glad_subset_and_apply_ragged_work(self): with datasets.glad() as ds: 
- ds_sub = subset(ds, {"id": ["CARTHE_001", "CARTHE_002"]}, id_var_name="id") + ds_sub = subset( + ds, + {"id": ["CARTHE_001", "CARTHE_002"]}, + id_var_name="id", + row_dim_name="traj", + ) self.assertTrue(ds_sub) mean_lon = apply_ragged(np.mean, [ds_sub.longitude], ds_sub.rowsize) self.assertTrue(mean_lon.size == 2) diff --git a/tests/pairs_tests.py b/tests/pairs_tests.py index 1bcc620e..93417d72 100644 --- a/tests/pairs_tests.py +++ b/tests/pairs_tests.py @@ -14,7 +14,9 @@ class pairs_chance_pairs_from_ragged_tests(unittest.TestCase): def setUp(self) -> None: num_trajectories = 10 ids = ["CARTHE_%3.3i" % (i + 1) for i in range(num_trajectories)] - ds = ragged.subset(datasets.glad(), {"id": ids}, id_var_name="id") + ds = ragged.subset( + datasets.glad(), {"id": ids}, id_var_name="id", row_dim_name="traj" + ) self.lon = ds["longitude"] self.lat = ds["latitude"] self.time = ds["time"] diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index 28057553..6d7225b4 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -35,25 +35,28 @@ def sample_ragged_array() -> RaggedArray: [True, False, False, False], ] rowsize = [len(x) for x in longitude] - ids = [[d] * rowsize[i] for i, d in enumerate(drifter_id)] attrs_global = { "title": "test trajectories", "history": "version xyz", } - coords = {"lon": longitude, "lat": latitude, "ids": ids, "time": t} - metadata = {"id": drifter_id, "rowsize": rowsize} - data = {"test": test} + coords: dict[str, list] = {"id": drifter_id, "time": t} + metadata = {"rowsize": rowsize} + data: dict[str, list] = {"test": test, "lat": latitude, "lon": longitude} # append xr.Dataset to a list list_ds = [] for i in range(0, len(rowsize)): xr_coords = {} - for var in coords.keys(): - xr_coords[var] = ( - ["obs"], - coords[var][i], - {"long_name": f"variable {var}", "units": "-"}, - ) + xr_coords["id"] = ( + ["rows"], + [coords["id"][i]], + {"long_name": "variable id", "units": "-"}, + ) + xr_coords["time"] = ( + ["obs"], + coords["time"][i], + {"long_name": "variable time", "units": "-"}, + ) xr_data: dict[str, Any] = {} for var in metadata.keys(): @@ -77,9 +80,10 @@ def sample_ragged_array() -> RaggedArray: ra = RaggedArray.from_files( [0, 1, 2], lambda i: list_ds[i], - [("ids", "traj"), ("time", "obs"), ("lat", "obs"), ("lon", "obs")], - ["id", "rowsize"], - ["test"], + ["id", "time"], + name_meta=["rowsize"], + name_data=["test", "lat", "lon"], + name_dims={"rows": "rows", "obs": "obs"} ) return ra @@ -667,12 +671,12 @@ def test_arraylike_criterion(self): ds_sub = subset(self.ds, {"id": self.ds["id"][:2].values}) self.assertTrue(ds_sub["id"].size == 2) - def test_full_trajectories(self): + def test_full_rows(self): ds_id_rowsize = { i: j for i, j in zip(self.ds.id.values, self.ds.rowsize.values) } - ds_sub = subset(self.ds, {"lon": (-125, -111)}, full_trajectories=True) + ds_sub = subset(self.ds, {"lon": (-125, -111)}, full_rows=True) self.assertTrue(all(ds_sub.lon == [-121, -111, 51, 61, 71])) ds_sub_id_rowsize = { @@ -681,7 +685,7 @@ def test_full_trajectories(self): for k, v in ds_sub_id_rowsize.items(): self.assertTrue(ds_id_rowsize[k] == v) - ds_sub = subset(self.ds, {"lat": (30, 40)}, full_trajectories=True) + ds_sub = subset(self.ds, {"lat": (30, 40)}, full_rows=True) self.assertTrue(all(ds_sub.lat == [10, 20, 30, 40])) ds_sub_id_rowsize = { @@ -690,12 +694,12 @@ def test_full_trajectories(self): for k, v in ds_sub_id_rowsize.items(): self.assertTrue(ds_id_rowsize[k] == v) - ds_sub = subset(self.ds, {"time": (4, 5)}, full_trajectories=True) + 
ds_sub = subset(self.ds, {"time": (4, 5)}, full_rows=True)
         xr.testing.assert_equal(self.ds, ds_sub)
 
     def test_subset_by_rows(self):
         rows = [0, 2]  # test extracting first and third rows
-        ds_sub = subset(self.ds, {"traj": rows})
+        ds_sub = subset(self.ds, {"rows": rows})
         self.assertTrue(all(ds_sub["id"] == [1, 2]))
         self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))
@@ -733,7 +737,7 @@ def test_subset_callable_wrong_dim(self):
     def test_subset_callable_wrong_type(self):
         rows = [0, 2]  # test extracting first and third rows
         with self.assertRaises(TypeError):  # passing a tuple when a string is expected
-            subset(self.ds, {("traj",): rows})
+            subset(self.ds, {("rows",): rows})
 
     def test_subset_callable_tuple_unknown_var(self):
         func = lambda arr1, arr2: np.logical_and(
diff --git a/tests/raggedarray_tests.py b/tests/raggedarray_tests.py
index 2acd6e17..a545ac5d 100644
--- a/tests/raggedarray_tests.py
+++ b/tests/raggedarray_tests.py
@@ -31,6 +31,7 @@ def setUpClass(self):
         }
 
         self.variables_coords = [("id", "traj"), ("time", "obs")]
+
         # append xr.Dataset to a list
         list_ds = []
         for i in range(0, len(self.rowsize)):
@@ -64,12 +65,19 @@ def setUpClass(self):
         )
 
         # create test ragged array
+        self.name_coords = ["id", "time"]
+        self.name_meta = ["rowsize"]
+        self.name_data = ["temp"]
+        self.name_dims = {"traj": "rows", "obs": "obs"}
+        self.coord_dims = {"id": "traj", "time": "obs"}
         self.ra = RaggedArray.from_files(
             [0, 1, 2],
             lambda i: list_ds[i],
-            self.variables_coords,
-            ["rowsize"],
-            ["temp"],
+            self.name_coords,
+            self.name_meta,
+            self.name_data,
+            self.name_dims,
+            lambda i: self.rowsize[i],
         )
 
         # output archive
@@ -85,11 +93,26 @@ def tearDownClass(self):
         os.remove(PARQUET_ARCHIVE)
 
     def test_from_awkward(self):
-        ra = RaggedArray.from_awkward(ak.from_parquet(PARQUET_ARCHIVE))
+        ra = RaggedArray.from_awkward(
+            ak.from_parquet(PARQUET_ARCHIVE),
+            self.name_coords,
+            self.name_dims,
+            self.coord_dims,
+        )
         self.compare_awkward_array(ra.to_awkward())
 
     def test_from_xarray(self):
-        ra = RaggedArray.from_xarray(xr.open_dataset(NETCDF_ARCHIVE))
+        ra = RaggedArray.from_xarray(xr.open_dataset(NETCDF_ARCHIVE), "traj")
+        self.compare_awkward_array(ra.to_awkward())
+
+    def test_from_xarray_dim_names(self):
+        ds = xr.open_dataset(NETCDF_ARCHIVE)
+
+        ra = RaggedArray.from_xarray(
+            ds.rename_dims({"traj": "t", "obs": "o"}),
+            rows_dim_name="t",
+            obs_dim_name="o",
+        )
         self.compare_awkward_array(ra.to_awkward())
 
     def test_length_ragged_arrays(self):
@@ -163,12 +186,17 @@ def test_netcdf_output(self):
         """
         Validate the netCDF output archive
         """
-        ds = RaggedArray.from_netcdf(NETCDF_ARCHIVE)
+        ds = RaggedArray.from_netcdf(NETCDF_ARCHIVE, "traj")
         self.compare_awkward_array(ds.to_awkward())
 
     def test_parquet_output(self):
         """
         Validate the netCDF output archive
         """
-        ds = RaggedArray.from_parquet(PARQUET_ARCHIVE)
+        ds = RaggedArray.from_parquet(
+            PARQUET_ARCHIVE,
+            self.name_coords,
+            self.name_dims,
+            self.coord_dims,
+        )
         self.compare_awkward_array(ds.to_awkward())
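
Editor's note: the functions in clouddrift.ragged take the dataset's own dimension names, so datasets that keep the historical "traj" dimension (such as the GDP products, per GDP_DIMS above) remain fully supported. A short sketch of the renamed `subset` arguments, adapted from the docstring examples in this patch; `ds` stands for any ragged-array dataset whose row dimension is named "traj":

    import numpy as np
    from clouddrift.ragged import subset

    # `row_dim_name` replaces `traj_dim_name` and now defaults to "rows".
    gulf = subset(ds, {"lat": (21, 31), "lon": (-98, -78)}, row_dim_name="traj")

    # `full_rows=True` (formerly `full_trajectories=True`) keeps every
    # observation of each row that matches the criteria at least once.
    gulf_full = subset(
        ds,
        {"lat": (21, 31), "lon": (-98, -78)},
        row_dim_name="traj",
        full_rows=True,
    )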