diff --git a/analysis/compute_basins_stats.py b/analysis/compute_basins_stats.py index 33477cb..9744e57 100755 --- a/analysis/compute_basins_stats.py +++ b/analysis/compute_basins_stats.py @@ -131,9 +131,7 @@ m_id = str(m_id_re.group(1)) n_ensemble = len(ensemble) - ds = ds.expand_dims( - {"ensemble_id": [ensemble], "exp_id": [m_id]}, axis=[-1, -2] - ) + ds = ds.expand_dims({"exp_id": [m_id]}) if "ice_mass" in ds: ds["ice_mass"] /= 1e12 @@ -151,39 +149,15 @@ bmb_floating_da.name = "tendency_of_ice_mass_due_to_basal_mass_flux_floating" ds = xr.merge([ds, bmb_grounded_da, bmb_floating_da]) - ds.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True) + del ds["time"].attrs["bounds"] + ds = ds.drop_vars( + ["time_bounds", "timestamp"], errors="ignore" + ).rio.set_spatial_dims(x_dim="x", y_dim="y") ds.rio.write_crs(crs, inplace=True) p_config = ds["pism_config"] p_run_stats = ds["run_stats"] - ds = ds[mb_vars] - - print(f"Size in memory: {(ds.nbytes / 1024**3):.1f} GB") - - basins_file = result_dir / f"basins_sums_ensemble_{ensemble}_id_{m_id}.nc" - - client = Client() - print(f"Open client in browser: {client.dashboard_link}") - - start = time.time() - - basins_ds_scattered = client.scatter( - [ds] + [ds.rio.clip([basin.geometry]) for _, basin in basins.iterrows()] - ) - basin_names = ["GRACE"] + [basin["SUBREGION1"] for _, basin in basins.iterrows()] - n_basins = len(basin_names) - futures = client.map(compute_basin, basins_ds_scattered, basin_names) - progress(futures) - basin_sums = xr.concat(client.gather(futures), dim="basin").drop_vars( - ["mapping", "spatial_ref"] - ) - del basin_sums["time"].attrs["bounds"] - if cf: - basin_sums["basin"] = basin_sums["basin"].astype(f"S{n_basins}") - basin_sums["ensemble_id"] = basin_sums["ensemble_id"].astype(f"S{n_ensemble}") - basin_sums.attrs["Conventions"] = "CF-1.8" - # List of suffixes to exclude suffixes_to_exclude = ["_doc", "_type", "_units", "_option", "_choices"] @@ -212,16 +186,46 @@ pism_config = xr.DataArray( pc_vals, dims=["pism_config_axis"], - coords={"pism_config_axis": pc_keys}, + coords={"pism_config_axis": pc_keys, "exp_id": m_id, "ensemble_id": ensemble}, name="pism_config", ) run_stats = xr.DataArray( rs_vals, dims=["run_stats_axis"], - coords={"run_stats_axis": rs_keys}, + coords={"run_stats_axis": rs_keys, "exp_id": m_id, "ensemble_id": ensemble}, name="run_stats", ) - basin_sums = xr.merge([basin_sums, pism_config, run_stats]) + ds = xr.merge( + [ + ds[mb_vars].drop_vars(["pism_config", "run_stats"], errors="ignore"), + pism_config, + run_stats, + ] + ) + print(f"Size in memory: {(ds.nbytes / 1024**3):.1f} GB") + + basins_file = result_dir / f"basins_sums_ensemble_{ensemble}_id_{m_id}.nc" + + client = Client() + print(f"Open client in browser: {client.dashboard_link}") + + start = time.time() + + basins_ds_scattered = client.scatter( + [ds] + [ds.rio.clip([basin.geometry]) for _, basin in basins.iterrows()] + ) + basin_names = ["GRACE"] + [basin["SUBREGION1"] for _, basin in basins.iterrows()] + n_basins = len(basin_names) + futures = client.map(compute_basin, basins_ds_scattered, basin_names) + progress(futures) + basin_sums = ( + xr.concat(client.gather(futures), dim="basin") + .drop_vars(["mapping", "spatial_ref"]) + .sortby(["basin", "time"]) + ) + if cf: + basin_sums["basin"] = basin_sums["basin"].astype(f"S{n_basins}") + basin_sums.attrs["Conventions"] = "CF-1.8" basin_sums.to_netcdf(basins_file, engine=engine) diff --git a/pism_ragis/processing.py b/pism_ragis/processing.py index 15bfa2e..2a098f8 100644 --- a/pism_ragis/processing.py +++ b/pism_ragis/processing.py @@ -600,6 +600,7 @@ def load_ensemble( ds = xr.open_mfdataset( filenames, parallel=parallel, + chunks={"exp_id": -1, "pism_config_axis": -1}, engine=engine, ).drop_vars(["spatial_ref", "mapping"], errors="ignore") print("Done.")