From 82683dba942cb428f6d0a542fa51f34b8bae2f37 Mon Sep 17 00:00:00 2001
From: Kevin Santana
Date: Thu, 31 Oct 2024 18:32:12 -0700
Subject: [PATCH] Don't load the entire dataset; allow dask to lazy-load it as needed

---
 clouddrift/adapters/gdp/gdpsource.py | 168 +++++++++++++--------------
 clouddrift/raggedarray.py            |   6 +-
 2 files changed, 81 insertions(+), 93 deletions(-)

diff --git a/clouddrift/adapters/gdp/gdpsource.py b/clouddrift/adapters/gdp/gdpsource.py
index df101bd1..4f01e746 100644
--- a/clouddrift/adapters/gdp/gdpsource.py
+++ b/clouddrift/adapters/gdp/gdpsource.py
@@ -2,11 +2,10 @@
 
 import datetime
 import gzip
-import logging
 import os
 import tempfile
 import warnings
-from typing import Callable
+from typing import Any
 
 import dask.dataframe as dd
 import numpy as np
@@ -104,7 +103,6 @@
 
 
 _INPUT_COLS_DTYPES = {
-    "id": np.int64,
     "posObsMonth": np.float32,
     "posObsDay": np.float64,
     "posObsYear": np.float32,
@@ -123,6 +121,7 @@
 }
 
 _INPUT_COLS_PREFILTER_DTYPES: dict[str, type[object]] = {
+    "id": np.float64,
     "posObsMonth": np.str_,
     "posObsYear": np.float64,
     "senObsMonth": np.str_,
@@ -240,8 +239,6 @@
     "summary": "Global Drifter Program source (raw) data",
 }
 
-_logger = logging.getLogger(__name__)
-
 
 def _get_download_list(tmp_path: str) -> list[tuple[str, str]]:
     suffix = "rawfiles"
@@ -269,7 +266,7 @@ def _rowsize(id_, **kwargs) -> int:
 
 def _preprocess(id_, **kwargs) -> xr.Dataset:
     md_df: pd.DataFrame | None = kwargs.get("md_df")
-    data_df: pd.DataFrame | None = kwargs.get("data_df")
+    data_df: pd.DataFrame | dd.DataFrame | None = kwargs.get("data_df")
     use_fill_values: bool = kwargs.get("use_fill_values", False)
 
     if md_df is None or data_df is None:
@@ -279,6 +276,8 @@ def _preprocess(id_, **kwargs) -> xr.Dataset:
 
     traj_md_df = md_df[md_df["ID"] == id_]
     traj_data_df = data_df[data_df["id"] == id_]
+    if isinstance(traj_data_df, dd.DataFrame):
+        traj_data_df = traj_data_df.compute()
     rowsize = len(traj_data_df)
 
     md_variables = {
@@ -331,7 +330,7 @@ def _preprocess(id_, **kwargs) -> xr.Dataset:
         "id": (["traj"], np.array([id_]).astype(np.int64)),
         "position_datetime": (
             ["obs"],
-            traj_data_df[["position_datetime"]].values.flatten().astype(np.datetime64),
+            traj_data_df[["position_datetime"]].values.flatten(),
         ),
     }
 
@@ -340,44 +339,24 @@ def _preprocess(id_, **kwargs) -> xr.Dataset:
     return dataset
 
 
-def _apply_remove(df: pd.DataFrame, filters: list[Callable]) -> pd.DataFrame:
-    temp_df = df
-    for filter_ in filters:
-        mask = filter_(temp_df)
-        temp_df = temp_df[~mask]
-    return temp_df
-
-
-def _apply_transform(
-    df: pd.DataFrame,
-    transforms: dict[str, tuple[list[str], Callable]],
-) -> pd.DataFrame:
-    tmp_df = df
-    for output_col in transforms.keys():
-        input_cols, func = transforms[output_col]
-        args = list()
-        for col in input_cols:
-            arg = df[[col]].values.flatten()
-            args.append(arg)
-        tmp_df = tmp_df.assign(**{output_col: func(*args)})
-        tmp_df = tmp_df.drop(input_cols, axis=1)
-    return tmp_df
-
-
 def _parse_datetime_with_day_ratio(
-    month_series: np.ndarray, day_series: np.ndarray, year_series: np.ndarray
-) -> np.ndarray:
-    values = list()
-    for month, day_with_ratio, year in zip(month_series, day_series, year_series):
+    month: float, day_with_ratio: float, year: float
+) -> Any:
+    try:
         day = day_with_ratio // 1
         dayratio = day_with_ratio - day
-        seconds = dayratio * _SECONDS_IN_DAY
-        dt_ns = (
-            datetime.datetime(year=int(year), month=int(month), day=int(1))
-            + datetime.timedelta(days=int(day), seconds=seconds)
-        ).timestamp() * 10**9
-        values.append(int(dt_ns))
-    return np.array(values).astype("datetime64[ns]")
+        second = dayratio * _SECONDS_IN_DAY
+        # seconds = (
+        #     datetime.datetime(year=int(year), month=int(month), day=int(day))
+        #     + datetime.timedelta(seconds=int(second))
+        # ).timestamp()
+        # return np.datetime64(int(seconds * 10**9)).astype(np.dtype("datetime64[ns]"))
+        return np.datetime64(
+            datetime.datetime(year=int(year), month=int(month), day=int(day))
+            + datetime.timedelta(seconds=int(second))
+        ).astype(np.dtype("datetime64[ns]"))
+    except Exception as _:
+        return np.datetime64("NaT", "ns")
 
 
 def _process(
@@ -386,30 +365,25 @@
     use_fill_values: bool,
 ) -> xr.Dataset:
     """Process each dataframe chunk. Return a dictionary mapping each drifter to a unique xarray Dataset."""
-
-    # Transform the initial dataframe filtering out rows with really anomolous values
-    # examples include: years in the future, years way in the past before GDP program, etc...
-    if isinstance(df, dd.DataFrame):
-        source_df = df.compute(optimize_graph=True)
-    else:
-        source_df = df
-
-    clean_df = _apply_remove(
-        source_df,
-        filters=[
-            # Filter out year values that are in the future or predating the GDP program
-            lambda df: (df["posObsYear"] > datetime.datetime.now().year)
-            | (df["posObsYear"] < 0),
-            lambda df: (df["senObsYear"] > datetime.datetime.now().year)
-            | (df["senObsYear"] < 0),
-            # Filter out month values that contain non-numeric characters
-            lambda df: df["senObsMonth"].astype(np.str_).str.contains(r"[\D]"),
-            lambda df: df["posObsMonth"].astype(np.str_).str.contains(r"[\D]"),
-            # Filter out drogue values that cannot be interpret as floating point values.
-            # (e.g. - have more than one decimal point)
-            lambda df: df["drogue"].astype(np.str_).str.match(r"(\d+[\.]+){2,}"),
-        ],
-    )
+    this_year = datetime.datetime.now().year
+    source_df = df
+
+    for filter_ in [
+        lambda df: df["id"].notna(),
+        # Filter out year values that are in the future or predating the GDP program
+        lambda df: df["posObsYear"] <= this_year,
+        lambda df: df["posObsYear"] > 0,
+        lambda df: df["senObsYear"] <= this_year,
+        lambda df: df["senObsYear"] > 0,
+        # Filter out month values that contain non-numeric characters
+        lambda df: df["senObsMonth"].astype(np.str_).str.contains(r"[\D]") == False,
+        lambda df: df["posObsMonth"].astype(np.str_).str.contains(r"[\D]") == False,
+        # Filter out drogue values that cannot be interpreted as floating point values.
+        # (e.g. - have more than one decimal point)
+        lambda df: df["drogue"].astype(np.str_).str.match(r"(\d+[\.]+){2,}") == False,
+    ]:
+        df = df[filter_(df)]
+    clean_df = df
 
     source_df_len = len(source_df)
     clean_df_len = len(clean_df)
@@ -420,25 +394,32 @@
         raise ValueError("All rows removed from dataframe, please review filters")
 
     df = clean_df.astype(_INPUT_COLS_DTYPES)
-    df = _apply_transform(
-        df,
-        {
-            "position_datetime": (
-                ["posObsMonth", "posObsDay", "posObsYear"],
-                _parse_datetime_with_day_ratio,
-            ),
-            "sensor_datetime": (
-                ["senObsMonth", "senObsDay", "senObsYear"],
-                _parse_datetime_with_day_ratio,
-            ),
-        },
+
+    df["position_datetime"] = df.apply(
+        lambda row: _parse_datetime_with_day_ratio(
+            row["posObsMonth"], row["posObsDay"], row["posObsYear"]
+        ),
+        axis=1,
+    )
+
+    df["sensor_datetime"] = df.apply(
+        lambda row: _parse_datetime_with_day_ratio(
+            row["senObsMonth"], row["senObsDay"], row["senObsYear"]
+        ),
+        axis=1,
     )
 
     # Find and process drifters found and documented in the drifter metadata.
-    ids_in_data = np.unique(df[["id"]].values)
+    if isinstance(df, dd.DataFrame):
+        ids_in_data: Any = df[["id"]].values
+        ids_in_data = ids_in_data.compute().flatten()
+    else:
+        ids_in_data = np.unique(df[["id"]].values)
     ids_with_md = np.intersect1d(ids_in_data, gdp_metadata_df[["ID"]].values)
+    len_ids_with_md = len(ids_with_md)
+    len_ids_in_data = len(ids_in_data)
 
-    if len(ids_with_md) < len(ids_in_data):
+    if len_ids_with_md < len_ids_in_data:
         warnings.warn(
             "Chunk has drifter ids not found in the metadata table. "
             + "Using fill values"
@@ -452,18 +433,25 @@
     else:
         selected_ids = ids_with_md
 
-    gdp_start_dates = list()
-    for id_ in selected_ids:
-        selected_drifter = gdp_metadata_df[gdp_metadata_df["ID"] == id_]
-
-        if len(selected_drifter) == 0:
-            gdp_start_dates.append(np.datetime64("NaT"))
-        else:
-            gdp_start_dates.append(selected_drifter[["Start_date"]].values.flatten()[0])
-
-    start_date_sortkey = np.argsort(gdp_start_dates)
+    # Get metadata for selected ids
+    mask = np.isin(gdp_metadata_df[["ID"]].values.flatten(), selected_ids)
+    selected_metadata = gdp_metadata_df[mask]
+
+    # initialize with NaN to handle selected ids with no metadata, then populate with selected ids
+    start_dates = np.full(
+        selected_ids.shape, np.NaN
+    )  # Initialize with NaN for selected ids with no metadata
+    start_dates[: len(selected_metadata)] = selected_metadata[
+        ["Start_date"]
+    ].values.flatten()
+    start_date_sortkey = np.argsort(start_dates)
     start_date_sorted_ids = selected_ids[start_date_sortkey]
 
+    if isinstance(df, dd.DataFrame):
+        df = df.set_index("id", drop=False).persist()
+    else:
+        df = df.set_index("id")
+
     ra = RaggedArray.from_files(
         indices=start_date_sorted_ids,
         preprocess_func=_preprocess,
diff --git a/clouddrift/raggedarray.py b/clouddrift/raggedarray.py
index a7aadb06..a8a7eb3a 100644
--- a/clouddrift/raggedarray.py
+++ b/clouddrift/raggedarray.py
@@ -99,7 +99,7 @@ def from_awkward(
     @classmethod
     def from_files(
         cls,
-        indices: list[int],
+        indices: list[int] | np.ndarray,
        preprocess_func: Callable[[int], xr.Dataset],
         name_coords: list,
         name_meta: list = list(),
@@ -273,7 +273,7 @@ def from_xarray(
 
     @staticmethod
     def number_of_observations(
-        rowsize_func: Callable[[int], int], indices: list, **kwargs
+        rowsize_func: Callable[[int], int], indices: list | np.ndarray, **kwargs
     ) -> np.ndarray:
         """Iterate through the files and evaluate the number of observations.
 
@@ -343,7 +343,7 @@ def attributes(
     @staticmethod
     def allocate(
         preprocess_func: Callable[[int], xr.Dataset],
-        indices: list,
+        indices: list | np.ndarray,
         rowsize: list | np.ndarray | xr.DataArray,
         name_coords: list,
         name_meta: list,
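
Reviewer note, not part of the patch: a minimal, self-contained sketch of the
lazy-loading pattern this change adopts. The file and column names below are
illustrative stand-ins for the GDP raw files, not the real dataset. The point
is the shape of the flow: build the dask graph without reading any data, index
and persist on "id", then materialize one trajectory at a time, as
_preprocess() now does with its per-id compute().

    import dask.dataframe as dd
    import numpy as np
    import pandas as pd

    # Toy stand-in for a raw data file (columns are illustrative only).
    pd.DataFrame(
        {"id": [1, 1, 2, 2, 2], "lat": np.linspace(0.0, 4.0, 5)}
    ).to_csv("drifters.csv", index=False)

    # Lazy: this builds a task graph; no rows are read yet.
    df = dd.read_csv("drifters.csv")

    # Index on "id" so dask partitions by drifter; persist() caches the
    # indexed (still chunked) frame for the repeated per-id lookups.
    df = df.set_index("id", drop=False).persist()

    # Selecting one drifter materializes only the partitions holding it,
    # instead of computing the entire frame up front.
    traj = df[df["id"] == 2].compute()
    print(len(traj))  # -> 3

Indexing before the per-id loop is what makes the removed whole-frame
df.compute(optimize_graph=True) call unnecessary: each lookup touches a
bounded slice of the data, so peak memory stays near one trajectory rather
than the full dataset.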