diff --git a/HISTORY.rst b/HISTORY.rst index 53c2d12..0f5d265 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,7 +4,9 @@ History 2.2.10 (unreleased) ------------------- -- Nothing changed yet. +- Move _field_model_dict and its methods back to GridH5ResultAdmin. +- Add `substances` and overwrite `get_model_instance_by_field_name` method + in GridH5WaterQualityResultAdmin. 2.2.9 (2024-05-27) diff --git a/threedigrid/admin/gridadmin.py b/threedigrid/admin/gridadmin.py index 3679ff2..41d031a 100644 --- a/threedigrid/admin/gridadmin.py +++ b/threedigrid/admin/gridadmin.py @@ -4,7 +4,6 @@ """ import logging -from collections import defaultdict import h5py import numpy as np @@ -17,7 +16,6 @@ from threedigrid.admin.nodes.models import Cells, EmbeddedNodes, Grid, Nodes from threedigrid.admin.pumps.models import Pumps from threedigrid.geo_utils import raise_import_exception, transform_bbox -from threedigrid.orm.models import Model try: import pyproj @@ -55,7 +53,6 @@ def __init__(self, h5_file_path, file_modus="r", set_props=False): :param file_modus: mode with which to open the file (defaults to r=READ) """ - self._field_model_dict = defaultdict(list) self.grid_file = h5_file_path self.datasource_class = H5pyGroup self.is_rpc = False @@ -191,50 +188,6 @@ def has_levees(self): return False return bool(self.levees.id.size) - @property - def _field_model_map(self): - """ - :return: a dict of {: [model name, ...]} - """ - if self._field_model_dict: - return self._field_model_dict - - model_names = set() - for attr_name in dir(self): - # skip private attrs - if any([attr_name.startswith("__"), attr_name.startswith("_")]): - continue - try: - attr = getattr(self, attr_name) - except AttributeError: - logger.warning( - "Attribute: '{}' does not " "exist in h5py_file.".format(attr_name) - ) - continue - if not issubclass(type(attr), Model): - continue - model_names.add(attr_name) - - for model_name in model_names: - for x in getattr(self, model_name)._field_names: - self._field_model_dict[x].append(model_name) - return self._field_model_dict - - def get_model_instance_by_field_name(self, field_name): - """ - :param field_name: name of a models field - :return: instance of the model the field belongs to - :raises IndexError if the field name is not unique across models - """ - model_name = self._field_model_map.get(field_name) - if not model_name or len(model_name) != 1: - raise IndexError( - "Ambiguous result. Field name {} yields {} model(s)".format( - field_name, len(model_name) if model_name else 0 - ) - ) - return getattr(self, model_name[0]) - def get_extent_subset(self, subset_name, target_epsg_code=""): """ diff --git a/threedigrid/admin/gridresultadmin.py b/threedigrid/admin/gridresultadmin.py index dea3581..040db58 100644 --- a/threedigrid/admin/gridresultadmin.py +++ b/threedigrid/admin/gridresultadmin.py @@ -3,6 +3,7 @@ import logging import re +from collections import defaultdict from typing import List, Optional, Union import h5py @@ -37,6 +38,7 @@ StructureControl, StructureControlTypes, ) +from threedigrid.orm.models import Model logger = logging.getLogger(__name__) @@ -61,6 +63,7 @@ def __init__(self, h5_file_path, netcdf_file_path, file_modus="r", swmr=False): called subgrid_map.nc) :param file_modus: modus in which to open the files """ + self._field_model_dict = defaultdict(list) self._netcdf_file_path = netcdf_file_path super().__init__(h5_file_path, file_modus) @@ -154,6 +157,50 @@ def version_check(self): self.threedicore_version, ) + @property + def _field_model_map(self): + """ + :return: a dict of {: [model name, ...]} + """ + if self._field_model_dict: + return self._field_model_dict + + model_names = set() + for attr_name in dir(self): + # skip private attrs + if any([attr_name.startswith("__"), attr_name.startswith("_")]): + continue + try: + attr = getattr(self, attr_name) + except AttributeError: + logger.warning( + "Attribute: '{}' does not " "exist in h5py_file.".format(attr_name) + ) + continue + if not issubclass(type(attr), Model): + continue + model_names.add(attr_name) + + for model_name in model_names: + for x in getattr(self, model_name)._field_names: + self._field_model_dict[x].append(model_name) + return self._field_model_dict + + def get_model_instance_by_field_name(self, field_name): + """ + :param field_name: name of a models field + :return: instance of the model the field belongs to + :raises IndexError if the field name is not unique across models + """ + model_name = self._field_model_map.get(field_name) + if not model_name or len(model_name) != 1: + raise IndexError( + "Ambiguous result. Field name {} yields {} model(s)".format( + field_name, len(model_name) if model_name else 0 + ) + ) + return getattr(self, model_name[0]) + @property def threedicore_result_version(self): """ @@ -486,6 +533,7 @@ def __init__( netcdf_file_path: str, file_modus: str = "r", swmr: bool = False, + substances: List[str] = [], ) -> None: """ :param h5_file_path: path to the hdf5 gridadmin file @@ -529,6 +577,21 @@ def __init__( if isinstance(value, bytes): value = value.decode("utf-8") self.__getattribute__(substance).__setattr__(attr, value) + self.substances = list(substances) + + def get_model_instance_by_field_name(self, field_name): + """ + :param field_name: name of a models field + :return: instance of the model the field belongs to + :raises AttributeError if the model instance cannot be found + """ + try: + model_instance = self.__getattribute__(field_name) + return model_instance + except AttributeError: + raise AttributeError( + f"Model instance with field name {field_name} not found" + ) def set_timeseries_chunk_size(self, new_chunk_size: int) -> None: """