diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 1871c11b85..809f2c5bba 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -1,12 +1,14 @@ from __future__ import annotations -import shutil -import pickle import warnings -import tempfile +import platform from pathlib import Path from tqdm.auto import tqdm +from concurrent.futures import ProcessPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits + import numpy as np from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -314,11 +316,13 @@ def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + mp_context = job_kwargs["mp_context"] # fit model/models # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": - pca_models = self._fit_by_channel_local(n_jobs, progress_bar) + pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context) for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models @@ -411,12 +415,16 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): ) processor.run() - def _fit_by_channel_local(self, n_jobs, progress_bar): + def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context): from sklearn.decomposition import IncrementalPCA - from concurrent.futures import ProcessPoolExecutor p = self.params + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids # there is one PCA per channel for independent fit per channel @@ -436,13 +444,18 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): pca = pca_models[chan_ind] pca.partial_fit(wfs[:, :, wf_ind]) else: - # parallel + # create list of args to parallelize. For convenience, the max_threads_per_process is passed + # as last argument items = [ - (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds) + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process) + for wf_ind, chan_ind in enumerate(channel_inds) ] n_jobs = min(n_jobs, len(items)) - with ProcessPoolExecutor(max_workers=n_jobs) as executor: + with ProcessPoolExecutor( + max_workers=n_jobs, + mp_context=mp.get_context(mp_context), + ) as executor: results = executor.map(_partial_fit_one_channel, items) for chan_ind, pca_model_updated in results: pca_models[chan_ind] = pca_model_updated @@ -674,6 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte def _partial_fit_one_channel(args): - chan_ind, pca_model, wf_chan = args - pca_model.partial_fit(wf_chan) - return chan_ind, pca_model + chan_ind, pca_model, wf_chan, max_threads_per_process = args + + if max_threads_per_process is None: + pca_model.partial_fit(wf_chan) + return chan_ind, pca_model + else: + with threadpool_limits(limits=int(max_threads_per_process)): + pca_model.partial_fit(wf_chan) + return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 4de86be32b..7a509c410f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -18,6 +18,18 @@ class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): def test_extension(self, params): self.run_extension_tests(ComputePrincipalComponents, params=params) + def test_multi_processing(self): + """ + Test the extension works with multiple processes. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, extension_class=ComputePrincipalComponents + ) + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) + sorting_analyzer.compute( + "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn" + ) + def test_mode_concatenated(self): """ Replicate the "extension_function_params_list" test outside of diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 7c099a2f74..4c68dfea59 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -2,15 +2,16 @@ from __future__ import annotations - +import warnings from copy import deepcopy - -import numpy as np +import platform from tqdm.auto import tqdm -from concurrent.futures import ProcessPoolExecutor +import numpy as np -import warnings +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits from .misc_metrics import compute_num_spikes, compute_firing_rates @@ -56,6 +57,8 @@ def compute_pc_metrics( seed=None, n_jobs=1, progress_bar=False, + mp_context=None, + max_threads_per_process=None, ) -> dict: """ Calculate principal component derived metrics. @@ -144,17 +147,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = ( - pcs_flat, - labels, - non_nn_metrics, - unit_id, - unit_ids, - qm_params, - seed, - n_spikes_all_units, - fr_all_units, - ) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -167,7 +160,15 @@ def compute_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric elif run_in_parallel and non_nn_metrics: - with ProcessPoolExecutor(n_jobs) as executor: + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + + with ProcessPoolExecutor( + max_workers=n_jobs, + mp_context=mp.get_context(mp_context), + ) as executor: results = executor.map(pca_metrics_one_unit, items) if progress_bar: results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics") @@ -976,26 +977,19 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - ( - pcs_flat, - labels, - metric_names, - unit_id, - unit_ids, - qm_params, - seed, - # we_folder, - n_spikes_all_units, - fr_all_units, - ) = args - - # if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names: - # we = load_waveforms(we_folder) + (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + + if max_threads_per_process is None: + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + else: + with threadpool_limits(limits=int(max_threads_per_process)): + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + +def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: - try: isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) except: diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6ddeb02689..f2e912c6b4 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -1,9 +1,7 @@ import pytest import numpy as np -from spikeinterface.qualitymetrics import ( - compute_pc_metrics, -) +from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list def test_calculate_pc_metrics(small_sorting_analyzer): @@ -22,3 +20,24 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + + +def test_pca_metrics_multi_processing(small_sorting_analyzer): + sorting_analyzer = small_sorting_analyzer + + metric_names = get_quality_pca_metric_list() + metric_names.remove("nn_isolation") + metric_names.remove("nn_noise_overlap") + + print(f"Computing PCA metrics with 1 thread per process") + res1 = compute_pc_metrics( + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True + ) + print(f"Computing PCA metrics with 2 thread per process") + res2 = compute_pc_metrics( + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + ) + print("Computing PCA metrics with spawn context") + res2 = compute_pc_metrics( + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + )