From 039b408a59ce965b91908de82d0bc55114f8655e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:43:39 +0000 Subject: [PATCH] move backwards compat to `_handle_backward_compatibility_on_load` --- .../postprocessing/template_metrics.py | 25 ++++++++++--------- .../quality_metric_calculator.py | 14 +++++------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 477ad04440..7de6e8766a 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -88,9 +88,22 @@ class ComputeTemplateMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True min_channels_for_multi_channel_warning = 10 + def _handle_backward_compatibility_on_load(self): + + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + def _set_params( self, metric_names=None, @@ -344,18 +357,6 @@ def _run(self, verbose=False): def _get_data(self): return self.data["metrics"] - def load_params(self): - AnalyzerExtension.load_params(self) - # For backwards compatibility - this reformats metrics_kwargs as metric_params - if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: - - metric_params = {} - for metric_name in self.params["metric_names"]: - metric_params[metric_name] = deepcopy(metrics_kwargs) - self.params["metric_params"] = metric_params - - del self.params["metrics_kwargs"] - register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory() diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index e7e7c244ea..d71450853f 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -55,6 +55,13 @@ class ComputeQualityMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] def _set_params( self, @@ -262,13 +269,6 @@ def _run(self, verbose=False, **job_kwargs): def _get_data(self): return self.data["metrics"] - def load_params(self): - AnalyzerExtension.load_params(self) - # For backwards compatibility - this renames qm_params as metric_params - if (qm_params := self.params.get("qm_params")) is not None: - self.params["metric_params"] = qm_params - del self.params["qm_params"] - register_result_extension(ComputeQualityMetrics) compute_quality_metrics = ComputeQualityMetrics.function_factory()