Skip to content

Commit

Permalink
Add method to_xarray to EOProduct
Browse files Browse the repository at this point in the history
  • Loading branch information
amarandon committed Dec 4, 2024
1 parent cb2c811 commit f50d67b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 2 deletions.
2 changes: 1 addition & 1 deletion eodag_cube/api/product/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# limitations under the License.
"""EODAG product package"""
from ._assets import Asset, AssetsDict # noqa
from ._product import EOProduct # noqa
from ._product import EOProduct, XarrayDict # noqa
31 changes: 31 additions & 0 deletions eodag_cube/api/product/_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
from __future__ import annotations

import logging
import os
from collections import UserDict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import numpy as np
import rasterio
Expand All @@ -40,6 +43,12 @@
logger = logging.getLogger("eodag-cube.api.product")


class XarrayDict(UserDict[str, xr.Dataset]):
"""
Dictionnary which keys are file paths and values are xarray Datasets.
"""


class EOProduct(EOProduct_core):
"""A wrapper around an Earth Observation Product originating from a search.
Expand Down Expand Up @@ -245,3 +254,25 @@ def _get_rio_env(self, dataset_address: str) -> Dict[str, Any]:
return rio_env_dict
else:
return {}

def _build_xarray_dict(self, **kwargs):
result = XarrayDict()
product_path = urlparse(self.location).path
for root, _dirs, filenames in os.walk(product_path):
for filename in filenames:
filepath = os.path.join(root, filename)
try:
ds = xr.open_dataset(filepath, **kwargs)
result[filepath] = ds
except ValueError as exc:
logger.debug("Cannot open %s with xarray: %s", filepath, exc)
return result

def to_xarray(self, **kwargs) -> XarrayDict:
"""
Return a dictionnary which keys are file paths and values are xarray Datasets.
Any keyword arguments passed will be forwarded to xarray.open_dataset.
"""
self.download()
return self._build_xarray_dict(**kwargs)
48 changes: 47 additions & 1 deletion tests/units/test_eoproduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@
# limitations under the License.

import itertools
import os
import random
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import xarray as xr
from rasterio.session import AWSSession

from tests import TEST_GRIB_FILE_PATH, TEST_GRIB_FILENAME, EODagTestCase
from eodag_cube.api.product import XarrayDict
from tests import (
TEST_GRIB_FILE_PATH,
TEST_GRIB_FILENAME,
TEST_GRIB_PRODUCT_PATH,
TEST_RESOURCES_PATH,
EODagTestCase,
)
from tests.context import (
DEFAULT_PROJ,
Authentication,
Expand Down Expand Up @@ -238,3 +249,38 @@ def test_get_rio_env(self):
self.assertEqual(rio_env["AWS_HTTPS"], "YES")
self.assertEqual(rio_env["AWS_S3_ENDPOINT"], "some.where")
self.assertEqual(rio_env["AWS_VIRTUAL_HOSTING"], "FALSE")

def populate_directory_with_heterogeneous_files(self, destination):
"""
Put various files in the destination directory:
- a grib file
- an .idx file that often comes with grib files
- a JPEG2000 file
- an XML file
"""
# Copy all files from a grib product
shutil.copytree(TEST_GRIB_PRODUCT_PATH, destination, dirs_exist_ok=True)

# Copy files from an S2A product
s2a_path = os.path.join(
TEST_RESOURCES_PATH,
"products",
"S2A_MSIL1C_20180101T105441_N0206_R051_T31TDH_20180101T124911.SAFE",
)
shutil.copytree(s2a_path, destination, dirs_exist_ok=True)

def test_build_xarray_dict(self):
with TemporaryDirectory(prefix="eodag-cube-tests") as tmp_dir:
product = EOProduct(
self.provider, self.eoproduct_props, productType=self.product_type
)
product.location = f"file://{tmp_dir}"
self.populate_directory_with_heterogeneous_files(tmp_dir)

xarray_dict = product._build_xarray_dict()

self.assertIsInstance(xarray_dict, XarrayDict)
self.assertEqual(len(xarray_dict), 2)
for key, value in xarray_dict.items():
self.assertIn(Path(key).suffix, {".grib", ".jp2"})
self.assertIsInstance(value, xr.Dataset)

0 comments on commit f50d67b

Please sign in to comment.