From 3dd703c1428653438d2a10ed04235e67db469237 Mon Sep 17 00:00:00 2001
From: Andy Aschwanden
Date: Sun, 22 Dec 2024 19:03:19 +0100
Subject: [PATCH] refactoring

Separated sampling and plotting.

---
 analysis/analyze_scalar.py        | 368 ++++++++-------------
 data/03_prepare_mass_balance.py   |  25 +-
 pism_ragis/data/ragis_config.toml |  30 +-
 pism_ragis/filtering.py           | 117 ++++++-
 pism_ragis/plotting.py            | 528 ++++++++++++++++--------------
 pism_ragis/processing.py          |  85 +++++
 pyproject.toml                    |   2 +-
 7 files changed, 661 insertions(+), 494 deletions(-)

diff --git a/analysis/analyze_scalar.py b/analysis/analyze_scalar.py
index 6776296..2bb0b67 100644
--- a/analysis/analyze_scalar.py
+++ b/analysis/analyze_scalar.py
@@ -30,11 +30,12 @@
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from functools import wraps
 from importlib.resources import files
+from itertools import chain
 from pathlib import Path
 from typing import Any, Callable, Dict, Hashable, List, Mapping, Union

 import dask
-import matplotlib
+import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
@@ -49,8 +50,7 @@
 import pism_ragis.processing as prp
 from pism_ragis.analyze import delta_analysis, sobol_analysis
 from pism_ragis.decorators import profileit, timeit
-from pism_ragis.filtering import filter_outliers, importance_sampling
-from pism_ragis.likelihood import log_normal
+from pism_ragis.filtering import filter_outliers, run_importance_sampling
 from pism_ragis.logger import get_logger
 from pism_ragis.plotting import (
     plot_basins,
@@ -58,10 +58,11 @@
     plot_prior_posteriors,
     plot_sensitivity_indices,
 )
+from pism_ragis.processing import config_to_dataframe, filter_config

 logger = get_logger("pism_ragis")

-matplotlib.use("Agg")
+# mpl.use("Agg")
 xr.set_options(keep_attrs=True)
 plt.style.use("tableau-colorblind10")
 # Ignore specific RuntimeWarnings
@@ -73,6 +74,36 @@
 )


+def sort_columns(df: pd.DataFrame, sorted_columns: list) -> pd.DataFrame:
+    """
+    Sort columns of a DataFrame.
+
+    This function sorts the columns of a DataFrame such that the columns specified in
+    `sorted_columns` appear in the specified order, while all other columns appear before
+    the sorted columns in their original order.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        The input DataFrame to be sorted.
+    sorted_columns : list
+        A list of column names to be sorted.
+
+    Returns
+    -------
+    pd.DataFrame
+        The DataFrame with columns sorted as specified.
+    """
+    # Identify columns that are not in the list
+    other_columns = [col for col in df.columns if col not in sorted_columns]
+
+    # Concatenate other columns with the sorted columns
+    new_column_order = other_columns + sorted_columns
+
+    # Reindex the DataFrame
+    return df.reindex(columns=new_column_order)
+
+
 def add_prefix_coord(
     sensitivity_indices: xr.Dataset, parameter_groups: Dict
 ) -> xr.Dataset:
@@ -166,156 +197,6 @@ def prepare_input(
     return df


-@timeit
-def run_sampling(
-    observed: xr.Dataset,
-    simulated: xr.Dataset,
-    obs_mean_vars: List[str] = ["grounding_line_flux", "mass_balance"],
-    obs_std_vars: List[str] = [
-        "grounding_line_flux_uncertainty",
-        "mass_balance_uncertainty",
-    ],
-    sim_vars: List[str] = ["grounding_line_flux", "mass_balance"],
-    filter_range: List[int] = [1990, 2019],
-    fudge_factor: float = 3.0,
-    fig_dir: Union[str, Path] = "figures",
-    params: List[str] = [],
-    config: Dict = {},
-) -> pd.DataFrame:
-    """
-    Run sampling to process observed and simulated datasets.
- - This function performs importance sampling using the specified observed and simulated datasets, - processes the results, and returns a DataFrame with the prior and posterior configurations. - - Parameters - ---------- - observed : xr.Dataset - The observed dataset. - simulated : xr.Dataset - The simulated dataset. - obs_mean_vars : List[str], optional - A list of variable names for the observed mean values, by default ["grounding_line_flux", "mass_balance"]. - obs_std_vars : List[str], optional - A list of variable names for the observed standard deviation values, by default ["grounding_line_flux_uncertainty", "mass_balance_uncertainty"]. - sim_vars : List[str], optional - A list of variable names for the simulated values, by default ["grounding_line_flux", "mass_balance"]. - filter_range : List[int], optional - A list containing the start and end years for filtering, by default [1990, 2019]. - fudge_factor : float, optional - A fudge factor for the importance sampling, by default 3.0. - fig_dir : Union[str, Path], optional - The directory where figures will be saved, by default "figures". - params : List[str], optional - A list of parameter names to be used for filtering configurations, by default []. - config : Dict, optional - A dictionary containing configuration settings for the RAGIS model, by default {}. - - Returns - ------- - pd.DataFrame - A DataFrame containing the prior and posterior configurations. - - Examples - -------- - >>> observed = xr.Dataset(...) - >>> simulated = xr.Dataset(...) - >>> result = run_sampling(observed, simulated) - """ - filter_start_year, filter_end_year = filter_range - - simulated_prior = simulated - simulated_prior["ensemble"] = "Prior" - - prior_config = filter_config(simulated.isel({"time": 0}), params) - prior_df = config_to_dataframe(prior_config, ensemble="Prior") - prior_posterior_list = [] - - for obs_mean_var, obs_std_var, sim_var in zip( - obs_mean_vars, obs_std_vars, sim_vars - ): - print(f"Importance sampling using {obs_mean_var}") - f = importance_sampling( - simulated=simulated.sel( - time=slice(str(filter_start_year), str(filter_end_year)) - ), - observed=observed.sel( - time=slice(str(filter_start_year), str(filter_end_year)) - ), - log_likelihood=log_normal, - fudge_factor=fudge_factor, - n_samples=len(simulated.exp_id), - obs_mean_var=obs_mean_var, - obs_std_var=obs_std_var, - sim_var=sim_var, - ) - - with ProgressBar() as pbar: - result = f.compute() - logger.info( - "Importance Sampling: Finished in %2.2f seconds", pbar.last_duration - ) - - importance_sampled_ids = result["exp_id_sampled"] - importance_sampled_ids["basin"] = importance_sampled_ids["basin"].astype(str) - - simulated_posterior = simulated.sel(exp_id=importance_sampled_ids) - simulated_posterior["ensemble"] = "Posterior" - - posterior_config = filter_config(simulated_posterior.isel({"time": 0}), params) - posterior_df = config_to_dataframe(posterior_config, ensemble="Posterior") - - prior_posterior_f = pd.concat([prior_df, posterior_df]).reset_index(drop=True) - prior_posterior_f["filtered_by"] = obs_mean_var - prior_posterior_list.append(prior_posterior_f) - - plot_basins( - observed, - simulated_prior, - simulated_posterior, - obs_mean_var, - filter_range=filter_range, - fig_dir=fig_dir, - config=config, - ) - prior_posterior = pd.concat(prior_posterior_list).reset_index(drop=True) - prior_posterior = prior_posterior.apply(prp.convert_column_to_numeric) - return prior_posterior - - -def filter_config(ds: xr.Dataset, params: list[str]) -> xr.DataArray: - """ - 
Filter the configuration parameters from the dataset. - - This function selects the specified configuration parameters from the dataset - and returns them as a DataArray. - - Parameters - ---------- - ds : xr.Dataset - The input dataset containing the configuration parameters. - params : List[str] - A list of configuration parameter names to be selected. - - Returns - ------- - xr.DataArray - The selected configuration parameters as a DataArray. - - Examples - -------- - >>> ds = xr.Dataset({'pism_config': (('pism_config_axis',), [1, 2, 3])}, - coords={'pism_config_axis': ['param1', 'param2', 'param3']}) - >>> filter_config(ds, ['param1', 'param3']) - - array([1, 3]) - Coordinates: - * pism_config_axis (pism_config_axis) pd.DataFrame: - """ - Convert an xarray DataArray configuration to a pandas DataFrame. - - This function converts the input DataArray containing configuration data into a - pandas DataFrame. The dimensions of the DataArray (excluding 'pism_config_axis') - are used as the index, and the 'pism_config_axis' values are used as columns. - - Parameters - ---------- - config : xr.DataArray - The input DataArray containing the configuration data. - ensemble : Union[str, None], optional - An optional string to add as a column named 'ensemble' in the DataFrame, by default None. - - Returns - ------- - pd.DataFrame - A DataFrame where the dimensions of the DataArray (excluding 'pism_config_axis') - are used as the index, and the 'pism_config_axis' values are used as columns. - - Examples - -------- - >>> config = xr.DataArray( - ... data=[[1, 2, 3], [4, 5, 6]], - ... dims=["time", "pism_config_axis"], - ... coords={"time": [0, 1], "pism_config_axis": ["param1", "param2", "param3"]} - ... ) - >>> df = config_to_dataframe(config) - >>> print(df) - pism_config_axis time param1 param2 param3 - 0 0 1 2 3 - 1 1 4 5 6 - - >>> df = config_to_dataframe(config, ensemble="ensemble1") - >>> print(df) - pism_config_axis time param1 param2 param3 ensemble - 0 0 1 2 3 ensemble1 - 1 1 4 5 6 ensemble1 - """ - dims = [dim for dim in config.dims if dim != "pism_config_axis"] - df = config.to_dataframe().reset_index() - df = df.pivot(index=dims, columns="pism_config_axis", values="pism_config") - df.reset_index(inplace=True) - if ensemble: - df["ensemble"] = ensemble - return df - - def convert_bstrings_to_str(element: Any) -> Any: """ Convert byte strings to regular strings. 
@@ -775,13 +604,19 @@ def run_sensitivity_analysis( png_dir = plot_dir / Path("pngs") png_dir.mkdir(parents=True, exist_ok=True) - plt.rcParams["font.size"] = 6 - - flux_vars = config["Flux Variables"] - flux_uncertainty_vars = { - k + "_uncertainty": v + "_uncertainty" for k, v in flux_vars.items() + rcparams = { + "axes.linewidth": 0.25, + "xtick.direction": "in", + "xtick.major.size": 2.5, + "xtick.major.width": 0.25, + "ytick.direction": "in", + "ytick.major.size": 2.5, + "ytick.major.width": 0.25, + "hatch.linewidth": 0.25, } + plt.rcParams.update(rcparams) + simulated_ds = prepare_simulations( basin_files, config, reference_date, parallel=parallel, engine=engine ) @@ -865,29 +700,73 @@ def run_sensitivity_analysis( {"time": resampling_frequency} ).mean() - prior_posterior_mankoff = run_sampling( - observed=observed_mankoff_basins_resampled_ds, - simulated=simulated_mankoff_basins_resampled_ds, - filter_range=filter_range, - fudge_factor=fudge_factor, - params=params, - config=config, - fig_dir=fig_dir, + obs_mean_vars_mankoff: List[str] = ["grounding_line_flux", "mass_balance"] + obs_std_vars_mankoff: List[str] = [ + "grounding_line_flux_uncertainty", + "mass_balance_uncertainty", + ] + sim_vars_mankoff: List[str] = ["grounding_line_flux", "mass_balance"] + + sim_plot_vars = ( + [ragis_config["Cumulative Variables"]["cumulative_mass_balance"]] + + list(ragis_config["Flux Variables"].values()) + + ["ensemble"] ) - prior_posterior_grace = run_sampling( - observed=observed_grace_basins_resampled_ds, - simulated=simulated_grace_basins_resampled_ds, - obs_mean_vars=["mass_balance"], - obs_std_vars=["mass_balance_uncertainty"], - sim_vars=["mass_balance"], - fudge_factor=10, - filter_range=filter_range, - params=params, - config=config, - fig_dir=fig_dir, + prior_posterior_mankoff, simulated_prior_mankoff, simulated_posterior_mankoff = ( + run_importance_sampling( + observed=observed_mankoff_basins_resampled_ds, + simulated=simulated_mankoff_basins_resampled_ds, + obs_mean_vars=obs_mean_vars_mankoff, + obs_std_vars=obs_std_vars_mankoff, + sim_vars=sim_vars_mankoff, + filter_range=filter_range, + fudge_factor=fudge_factor, + params=params, + ) + ) + + for filter_var in obs_mean_vars_mankoff: + plot_basins( + observed_mankoff_basins_resampled_ds, + simulated_prior_mankoff[sim_plot_vars], + simulated_posterior_mankoff.sel({"filtered_by": filter_var})[sim_plot_vars], + filter_var=filter_var, + filter_range=filter_range, + fig_dir=fig_dir, + config=config, + ) + + obs_mean_vars_grace: List[str] = ["mass_balance"] + obs_std_vars_grace: List[str] = [ + "mass_balance_uncertainty", + ] + sim_vars_grace: List[str] = ["mass_balance"] + + prior_posterior_grace, simulated_prior_grace, simulated_posterior_grace = ( + run_importance_sampling( + observed=observed_grace_basins_resampled_ds, + simulated=simulated_grace_basins_resampled_ds, + obs_mean_vars=obs_mean_vars_grace, + obs_std_vars=obs_std_vars_grace, + sim_vars=sim_vars_grace, + fudge_factor=fudge_factor, + filter_range=filter_range, + params=params, + ) ) + for filter_var in obs_mean_vars_grace: + plot_basins( + observed_grace_basins_resampled_ds, + simulated_prior_grace[sim_plot_vars], + simulated_posterior_grace.sel({"filtered_by": filter_var})[sim_plot_vars], + filter_var=filter_var, + filter_range=filter_range, + fig_dir=fig_dir, + config=config, + ) + prior_posterior = pd.concat( [prior_posterior_mankoff, prior_posterior_grace] ).reset_index(drop=True) @@ -910,11 +789,31 @@ def run_sensitivity_analysis( ) bins_dict = config["Posterior Bins"] 
+    parameter_categories = config["Parameter Categories"]
+
+    params_sorted_by_category: dict = {
+        group: [] for group in sorted(parameter_categories.values())
+    }
+    for param in params:
+        prefix = param.split(".")[0]
+        if prefix in parameter_categories:
+            group = parameter_categories[prefix]
+            if param not in params_sorted_by_category[group]:
+                params_sorted_by_category[group].append(param)
+
+    params_sorted_list = list(chain(*params_sorted_by_category.values()))
+    if "frontal_melt.routing.parameter_a" in prior_posterior.columns:
+        prior_posterior["frontal_melt.routing.parameter_a"] *= 10**4
+    if "ocean.th.gamma_T" in prior_posterior.columns:
+        prior_posterior["ocean.th.gamma_T"] *= 10**5
+    if "calving.vonmises_calving.sigma_max" in prior_posterior.columns:
+        prior_posterior["calving.vonmises_calving.sigma_max"] *= 10**-3
+    prior_posterior_sorted = sort_columns(prior_posterior, params_sorted_list)
+
+    params_sorted_dict = {k: params_short_dict[k] for k in params_sorted_list}
     plot_prior_posteriors(
-        prior_posterior.rename(columns=params_short_dict),
-        bins_dict,
+        prior_posterior_sorted.rename(columns=params_sorted_dict),
         fig_dir=fig_dir,
-        config=config,
     )

     prior_config = filter_config(simulated.isel({"time": 0}), params)
@@ -944,8 +843,7 @@
     si_dir.mkdir(parents=True, exist_ok=True)
     sensitivity_indices.to_netcdf(si_dir / Path("sensitivity_indices.nc"))

-    parameter_groups = config["Parameter Groups"]
-    sensitivity_indices = add_prefix_coord(sensitivity_indices, parameter_groups)
+    sensitivity_indices = add_prefix_coord(sensitivity_indices, parameter_categories)

     # Group by the new coordinate and compute the sum for each group
     indices_vars = [v for v in sensitivity_indices.data_vars if "_conf" not in v]
diff --git a/data/03_prepare_mass_balance.py b/data/03_prepare_mass_balance.py
index 95b4533..65e1da5 100644
--- a/data/03_prepare_mass_balance.py
+++ b/data/03_prepare_mass_balance.py
@@ -114,10 +114,10 @@
     short_name = "GREENLAND_MASS_TELLUS_MASCON_CRI_TIME_SERIES_RL06.1_V3"

     results = download_earthaccess(result_dir=p, short_name=short_name)
-
+    grace_file = results[0]
     # Read the data into a pandas DataFrame
     df = pd.read_csv(
-        results[0],
+        grace_file,
         header=32,  # Skip the header lines
         sep="\s+",
         names=[
@@ -135,16 +135,25 @@
     df["time"] = date
     ds = xr.Dataset.from_dataframe(df.set_index(df["time"]))

-    ds["mass_balance"] = ds["cumulative_mass_balance"].diff(dim="time")
+    ds["cumulative_mass_balance"].attrs.update({"units": "Gt"})
     ds["cumulative_mass_balance_uncertainty"] = np.sqrt(
         (ds["mass_balance_uncertainty"] ** 2).cumsum(dim="time")
     )
-    ds["cumulative_mass_balance"].attrs.update({"units": "Gt"})
     ds["cumulative_mass_balance_uncertainty"].attrs.update({"units": "Gt"})
-    ds["mass_balance"].attrs.update({"units": "Gt year^-1"})
-    ds["mass_balance_uncertainty"].attrs.update({"units": "Gt year^-1"})
-    ds = ds.expand_dims({"basin": ["GRACE"]}, axis=-1)
+
+    ds = ds.pint.quantify()
+
+    interval_in_years = (
+        (ds.time.diff(dim="time") / np.timedelta64(1, "s"))
+        .pint.quantify("s")
+        .pint.to("year")
+    )
+
+    ds["mass_balance"] = (
+        ds["cumulative_mass_balance"].diff(dim="time") / interval_in_years
+    )
+    ds = ds.expand_dims({"basin": ["GRACE"]})

     fn = "grace_greenland_mass_balance.nc"
     p_fn = p / fn
-    grace_ds = ds
+    grace_ds = ds.pint.dequantify()
     save_netcdf(grace_ds, p_fn)
diff --git a/pism_ragis/data/ragis_config.toml b/pism_ragis/data/ragis_config.toml
index 649e2d0..18589de 100644
--- a/pism_ragis/data/ragis_config.toml
+++ b/pism_ragis/data/ragis_config.toml
@@ -1,23 +1,23 @@
 [Parameters]

-'calving.vonmises_calving.sigma_max' = '$\sigma_{\mathrm{max}}$ (Pa)'
+'calving.vonmises_calving.sigma_max' = '$\sigma_{\mathrm{max}}$ (kPa)'
 'calving.rate_scaling.file' = 'Calving'
-'geometry.front_retreat.prescribed.file' = 'Retreat'
-'ocean.th.gamma_T' = '$\gamma$'
-'surface.given.file' = 'Climate'
-'ocean.th.file' = 'Ocean'
-'frontal_melt.routing.parameter_a' = '$a$'
-'frontal_melt.routing.parameter_b' = '$b$'
+'geometry.front_retreat.prescribed.file' = 'Retreat Method'
+'ocean.th.gamma_T' = '$\gamma$ (10$^{-5}$ 1)'
+'surface.given.file' = 'Climate Forcing'
+'ocean.th.file' = 'Ocean Forcing'
+'frontal_melt.routing.parameter_a' = '$a$ (10$^{-4}$ m$^{-\alpha}$ day$^{\alpha-1}$ Celsius$^{-\beta}$)'
+'frontal_melt.routing.parameter_b' = '$b$ (day$^{\alpha-1}$ Celsius$^{-\beta}$)'
 'frontal_melt.routing.power_alpha' = '$\alpha$ (1)'
 'frontal_melt.routing.power_beta' = '$\beta$ (1)'
 'stress_balance.sia.enhancement_factor' = '$E_{\mathrm{SIA}}$ (1)'
 'stress_balance.ssa.Glen_exponent' = '$n_{\mathrm{SSA}}$ (1)'
 'basal_resistance.pseudo_plastic.q' = '$q$ (1)'
 'basal_yield_stress.mohr_coulomb.till_effective_fraction_overburden' = '$\delta$ (1)'
-'basal_yield_stress.mohr_coulomb.topg_to_phi.phi_min' = '$\phi_{\mathrm{min}}^{\circ{}}$'
-'basal_yield_stress.mohr_coulomb.topg_to_phi.phi_max' = '$\phi_{\mathrm{max}}^{\circ{}}$'
-'basal_yield_stress.mohr_coulomb.topg_to_phi.topg_min' = '$z_{\mathrm{min}}^{\circ{}}$'
-'basal_yield_stress.mohr_coulomb.topg_to_phi.topg_max' = '$z_{\mathrm{min}}^{\circ{}}$'
+'basal_yield_stress.mohr_coulomb.topg_to_phi.phi_min' = '$\phi_{\mathrm{min}} (^{\circ{}})$'
+'basal_yield_stress.mohr_coulomb.topg_to_phi.phi_max' = '$\phi_{\mathrm{max}} (^{\circ{}})$'
+'basal_yield_stress.mohr_coulomb.topg_to_phi.topg_min' = '$z_{\mathrm{min}}$ (m)'
+'basal_yield_stress.mohr_coulomb.topg_to_phi.topg_max' = '$z_{\mathrm{max}}$ (m)'


 ['Posterior Bins']
@@ -40,7 +40,7 @@
 'basal_yield_stress.mohr_coulomb.topg_to_phi.topg_min' = 10
 'basal_yield_stress.mohr_coulomb.topg_to_phi.topg_max' = 10

-['Parameter Groups']
+['Parameter Categories']

 "surface" = "Climate"
 "atmosphere" = "Climate"
@@ -59,6 +59,12 @@
 mass_flux = "mass_balance"
 grounding_line_flux = "grounding_line_flux"
 smb_flux = "surface_mass_balance"

+['Flux Uncertainty Variables']
+
+mass_flux_uncertainty = "mass_balance_uncertainty"
+grounding_line_flux_uncertainty = "grounding_line_flux_uncertainty"
+smb_flux_uncertainty = "surface_mass_balance_uncertainty"
+
 ['Cumulative Variables']

 cumulative_mass_balance = "cumulative_mass_balance"
diff --git a/pism_ragis/filtering.py b/pism_ragis/filtering.py
index 553ea4d..9dd970e 100644
--- a/pism_ragis/filtering.py
+++ b/pism_ragis/filtering.py
@@ -24,13 +24,18 @@
 import logging
 import warnings
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Tuple

 import numpy as np
+import pandas as pd
 import xarray as xr
+from dask.diagnostics import ProgressBar

+import pism_ragis.processing as prp
 from pism_ragis.decorators import timeit
+from pism_ragis.likelihood import log_normal
 from pism_ragis.logger import get_logger
+from pism_ragis.processing import config_to_dataframe, filter_config

 logger: logging.Logger = get_logger("pism_ragis")

@@ -262,3 +267,113 @@
     outliers_ds = ds.sel(exp_id=outlier_exp_ids)
     print(f"Ensemble size: {n_members}, outlier-filtered size: {n_members_filtered}\n")
     return filtered_ds, outliers_ds
+
+
+@timeit
+def run_importance_sampling(
+    observed: xr.Dataset,
+    simulated: xr.Dataset,
+    obs_mean_vars: List[str] = ["grounding_line_flux", "mass_balance"],
+    obs_std_vars: List[str] = [
+        "grounding_line_flux_uncertainty",
+        "mass_balance_uncertainty",
+    ],
+    sim_vars: List[str] = ["grounding_line_flux", "mass_balance"],
+    filter_range: List[int] = [1990, 2019],
+    fudge_factor: float = 3.0,
+    params: List[str] = [],
+) -> Tuple[pd.DataFrame, xr.Dataset, xr.Dataset]:
+    """
+    Run importance sampling on observed and simulated datasets.
+
+    This function performs importance sampling using the specified observed and simulated datasets,
+    processes the results, and returns the prior and posterior configurations together with the
+    prior and posterior datasets.
+
+    Parameters
+    ----------
+    observed : xr.Dataset
+        The observed dataset.
+    simulated : xr.Dataset
+        The simulated dataset.
+    obs_mean_vars : List[str], optional
+        A list of variable names for the observed mean values, by default ["grounding_line_flux", "mass_balance"].
+    obs_std_vars : List[str], optional
+        A list of variable names for the observed standard deviation values, by default ["grounding_line_flux_uncertainty", "mass_balance_uncertainty"].
+    sim_vars : List[str], optional
+        A list of variable names for the simulated values, by default ["grounding_line_flux", "mass_balance"].
+    filter_range : List[int], optional
+        A list containing the start and end years for filtering, by default [1990, 2019].
+    fudge_factor : float, optional
+        A fudge factor for the importance sampling, by default 3.0.
+    params : List[str], optional
+        A list of parameter names to be used for filtering configurations, by default [].
+
+    Returns
+    -------
+    Tuple[pd.DataFrame, xr.Dataset, xr.Dataset]
+        A tuple of the prior and posterior configurations, the prior dataset, and the posterior dataset.
+
+    Examples
+    --------
+    >>> observed = xr.Dataset(...)
+    >>> simulated = xr.Dataset(...)
+    >>> prior_posterior, prior, posterior = run_importance_sampling(observed, simulated)
+    """
+    filter_start_year, filter_end_year = filter_range
+
+    simulated_prior = simulated
+    simulated_prior["ensemble"] = "Prior"
+
+    prior_config = filter_config(simulated.isel({"time": 0}), params)
+    prior_df = config_to_dataframe(prior_config, ensemble="Prior")
+
+    prior_posterior_list = []
+    posterior_list = []
+    for obs_mean_var, obs_std_var, sim_var in zip(
+        obs_mean_vars, obs_std_vars, sim_vars
+    ):
+        print(f"Importance sampling using {obs_mean_var}")
+        f = importance_sampling(
+            simulated=simulated.sel(
+                time=slice(str(filter_start_year), str(filter_end_year))
+            ),
+            observed=observed.sel(
+                time=slice(str(filter_start_year), str(filter_end_year))
+            ),
+            log_likelihood=log_normal,
+            fudge_factor=fudge_factor,
+            n_samples=len(simulated.exp_id),
+            obs_mean_var=obs_mean_var,
+            obs_std_var=obs_std_var,
+            sim_var=sim_var,
+        )
+
+        with ProgressBar() as pbar:
+            result = f.compute()
+            logger.info(
+                "Importance Sampling: Finished in %2.2f seconds", pbar.last_duration
+            )
+
+        importance_sampled_ids = result["exp_id_sampled"]
+        importance_sampled_ids["basin"] = importance_sampled_ids["basin"].astype(str)
+
+        simulated_posterior = simulated.sel(exp_id=importance_sampled_ids)
+        simulated_posterior["ensemble"] = "Posterior"
+        simulated_posterior = simulated_posterior.expand_dims(
+            {"filtered_by": [obs_mean_var]}
+        )
+
+        posterior_config = filter_config(simulated_posterior.isel({"time": 0}), params)
+        posterior_df = config_to_dataframe(posterior_config, ensemble="Posterior")
+
+        prior_posterior_f = pd.concat([prior_df, posterior_df]).reset_index(drop=True)
+        prior_posterior_f["filtered_by"] = obs_mean_var
+        prior_posterior_list.append(prior_posterior_f)
+
+        posterior_list.append(simulated_posterior)
+
+    prior_posterior = pd.concat(prior_posterior_list).reset_index(drop=True)
+    prior_posterior = prior_posterior.apply(prp.convert_column_to_numeric)
+    prior = simulated_prior
+    posterior = xr.concat(posterior_list, dim="filtered_by")
+    return prior_posterior, prior, posterior
diff --git a/pism_ragis/plotting.py b/pism_ragis/plotting.py
index 64875c1..da77f62 100644
--- a/pism_ragis/plotting.py
+++ b/pism_ragis/plotting.py
@@ -26,6 +26,7 @@
 from pathlib import Path
 from typing import Dict, List, Union

+import matplotlib as mpl
 import matplotlib.pylab as plt
 import numpy as np
 import pandas as pd
@@ -49,9 +50,8 @@
 @timeit
 def plot_prior_posteriors(
     df: pd.DataFrame,
-    bins_dict: Dict,
     fig_dir: Union[str, Path] = "figures",
-    config: Dict = {},
+    fontsize: float = 4,
 ):
     """
     Plot histograms of prior and posterior distributions.
@@ -60,16 +60,12 @@
     ----------
     df : pd.DataFrame
         DataFrame containing the data to plot.
-    bins_dict : Dict
-        Dictionary containing the number of bins for each variable.
     fig_dir : Union[str, Path], optional
         Directory to save the figures, by default "figures".
-    config : Dict, optional
-        Configuration dictionary, by default {}.
+    fontsize : float, optional
+        Font size for the plot, by default 4.
     """
-    params_short_dict = config["Parameters"]
-
     plot_dir = fig_dir / Path("basin_histograms")
     plot_dir.mkdir(parents=True, exist_ok=True)
     pdf_dir = plot_dir / Path("pdfs")
@@ -77,52 +73,59 @@
     png_dir = plot_dir / Path("pngs")
     png_dir.mkdir(parents=True, exist_ok=True)

-    for (basin, filtering_var), m_df in df.groupby(by=["basin", "filtered_by"]):
-        plt.rcParams["font.size"] = 4
-        fig, axs = plt.subplots(
-            4,
-            4,
-            sharey=True,
-            figsize=[6.2, 4.2],
-        )
-        fig.subplots_adjust(hspace=0.75, wspace=0.1)
-        for k, (v, v_s) in enumerate(params_short_dict.items()):
-            legend = bool(k == 0)
-            try:
-                _ = sns.histplot(
-                    data=m_df,
-                    x=v_s,
-                    hue="ensemble",
-                    hue_order=["Prior", "Posterior"],
-                    bins=bins_dict[v],
-                    palette=sim_cmap,
-                    common_norm=False,
-                    stat="probability",
-                    multiple="dodge",
-                    alpha=0.8,
-                    linewidth=0.2,
-                    ax=axs.ravel()[k],
-                    legend=legend,
-                )
-            except:
-                pass
-            if legend:
-                axs.ravel()[k].get_legend().set_title(None)
-                axs.ravel()[k].get_legend().get_frame().set_linewidth(0.0)
-                axs.ravel()[k].get_legend().get_frame().set_alpha(0.0)
-
-        for ax in axs.flatten():
-            ax.set_ylabel("")
-            ax.set_ylim(0, 1)
-            ticklabels = ax.get_xticklabels()
-            for tick in ticklabels:
-                tick.set_rotation(15)
-        fn = pdf_dir / Path(f"{basin}_prior_posterior_filtered_by_{filtering_var}.pdf")
-        fig.savefig(fn)
-        fn = png_dir / Path(f"{basin}_prior_posterior_filtered_by_{filtering_var}.png")
-        fig.savefig(fn, dpi=300)
-        plt.close()
-        del fig
+    group_columns = ["basin", "filtered_by"]
+
+    rc_params = {
+        "font.size": fontsize,
+        # Add other rcParams settings if needed
+    }
+
+    with mpl.rc_context(rc=rc_params):
+        for (basin, filter_var), m_df in df.groupby(by=group_columns):
+            m_df = m_df.drop(columns=group_columns + ["exp_id"])
+            fig, axs = plt.subplots(
+                3,
+                6,
+                sharey=False,
+                figsize=[6.2, 3.2],
+            )
+            fig.subplots_adjust(hspace=0.5, wspace=0.22)
+            for k, v in enumerate(m_df.drop(columns=["ensemble"]).columns):
+                legend = bool(k == 1)
+                try:
+                    _ = sns.histplot(
+                        data=m_df,
+                        x=v,
+                        hue="ensemble",
+                        hue_order=["Prior", "Posterior"],
+                        palette=sim_cmap,
+                        common_norm=False,
+                        stat="count",
+                        multiple="dodge",
+                        alpha=0.8,
+                        linewidth=0.2,
+                        ax=axs.ravel()[k],
+                        legend=legend,
+                    )
+                except Exception:
+                    pass
+                if legend:
+                    
axs.ravel()[k].get_legend().set_title(None) + axs.ravel()[k].get_legend().get_frame().set_linewidth(0.0) + axs.ravel()[k].get_legend().get_frame().set_alpha(0.0) + + for ax in axs.flatten(): + ax.set_ylabel("") + # ax.set_ylim(0, 1) + ticklabels = ax.get_xticklabels() + for tick in ticklabels: + tick.set_rotation(15) + fn = pdf_dir / Path(f"{basin}_prior_posterior_filtered_by_{filter_var}.pdf") + fig.savefig(fn) + fn = png_dir / Path(f"{basin}_prior_posterior_filtered_by_{filter_var}.png") + fig.savefig(fn, dpi=300) + plt.close() + del fig @timeit @@ -130,7 +133,7 @@ def plot_basins( observed: xr.Dataset, prior: xr.Dataset, posterior: xr.Dataset, - filtering_var: str, + filter_var: str, filter_range: List[int] = [1990, 2019], fig_dir: Union[str, Path] = "figures", plot_range: List[int] = [1980, 2020], @@ -151,7 +154,7 @@ def plot_basins( The prior dataset. posterior : xr.Dataset The posterior dataset. - filtering_var : str + filter_var : str The variable used for filtering. filter_range : List[int], optional A list containing the start and end years for filtering, by default [1990, 2019]. @@ -186,7 +189,7 @@ def plot_basins( {"time": slice(str(plot_range[0]), str(plot_range[1]))} ), config=config, - filtering_var=filtering_var, + filter_var=filter_var, filter_range=filter_range, fig_dir=fig_dir, obs_alpha=obs_alpha, @@ -232,26 +235,28 @@ def plot_sensitivity_indices( png_dir = plot_dir / "pngs" png_dir.mkdir(parents=True, exist_ok=True) - plt.rcParams["font.size"] = fontsize - - fig, ax = plt.subplots(1, 1, figsize=(6.2, 3.6)) - for g in ds.sensitivity_indices_group: - indices_da = ds[indices_var].sel(sensitivity_indices_group=g) - conf_da = ds[indices_conf_var].sel(sensitivity_indices_group=g) - ax.fill_between( - indices_da.time, - (indices_da - conf_da), - (indices_da + conf_da), - alpha=0.25, - ) - indices_da.plot(hue="sensitivity_indices_group", ax=ax, lw=0.75, label=g.values) - legend = ax.legend(loc="upper left") - legend.get_frame().set_linewidth(0.0) - legend.get_frame().set_alpha(0.0) - ax.set_title(f"{indices_var} for basin {basin} for {filter_var}") - fn = pdf_dir / f"basin_{basin}_{indices_var}_for_{filter_var}.pdf" - fig.savefig(fn) - plt.close() + with mpl.rc_context({"font.size": fontsize}): + + fig, ax = plt.subplots(1, 1, figsize=(6.2, 3.6)) + for g in ds.sensitivity_indices_group: + indices_da = ds[indices_var].sel(sensitivity_indices_group=g) + conf_da = ds[indices_conf_var].sel(sensitivity_indices_group=g) + ax.fill_between( + indices_da.time, + (indices_da - conf_da), + (indices_da + conf_da), + alpha=0.25, + ) + indices_da.plot( + hue="sensitivity_indices_group", ax=ax, lw=0.75, label=g.values + ) + legend = ax.legend(loc="upper left") + legend.get_frame().set_linewidth(0.0) + legend.get_frame().set_alpha(0.0) + ax.set_title(f"{indices_var} for basin {basin} for {filter_var}") + fn = pdf_dir / f"basin_{basin}_{indices_var}_for_{filter_var}.pdf" + fig.savefig(fn) + plt.close() @timeit @@ -260,7 +265,7 @@ def plot_obs_sims( sim_prior: xr.Dataset, sim_posterior: xr.Dataset, config: dict, - filtering_var: str, + filter_var: str, filter_range: List[int] = [1990, 2019], fig_dir: Union[str, Path] = "figures", reference_year: float = 1986.0, @@ -283,7 +288,7 @@ def plot_obs_sims( Posterior simulation dataset. config : dict Configuration dictionary containing variable names. - filtering_var : str + filter_var : str Variable used for filtering. filter_range : List[int], optional Range of years for filtering, by default [1990, 2019]. 
@@ -317,174 +322,220 @@ def plot_obs_sims( basin = obs.basin.values mass_cumulative_varname = config["Cumulative Variables"]["cumulative_mass_balance"] - mass_cumulative_uncertainty_varname = mass_cumulative_varname + "_uncertainty" + mass_cumulative_uncertainty_varname = config["Cumulative Uncertainty Variables"][ + "cumulative_mass_balance_uncertainty" + ] grounding_line_flux_varname = config["Flux Variables"]["grounding_line_flux"] - grounding_line_flux_uncertainty_varname = ( - grounding_line_flux_varname + "_uncertainty" - ) - - plt.rcParams["font.size"] = fontsize - - fig, axs = plt.subplots( - 2, - 1, - sharex=True, - figsize=(6.2, 2.8), - height_ratios=[2, 1], - ) - fig.subplots_adjust(hspace=0.05, wspace=0.05) - - obs_ci = axs[0].fill_between( - obs["time"], - obs[mass_cumulative_varname] - sigma * obs[mass_cumulative_uncertainty_varname], - obs[mass_cumulative_varname] + sigma * obs[mass_cumulative_uncertainty_varname], - color=obs_cmap[0], - alpha=obs_alpha, - lw=0, - label=f"Observed ({sigma}-$\sigma$ uncertainty)", - ) - - if grounding_line_flux_varname in obs.data_vars: - axs[1].fill_between( + grounding_line_flux_uncertainty_varname = config["Flux Uncertainty Variables"][ + "grounding_line_flux_uncertainty" + ] + mass_flux_varname = config["Flux Variables"]["mass_flux"] + mass_flux_uncertainty_varname = config["Flux Uncertainty Variables"][ + "mass_flux_uncertainty" + ] + + with mpl.rc_context({"font.size": fontsize}): + + fig, axs = plt.subplots( + 3, + 1, + sharex=True, + figsize=(6.2, 3.6), + height_ratios=[2, 1, 1], + ) + fig.subplots_adjust(hspace=0.05, wspace=0.05) + + obs_ci = axs[0].fill_between( obs["time"], - obs[grounding_line_flux_varname] - - sigma * obs[grounding_line_flux_uncertainty_varname], - obs[grounding_line_flux_varname] - + sigma * obs[grounding_line_flux_uncertainty_varname], + obs[mass_cumulative_varname] + - sigma * obs[mass_cumulative_uncertainty_varname], + obs[mass_cumulative_varname] + + sigma * obs[mass_cumulative_uncertainty_varname], color=obs_cmap[0], alpha=obs_alpha, lw=0, + label=f"Observed ({sigma}-$\sigma$ uncertainty)", ) - sim_cis = [] - if sim_prior is not None: - sim_prior = sim_prior[ - [mass_cumulative_varname, grounding_line_flux_varname, "ensemble"] - ].load() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - quantiles = {} - for q in [percentiles[0], 0.5, percentiles[1]]: - quantiles[q] = sim_prior.utils.drop_nonnumeric_vars().quantile( - q, dim="exp_id", skipna=True - ) + if mass_flux_varname in obs.data_vars: + axs[1].fill_between( + obs["time"], + obs[mass_flux_varname] - sigma * obs[mass_flux_uncertainty_varname], + obs[mass_flux_varname] + sigma * obs[mass_flux_uncertainty_varname], + color=obs_cmap[0], + alpha=obs_alpha, + lw=0, + ) - for k, m_var in enumerate( - [mass_cumulative_varname, grounding_line_flux_varname] - ): - sim_ci = axs[k].fill_between( - quantiles[0.5].time, - quantiles[percentiles[0]][m_var], - quantiles[percentiles[1]][m_var], - alpha=sim_alpha, - color=sim_cmap[0], - label=f"""{sim_prior["ensemble"].values} ({percentile_range:.0f}% credibility interval)""", + if grounding_line_flux_varname in obs.data_vars: + axs[-1].fill_between( + obs["time"], + obs[grounding_line_flux_varname] + - sigma * obs[grounding_line_flux_uncertainty_varname], + obs[grounding_line_flux_varname] + + sigma * obs[grounding_line_flux_uncertainty_varname], + color=obs_cmap[0], + alpha=obs_alpha, lw=0, ) - if k == 0: - sim_cis.append(sim_ci) - if sim_posterior is not None: 
- sim_posterior = sim_posterior[ - [mass_cumulative_varname, grounding_line_flux_varname, "ensemble"] - ].load() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - quantiles = {} - for q in [percentiles[0], 0.5, percentiles[1]]: - quantiles[q] = sim_posterior.utils.drop_nonnumeric_vars().quantile( - q, dim="exp_id", skipna=True + + sim_cis = [] + if sim_prior is not None: + sim_prior = sim_prior[ + [ + mass_cumulative_varname, + mass_flux_varname, + grounding_line_flux_varname, + "ensemble", + ] + ].load() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + quantiles = {} + for q in [percentiles[0], 0.5, percentiles[1]]: + quantiles[q] = sim_prior.utils.drop_nonnumeric_vars().quantile( + q, dim="exp_id", skipna=True + ) + + for k, m_var in enumerate( + [ + mass_cumulative_varname, + mass_flux_varname, + grounding_line_flux_varname, + ] + ): + sim_ci = axs[k].fill_between( + quantiles[0.5].time, + quantiles[percentiles[0]][m_var], + quantiles[percentiles[1]][m_var], + alpha=sim_alpha, + color=sim_cmap[0], + label=f"""{sim_prior["ensemble"].values} ({percentile_range:.0f}% credibility interval)""", + lw=0, + ) + if k == 0: + sim_cis.append(sim_ci) + if sim_posterior is not None: + sim_posterior = sim_posterior[ + [ + mass_cumulative_varname, + mass_flux_varname, + grounding_line_flux_varname, + "ensemble", + ] + ].load() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + quantiles = {} + for q in [percentiles[0], 0.5, percentiles[1]]: + quantiles[q] = sim_posterior.utils.drop_nonnumeric_vars().quantile( + q, dim="exp_id", skipna=True + ) + + for k, m_var in enumerate( + [ + mass_cumulative_varname, + mass_flux_varname, + grounding_line_flux_varname, + ] + ): + sim_ci = axs[k].fill_between( + quantiles[0.5].time, + quantiles[percentiles[0]][m_var], + quantiles[percentiles[1]][m_var], + alpha=sim_alpha, + color=sim_cmap[1], + label=f"""{sim_posterior["ensemble"].values} ({percentile_range:.0f}% credibility interval)""", + lw=0, + ) + if k == 0: + sim_cis.append(sim_ci) + axs[k].plot( + quantiles[0.5].time, + quantiles[0.5][m_var], + lw=0.75, + color=sim_cmap[1], ) - for k, m_var in enumerate( - [mass_cumulative_varname, grounding_line_flux_varname] - ): - sim_ci = axs[k].fill_between( - quantiles[0.5].time, - quantiles[percentiles[0]][m_var], - quantiles[percentiles[1]][m_var], - alpha=sim_alpha, - color=sim_cmap[1], - label=f"""{sim_posterior["ensemble"].values} ({percentile_range:.0f}% credibility interval)""", - lw=0, + if sim_posterior is not None: + y_min, y_max = axs[-1].get_ylim() + scaler = y_min + (y_max - y_min) * 0.05 + obs_filtered = obs.sel( + time=slice(f"{filter_range[0]}", f"{filter_range[-1]}") ) - if k == 0: - sim_cis.append(sim_ci) - axs[k].plot( - quantiles[0.5].time, quantiles[0.5][m_var], lw=0.75, color=sim_cmap[1] + filter_range_ds = obs_filtered[mass_cumulative_varname] + filter_range_ds *= 0 + filter_range_ds += scaler + _ = filter_range_ds.plot( + ax=axs[-1], lw=1, ls="solid", color="k", label="Filtering Range" ) - - if sim_posterior is not None: - y_min, y_max = axs[1].get_ylim() - scaler = y_min + (y_max - y_min) * 0.05 - obs_filtered = obs.sel(time=slice(f"{filter_range[0]}", f"{filter_range[-1]}")) - filter_range_ds = obs_filtered[mass_cumulative_varname] - filter_range_ds *= 0 - filter_range_ds += scaler - _ = filter_range_ds.plot( - ax=axs[1], lw=1, ls="solid", color="k", label="Filtering 
Range" - ) - x_s = ( - filter_range_ds.time.values[0] - + (filter_range_ds.time.values[-1] - filter_range_ds.time.values[0]) / 2 - ) - y_s = scaler - axs[1].text( - x_s, - y_s, - "Filtering Range", - horizontalalignment="center", - fontweight="medium", + x_s = ( + filter_range_ds.time.values[0] + + (filter_range_ds.time.values[-1] - filter_range_ds.time.values[0]) / 2 + ) + y_s = scaler + axs[-1].text( + x_s, + y_s, + "Filtering Range", + horizontalalignment="center", + fontweight="medium", + ) + legend = axs[0].legend( + handles=[obs_ci, *sim_cis], ) - legend = axs[0].legend( - handles=[obs_ci, *sim_cis], - ) - legend.get_frame().set_linewidth(0.0) - legend.get_frame().set_alpha(0.0) - - axs[0].add_artist(legend) - - axs[0].xaxis.set_tick_params(labelbottom=False) - - axs[0].set_ylabel(f"Cumulative mass\nloss since {reference_year} (Gt)") - axs[0].set_xlabel("") - if sim_posterior is not None: - axs[0].set_title(f"{basin} filtered by {filtering_var}") - else: - axs[0].set_title(f"{basin}") - axs[1].set_title("") - axs[1].set_ylabel("Grounding Line\nFlux (Gt/yr)") - axs[-1].set_xlim(np.datetime64("1980-01-01"), np.datetime64("2020-01-01")) - fig.tight_layout() - - if sim_prior is not None: - prior_str = "prior" - else: - prior_str = "" - if sim_posterior is not None: - posterior_str = "_posterior" - else: - posterior_str = "" - prior_posterior_str = prior_str + posterior_str - - fig.savefig( - pdf_dir - / Path( - f"{basin}_mass_accounting_{prior_posterior_str}_filtered_by_{filtering_var}.pdf" + legend.get_frame().set_linewidth(0.0) + legend.get_frame().set_alpha(0.0) + + axs[0].add_artist(legend) + + axs[0].xaxis.set_tick_params(labelbottom=False) + + axs[0].set_ylabel(f"Cumulative mass\nloss since {reference_year} (Gt)") + axs[0].set_xlabel("") + if sim_posterior is not None: + axs[0].set_title(f"{basin} filtered by {filter_var}") + else: + axs[0].set_title(f"{basin}") + axs[1].set_title("") + axs[1].set_ylabel("Mass balance\n (Gt/yr)") + axs[-1].set_title("") + axs[-1].set_ylabel("Grounding Line\nFlux (Gt/yr)") + axs[-1].set_xlim(np.datetime64("1980-01-01"), np.datetime64("2020-01-01")) + fig.tight_layout() + + if sim_prior is not None: + prior_str = "prior" + else: + prior_str = "" + if sim_posterior is not None: + posterior_str = "_posterior" + else: + posterior_str = "" + prior_posterior_str = prior_str + posterior_str + + fig.savefig( + pdf_dir + / Path( + f"{basin}_mass_accounting_{prior_posterior_str}_filtered_by_{filter_var}.pdf" + ) ) - ) - fig.savefig( - png_dir - / Path( - f"{basin}_mass_accounting_{prior_posterior_str}_filtered_by_{filtering_var}.png", - dpi=300, + fig.savefig( + png_dir + / Path( + f"{basin}_mass_accounting_{prior_posterior_str}_filtered_by_{filter_var}.png", + dpi=300, + ) ) - ) - plt.close() - del fig + plt.close() + del fig def plot_outliers( - filtered_da: xr.DataArray, outliers_da: xr.DataArray, filename: Union[Path, str] + filtered_da: xr.DataArray, + outliers_da: xr.DataArray, + filename: Union[Path, str], + fontsize: int = 6, ): """ Plot outliers in the given DataArrays and save the plot to a file. @@ -501,6 +552,8 @@ def plot_outliers( The DataArray containing the outliers. filename : Union[Path, str] The path or filename where the plot will be saved. + fontsize : int, optional + The font size for the plot, by default 6. Examples -------- @@ -516,13 +569,14 @@ def plot_outliers( ... 
) >>> plot_outliers(filtered_da, outliers_da, "outliers_plot.png") """ - fig, ax = plt.subplots(1, 1) - if outliers_da.size > 0: - outliers_da.plot( - hue="exp_id", color=sim_cmap[0], add_legend=False, ax=ax, lw=0.25 - ) - if filtered_da.size > 0: - filtered_da.plot( - hue="exp_id", color=sim_cmap[1], add_legend=False, ax=ax, lw=0.25 - ) - fig.savefig(filename) + with mpl.rc_context({"font.size": fontsize}): + fig, ax = plt.subplots(1, 1) + if outliers_da.size > 0: + outliers_da.plot( + hue="exp_id", color=sim_cmap[0], add_legend=False, ax=ax, lw=0.25 + ) + if filtered_da.size > 0: + filtered_da.plot( + hue="exp_id", color=sim_cmap[1], add_legend=False, ax=ax, lw=0.25 + ) + fig.savefig(filename) diff --git a/pism_ragis/processing.py b/pism_ragis/processing.py index 96e9a1c..4698215 100644 --- a/pism_ragis/processing.py +++ b/pism_ragis/processing.py @@ -951,3 +951,88 @@ def transpose_dataframe(df: pd.DataFrame, exp_id: str) -> pd.DataFrame: df.columns = param_names df["exp_id"] = exp_id return df + + +def filter_config(ds: xr.Dataset, params: list[str]) -> xr.DataArray: + """ + Filter the configuration parameters from the dataset. + + This function selects the specified configuration parameters from the dataset + and returns them as a DataArray. + + Parameters + ---------- + ds : xr.Dataset + The input dataset containing the configuration parameters. + params : List[str] + A list of configuration parameter names to be selected. + + Returns + ------- + xr.DataArray + The selected configuration parameters as a DataArray. + + Examples + -------- + >>> ds = xr.Dataset({'pism_config': (('pism_config_axis',), [1, 2, 3])}, + coords={'pism_config_axis': ['param1', 'param2', 'param3']}) + >>> filter_config(ds, ['param1', 'param3']) + + array([1, 3]) + Coordinates: + * pism_config_axis (pism_config_axis) pd.DataFrame: + """ + Convert an xarray DataArray configuration to a pandas DataFrame. + + This function converts the input DataArray containing configuration data into a + pandas DataFrame. The dimensions of the DataArray (excluding 'pism_config_axis') + are used as the index, and the 'pism_config_axis' values are used as columns. + + Parameters + ---------- + config : xr.DataArray + The input DataArray containing the configuration data. + ensemble : Union[str, None], optional + An optional string to add as a column named 'ensemble' in the DataFrame, by default None. + + Returns + ------- + pd.DataFrame + A DataFrame where the dimensions of the DataArray (excluding 'pism_config_axis') + are used as the index, and the 'pism_config_axis' values are used as columns. + + Examples + -------- + >>> config = xr.DataArray( + ... data=[[1, 2, 3], [4, 5, 6]], + ... dims=["time", "pism_config_axis"], + ... coords={"time": [0, 1], "pism_config_axis": ["param1", "param2", "param3"]} + ... 
) + >>> df = config_to_dataframe(config) + >>> print(df) + pism_config_axis time param1 param2 param3 + 0 0 1 2 3 + 1 1 4 5 6 + + >>> df = config_to_dataframe(config, ensemble="ensemble1") + >>> print(df) + pism_config_axis time param1 param2 param3 ensemble + 0 0 1 2 3 ensemble1 + 1 1 4 5 6 ensemble1 + """ + dims = [dim for dim in config.dims if dim != "pism_config_axis"] + df = config.to_dataframe().reset_index() + df = df.pivot(index=dims, columns="pism_config_axis", values="pism_config") + df.reset_index(inplace=True) + if ensemble: + df["ensemble"] = ensemble + return df diff --git a/pyproject.toml b/pyproject.toml index 42e41c5..d41c6d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pism-ragis" -version = "0.3.0" +version = "0.3.1" maintainers = [{name = "Andy Aschwanden", email = "andy.aschwanden@gmail.com"}] description = """Home of NASA ROSES project "A Reanalysis of the Greenland Ice Sheet""" readme = "README.md"
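
--
Usage sketch for the refactored API (illustrative, not part of the patch; assumes
`observed` and `simulated` are xr.Datasets with "time", "exp_id", and "basin"
dimensions, `params` is a list of PISM parameter names, and `config` is the parsed
ragis_config.toml). Sampling now returns the prior/posterior configurations plus
both datasets, and the posterior carries a "filtered_by" dimension, so plotting is
a separate step:

    from pism_ragis.filtering import run_importance_sampling
    from pism_ragis.plotting import plot_basins

    obs_mean_vars = ["grounding_line_flux", "mass_balance"]

    # Resample simulations against observations; the fudge factor inflates the
    # observational uncertainty used in the log-normal likelihood.
    prior_posterior, prior, posterior = run_importance_sampling(
        observed=observed,
        simulated=simulated,
        obs_mean_vars=obs_mean_vars,
        obs_std_vars=[v + "_uncertainty" for v in obs_mean_vars],
        sim_vars=obs_mean_vars,
        filter_range=[1990, 2019],
        fudge_factor=3.0,
        params=params,
    )

    # Plotting is decoupled from sampling: select each filtered posterior along
    # the new "filtered_by" dimension before handing it to plot_basins.
    for filter_var in obs_mean_vars:
        plot_basins(
            observed,
            prior,
            posterior.sel({"filtered_by": filter_var}),
            filter_var=filter_var,
            filter_range=[1990, 2019],
            fig_dir="figures",
            config=config,
        )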