Skip to content

Commit

Permalink
move backwards compat to _handle_backward_compatibility_on_load
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Dec 2, 2024
1 parent 9db0b83 commit 039b408
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
25 changes: 13 additions & 12 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,22 @@ class ComputeTemplateMetrics(AnalyzerExtension):
need_recording = False
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

min_channels_for_multi_channel_warning = 10

def _handle_backward_compatibility_on_load(self):

# For backwards compatibility - this reformats metrics_kwargs as metric_params
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None:

metric_params = {}
for metric_name in self.params["metric_names"]:
metric_params[metric_name] = deepcopy(metrics_kwargs)
self.params["metric_params"] = metric_params

del self.params["metrics_kwargs"]

def _set_params(
self,
metric_names=None,
Expand Down Expand Up @@ -344,18 +357,6 @@ def _run(self, verbose=False):
def _get_data(self):
return self.data["metrics"]

def load_params(self):
AnalyzerExtension.load_params(self)
# For backwards compatibility - this reformats metrics_kwargs as metric_params
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None:

metric_params = {}
for metric_name in self.params["metric_names"]:
metric_params[metric_name] = deepcopy(metrics_kwargs)
self.params["metric_params"] = metric_params

del self.params["metrics_kwargs"]


register_result_extension(ComputeTemplateMetrics)
compute_template_metrics = ComputeTemplateMetrics.function_factory()
Expand Down
14 changes: 7 additions & 7 deletions src/spikeinterface/qualitymetrics/quality_metric_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ class ComputeQualityMetrics(AnalyzerExtension):
need_recording = False
use_nodepipeline = False
need_job_kwargs = True
need_backward_compatibility_on_load = True

def _handle_backward_compatibility_on_load(self):
# For backwards compatibility - this renames qm_params as metric_params
if (qm_params := self.params.get("qm_params")) is not None:
self.params["metric_params"] = qm_params
del self.params["qm_params"]

def _set_params(
self,
Expand Down Expand Up @@ -262,13 +269,6 @@ def _run(self, verbose=False, **job_kwargs):
def _get_data(self):
return self.data["metrics"]

def load_params(self):
AnalyzerExtension.load_params(self)
# For backwards compatibility - this renames qm_params as metric_params
if (qm_params := self.params.get("qm_params")) is not None:
self.params["metric_params"] = qm_params
del self.params["qm_params"]


register_result_extension(ComputeQualityMetrics)
compute_quality_metrics = ComputeQualityMetrics.function_factory()
Expand Down

0 comments on commit 039b408

Please sign in to comment.