
Commit

move import, fix naming and add back skip_download
kevinsantana11 committed Oct 23, 2024
1 parent 0d764b2 commit abab73d
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions clouddrift/adapters/gdp/gdpsource.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import datetime
+import gzip
 import logging
 import os
 import tempfile
@@ -380,17 +381,21 @@ def _parse_datetime_with_day_ratio(
 
 
 def _process(
-    df: dd.DataFrame,
+    df: dd.DataFrame | pd.DataFrame,
     gdp_metadata_df: pd.DataFrame,
     use_fill_values: bool,
 ) -> xr.Dataset:
     """Process each dataframe chunk. Return a dictionary mapping each drifter to a unique xarray Dataset."""
 
     # Transform the initial dataframe filtering out rows with really anomolous values
     # examples include: years in the future, years way in the past before GDP program, etc...
-    preremove_df = df.compute()
-    df_chunk = _apply_remove(
-        preremove_df,
+    if isinstance(df, dd.DataFrame):
+        source_df = df.compute(optimize_graph=True)
+    else:
+        source_df = df
+
Check warning (Codecov / codecov/patch) on line 395 in clouddrift/adapters/gdp/gdpsource.py: added line #L395 was not covered by tests.

+    clean_df = _apply_remove(
+        source_df,
         filters=[
             # Filter out year values that are in the future or predating the GDP program
             lambda df: (df["posObsYear"] > datetime.datetime.now().year)
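Note on the widened signature: `_process` now accepts either a Dask or a pandas DataFrame and only materializes the former. A minimal, self-contained sketch of that dispatch pattern (not the library's code; the helper name, sample frame, and column are made up) could look like this:

```python
from __future__ import annotations

import dask.dataframe as dd
import pandas as pd


def materialize(df: dd.DataFrame | pd.DataFrame) -> pd.DataFrame:
    """Return a concrete pandas DataFrame regardless of the input flavor."""
    if isinstance(df, dd.DataFrame):
        # Dask frames are lazy; compute() executes the task graph, and
        # optimize_graph=True lets Dask fuse/prune tasks before running it.
        return df.compute(optimize_graph=True)
    # Already an in-memory pandas frame: pass it through unchanged.
    return df


# Hypothetical check: both inputs yield the same pandas DataFrame.
pdf = pd.DataFrame({"id": [1, 2, 3]})
ddf = dd.from_pandas(pdf, npartitions=2)
assert materialize(ddf).equals(materialize(pdf))
```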
@@ -406,20 +411,17 @@ def _process(
         ],
     )
 
-    preremove_len = len(preremove_df)
-    postremove_len = len(df_chunk)
-
-    if preremove_len != postremove_len:
-        warnings.warn(
-            f"Filters removed {preremove_len - postremove_len} rows from chunk"
-        )
+    source_df_len = len(source_df)
+    clean_df_len = len(clean_df)
+    if source_df_len != clean_df_len:
+        warnings.warn(f"Filters removed {source_df_len - clean_df_len} rows from chunk")
 
-    if postremove_len == 0:
+    if clean_df_len == 0:
         raise ValueError("All rows removed from dataframe, please review filters")

Check warning (Codecov / codecov/patch) on line 420 in clouddrift/adapters/gdp/gdpsource.py: added line #L420 was not covered by tests.

-    df_chunk = df_chunk.astype(_INPUT_COLS_DTYPES)
-    df_chunk = _apply_transform(
-        df_chunk,
+    df = clean_df.astype(_INPUT_COLS_DTYPES)
+    df = _apply_transform(
+        df,
         {
             "position_datetime": (
                 ["posObsMonth", "posObsDay", "posObsYear"],
@@ -433,7 +435,7 @@ def _process(
     )
 
     # Find and process drifters found and documented in the drifter metadata.
-    ids_in_data = np.unique(df_chunk[["id"]].values)
+    ids_in_data = np.unique(df[["id"]].values)
     ids_with_md = np.intersect1d(ids_in_data, gdp_metadata_df[["ID"]].values)
 
     if len(ids_with_md) < len(ids_in_data):
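The id bookkeeping above combines two NumPy set operations. A tiny standalone illustration with invented drifter ids (not GDP data) shows what `intersect1d` and `setdiff1d` return, and why the warning a few lines below now reports only the count of missing ids:

```python
import numpy as np

# Invented ids: five drifters appear in the data, one of them lacks metadata.
ids_in_data = np.array([101, 102, 103, 104, 105])
metadata_ids = np.array([101, 102, 104, 105, 999])

ids_with_md = np.intersect1d(ids_in_data, metadata_ids)   # [101 102 104 105]
ids_missing_md = np.setdiff1d(ids_in_data, ids_with_md)   # [103]

# Reporting len(...) keeps the warning short even when thousands of ids are missing.
print(len(ids_missing_md))  # 1
```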
@@ -442,7 +444,7 @@ def _process(
             + "Using fill values"
             if use_fill_values
             else "Ignoring data observations"
-            + f" for missing metadata ids: {np.setdiff1d(ids_in_data, ids_with_md)}."
+            + f" for missing metadata ids: {len(np.setdiff1d(ids_in_data, ids_with_md))}."
         )
 
     if use_fill_values:
@@ -471,7 +473,7 @@ def _process(
         name_data=_DATA_VARS,
         name_dims={"traj": "rows", "obs": "obs"},
         md_df=gdp_metadata_df,
-        data_df=df_chunk,
+        data_df=df,
         use_fill_values=use_fill_values,
         tqdm={"disable": True},
     )
@@ -482,6 +484,7 @@ def to_raggedarray(
     tmp_path: str = _TMP_PATH,
     max: int | None = None,
     use_fill_values: bool = True,
+    skip_download: bool = False,
 ) -> xr.Dataset:
     """Transforms the GDP source dataset into a ragged array xarray Dataset.
@@ -499,6 +502,8 @@ def to_raggedarray(
         for testing purposes.
     use_fill_values: bool, True (default)
         When True, missing metadata fields are replaced with fill values. dataset.
+    skip_download: bool, False (default)
+        When True, skip downloading the files. This can be used when wanting
 
     Returns
     -------
@@ -515,12 +520,11 @@ def to_raggedarray(
         requests = requests[:max]
 
     # Download necessary data and metadata files.
-    download_with_progress(requests)
+    if not skip_download:
+        download_with_progress(requests)
 
     gdp_metadata_df = get_gdp_metadata(tmp_path)
 
-    import gzip
-
     data_files = list()
    for compressed_data_file in tqdm(
         [dst for (_, dst) in requests], desc="Decompressing files", unit="file"
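For context on moving `import gzip` to the top of the module: the loop above (truncated by the diff view) decompresses the downloaded `.gz` files before they are read. A rough standard-library sketch of that step, with hypothetical paths and assuming a stream copy via `shutil`, might be:

```python
import gzip
import shutil


def decompress_gz(compressed_path: str, decompressed_path: str) -> str:
    """Stream-decompress a .gz file to disk and return the output path."""
    with gzip.open(compressed_path, "rb") as src, open(decompressed_path, "wb") as dst:
        # copyfileobj streams in chunks, so large files never load fully into memory.
        shutil.copyfileobj(src, dst)
    return decompressed_path


# Hypothetical usage for one downloaded chunk:
# decompress_gz("/tmp/gdp/buoydata_0001.dat.gz", "/tmp/gdp/buoydata_0001.dat")
```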
@@ -544,7 +548,6 @@
             header=None,
             names=_INPUT_COLS,
             dtype=wanted_dtypes,
-            engine="c",
             blocksize="1GB",
             assume_missing=True,
         )
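Taken together, the restored `skip_download` flag makes repeated runs against already-downloaded source files straightforward. An illustrative call sketch (assuming the module path shown in the diff, and assuming no required positional arguments beyond the keywords visible here):

```python
from clouddrift.adapters.gdp import gdpsource

# First run: download a couple of source file requests (max limits them for testing)
# and build the ragged-array dataset with fill values for missing metadata.
ds = gdpsource.to_raggedarray(max=2, use_fill_values=True)

# Subsequent runs: reuse the files already present in the temporary directory
# instead of downloading them again.
ds = gdpsource.to_raggedarray(max=2, use_fill_values=True, skip_download=True)
```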
