Skip to content

Commit

Permalink
Softmax Score Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Sohambutala committed Sep 30, 2024
1 parent 5a2a477 commit f7d5ad8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion echodataflow/stages/subflows/echoshader_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def eshader_preprocess(ed: EchodataflowObject, working_dir, config: Dataset, sta
del ds_MVBS_combined_resampled

ds_MVBS_combined_resampled = xr.open_zarr(working_dir + "/" + "eshader.zarr", storage_options=config.output.storage_options_dict)
print(ds_MVBS_combined_resampled["softmax"])
# print(ds_MVBS_combined_resampled["softmax"])


# hv_ds = hv.Dataset(ds_MVBS_combined_resampled["softmax"])
Expand Down
44 changes: 38 additions & 6 deletions echodataflow/stages/subflows/initialization_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,47 @@ def apply_scores_if_needed(config: Dataset, store_output: Output, end_time: date

edf.data = edf.data.sel(ping_time=slice(min_time, max_time))
score_ds = score_ds.sel(ping_time=slice(min_time, max_time))

# Apply softmax or other score-based logic
softmax = xr.apply_ufunc(scipy.special.softmax, score_ds, kwargs={'axis': 0}, dask="allowed")

# resampling
edf.data = edf.data.resample(ping_time="30s").mean()
softmax = softmax.resample(ping_time="30s").mean()
score_ds = score_ds.resample(ping_time="30s").mean()

score_ds = score_ds.transpose("ping_time", "depth", "species")
tensor_scores = torch.tensor(score_ds['__xarray_dataarray_variable__'].compute().values)
temperature = 0.5
softmax_scores = torch.nn.functional.softmax(tensor_scores / temperature, dim=2)
score_ds = score_ds.assign(softmax=(('ping_time', 'depth', 'species'), softmax_scores.numpy()))

try:
score_ds.to_zarr(
os.path.join(os.path.expanduser("~"), ".echodataflow", "score.zarr"),
mode="w",
consolidated=True
)
log_util.log(
msg={
"msg": f"min {min_time} Max {max_time}" ,
"mod_name": __file__,
"func_name": "Init Flow",
},
eflogging=config.logging,
)
except:
pass


log_util.log(
msg={
"msg": f"min {score_ds['softmax'].min().values} Max {score_ds['softmax'].max().values}" ,
"mod_name": __file__,
"func_name": "Init Flow",
},
eflogging=config.logging,
)

# # Apply softmax or other score-based logic
# softmax = xr.apply_ufunc(scipy.special.softmax, score_ds, kwargs={'axis': 0}, dask="allowed")

edf.data = edf.data.assign(softmax=softmax.sel(species="hake")["__xarray_dataarray_variable__"])
edf.data = edf.data.assign(softmax=score_ds.sel(species="hake")["softmax"])

edf.data.to_zarr(
os.path.join(os.path.expanduser("~"), ".echodataflow", "eshader.zarr"),
Expand Down

0 comments on commit f7d5ad8

Please sign in to comment.