Add max_threads_per_process and mp_context to pca by channel computation and PCA metrics #3434

Merged (6 commits, Sep 27, 2024)
43 changes: 31 additions & 12 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -1,12 +1,14 @@
from __future__ import annotations

import shutil
import pickle
import warnings
import tempfile
import platform
from pathlib import Path
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp
from threadpoolctl import threadpool_limits

import numpy as np

from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
@@ -314,11 +316,13 @@ def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
n_jobs = job_kwargs["n_jobs"]
progress_bar = job_kwargs["progress_bar"]
max_threads_per_process = job_kwargs["max_threads_per_process"]
mp_context = job_kwargs["mp_context"]

# fit model/models
# TODO : make parallel for by_channel_global and concatenated
if mode == "by_channel_local":
pca_models = self._fit_by_channel_local(n_jobs, progress_bar)
pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context)
for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids):
self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind]
pca_model = pca_models
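The two new keys travel with `job_kwargs`, so the per-channel fit can be driven directly from `compute`. A usage sketch (an existing SortingAnalyzer is assumed; the parameter values are illustrative and mirror the new test further down):

```python
# Cap each worker at 2 BLAS threads and force the "spawn" start method.
sorting_analyzer.compute(
    "principal_components",
    mode="by_channel_local",
    n_jobs=4,
    max_threads_per_process=2,
    mp_context="spawn",
)
```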
@@ -411,12 +415,16 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
)
processor.run()

def _fit_by_channel_local(self, n_jobs, progress_bar):
def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context):
from sklearn.decomposition import IncrementalPCA
from concurrent.futures import ProcessPoolExecutor

p = self.params

if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')

unit_ids = self.sorting_analyzer.unit_ids
channel_ids = self.sorting_analyzer.channel_ids
# there is one PCA per channel for independent fit per channel
@@ -436,13 +444,18 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
pca = pca_models[chan_ind]
pca.partial_fit(wfs[:, :, wf_ind])
else:
# parallel
# create list of args to parallelize. For convenience, the max_threads_per_process is passed
# as the last argument
items = [
(chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)
(chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process)
for wf_ind, chan_ind in enumerate(channel_inds)
]
n_jobs = min(n_jobs, len(items))

with ProcessPoolExecutor(max_workers=n_jobs) as executor:
with ProcessPoolExecutor(
max_workers=n_jobs,
mp_context=mp.get_context(mp_context),
) as executor:
results = executor.map(_partial_fit_one_channel, items)
for chan_ind, pca_model_updated in results:
pca_models[chan_ind] = pca_model_updated
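For reference, a minimal self-contained sketch of the executor pattern used above: `mp.get_context()` selects the start method and `executor.map` fans the per-channel jobs out to worker processes. The worker function below is a stand-in, not part of the diff:

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

def _square(x):
    # stand-in for _partial_fit_one_channel
    return x * x

if __name__ == "__main__":
    # mp.get_context(None) returns the platform default context, so passing
    # mp_context=None through keeps the previous behaviour.
    with ProcessPoolExecutor(max_workers=2, mp_context=mp.get_context("spawn")) as executor:
        print(list(executor.map(_square, range(4))))
```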
@@ -674,6 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte


def _partial_fit_one_channel(args):
chan_ind, pca_model, wf_chan = args
pca_model.partial_fit(wf_chan)
return chan_ind, pca_model
chan_ind, pca_model, wf_chan, max_threads_per_process = args

if max_threads_per_process is None:
pca_model.partial_fit(wf_chan)
return chan_ind, pca_model
else:
with threadpool_limits(limits=int(max_threads_per_process)):
pca_model.partial_fit(wf_chan)
return chan_ind, pca_model
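The worker now optionally wraps `partial_fit` in a `threadpoolctl` limit so the BLAS/OpenMP pools inside each process do not oversubscribe the machine when several workers run at once. A minimal sketch of that pattern in isolation (array sizes are arbitrary):

```python
import numpy as np
from threadpoolctl import threadpool_limits

# Inside the "with" block, BLAS-backed operations (matrix products, the SVD
# steps inside IncrementalPCA, etc.) use at most one thread in this process.
with threadpool_limits(limits=1):
    a = np.random.rand(256, 256)
    _ = a @ a
```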
@@ -18,6 +18,18 @@ class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite):
def test_extension(self, params):
self.run_extension_tests(ComputePrincipalComponents, params=params)

def test_multi_processing(self):
"""
Test the extension works with multiple processes.
"""
sorting_analyzer = self._prepare_sorting_analyzer(
format="memory", sparse=False, extension_class=ComputePrincipalComponents
)
sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2)
sorting_analyzer.compute(
"principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn"
)

def test_mode_concatenated(self):
"""
Replicate the "extension_function_params_list" test outside of
60 changes: 27 additions & 33 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -2,15 +2,16 @@

from __future__ import annotations


import warnings
from copy import deepcopy

import numpy as np
import platform
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor

import numpy as np

import warnings
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from threadpoolctl import threadpool_limits

from .misc_metrics import compute_num_spikes, compute_firing_rates

@@ -56,6 +57,8 @@ def compute_pc_metrics(
seed=None,
n_jobs=1,
progress_bar=False,
mp_context=None,
max_threads_per_process=None,
) -> dict:
"""
Calculate principal component derived metrics.
@@ -144,17 +147,7 @@ def compute_pc_metrics(
pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
pcs_flat = pcs.reshape(pcs.shape[0], -1)

func_args = (
pcs_flat,
labels,
non_nn_metrics,
unit_id,
unit_ids,
qm_params,
seed,
n_spikes_all_units,
fr_all_units,
)
func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process)
items.append(func_args)

if not run_in_parallel and non_nn_metrics:
@@ -167,7 +160,15 @@
for metric_name, metric in pca_metrics_unit.items():
pc_metrics[metric_name][unit_id] = metric
elif run_in_parallel and non_nn_metrics:
with ProcessPoolExecutor(n_jobs) as executor:
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')

with ProcessPoolExecutor(
max_workers=n_jobs,
mp_context=mp.get_context(mp_context),
) as executor:
results = executor.map(pca_metrics_one_unit, items)
if progress_bar:
results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics")
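The same start-method guard as in `principal_component.py` is applied before the executor is built. In isolation the check reads roughly as follows (a sketch of the logic shown above, not a helper that exists in the codebase):

```python
import platform
import warnings

def check_mp_context(mp_context):
    # "fork" is unavailable on Windows and discouraged on macOS since Python 3.8;
    # None falls back to the platform's default start method.
    if mp_context is not None and platform.system() == "Windows":
        assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
    elif mp_context == "fork" and platform.system() == "Darwin":
        warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
```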
@@ -976,26 +977,19 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):


def pca_metrics_one_unit(args):
(
pcs_flat,
labels,
metric_names,
unit_id,
unit_ids,
qm_params,
seed,
# we_folder,
n_spikes_all_units,
fr_all_units,
) = args

# if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
# we = load_waveforms(we_folder)
(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args

if max_threads_per_process is None:
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
else:
with threadpool_limits(limits=int(max_threads_per_process)):
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)


def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params):
pc_metrics = {}
# metrics
if "isolation_distance" in metric_names or "l_ratio" in metric_names:

try:
isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
except:
25 changes: 22 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
@@ -1,9 +1,7 @@
import pytest
import numpy as np

from spikeinterface.qualitymetrics import (
compute_pc_metrics,
)
from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list


def test_calculate_pc_metrics(small_sorting_analyzer):
@@ -22,3 +20,24 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
assert not np.all(np.isnan(res2[metric_name].values))

assert np.array_equal(res1[metric_name].values, res2[metric_name].values)


def test_pca_metrics_multi_processing(small_sorting_analyzer):
sorting_analyzer = small_sorting_analyzer

metric_names = get_quality_pca_metric_list()
metric_names.remove("nn_isolation")
metric_names.remove("nn_noise_overlap")

print(f"Computing PCA metrics with 1 thread per process")
res1 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True
)
print(f"Computing PCA metrics with 2 thread per process")
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
)
print("Computing PCA metrics with spawn context")
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
)
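As a rough sizing rule (an assumption, not something the diff enforces), the parallel section uses about n_jobs x max_threads_per_process CPU threads, so a sketch sized for an 8-core machine, reusing `sorting_analyzer` and `metric_names` from the test above, might be:

```python
# ~8 threads total: 4 worker processes x 2 BLAS threads each
res = compute_pc_metrics(
    sorting_analyzer, n_jobs=4, metric_names=metric_names,
    max_threads_per_process=2, mp_context="spawn", progress_bar=True,
)
```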