diff --git a/deployment/docker_trigger.py b/deployment/docker_trigger.py index 1729115..4bcd63f 100644 --- a/deployment/docker_trigger.py +++ b/deployment/docker_trigger.py @@ -4,10 +4,10 @@ from echoflow import echoflow_start from echoflow.stages.echoflow_trigger import echoflow_trigger from prefect import flow -from prefect.task_runners import SequentialTaskRunner +from prefect.task_runners import ThreadPoolTaskRunner from typing import Any, Dict, Optional, Union -@flow(name="Docker-Trigger", task_runner=SequentialTaskRunner()) +@flow(name="Docker-Trigger", task_runner=ThreadPoolTaskRunner(max_workers=1)) def docker_trigger( dataset_config: Union[dict, str, Path], pipeline_config: Union[dict, str, Path], diff --git a/echodataflow/__init__.py b/echodataflow/__init__.py index d02aaaf..cc75207 100644 --- a/echodataflow/__init__.py +++ b/echodataflow/__init__.py @@ -7,8 +7,9 @@ echodataflow_create_prefect_profile, echodataflow_start, get_active_profile, load_profile) -from .utils.config_utils import extract_fs, glob_url, load_block +from .utils.config_utils import extract_fs, glob_url from .utils.file_utils import get_ed_list, get_last_run_output, get_zarr_list +from .utils.filesystem_utils import load_block from .docker_trigger import docker_trigger try: diff --git a/echodataflow/docker_trigger.py b/echodataflow/docker_trigger.py index fe259b0..f1c9d67 100644 --- a/echodataflow/docker_trigger.py +++ b/echodataflow/docker_trigger.py @@ -1,10 +1,10 @@ from pathlib import Path from echodataflow.stages.echodataflow_trigger import echodataflow_trigger from prefect import flow -from prefect.task_runners import SequentialTaskRunner +from prefect.task_runners import ThreadPoolTaskRunner from typing import Any, Dict, Optional, Union -@flow(name="docker-trigger-latest", task_runner=SequentialTaskRunner()) +@flow(name="docker-trigger-latest", task_runner=ThreadPoolTaskRunner(max_workers=1)) def docker_trigger( dataset_config: Union[dict, str, Path], pipeline_config: Union[dict, str, Path], diff --git a/echodataflow/extensions/file_downloader.py b/echodataflow/extensions/file_downloader.py index f9a42db..c6d7cb7 100644 --- a/echodataflow/extensions/file_downloader.py +++ b/echodataflow/extensions/file_downloader.py @@ -15,10 +15,12 @@ from prefect.client.schemas.objects import FlowRun, StateType from prefect.states import Cancelled -from echodataflow.utils.config_utils import glob_url, handle_storage_options +from echodataflow.utils.config_utils import glob_url from echodataflow.utils.file_utils import extract_fs, make_temp_folder import shlex +from echodataflow.utils.filesystem_utils import handle_storage_options + @task def download_temp_file(file_url: str, storage_options: Dict[str, Any], dest_dir: str, delete_on_transfer: bool, replace: bool) -> str: diff --git a/echodataflow/extensions/file_monitor.py b/echodataflow/extensions/file_monitor.py index 946d7ef..5e383a4 100644 --- a/echodataflow/extensions/file_monitor.py +++ b/echodataflow/extensions/file_monitor.py @@ -19,8 +19,8 @@ from echodataflow.models.datastore import StorageType from echodataflow.models.run import EDFRun, FileDetails -from echodataflow.utils.config_utils import glob_url, load_block -from prefect.task_runners import SequentialTaskRunner +from echodataflow.utils.config_utils import glob_url +from echodataflow.utils.filesystem_utils import load_block @task def execute_flow( @@ -112,10 +112,10 @@ def file_monitor( new_run = datetime.now(tz=timezone.utc).isoformat() edfrun: EDFRun = None - try: + try: edfrun = load_block( name=block_name, - type=StorageType.EDFRUN, + stype=StorageType.EDFRUN, ) except Exception as e: print(e) diff --git a/echodataflow/models/echodataflow_config.py b/echodataflow/models/echodataflow_config.py index 1a3790c..9fb302a 100644 --- a/echodataflow/models/echodataflow_config.py +++ b/echodataflow/models/echodataflow_config.py @@ -18,11 +18,9 @@ Email: sbutala@uw.edu Date: August 22, 2023 """ -import json from typing import Any, Dict, List, Optional from prefect.blocks.core import Block -from pydantic import SecretStr from .datastore import StorageType @@ -43,9 +41,9 @@ class EchodataflowPrefectConfig(Block): class Config: arbitrary_types_allowed = True - prefect_account_id: str = None - prefect_api_key: str = None - prefect_workspace_id: str = None + prefect_account_id: Optional[str] = None + prefect_api_key: Optional[str] = None + prefect_workspace_id: Optional[str] = None profile_name: str = None diff --git a/echodataflow/stages/echodataflow.py b/echodataflow/stages/echodataflow.py index b08aca1..372fcd4 100644 --- a/echodataflow/stages/echodataflow.py +++ b/echodataflow/stages/echodataflow.py @@ -47,9 +47,9 @@ ) import echopype as ep -from echodataflow.utils.config_utils import load_block from echodataflow.stages.echodataflow_trigger import echodataflow_trigger +from echodataflow.utils.filesystem_utils import handle_storage_options def check_internet_connection(host="8.8.8.8", port=53, timeout=5): @@ -236,7 +236,7 @@ def echodataflow_start( # Try loading the Prefect config block try: - load_block(name="echodataflow-config", type=StorageType.ECHODATAFLOW) + handle_storage_options({'block_name':"echodataflow-config", 'type':StorageType.ECHODATAFLOW}) except ValueError: print( "\nNo Prefect Cloud Configuration found. Creating Prefect Local named 'echodataflow-local'. Please add your prefect cloud " @@ -546,8 +546,7 @@ def load_credential_configuration(sync: bool = False): current_config = asyncio.run(current_config) if current_config is not None: for base in current_config.blocks: - block = load_block(base.name, base.type) - block_dict = dict(block) + block_dict = handle_storage_options(base) block_dict["name"] = base.name block_dict["active"] = base.active block_dict["options"] = json.dumps(base.options) diff --git a/echodataflow/stages/echodataflow_trigger.py b/echodataflow/stages/echodataflow_trigger.py index 120fec1..5ff371d 100644 --- a/echodataflow/stages/echodataflow_trigger.py +++ b/echodataflow/stages/echodataflow_trigger.py @@ -17,28 +17,24 @@ import json from pathlib import Path from typing import Optional, Union -from fastapi.encoders import jsonable_encoder +from fastapi.encoders import jsonable_encoder from prefect import flow -from prefect.task_runners import SequentialTaskRunner -from prefect.blocks.core import Block -from prefect.variables import Variable +from prefect.task_runners import ThreadPoolTaskRunner from echodataflow.aspects.singleton_echodataflow import Singleton_Echodataflow from echodataflow.models.datastore import Dataset from echodataflow.models.pipeline import Recipe from echodataflow.utils import log_util -from echodataflow.utils.config_utils import ( - check_config, - extract_config, - get_storage_options, - load_block, -) +from echodataflow.utils.config_utils import (check_config, + parse_dynamic_parameters, + parse_yaml_config) +from echodataflow.utils.filesystem_utils import handle_storage_options from .subflows.initialization_flow import init_flow -@flow(name="Echodataflow", task_runner=SequentialTaskRunner()) +@flow(name="Echodataflow", task_runner=ThreadPoolTaskRunner(max_workers=1)) def echodataflow_trigger( dataset_config: Union[dict, str, Path], pipeline_config: Union[dict, str, Path], @@ -82,48 +78,15 @@ def echodataflow_trigger( print("Pipeline output:", pipeline_output) """ - if storage_options: - # Check if storage_options is a Block (fsspec storage) and convert it to a dictionary - if isinstance(storage_options, Block): - storage_options = get_storage_options(storage_options=storage_options) - elif isinstance(storage_options, dict) and storage_options.get("block_name"): - block = load_block( - name=storage_options.get("block_name"), type=storage_options.get("type") - ) - storage_options = get_storage_options(block) - else: - storage_options = {} - - if isinstance(dataset_config, Path): - dataset_config = str(dataset_config) - if isinstance(logging_config, Path): - logging_config = str(logging_config) - if isinstance(pipeline_config, Path): - pipeline_config = str(pipeline_config) + storage_options = handle_storage_options(storage_options=storage_options) + + dataset_config_dict = parse_yaml_config(config=dataset_config, storage_options=storage_options) + logging_config_dict = parse_yaml_config(config=logging_config, storage_options=storage_options) + pipeline_config_dict = parse_yaml_config(config=pipeline_config, storage_options=storage_options) + if isinstance(json_data_path, Path): json_data_path = str(json_data_path) - if isinstance(dataset_config, str): - if not dataset_config.endswith((".yaml", ".yml")): - raise ValueError("Configuration file must be a YAML!") - dataset_config_dict = extract_config(dataset_config, storage_options) - elif isinstance(dataset_config, dict): - dataset_config_dict = dataset_config - - if isinstance(pipeline_config, str): - if not pipeline_config.endswith((".yaml", ".yml")): - raise ValueError("Configuration file must be a YAML!") - pipeline_config_dict = extract_config(pipeline_config, storage_options) - elif isinstance(pipeline_config, dict): - pipeline_config_dict = pipeline_config - - if isinstance(logging_config, str): - if not logging_config.endswith((".yaml", ".yml")): - raise ValueError("Configuration file must be a YAML!") - logging_config_dict = extract_config(logging_config, storage_options) - else: - logging_config_dict = logging_config - log_util.log( msg={ "msg": f"Dataset Configuration Loaded For This Run", @@ -144,7 +107,6 @@ def echodataflow_trigger( }, eflogging=dataset_config_dict.get("logging"), ) - print(dataset_config_dict) log_util.log( msg={ @@ -172,62 +134,18 @@ def echodataflow_trigger( check_config(dataset_config_dict, pipeline_config_dict) pipeline = Recipe(**pipeline_config_dict) dataset = Dataset(**dataset_config_dict) - - if options.get("storage_options_override") and not options["storage_options_override"]: - storage_options = {} - if not storage_options: - if dataset.output.storage_options: - if not dataset.output.storage_options.anon: - block = load_block( - name=dataset.output.storage_options.block_name, - type=dataset.output.storage_options.type, - ) - dataset.output.storage_options_dict = get_storage_options(block) - else: - dataset.output.storage_options_dict = {"anon": dataset.output.storage_options.anon} - - if dataset.args.storage_options: - if not dataset.args.storage_options.anon: - block = load_block( - name=dataset.args.storage_options.block_name, - type=dataset.args.storage_options.type, - ) - dataset.args.storage_options_dict = get_storage_options(block) - else: - dataset.args.storage_options_dict = {"anon": dataset.args.storage_options.anon} - if dataset.args.group: - if dataset.args.group.storage_options: - if not dataset.args.group.storage_options.anon: - block = load_block( - name=dataset.args.group.storage_options.block_name, - type=dataset.args.group.storage_options.type, - ) - dataset.args.group.storage_options_dict = get_storage_options(block) - else: - dataset.args.group.storage_options_dict = { - "anon": dataset.args.group.storage_options.anon - } - else: + + if options.get("storage_options_override", False): dataset.output.storage_options_dict = storage_options dataset.args.storage_options_dict = storage_options dataset.args.group.storage_options_dict = storage_options + else: + dataset.output.storage_options_dict = handle_storage_options(storage_options=dataset.output.storage_options) + dataset.args.storage_options_dict = handle_storage_options(storage_options=dataset.args.storage_options) + dataset.args.group.storage_options_dict = handle_storage_options(storage_options=dataset.args.group.storage_options) + + edf = Singleton_Echodataflow(log_file=logging_config_dict, pipeline=pipeline, dataset=dataset) - print("\nInitiliazing Singleton Object") - Singleton_Echodataflow(log_file=logging_config_dict, pipeline=pipeline, dataset=dataset) - - if dataset.args.parameters and dataset.args.parameters.file_name and dataset.args.parameters.file_name == "VAR_RUN_NAME": - var: Variable = Variable.get("run_name", default=None) - if not var: - raise ValueError("No variable found for name `run_name`") - else: - dataset.args.parameters.file_name = var.value - - # Change made to enable dynamic execution using an extension - if options and options.get("file_name"): - dataset.args.parameters.file_name = options.get("file_name") + dataset = parse_dynamic_parameters(dataset, options=options) - if options and options.get("run_name"): - dataset.name = options.get("run_name") - - print("\nReading Configurations") return init_flow(config=dataset, pipeline=pipeline, json_data_path=json_data_path) \ No newline at end of file diff --git a/echodataflow/stages/subflows/initialization_flow.py b/echodataflow/stages/subflows/initialization_flow.py index 6601308..db3a9f0 100644 --- a/echodataflow/stages/subflows/initialization_flow.py +++ b/echodataflow/stages/subflows/initialization_flow.py @@ -25,7 +25,7 @@ from distributed import Client, LocalCluster from fastapi.encoders import jsonable_encoder from prefect import flow -from prefect.task_runners import SequentialTaskRunner +from prefect.task_runners import ThreadPoolTaskRunner from prefect_dask import DaskTaskRunner from echodataflow.aspects.echodataflow_aspect import echodataflow @@ -45,9 +45,10 @@ process_output_groups, store_json_output) from echodataflow.utils.function_utils import dynamic_function_call +from echodataflow.utils.xr_utils import combine_datasets, fetch_slice_from_store -@flow(name="Initialization", task_runner=SequentialTaskRunner()) +@flow(name="Initialization", task_runner=ThreadPoolTaskRunner(max_workers=1)) @echodataflow(type="FLOW") def init_flow(pipeline: Recipe, config: Dataset, json_data_path: Optional[str] = None): """ @@ -418,27 +419,16 @@ def get_input_from_store_folder(config: Dataset): store_5_output = process_store_folder(config, store_5, end_time) for name, gr in store_18_output.group.items(): - - edf_18 = gr.data[0] - store_18 = xr.open_mfdataset(paths=[ed.out_path for ed in gr.data], engine="zarr", - combine="by_coords", - data_vars="minimal", - coords="minimal", - compat="override").compute() - store_18 = store_18.sel(ping_time=slice(pd.to_datetime(edf_18.start_time, unit="ns"), pd.to_datetime(edf_18.end_time, unit="ns"))) + edf_18 = fetch_slice_from_store(edf_group=gr, config=config) + store_18 = edf_18.data if not store_5_output.group.get(name): raise ValueError(f"No window found in MVBS store (5 channels); window missing -> {name}") - edf_5 = store_5_output.group[name].data[0] - store_5 = xr.open_mfdataset(paths=[ed.out_path for ed in store_5_output.group[name].data], engine="zarr", - combine="by_coords", - data_vars="minimal", - coords="minimal", - compat="override").compute() - store_5 = store_5.sel(ping_time=slice(pd.to_datetime(edf_5.start_time, unit="ns"), pd.to_datetime(edf_5.end_time, unit="ns"))) + edf_5 = fetch_slice_from_store(edf_group=store_5_output.group[name], config=config) + store_5 = edf_5.data - edf_5.data, edf_5.data_ref = combine_datasets(store_18, store_5) + edf_5.data, edf_5.data_ref = combine_datasets(store_18=store_18, store_5=store_5, config=config) combo_output.group[name] = gr.model_copy() combo_output.group[name].data = [edf_5] @@ -449,7 +439,7 @@ def get_input_from_store_folder(config: Dataset): eflogging=config.logging, ) - for dim, size in edf_5.data_ref.dims.items(): + for dim, size in edf_5.data.dims.items(): log_util.log( msg={"msg": f"{ dim } : {size}", "mod_name": __file__, "func_name": "Mask"}, use_dask=False, @@ -458,83 +448,6 @@ def get_input_from_store_folder(config: Dataset): return combo_output -def process_xrd(ds: xr.Dataset, freq_wanted = [120000, 38000, 18000]) -> xr.Dataset: - ds = ds.sel(depth=slice(None, 590)) - - ch_wanted = [int((np.abs(ds["frequency_nominal"]-freq)).argmin()) for freq in freq_wanted] - ds = ds.isel( - channel=ch_wanted - ) - return ds - -def combine_datasets(store_18: xr.Dataset, store_5: xr.Dataset) -> Tuple[torch.Tensor, xr.Dataset]: - ds_32k_120k = None - ds_18k = None - combined_ds = None - try: - partial_channel_name = ["ES18"] - ds_18k = extract_channels(store_18, partial_channel_name) - partial_channel_name = ["ES38", "ES120"] - ds_32k_120k = extract_channels(store_5, partial_channel_name) - except Exception as e: - partial_channel_name = ["ES18"] - ds_18k = extract_channels(store_5, partial_channel_name) - partial_channel_name = ["ES38", "ES120"] - ds_32k_120k = extract_channels(store_18, partial_channel_name) - - if not ds_18k or not ds_32k_120k: - raise ValueError("Could not find the required channels in the datasets") - - ds_18k = process_xrd(ds_18k, freq_wanted=[18000]) - ds_32k_120k = process_xrd(ds_32k_120k, freq_wanted=[120000, 38000]) - - combined_ds = xr.merge([ds_18k["Sv"], ds_32k_120k["Sv"], - ds_18k['latitude'], ds_18k['longitude'], - ds_18k["frequency_nominal"], ds_32k_120k["frequency_nominal"] - ]) - combined_ds.attrs = ds_18k.attrs - - combined_ds = ( - combined_ds - .transpose("channel", "depth", "ping_time") - .sel(depth=slice(None, 590)) - ) - - depth = combined_ds['depth'] - ping_time = combined_ds['ping_time'] - - # Create a tensor with R=120 kHz, G=38 kHz, B=18 kHz mapping - red_channel = extract_channels(combined_ds, ["ES120"]) - green_channel = extract_channels(combined_ds, ["ES38"]) - blue_channel = extract_channels(combined_ds, ["ES18"]) - - tensor = xr.concat([red_channel, green_channel, blue_channel], dim='channel') - tensor['channel'] = ['R', 'G', 'B'] - tensor = tensor.assign_coords({'depth': depth, 'ping_time': ping_time}) - - mvbs_tensor = torch.tensor(tensor['Sv'].values, dtype=torch.float32) - - return (mvbs_tensor, combined_ds) - -def extract_channels(dataset: xr.Dataset, partial_names: List[str]) -> xr.Dataset: - """ - Extracts multiple channels data from the given xarray dataset using partial channel names. - - Args: - dataset (xr.Dataset): The input xarray dataset containing multiple channels. - partial_names (List[str]): The list of partial names of the channels to extract. - - Returns: - xr.Dataset: The dataset containing only the specified channels data. - """ - matching_channels = [] - for partial_name in partial_names: - matching_channels.extend([channel for channel in dataset.channel.values if partial_name in str(channel)]) - - if len(matching_channels) == 0: - raise ValueError(f"No channels found matching any of '{partial_names}'") - - return dataset.sel(channel=matching_channels) def process_store_folder(config: Dataset, store: str, end_time: datetime): output: Output = Output() diff --git a/echodataflow/stages/subflows/mask_prediction.py b/echodataflow/stages/subflows/mask_prediction.py index 8a0ee75..395be54 100644 --- a/echodataflow/stages/subflows/mask_prediction.py +++ b/echodataflow/stages/subflows/mask_prediction.py @@ -16,22 +16,22 @@ Date: August 22, 2023 """ from collections import defaultdict -from pathlib import Path from typing import Dict, Optional -from prefect import flow, task import torch import xarray as xr -import numpy as np +from prefect import flow, task -import pandas as pd from echodataflow.aspects.echodataflow_aspect import echodataflow from echodataflow.models.datastore import Dataset -from echodataflow.models.output_model import EchodataflowObject, ErrorObject, Group +from echodataflow.models.output_model import (EchodataflowObject, ErrorObject, + Group) from echodataflow.models.pipeline import Stage from echodataflow.utils import log_util -from echodataflow.utils.file_utils import get_out_zarr, get_working_dir, get_zarr_list, isFile -from src.model.BinaryHakeModel import BinaryHakeModel +from echodataflow.utils.file_utils import (get_out_zarr, get_working_dir, + get_zarr_list, isFile) +from echodataflow.utils.flow_utils import load_data_in_memory, load_model +from echodataflow.utils.xr_utils import assemble_da, convert_to_tensor @flow @@ -68,66 +68,12 @@ def echodataflow_mask_prediction( futures = defaultdict(list) - try: - log_util.log( - msg={"msg": f"Loading model now ---->", "mod_name": __file__, "func_name": "Mask_Prediction"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - model_path = f"/home/exouser/hake_data/model/backup_model_weights/binary_hake_model_1.0m_bottom_offset_1.0m_depth_2017_2019_ver_1.ckpt" - - # Load binary hake models with weights - model = BinaryHakeModel("placeholder_experiment_name", - Path("placeholder_score_tensor_dir"), - "placeholder_tensor_log_dir", 0).eval() - - log_util.log( - msg={"msg": f"Loading model at", "mod_name": __file__, "func_name": "Mask_Prediction"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - log_util.log( - msg={"msg": f"{stage.external_params.get('model_path', model_path)}", "mod_name": __file__, "func_name": "Mask_Prediction"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - model.load_state_dict(torch.load( - stage.external_params.get('model_path', model_path) - )["state_dict"]) - - log_util.log( - msg={"msg": f"Model loaded succefully", "mod_name": __file__, "func_name": "Mask_Prediction"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - except Exception as e: - log_util.log( - msg={"msg": "", "mod_name": __file__, "func_name": "Mask_Prediction"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - error=e - ) - raise e + model = load_model(stage=stage, config=config) - for name, gr in groups.items(): - if gr.metadata and gr.metadata.is_store_folder and len(gr.data) > 0: - edf = gr.data[0] - store = xr.open_mfdataset(paths=[ed.out_path for ed in gr.data], engine="zarr", - combine="by_coords", - data_vars="minimal", - coords="minimal", - compat="override").compute() - edf.data = store.sel(ping_time=slice(pd.to_datetime(edf.start_time, unit="ns"), pd.to_datetime(edf.end_time, unit="ns"))) - if edf.data.notnull().any(): - gr.data = [edf] - else: - continue - del store - + groups = load_data_in_memory(config=config, groups=groups) + + for name, gr in groups.items(): for ed in gr.data: - gname = ed.out_path.split(".")[0] + ".MaskPrediction" new_process = process_mask_prediction.with_options( task_run_name=gname, name=gname, retries=3 @@ -146,6 +92,28 @@ def echodataflow_mask_prediction( return groups +@task +@echodataflow() +def process_mask_prediction_tensor( + groups: Dict[str, Group], config: Dataset, stage: Stage, prev_stage: Optional[Stage] +): + working_dir = get_working_dir(stage=stage, config=config) + + model = load_model(stage=stage, config=config) + + groups = load_data_in_memory(config=config, groups=groups) + + for name, gr in groups.items(): + results = [] + for ed in gr.data: + pmpu = process_mask_prediction_util.with_options(task_run_name=ed.filename) + results.append(pmpu.fn(ed, config, stage, working_dir, model)) + + groups[name].data = results + + return groups + + @task @echodataflow() def process_mask_prediction( @@ -221,62 +189,23 @@ def process_mask_prediction( eflogging=config.logging, ) - ed_list = get_zarr_list.fn(transect_data=ed, storage_options=config.output.storage_options_dict) + mvbs_slice = get_zarr_list.fn(transect_data=ed, storage_options=config.output.storage_options_dict)[0] - ed_list[0] = ed_list[0].sel(depth=slice(None, 590)) + mvbs_slice = mvbs_slice.sel(depth=slice(None, 590)) log_util.log( msg={"msg": 'Computing mask_prediction', "mod_name": __file__, "func_name": file_name}, use_dask=stage.options["use_dask"], eflogging=config.logging, ) - - bottom_offset = stage.external_params.get('bottom_offset', 1.0) - temperature = stage.external_params.get('temperature', 0.5) - freq_wanted = stage.external_params.get('freq_wanted', [120000, 38000, 18000]) - - ch_wanted = [int((np.abs(ed_list[0]["frequency_nominal"]-freq)).argmin()) for freq in freq_wanted] - - log_util.log( - msg={"msg": f"Channel order {ch_wanted}", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - # Ensure dims sequence is (channel, depth, ping_time) - # and channel sequence is 120, 38, 18 kHz - mvbs_slice = ( - ed_list[0] - .transpose("channel", "depth", "ping_time") - .isel(channel=ch_wanted) - ) - mvbs_tensor = torch.tensor(mvbs_slice['Sv'].values, dtype=torch.float32) + if ed.data_ref is not None: + input_tensor = ed.data_ref + else: + mvbs_slice, input_tensor = convert_to_tensor(combined_ds=mvbs_slice, freq_wanted=stage.external_params.get('freq_wanted', [120000, 38000, 18000]), config=config) - da_MVBS_tensor = torch.clip( - mvbs_tensor.clone().detach().to(torch.float16), - min=-70, - max=-36, - ) - - log_util.log( - msg={"msg": f"converted and clipped tensor", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - # Replace NaN values with min Sv - da_MVBS_tensor[torch.isnan(da_MVBS_tensor)] = -70 - - MVBS_tensor_normalized = ( - (da_MVBS_tensor - (-70.0)) / (-36.0 - (-70.0)) * 255.0 - ) - input_tensor = MVBS_tensor_normalized.unsqueeze(0).float() + temperature = stage.external_params.get('temperature', 0.5) - log_util.log( - msg={"msg": f"Normalized tensor", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) score_tensor = model(input_tensor).detach().squeeze(0) log_util.log( @@ -284,8 +213,6 @@ def process_mask_prediction( use_dask=stage.options["use_dask"], eflogging=config.logging, ) - - # dims = stage.external_params.get('dims', ['ping_time', 'depth']) dims = { 'species': [ "background", "hake"], @@ -327,28 +254,6 @@ def process_mask_prediction( storage_options=config.output.storage_options_dict, ) - else: - log_util.log( - msg={ - "msg": f"Skipped processing {file_name}. File found in the destination folder. To replace or reprocess set `use_offline` flag to False", - "mod_name": __file__, - "func_name": file_name, - }, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - log_util.log( - msg={"msg": f" ---- Exiting ----", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - ed.stages["mask"] = out_zarr - ed.error = ErrorObject(errorFlag=False) - ed.stages[stage.name] = out_zarr - - if mvbs_slice: - slice_zarr = get_out_zarr( group=stage.options.get("group", True), working_dir=working_dir, @@ -366,186 +271,13 @@ def process_mask_prediction( ed.out_path = slice_zarr ed.stages[stage.name] = slice_zarr - - ed.data = None - del da_mask_hake - del da_score_hake - del softmax_score_tensor - del score_tensor - del input_tensor - del mvbs_slice - - return ed - except Exception as e: - log_util.log( - msg={"msg": "", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - error=e - ) - ed.error = ErrorObject(errorFlag=True, error_desc=str(e)) - return ed - -def assemble_da(data_array, dims): - da = xr.DataArray( - data_array, dims=dims.keys() - ) - da = da.assign_coords(dims - ) - return da - - -@task -@echodataflow() -def process_mask_prediction_tensor( - groups: Dict[str, Group], config: Dataset, stage: Stage, prev_stage: Optional[Stage] -): - working_dir = get_working_dir(stage=stage, config=config) - - for name, gr in groups.items(): - results = [] - for ed in gr.data: - if ed.data is not None: - log_util.log( - msg={"msg": "ed data is not none", "mod_name": __file__, "func_name": "Mask"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - else: - log_util.log( - msg={"msg": "ed data is none", "mod_name": __file__, "func_name": "Mask"}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - pmpu = process_mask_prediction_util.with_options(task_run_name=ed.filename) - results.append(pmpu.fn(ed, config, stage, working_dir)) - - groups[name].data = results - - return groups - -@task -def process_mask_prediction_util(ed: EchodataflowObject, config: Dataset, stage: Stage, working_dir: str): - file_name = ed.filename + "_mask.zarr" - - try: - log_util.log( - msg={"msg": " ---- Entering ----", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - out_zarr = get_out_zarr( - group=stage.options.get("group", True), - working_dir=working_dir, - transect=ed.group_name, - file_name=file_name, - storage_options=config.output.storage_options_dict, - ) - - log_util.log( - msg={ - "msg": f"Processing file, output will be at {out_zarr}", - "mod_name": __file__, - "func_name": file_name, - }, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - if ( - stage.options.get("use_offline") == False - or isFile(out_zarr, config.output.storage_options_dict) == False - ): - log_util.log( - msg={ - "msg": f"File not found in the destination folder / use_offline flag is False", - "mod_name": __file__, - "func_name": file_name, - }, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - log_util.log( - msg={"msg": 'Computing mask_prediction', "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - bottom_offset = stage.external_params.get('bottom_offset', 1.0) - temperature = stage.external_params.get('temperature', 0.5) - - model_path = f"/home/exouser/hake_data/model/backup_model_weights/binary_hake_model_{bottom_offset}m_bottom_offset_1.0m_depth_2017_2019_ver_1.ckpt" - - # Load binary hake models with weights - model = BinaryHakeModel("placeholder_experiment_name", - Path("placeholder_score_tensor_dir"), - "placeholder_tensor_log_dir", 0).eval() - model.load_state_dict(torch.load( - stage.external_params.get('model_path', model_path) - )["state_dict"]) - - - mvbs_tensor = ed.data # tensor - - da_MVBS_tensor = torch.clip( - mvbs_tensor.clone().detach().to(torch.float16), - min=-70, - max=-36, - ) - - # Replace NaN values with min Sv - da_MVBS_tensor[torch.isnan(da_MVBS_tensor)] = -70 - - MVBS_tensor_normalized = ( - (da_MVBS_tensor - (-70.0)) / (-36.0 - (-70.0)) * 255.0 - ) - input_tensor = MVBS_tensor_normalized.unsqueeze(0).float() - - score_tensor = model(input_tensor).detach().squeeze(0) - - log_util.log( - msg={"msg": f"Converting to Zarr", "mod_name": __file__, "func_name": file_name}, - use_dask=stage.options["use_dask"], - eflogging=config.logging, - ) - - dims = {'species': [ "background", "hake"], 'depth': ed.data_ref["depth"].values, 'ping_time': ed.data_ref["ping_time"].values} - - da_score_hake = assemble_da(score_tensor.numpy(), dims=dims) - - softmax_score_tensor = torch.nn.functional.softmax( - score_tensor / temperature, dim=0 - ) - - dims.pop('species') - da_softmax_hake = assemble_da(softmax_score_tensor.numpy()[1,:,:], dims=dims) - - da_mask_hake = assemble_da(da_softmax_hake.where(da_softmax_hake > stage.options.get('th_softmax', 0.9)), dims=dims) - - score_zarr = get_out_zarr( - group=True, - working_dir=working_dir, - transect="Hake_Score", - file_name=ed.filename + "_score_hake.zarr", - storage_options=config.output.storage_options_dict, - ) - - da_score_hake.to_zarr( - store=score_zarr, - mode="w", - consolidated=True, - storage_options=config.output.storage_options_dict, - ) - - # Get mask from score - da_mask_hake.to_zarr( - store=out_zarr, - mode="w", - consolidated=True, - storage_options=config.output.storage_options_dict, - ) + del mvbs_slice + del da_mask_hake + del da_score_hake + del softmax_score_tensor + del score_tensor + del input_tensor else: log_util.log( @@ -566,32 +298,8 @@ def process_mask_prediction_util(ed: EchodataflowObject, config: Dataset, stage: ed.stages["mask"] = out_zarr ed.error = ErrorObject(errorFlag=False) ed.stages[stage.name] = out_zarr - ed.data = None - slice_zarr = get_out_zarr( - group=stage.options.get("group", True), - working_dir=working_dir, - transect=ed.group_name, - file_name=ed.filename+"_MVBS_Slice.zarr", - storage_options=config.output.storage_options_dict, - ) - - ed.data_ref.to_zarr( - store=slice_zarr, - mode="w", - consolidated=True, - storage_options=config.output.storage_options_dict, - ) - ed.out_path = slice_zarr - ed.data_ref = None - - ed.stages[stage.name] = slice_zarr - - del da_mask_hake - del da_score_hake - del softmax_score_tensor - del score_tensor - del input_tensor + except Exception as e: log_util.log( msg={"msg": "", "mod_name": __file__, "func_name": file_name}, @@ -599,8 +307,8 @@ def process_mask_prediction_util(ed: EchodataflowObject, config: Dataset, stage: eflogging=config.logging, error=e ) - ed.error = ErrorObject(errorFlag=True, error_desc=str(e)) + ed.error = ErrorObject(errorFlag=True, error_desc=str(e)) + finally: ed.data = None ed.data_ref = None - finally: - return ed + return ed \ No newline at end of file diff --git a/echodataflow/stages/subflows/slice_store.py b/echodataflow/stages/subflows/slice_store.py index 4d7f1cc..04f0768 100644 --- a/echodataflow/stages/subflows/slice_store.py +++ b/echodataflow/stages/subflows/slice_store.py @@ -19,7 +19,6 @@ from typing import Dict, Optional from prefect import flow -from prefect.task_runners import SequentialTaskRunner from prefect.variables import Variable import xarray as xr import pandas as pd diff --git a/echodataflow/stages/subflows/write_output.py b/echodataflow/stages/subflows/write_output.py index 3f596fb..2e7f513 100644 --- a/echodataflow/stages/subflows/write_output.py +++ b/echodataflow/stages/subflows/write_output.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from prefect import flow import xarray as xr -from prefect.task_runners import SequentialTaskRunner +from prefect.task_runners import ThreadPoolTaskRunner import zarr.sync from echodataflow.models.datastore import Dataset from echodataflow.models.output_model import ErrorObject, Group @@ -11,7 +11,7 @@ from numcodecs import Zlib import zarr.storage -@flow(task_runner=SequentialTaskRunner()) +@flow(task_runner=ThreadPoolTaskRunner(max_workers=1)) def write_output(groups: Dict[str, Group], config: Dataset, stage: Stage, prev_stage: Optional[Stage]): log_util.log( msg={ diff --git a/echodataflow/tests/flow_tests/MVBS_pipeline.yaml b/echodataflow/tests/flow_tests/MVBS_pipeline.yaml index ecd073e..4fc67f1 100644 --- a/echodataflow/tests/flow_tests/MVBS_pipeline.yaml +++ b/echodataflow/tests/flow_tests/MVBS_pipeline.yaml @@ -1,6 +1,4 @@ active_recipe: MVBS_pipeline -use_local_dask: true -n_workers: 2 pipeline: - recipe_name: MVBS_pipeline stages: diff --git a/echodataflow/tests/unit/__init__.py b/echodataflow/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/echodataflow/tests/unit/test_filesystem_utils.py b/echodataflow/tests/unit/test_filesystem_utils.py new file mode 100644 index 0000000..4e9ee1c --- /dev/null +++ b/echodataflow/tests/unit/test_filesystem_utils.py @@ -0,0 +1,60 @@ +import pytest +from unittest.mock import Mock, patch +from echodataflow.utils.filesystem_utils import handle_storage_options +from echodataflow.models.datastore import StorageOptions +from prefect_aws import AwsCredentials + +# Adjust Mocks setup +Block = Mock(return_value = AwsCredentials(aws_access_key_id="test", aws_secret_access_key='password')) +MockStorageOptions = Mock(return_value = StorageOptions()) +load_block = Mock() + +@pytest.fixture +def setup_blocks(): + block = Block() + storage_options = MockStorageOptions() + return block, storage_options + +class TestHandleStorageOptions: + def test_none(self): + """Test handling with no parameters.""" + assert handle_storage_options() == {} + + def test_empty_dict(self): + """Test handling with an empty dictionary.""" + assert handle_storage_options({}) == {} + + def test_anon_dict(self): + """Test handling with anonymous dictionary.""" + assert handle_storage_options({'anon': True}) == {'anon': True} + + def test_block(self, setup_blocks): + block, _ = setup_blocks + expected_dict = {'key': 'test', 'secret': 'password'} + assert handle_storage_options(block) == expected_dict + + def test_anonymous_storage_options(self, setup_blocks): + _, storage_options = setup_blocks + storage_options.anon = True + assert handle_storage_options(storage_options) == {"anon": True} + + @patch('echodataflow.utils.filesystem_utils.load_block') + def test_dict_with_block_name(self, mock_load_block): + storage_dict = {'block_name': 'echoflow-aws-credentials', 'type': 'AWS'} + expected_dict = {'key': 'test', 'secret': 'password'} + block = AwsCredentials(aws_access_key_id="test", aws_secret_access_key='password') + mock_load_block.return_value = block + assert handle_storage_options(storage_dict) == expected_dict + mock_load_block.assert_called_with(name="echoflow-aws-credentials", type="AWS") + + @patch('echodataflow.utils.filesystem_utils.load_block') + def test_storage_options(self, mock_load_block, setup_blocks): + _, storage_options = setup_blocks + storage_options.anon = False + storage_options.block_name = "echoflow-aws-credentials" + storage_options.type = "AWS" + expected_dict = {'key': 'test', 'secret': 'password'} + block = AwsCredentials(aws_access_key_id="test", aws_secret_access_key='password') + mock_load_block.return_value = block + assert handle_storage_options(storage_options) == expected_dict + mock_load_block.assert_called_with(name="echoflow-aws-credentials", type="AWS") \ No newline at end of file diff --git a/echodataflow/utils/config_utils.py b/echodataflow/utils/config_utils.py index 0899cc2..19d378b 100644 --- a/echodataflow/utils/config_utils.py +++ b/echodataflow/utils/config_utils.py @@ -22,8 +22,6 @@ raw_url_file: Optional[str] = None, json_storage_options: StorageOptions = None ) -> List[List[Dict[str, Any]]] - get_storage_options(storage_options: Block = None) -> Dict[str, Any] - load_block(name: str = None, type: StorageType = None) Author: Soham Butala Email: sbutala@uw.edu @@ -33,27 +31,21 @@ import itertools as it import json import os +from pathlib import Path import re -from typing import Any, Coroutine, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from zipfile import ZipFile -import nest_asyncio import yaml from dateutil import parser from prefect import task -from prefect.filesystems import Block -from prefect_aws import AwsCredentials -from prefect_azure import AzureCosmosDbCredentials +from prefect.variables import Variable from echodataflow.aspects.echodataflow_aspect import echodataflow -from echodataflow.models.datastore import Dataset, StorageOptions, StorageType -from echodataflow.models.echodataflow_config import EchodataflowConfig +from echodataflow.models.datastore import Dataset, StorageOptions from echodataflow.models.pipeline import Stage -from echodataflow.models.run import EDFRun from echodataflow.utils.file_utils import extract_fs, isFile -nest_asyncio.apply() - @task def extract_config( @@ -509,81 +501,39 @@ def club_raw_files( all_files.append(files) return all_files - -def get_storage_options(storage_options: Block = None) -> Dict[str, Any]: - """ - Get storage options from a Block. - - Parameters: - storage_options (Block, optional): A block containing storage options. - - Returns: - Dict[str, Any]: Dictionary containing storage options. - - Example: - aws_credentials = AwsCredentials(...) - storage_opts = get_storage_options(aws_credentials) - """ - storage_options_dict: Dict[str, Any] = {} - if storage_options is not None: - if isinstance(storage_options, AwsCredentials): - storage_options_dict["key"] = storage_options.aws_access_key_id - storage_options_dict[ - "secret" - ] = storage_options.aws_secret_access_key.get_secret_value() - if storage_options.aws_session_token: - storage_options_dict["token"] = storage_options.aws_session_token - - return storage_options_dict - - -def handle_storage_options(storage_options: Optional[Dict] = None) -> Dict: - if storage_options: - if isinstance(storage_options, Block): - return get_storage_options(storage_options=storage_options) - elif isinstance(storage_options, dict) and storage_options.get("block_name"): - block = load_block( - name=storage_options.get("block_name"), type=storage_options.get("type") - ) - return get_storage_options(block) - else: - return storage_options if storage_options and len(storage_options.keys()) > 0 else {} - return {} - -def load_block(name: str = None, type: StorageType = None): - """ - Load a block of a specific type by name. - - Parameters: - name (str, optional): The name of the block to load. - type (StorageType, optional): The type of the block to load. - - Returns: - block: The loaded block. - - Raises: - ValueError: If name or type is not provided. - - Example: - loaded_aws_credentials = load_block(name="my-aws-creds", type=StorageType.AWS) - """ - if name is None or type is None: - raise ValueError("Cannot load block without name") - - if type == StorageType.AWS or type == StorageType.AWS.value: - coro = AwsCredentials.load(name=name) - elif type == StorageType.AZCosmos or type == StorageType.AZCosmos.value: - coro = AzureCosmosDbCredentials.load(name=name) - elif type == StorageType.ECHODATAFLOW or type == StorageType.ECHODATAFLOW.value: - coro = EchodataflowConfig.load(name=name) - elif type == StorageType.EDFRUN or type == StorageType.EDFRUN.value: - coro = EDFRun.load(name=name) - - if isinstance(coro, Coroutine): - block = nest_asyncio.asyncio.run(coro) - else: - block = coro - return block +def parse_yaml_config(config: Union[dict, str, Path], storage_options: Dict[str, Any]) -> Dict: + if isinstance(config, Path) or isinstance(config, str): + config = convert_path_to_str(config) + validate_yaml_file(config) + return extract_config(config, storage_options) + return config + +def convert_path_to_str(config: Union[str, Path]) -> str: + return str(config) if isinstance(config, Path) else config + +def validate_yaml_file(config_str) -> None: + if not config_str.endswith((".yaml", ".yml")): + raise ValueError("Configuration file must be a YAML!") + +def parse_dynamic_parameters(dataset: Dataset, options: Dict[str, Any]) -> Dataset: + update_file_name_from_variable(dataset) + apply_options_to_dataset(dataset, options) + return dataset + +def update_file_name_from_variable(dataset: Dataset) -> None: + parameters = dataset.args.parameters + if parameters.file_name == "VAR_RUN_NAME": + run_name_var = Variable.get("run_name", default=None) + if not run_name_var: + raise ValueError("No variable found for name `run_name`") + parameters.file_name = run_name_var.value + +def apply_options_to_dataset(dataset: Dataset, options: Dict[str, Any]) -> None: + if "file_name" in options: + dataset.args.parameters.file_name = options["file_name"] + if "run_name" in options: + dataset.name = options["run_name"] + def sanitize_external_params(config: Dataset, external_params: Dict[str, Any]): """ diff --git a/echodataflow/utils/filesystem_utils.py b/echodataflow/utils/filesystem_utils.py new file mode 100644 index 0000000..0b8f10d --- /dev/null +++ b/echodataflow/utils/filesystem_utils.py @@ -0,0 +1,105 @@ +from typing import Any, Dict, Optional, Union + +import nest_asyncio +from prefect.filesystems import Block +from prefect_aws import AwsCredentials +from prefect_azure import AzureCosmosDbCredentials + +from echodataflow.models.datastore import StorageOptions, StorageType +from echodataflow.models.echodataflow_config import (BaseConfig, + EchodataflowConfig) +from echodataflow.models.run import EDFRun + + +def handle_storage_options(storage_options: Optional[Union[Dict, StorageOptions, Block, BaseConfig]] = None) -> Dict: + if isinstance(storage_options, Block): + return _handle_block(storage_options) + elif isinstance(storage_options, dict): + return _handle_dict_options(storage_options) + elif isinstance(storage_options, StorageOptions): + return _handle_storage_options_class(storage_options) + elif isinstance(storage_options, BaseConfig): + return _handle_baseconfig_options_class(storage_options) + else: + return _handle_default(storage_options) + +def _handle_block(block: Block) -> Dict: + return get_storage_options(storage_options=block) + +def _handle_dict_options(options: Dict[str, Any]) -> Dict: + if "block_name" in options: + block = load_block(name=options["block_name"], stype=options.get("type", None)) + return get_storage_options(block) + return options if options else {} + +def _handle_storage_options_class(options: StorageOptions) -> Dict: + if not options.anon: + block = load_block(name=options.block_name, stype=options.type) + return get_storage_options(block) + return {"anon": options.anon} + +def _handle_baseconfig_options_class(options: BaseConfig) -> Dict: + block = load_block(name=options.name, stype=options.type) + return dict(block) + +def _handle_default(options: Dict[str, Any]): + return options if options else {} + + +def get_storage_options(storage_options: Block = None) -> Dict[str, Any]: + """ + Get storage options from a Block. + + Parameters: + storage_options (Block, optional): A block containing storage options. + + Returns: + Dict[str, Any]: Dictionary containing storage options. + + Example: + aws_credentials = AwsCredentials(...) + storage_opts = get_storage_options(aws_credentials) + """ + storage_options_dict: Dict[str, Any] = {} + if storage_options is not None: + if isinstance(storage_options, AwsCredentials): + storage_options_dict["key"] = storage_options.aws_access_key_id + storage_options_dict[ + "secret" + ] = storage_options.aws_secret_access_key.get_secret_value() + if storage_options.aws_session_token: + storage_options_dict["token"] = storage_options.aws_session_token + + return storage_options_dict + +def load_block(name: str, stype: StorageType): + """ + Load a block of a specific type by name. + + Parameters: + name (str, optional): The name of the block to load. + stype (StorageType, optional): The type of the block to load. + + Returns: + block: The loaded block. + + Raises: + ValueError: If name or type is not provided. + + Example: + loaded_aws_credentials = load_block(name="my-aws-creds", stype=StorageType.AWS) + """ + if name is None or stype is None: + raise ValueError("Cannot load block without name or type") + + loader_map = { + StorageType.AWS: AwsCredentials, + StorageType.AZCosmos: AzureCosmosDbCredentials, + StorageType.ECHODATAFLOW: EchodataflowConfig, + StorageType.EDFRUN: EDFRun + } + + if stype in loader_map: + return nest_asyncio.asyncio.run(loader_map[stype].load(name=name)) + else: + raise ValueError(f"Unsupported storage type: {stype}") \ No newline at end of file diff --git a/echodataflow/utils/flow_utils.py b/echodataflow/utils/flow_utils.py new file mode 100644 index 0000000..2e87d58 --- /dev/null +++ b/echodataflow/utils/flow_utils.py @@ -0,0 +1,57 @@ +from typing import Dict +from echodataflow.models.datastore import Dataset +from echodataflow.models.output_model import Group +from echodataflow.models.pipeline import Stage +from echodataflow.utils import log_util +from pathlib import Path +import torch +from echodataflow.utils.xr_utils import fetch_slice_from_store +from src.model.BinaryHakeModel import BinaryHakeModel + + +def load_model(stage: Stage, config: Dataset): + try: + log_util.log( + msg={"msg": f"Loading model now ---->", "mod_name": __file__, "func_name": "Mask_Prediction"}, + use_dask=stage.options["use_dask"], + eflogging=config.logging, + ) + model_path = f"/home/exouser/hake_data/model/backup_model_weights/binary_hake_model_1.0m_bottom_offset_1.0m_depth_2017_2019_ver_1.ckpt" + + # Load binary hake models with weights + model = BinaryHakeModel("placeholder_experiment_name", + Path("placeholder_score_tensor_dir"), + "placeholder_tensor_log_dir", 0).eval() + + model.load_state_dict(torch.load( + stage.external_params.get('model_path', model_path) + )["state_dict"]) + + log_util.log( + msg={"msg": f"Model loaded succefully", "mod_name": __file__, "func_name": "Mask_Prediction"}, + use_dask=stage.options["use_dask"], + eflogging=config.logging, + ) + except Exception as e: + log_util.log( + msg={"msg": "", "mod_name": __file__, "func_name": "Mask_Prediction"}, + use_dask=stage.options["use_dask"], + eflogging=config.logging, + error=e + ) + raise e + + return model + +def load_data_in_memory(config: Dataset, groups: Dict[str, Group]): + + for _, gr in groups.items(): + # From a store (list of file paths) fetch the slice of data and keep it in memory + if gr.metadata and gr.metadata.is_store_folder and len(gr.data) > 0: + edf = fetch_slice_from_store(edf_group=gr, config=config, start_time=edf.start_time, end_time=edf.end_time) + if edf.data.notnull().any(): + gr.data = [edf] + gr.metadata.is_store_folder = False + else: + continue + return groups \ No newline at end of file diff --git a/echodataflow/utils/xr_utils.py b/echodataflow/utils/xr_utils.py new file mode 100644 index 0000000..84174bd --- /dev/null +++ b/echodataflow/utils/xr_utils.py @@ -0,0 +1,173 @@ + + + +from typing import Any, Dict, List, Tuple + +import torch +from echodataflow.models.datastore import Dataset +from echodataflow.models.output_model import EchodataflowObject, Group +import xarray as xr +import pandas as pd +import numpy as np + +from echodataflow.models.pipeline import Stage +from echodataflow.utils import log_util + + +def fetch_slice_from_store(edf_group: Group, config: Dataset, options: Dict[str, Any] = None, start_time: str = None, end_time: str = None) -> EchodataflowObject: + edf = edf_group.data[0] + default_options = { + "engine":"zarr", + "combine":"by_coords", + "data_vars":"minimal", + "coords":"minimal", + "compat":"override", + "storage_options": config.args.storage_options_dict} if options is None else options + + if options: + default_options.update(options) + + if start_time is None: + start_time = edf.start_time + if end_time is None: + end_time = edf.end_time + + store = xr.open_mfdataset(paths=[ed.out_path for ed in edf_group.data], **default_options).compute() + store_slice = store.sel(ping_time=slice(pd.to_datetime(start_time, unit="ns"), pd.to_datetime(end_time, unit="ns"))) + + if store_slice["ping_time"].size == 0: + del store + del store_slice + raise ValueError(f"No data available between {start_time} and {end_time}") + + del store + edf.data = store_slice + + return edf + +def assemble_da(data_array: xr.DataArray, dims: Dict[str, Any]): + da = xr.DataArray( + data_array, dims=dims.keys() + ) + da = da.assign_coords(dims + ) + return da + +def process_xrd(ds: xr.Dataset, freq_wanted = [120000, 38000, 18000]) -> xr.Dataset: + ds = ds.sel(depth=slice(None, 590)) + + ch_wanted = [int((np.abs(ds["frequency_nominal"]-freq)).argmin()) for freq in freq_wanted] + ds = ds.isel( + channel=ch_wanted + ) + return ds + +def combine_datasets(store_18: xr.Dataset, store_5: xr.Dataset, config: Dataset) -> Tuple[xr.Dataset, torch.Tensor]: + ds_32k_120k = None + ds_18k = None + combined_ds = None + try: + partial_channel_name = ["ES18"] + ds_18k = extract_channels(store_18, partial_channel_name) + partial_channel_name = ["ES38", "ES120"] + ds_32k_120k = extract_channels(store_5, partial_channel_name) + except Exception as e: + partial_channel_name = ["ES18"] + ds_18k = extract_channels(store_5, partial_channel_name) + partial_channel_name = ["ES38", "ES120"] + ds_32k_120k = extract_channels(store_18, partial_channel_name) + + if not ds_18k or not ds_32k_120k: + raise ValueError("Could not find the required channels in the datasets") + + ds_18k = process_xrd(ds_18k, freq_wanted=[18000]) + ds_32k_120k = process_xrd(ds_32k_120k, freq_wanted=[120000, 38000]) + + combined_ds = xr.merge([ds_18k["Sv"], ds_32k_120k["Sv"], + ds_18k['latitude'], ds_18k['longitude'], + ds_18k["frequency_nominal"], ds_32k_120k["frequency_nominal"] + ]) + combined_ds.attrs = ds_18k.attrs + + + return convert_to_tensor(combined_ds=combined_ds, config=config) + +def convert_to_tensor(combined_ds: xr.Dataset, config: Dataset, freq_wanted: List[int] = [120000, 38000, 18000]) -> Tuple[xr.Dataset, torch.Tensor]: + """ + Convert dataset to a tensor and return the tensor and the dataset. + """ + + ch_wanted = [int((np.abs(combined_ds["frequency_nominal"]-freq)).argmin()) for freq in freq_wanted] + + log_util.log( + msg={"msg": f"Channel order {ch_wanted}", "mod_name": __file__, "func_name": "xr_utils.convert_to_tensor"}, + use_dask=False, + eflogging=config.logging, + ) + + depth = combined_ds['depth'] + ping_time = combined_ds['ping_time'] + + # Create a tensor with R=120 kHz, G=38 kHz, B=18 kHz mapping + red_channel = extract_channels(combined_ds, ["ES120"]) + green_channel = extract_channels(combined_ds, ["ES38"]) + blue_channel = extract_channels(combined_ds, ["ES18"]) + + ds = xr.concat([red_channel, green_channel, blue_channel], dim='channel') + ds['channel'] = ['R', 'G', 'B'] + ds = ds.assign_coords({'depth': depth, 'ping_time': ping_time}) + + ds = ( + ds + .transpose("channel", "depth", "ping_time") + .isel(channel=ch_wanted) + ) + + mvbs_tensor = torch.tensor(ds['Sv'].values, dtype=torch.float32) + + da_MVBS_tensor = torch.clip( + mvbs_tensor.clone().detach().to(torch.float16), + min=-70, + max=-36, + ) + log_util.log( + msg={"msg": f"converted and clipped tensor", "mod_name": __file__, "func_name": "xr_utils.convert_to_tensor"}, + use_dask=False, + eflogging=config.logging, + ) + + # Replace NaN values with min Sv + da_MVBS_tensor[torch.isnan(da_MVBS_tensor)] = -70 + + MVBS_tensor_normalized = ( + (da_MVBS_tensor - (-70.0)) / (-36.0 - (-70.0)) * 255.0 + ) + input_tensor = MVBS_tensor_normalized.unsqueeze(0).float() + log_util.log( + msg={"msg": f"Normalized tensor", "mod_name": __file__, "func_name": "xr_utils.convert_to_tensor"}, + use_dask=False, + eflogging=config.logging, + ) + + return (ds, input_tensor) + + +def extract_channels(dataset: xr.Dataset, partial_names: List[str]) -> xr.Dataset: + """ + Extracts multiple channels data from the given xarray dataset using partial channel names. + + Args: + dataset (xr.Dataset): The input xarray dataset containing multiple channels. + partial_names (List[str]): The list of partial names of the channels to extract. + + Returns: + xr.Dataset: The dataset containing only the specified channels data. + """ + matching_channels = [] + for partial_name in partial_names: + matching_channels.extend([channel for channel in dataset.channel.values if partial_name in str(channel)]) + + if len(matching_channels) == 0: + raise ValueError(f"No channels found matching any of '{partial_names}'") + + return dataset.sel(channel=matching_channels) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e805c72..f9468ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ write_to = "echodataflow/version.py" line-length = 100 [tool.pytest.ini_options] -testpaths = ["echodataflow/tests/flow_tests"] +testpaths = ["echodataflow/tests/flow_tests", "echodataflow/tests/unit"] addopts = "--cov=./ --cov-report=term --cov-report=xml" filterwarnings = [ "ignore::DeprecationWarning" @@ -30,5 +30,7 @@ include_namespace_packages = true omit = [ "*/tests/*", "*/__init__.py", - "*/docs/*" + "*/docs/*", + "*/deployment/*", + "*/setup.py", ] \ No newline at end of file