Skip to content

Commit

Permalink
Modify GridH5WaterQualityResultAdmin (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoanphungt authored May 29, 2024
1 parent c260254 commit 4b43c8d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 48 deletions.
4 changes: 3 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ History
2.2.10 (unreleased)
-------------------

- Nothing changed yet.
- Move _field_model_dict and its methods back to GridH5ResultAdmin.
- Add `substances` and overwrite `get_model_instance_by_field_name` method
in GridH5WaterQualityResultAdmin.


2.2.9 (2024-05-27)
Expand Down
47 changes: 0 additions & 47 deletions threedigrid/admin/gridadmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

import logging
from collections import defaultdict

import h5py
import numpy as np
Expand All @@ -17,7 +16,6 @@
from threedigrid.admin.nodes.models import Cells, EmbeddedNodes, Grid, Nodes
from threedigrid.admin.pumps.models import Pumps
from threedigrid.geo_utils import raise_import_exception, transform_bbox
from threedigrid.orm.models import Model

try:
import pyproj
Expand Down Expand Up @@ -55,7 +53,6 @@ def __init__(self, h5_file_path, file_modus="r", set_props=False):
:param file_modus: mode with which to open the file
(defaults to r=READ)
"""
self._field_model_dict = defaultdict(list)
self.grid_file = h5_file_path
self.datasource_class = H5pyGroup
self.is_rpc = False
Expand Down Expand Up @@ -191,50 +188,6 @@ def has_levees(self):
return False
return bool(self.levees.id.size)

@property
def _field_model_map(self):
"""
:return: a dict of {<field name>: [model name, ...]}
"""
if self._field_model_dict:
return self._field_model_dict

model_names = set()
for attr_name in dir(self):
# skip private attrs
if any([attr_name.startswith("__"), attr_name.startswith("_")]):
continue
try:
attr = getattr(self, attr_name)
except AttributeError:
logger.warning(
"Attribute: '{}' does not " "exist in h5py_file.".format(attr_name)
)
continue
if not issubclass(type(attr), Model):
continue
model_names.add(attr_name)

for model_name in model_names:
for x in getattr(self, model_name)._field_names:
self._field_model_dict[x].append(model_name)
return self._field_model_dict

def get_model_instance_by_field_name(self, field_name):
"""
:param field_name: name of a models field
:return: instance of the model the field belongs to
:raises IndexError if the field name is not unique across models
"""
model_name = self._field_model_map.get(field_name)
if not model_name or len(model_name) != 1:
raise IndexError(
"Ambiguous result. Field name {} yields {} model(s)".format(
field_name, len(model_name) if model_name else 0
)
)
return getattr(self, model_name[0])

def get_extent_subset(self, subset_name, target_epsg_code=""):
"""
Expand Down
63 changes: 63 additions & 0 deletions threedigrid/admin/gridresultadmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import re
from collections import defaultdict
from typing import List, Optional, Union

import h5py
Expand Down Expand Up @@ -37,6 +38,7 @@
StructureControl,
StructureControlTypes,
)
from threedigrid.orm.models import Model

logger = logging.getLogger(__name__)

Expand All @@ -61,6 +63,7 @@ def __init__(self, h5_file_path, netcdf_file_path, file_modus="r", swmr=False):
called subgrid_map.nc)
:param file_modus: modus in which to open the files
"""
self._field_model_dict = defaultdict(list)
self._netcdf_file_path = netcdf_file_path
super().__init__(h5_file_path, file_modus)

Expand Down Expand Up @@ -154,6 +157,50 @@ def version_check(self):
self.threedicore_version,
)

@property
def _field_model_map(self):
"""
:return: a dict of {<field name>: [model name, ...]}
"""
if self._field_model_dict:
return self._field_model_dict

model_names = set()
for attr_name in dir(self):
# skip private attrs
if any([attr_name.startswith("__"), attr_name.startswith("_")]):
continue
try:
attr = getattr(self, attr_name)
except AttributeError:
logger.warning(
"Attribute: '{}' does not " "exist in h5py_file.".format(attr_name)
)
continue
if not issubclass(type(attr), Model):
continue
model_names.add(attr_name)

for model_name in model_names:
for x in getattr(self, model_name)._field_names:
self._field_model_dict[x].append(model_name)
return self._field_model_dict

def get_model_instance_by_field_name(self, field_name):
"""
:param field_name: name of a models field
:return: instance of the model the field belongs to
:raises IndexError if the field name is not unique across models
"""
model_name = self._field_model_map.get(field_name)
if not model_name or len(model_name) != 1:
raise IndexError(
"Ambiguous result. Field name {} yields {} model(s)".format(
field_name, len(model_name) if model_name else 0
)
)
return getattr(self, model_name[0])

@property
def threedicore_result_version(self):
"""
Expand Down Expand Up @@ -486,6 +533,7 @@ def __init__(
netcdf_file_path: str,
file_modus: str = "r",
swmr: bool = False,
substances: List[str] = [],
) -> None:
"""
:param h5_file_path: path to the hdf5 gridadmin file
Expand Down Expand Up @@ -529,6 +577,21 @@ def __init__(
if isinstance(value, bytes):
value = value.decode("utf-8")
self.__getattribute__(substance).__setattr__(attr, value)
self.substances = list(substances)

def get_model_instance_by_field_name(self, field_name):
"""
:param field_name: name of a models field
:return: instance of the model the field belongs to
:raises AttributeError if the model instance cannot be found
"""
try:
model_instance = self.__getattribute__(field_name)
return model_instance
except AttributeError:
raise AttributeError(
f"Model instance with field name {field_name} not found"
)

def set_timeseries_chunk_size(self, new_chunk_size: int) -> None:
"""
Expand Down

0 comments on commit 4b43c8d

Please sign in to comment.