diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 1349802ce5..bfa4335ac1 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -287,7 +287,7 @@ available parameters are dictionaries and can be accessed with: 'detect_threshold': 5, 'freq_max': 5000.0, 'freq_min': 400.0, - 'max_threads_per_process': 1, + 'max_threads_per_worker': 1, 'mp_context': None, 'n_jobs': 20, 'nested_params': None, diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index ead7007920..bea77decfc 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -90,7 +90,7 @@ write_python, normal_pdf, ) -from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs +from .job_tools import get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs from .recording_tools import ( write_binary_recording, write_to_h5_dataset_format, diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 38f39c5481..195440c061 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -97,7 +97,7 @@ def is_set_global_dataset_folder() -> bool: ######################################## -_default_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) +_default_job_kwargs = dict(pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global global_job_kwargs global_job_kwargs = _default_job_kwargs.copy() diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 7a6172369b..ce7eb05dbc 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -12,8 +12,9 @@ import sys from tqdm.auto import tqdm -from concurrent.futures import 
ProcessPoolExecutor -import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +import multiprocessing +import threading from threadpoolctl import threadpool_limits @@ -39,6 +40,7 @@ job_keys = ( + "pool_engine", "n_jobs", "total_memory", "chunk_size", @@ -46,7 +48,7 @@ "chunk_duration", "progress_bar", "mp_context", - "max_threads_per_process", + "max_threads_per_worker", ) # theses key are the same and should not be in th final dict @@ -57,12 +59,63 @@ "chunk_duration", ) +def get_best_job_kwargs(): + """ + Gives best possible job_kwargs for the platform. + Currently this function is from developer experience, but may be adapted in the future. + """ + + n_cpu = os.cpu_count() + + if platform.system() == "Linux": + # maybe we should test this more but with linux the fork is still faster than threading + pool_engine = "process" + mp_context = "fork" + + # this is totally empirical but this is a good start + if n_cpu <= 16: + # for small n_cpu let's make many process + n_jobs = n_cpu + max_threads_per_worker = 1 + else: + # let's have fewer processes with more threads each + n_cpu = int(n_cpu / 4) + max_threads_per_worker = 8 + + else: # windows and mac + # on windows and macos the fork is forbidden and process+spwan is super slow at startup + # so let's go to threads + pool_engine = "thread" + mp_context = None + n_jobs = n_cpu + max_threads_per_worker = 1 + + return dict( + pool_engine=pool_engine, + mp_context=mp_context, + n_jobs=n_jobs, + max_threads_per_worker=max_threads_per_worker, + ) + + + def fix_job_kwargs(runtime_job_kwargs): from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set job_kwargs = get_global_job_kwargs() + # deprecation with backward compatibility + # this can be removed in 0.104.0 + if "max_threads_per_process" in runtime_job_kwargs: + runtime_job_kwargs = runtime_job_kwargs.copy() + runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process") + 
warnings.warn( + "job_kwargs: max_threads_per_process was changed to max_threads_per_worker, max_threads_per_process will be removed in 0.104", + DeprecationWarning, + stacklevel=2, + ) + for k in runtime_job_kwargs: assert k in job_keys, ( f"{k} is not a valid job keyword argument. " f"Available keyword arguments are: {list(job_keys)}" @@ -287,11 +340,15 @@ class ChunkRecordingExecutor: If True, output is verbose job_name : str, default: "" Job name + progress_bar : bool, default: False + If True, a progress bar is printed to monitor the progress of the process handle_returns : bool, default: False If True, the function can return values gather_func : None or callable, default: None Optional function that is called in the main thread and retrieves the results of each worker. This function can be used instead of `handle_returns` to implement custom storage on-the-fly. + pool_engine : "process" | "thread", default: "thread" + If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor n_jobs : int, default: 1 Number of jobs to be used. Use -1 to use as many jobs as number of cores total_memory : str, default: None @@ -305,13 +362,12 @@ class ChunkRecordingExecutor: mp_context : "fork" | "spawn" | None, default: None "fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context(). "fork" is only safely available on LINUX systems. - max_threads_per_process : int or None, default: None + max_threads_per_worker : int or None, default: None Limit the number of thread per process using threadpoolctl modules. This used only when n_jobs>1 If None, no limits. - progress_bar : bool, default: False - If True, a progress bar is printed to monitor the progress of the process - + need_worker_index : bool, default False + If True then each worker will also have a "worker_index" injected in the local worker dict. 
Returns ------- @@ -329,6 +385,7 @@ def __init__( progress_bar=False, handle_returns=False, gather_func=None, + pool_engine="thread", n_jobs=1, total_memory=None, chunk_size=None, @@ -336,19 +393,21 @@ def __init__( chunk_duration=None, mp_context=None, job_name="", - max_threads_per_process=1, + max_threads_per_worker=1, + need_worker_index=False, ): self.recording = recording self.func = func self.init_func = init_func self.init_args = init_args - if mp_context is None: - mp_context = recording.get_preferred_mp_context() - if mp_context is not None and platform.system() == "Windows": - assert mp_context != "fork", "'fork' mp_context not supported on Windows!" - elif mp_context == "fork" and platform.system() == "Darwin": - warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + if pool_engine == "process": + if mp_context is None: + mp_context = recording.get_preferred_mp_context() + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" 
+ elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') self.mp_context = mp_context @@ -368,7 +427,11 @@ def __init__( n_jobs=self.n_jobs, ) self.job_name = job_name - self.max_threads_per_process = max_threads_per_process + self.max_threads_per_worker = max_threads_per_worker + + self.pool_engine = pool_engine + + self.need_worker_index = need_worker_index if verbose: chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize @@ -380,6 +443,7 @@ def __init__( print( self.job_name, "\n" + f"engine={self.pool_engine} - " f"n_jobs={self.n_jobs} - " f"samples_per_chunk={self.chunk_size:,} - " f"chunk_memory={chunk_memory_str} - " @@ -402,69 +466,172 @@ def run(self, recording_slices=None): if self.n_jobs == 1: if self.progress_bar: - recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name) + recording_slices = tqdm(recording_slices, desc=self.job_name, total=len(recording_slices)) + + worker_dict = self.init_func(*self.init_args) + if self.need_worker_index: + worker_dict["worker_index"] = 0 - worker_ctx = self.init_func(*self.init_args) for segment_index, frame_start, frame_stop in recording_slices: - res = self.func(segment_index, frame_start, frame_stop, worker_ctx) + res = self.func(segment_index, frame_start, frame_stop, worker_dict) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) + else: n_jobs = min(self.n_jobs, len(recording_slices)) - # parallel - with ProcessPoolExecutor( - max_workers=n_jobs, - initializer=worker_initializer, - mp_context=mp.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), - ) as executor: - results = executor.map(function_wrapper, recording_slices) + if self.pool_engine == "process": + + if self.need_worker_index: + lock = multiprocessing.Lock() + array_pid = 
multiprocessing.Array("i", n_jobs) + for i in range(n_jobs): + array_pid[i] = -1 + else: + lock = None + array_pid = None + + # parallel + with ProcessPoolExecutor( + max_workers=n_jobs, + initializer=process_worker_initializer, + mp_context=multiprocessing.get_context(self.mp_context), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, self.need_worker_index, lock, array_pid), + ) as executor: + results = executor.map(process_function_wrapper, recording_slices) + + if self.progress_bar: + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) + + elif self.pool_engine == "thread": + # this is need to create a per worker local dict where the initializer will push the func wrapper + thread_local_data = threading.local() + + global _thread_started + _thread_started = 0 if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(recording_slices)) - - for res in results: - if self.handle_returns: - returns.append(res) - if self.gather_func is not None: - self.gather_func(res) + # here the tqdm threading do not work (maybe collision) so we need to create a pbar + # before thread spawning + pbar = tqdm(desc=self.job_name, total=len(recording_slices)) + + if self.need_worker_index: + lock = threading.Lock() + else: + lock = None + + with ThreadPoolExecutor( + max_workers=n_jobs, + initializer=thread_worker_initializer, + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, thread_local_data, self.need_worker_index, lock), + ) as executor: + + + recording_slices2 = [(thread_local_data, ) + args for args in recording_slices] + results = executor.map(thread_function_wrapper, recording_slices2) + + for res in results: + if self.progress_bar: + pbar.update(1) + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + 
self.gather_func(res) + if self.progress_bar: + pbar.close() + del pbar + else: + raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") + return returns + +class WorkerFuncWrapper: + """ + small wrapper that handles: + * local worker_dict + * max_threads_per_worker + """ + def __init__(self, func, worker_dict, max_threads_per_worker): + self.func = func + self.worker_dict = worker_dict + self.max_threads_per_worker = max_threads_per_worker + + def __call__(self, args): + segment_index, start_frame, end_frame = args + if self.max_threads_per_worker is None: + return self.func(segment_index, start_frame, end_frame, self.worker_dict) + else: + with threadpool_limits(limits=self.max_threads_per_worker): + return self.func(segment_index, start_frame, end_frame, self.worker_dict) # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool -# the tricks is : theses 2 variables are global per worker -# so they are not share in the same process -global _worker_ctx -global _func +# the trick is : this variable is global per worker (so not shared in the same process) +global _process_func_wrapper -def worker_initializer(func, init_func, init_args, max_threads_per_process): - global _worker_ctx - if max_threads_per_process is None: - _worker_ctx = init_func(*init_args) +def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid): + global _process_func_wrapper + if max_threads_per_worker is None: + worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): - _worker_ctx = init_func(*init_args) - _worker_ctx["max_threads_per_process"] = max_threads_per_process - global _func - _func = func - - -def function_wrapper(args): - segment_index, start_frame, end_frame = args - global _func - global _worker_ctx - max_threads_per_process = _worker_ctx["max_threads_per_process"] - if max_threads_per_process is None: - 
return _func(segment_index, start_frame, end_frame, _worker_ctx) + with threadpool_limits(limits=max_threads_per_worker): + worker_dict = init_func(*init_args) + + if need_worker_index: + child_process = multiprocessing.current_process() + lock.acquire() + worker_index = None + for i in range(len(array_pid)): + if array_pid[i] == -1: + worker_index = i + array_pid[i] = child_process.ident + break + worker_dict["worker_index"] = worker_index + lock.release() + + _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) + +def process_function_wrapper(args): + global _process_func_wrapper + return _process_func_wrapper(args) + + +# use by thread at init +global _thread_started + +def thread_worker_initializer(func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock): + if max_threads_per_worker is None: + worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): - return _func(segment_index, start_frame, end_frame, _worker_ctx) + with threadpool_limits(limits=max_threads_per_worker): + worker_dict = init_func(*init_args) + + if need_worker_index: + lock.acquire() + global _thread_started + worker_index = _thread_started + _thread_started += 1 + worker_dict["worker_index"] = worker_index + lock.release() + + thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) + +def thread_function_wrapper(args): + thread_local_data = args[0] + args = args[1:] + return thread_local_data.func_wrapper(args) + # Here some utils copy/paste from DART (Charlie Windolf) diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 9677378fc5..580287eb21 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -36,7 +36,7 @@ def test_global_tmp_folder(create_cache_folder): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, 
chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + job_kwargs = dict(pool_engine="thread", n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global_job_kwargs = get_global_job_kwargs() # test warning when not setting n_jobs and calling fix_job_kwargs @@ -44,7 +44,7 @@ def test_global_job_kwargs(): job_kwargs_split = fix_job_kwargs({}) assert global_job_kwargs == dict( - n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs @@ -59,7 +59,7 @@ def test_global_job_kwargs(): set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() assert global_job_kwargs == dict( - n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + pool_engine="thread", n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=cpu_count()) @@ -80,6 +80,6 @@ def test_global_job_kwargs(): if __name__ == "__main__": - test_global_dataset_folder() - test_global_tmp_folder() + # test_global_dataset_folder() + # test_global_tmp_folder() test_global_job_kwargs() diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 2f3aff0023..824532a11e 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -1,7 +1,9 @@ import pytest import os -from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs +import time + +from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs, get_best_job_kwargs from spikeinterface.core.job_tools 
import ( divide_segment_into_chunks, @@ -77,28 +79,25 @@ def test_ensure_chunk_size(): assert end_frame == recording.get_num_frames(segment_index=segment_index) -def func(segment_index, start_frame, end_frame, worker_ctx): +def func(segment_index, start_frame, end_frame, worker_dict): import os - import time - #  print('func', segment_index, start_frame, end_frame, worker_ctx, os.getpid()) + #  print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid()) time.sleep(0.010) # time.sleep(1.0) return os.getpid() def init_func(arg1, arg2, arg3): - worker_ctx = {} - worker_ctx["arg1"] = arg1 - worker_ctx["arg2"] = arg2 - worker_ctx["arg3"] = arg3 - return worker_ctx + worker_dict = {} + worker_dict["arg1"] = arg1 + worker_dict["arg2"] = arg2 + worker_dict["arg3"] = arg3 + return worker_dict def test_ChunkRecordingExecutor(): recording = generate_recording(num_channels=2) - # make serializable - recording = recording.save() init_args = "a", 120, "yep" @@ -139,7 +138,7 @@ def __call__(self, res): gathering_func2 = GatherClass() - # chunk + parallel + gather_func + # process + gather_func processor = ChunkRecordingExecutor( recording, func, @@ -148,6 +147,7 @@ def __call__(self, res): verbose=True, progress_bar=True, gather_func=gathering_func2, + pool_engine="process", n_jobs=2, chunk_duration="200ms", job_name="job_name", @@ -157,7 +157,7 @@ def __call__(self, res): assert gathering_func2.pos == num_chunks - # chunk + parallel + spawn + # process spawn processor = ChunkRecordingExecutor( recording, func, @@ -165,6 +165,7 @@ def __call__(self, res): init_args, verbose=True, progress_bar=True, + pool_engine="process", mp_context="spawn", n_jobs=2, chunk_duration="200ms", @@ -172,6 +173,21 @@ def __call__(self, res): ) processor.run() + # thread + processor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + verbose=True, + progress_bar=True, + pool_engine="thread", + n_jobs=2, + chunk_duration="200ms", + job_name="job_name", + ) + 
processor.run() + def test_fix_job_kwargs(): # test negative n_jobs @@ -220,10 +236,56 @@ def test_split_job_kwargs(): assert "other_param" not in job_kwargs and "n_jobs" in job_kwargs and "progress_bar" in job_kwargs + + +def func2(segment_index, start_frame, end_frame, worker_dict): + time.sleep(0.010) + # print(os.getpid(), worker_dict["worker_index"]) + return worker_dict["worker_index"] + + +def init_func2(): + # this leave time for other thread/process to start + time.sleep(0.010) + worker_dict = {} + return worker_dict + + +def test_worker_index(): + recording = generate_recording(num_channels=2) + init_args = tuple() + + for i in range(2): + # making this 2 times ensure to test that global variables are correctly reset + for pool_engine in ("process", "thread"): + processor = ChunkRecordingExecutor( + recording, + func2, + init_func2, + init_args, + progress_bar=False, + gather_func=None, + pool_engine=pool_engine, + n_jobs=2, + handle_returns=True, + chunk_duration="200ms", + need_worker_index=True + ) + res = processor.run() + # we should have a mix of 0 and 1 + assert 0 in res + assert 1 in res + +def test_get_best_job_kwargs(): + job_kwargs = get_best_job_kwargs() + print(job_kwargs) + if __name__ == "__main__": # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() # test_ChunkRecordingExecutor() - test_fix_job_kwargs() + # test_fix_job_kwargs() # test_split_job_kwargs() + # test_worker_index() + test_get_best_job_kwargs() diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 845eaf1310..ed27815758 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -173,21 +173,45 @@ def test_estimate_templates_with_accumulator(): job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - templates = estimate_templates_with_accumulator( - recording, spikes, sorting.unit_ids, 
nbefore, nafter, return_scaled=True, **job_kwargs - ) - print(templates.shape) - assert templates.shape[0] == sorting.unit_ids.size - assert templates.shape[1] == nbefore + nafter - assert templates.shape[2] == recording.get_num_channels() + # here we compare the result with the same mechanism with with several worker pool size + # this means that that acumulator are splitted and then agglomerated back + # this should lead to very small diff + # n_jobs=1 is done in loop + templates_by_worker = [] + + if platform.system() == "Linux": + engine_loop = ["thread", "process"] + else: + engine_loop = ["thread"] + + for pool_engine in engine_loop: + for n_jobs in (1, 2, 8): + job_kwargs = dict(pool_engine=pool_engine, n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + templates = estimate_templates_with_accumulator( + recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs + ) + assert templates.shape[0] == sorting.unit_ids.size + assert templates.shape[1] == nbefore + nafter + assert templates.shape[2] == recording.get_num_channels() + assert np.any(templates != 0) + + templates_by_worker.append(templates) + if len(templates_by_worker) > 1: + templates_loop = templates_by_worker[0] + np.testing.assert_almost_equal(templates, templates_loop, decimal=4) + + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, sharex=True) + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # ax = axs[0] + # ax.set_title(f"{pool_engine} {n_jobs}") + # ax.plot(templates[unit_index, :, :].T.flatten()) + # ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--") + # ax = axs[1] + # ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--") + # plt.show() - assert np.any(templates != 0) - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # for unit_index, unit_id in enumerate(sorting.unit_ids): - # ax.plot(templates[unit_index, :, :].T.flatten()) - # plt.show() def 
test_estimate_templates(): @@ -225,6 +249,6 @@ def test_estimate_templates(): if __name__ == "__main__": - test_waveform_tools() + # test_waveform_tools() test_estimate_templates_with_accumulator() - test_estimate_templates() + # test_estimate_templates() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 3affd7f0ec..8a7b15f886 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -296,17 +296,17 @@ def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker - worker_ctx = {} + worker_dict = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) - worker_ctx["recording"] = recording + worker_dict["recording"] = recording if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["arrays_info"] = arrays_info + worker_dict["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory @@ -321,33 +321,33 @@ def _init_worker_distribute_buffers( waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! 
shms[unit_id] = shm - worker_ctx["shms"] = shms - worker_ctx["waveforms_by_units"] = waveforms_by_units + worker_dict["shms"] = shms + worker_dict["waveforms_by_units"] = waveforms_by_units - worker_ctx["unit_ids"] = unit_ids - worker_ctx["spikes"] = spikes + worker_dict["unit_ids"] = unit_ids + worker_dict["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["inds_by_unit"] = inds_by_unit - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled + worker_dict["inds_by_unit"] = inds_by_unit + worker_dict["sparsity_mask"] = sparsity_mask + worker_dict["mode"] = mode - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - return_scaled = worker_ctx["return_scaled"] - inds_by_unit = worker_ctx["inds_by_unit"] - sparsity_mask = worker_ctx["sparsity_mask"] + recording = worker_dict["recording"] + unit_ids = worker_dict["unit_ids"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + return_scaled = worker_dict["return_scaled"] + inds_by_unit = worker_dict["inds_by_unit"] + sparsity_mask = worker_dict["sparsity_mask"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -383,12 +383,12 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx if in_chunk_pos.size == 0: continue - if worker_ctx["mode"] == "memmap": + if worker_dict["mode"] == "memmap": # open file in demand 
(and also autoclose it after) - filename = worker_ctx["arrays_info"][unit_id] + filename = worker_dict["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") - elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["waveforms_by_units"][unit_id] + elif worker_dict["mode"] == "shared_memory": + wfs = worker_dict["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -548,50 +548,50 @@ def extract_waveforms_to_single_buffer( def _init_worker_distribute_single_buffer( recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode + worker_dict = {} + worker_dict["recording"] = recording + worker_dict["wf_array_info"] = wf_array_info + worker_dict["spikes"] = spikes + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled + worker_dict["sparsity_mask"] = sparsity_mask + worker_dict["mode"] = mode if mode == "memmap": filename = wf_array_info["filename"] all_waveforms = np.load(str(filename), mmap_mode="r+") - worker_ctx["all_waveforms"] = all_waveforms + worker_dict["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - worker_ctx["shm"] = shm - worker_ctx["all_waveforms"] = all_waveforms + worker_dict["shm"] = shm + worker_dict["all_waveforms"] = all_waveforms # prepare segment slices segment_slices = [] for segment_index in 
range(recording.get_num_segments()): s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) - worker_ctx["segment_slices"] = segment_slices + worker_dict["segment_slices"] = segment_slices - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - segment_slices = worker_ctx["segment_slices"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - return_scaled = worker_ctx["return_scaled"] - sparsity_mask = worker_ctx["sparsity_mask"] - all_waveforms = worker_ctx["all_waveforms"] + recording = worker_dict["recording"] + segment_slices = worker_dict["segment_slices"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + return_scaled = worker_dict["return_scaled"] + sparsity_mask = worker_dict["sparsity_mask"] + all_waveforms = worker_dict["all_waveforms"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -630,7 +630,7 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work wf = wf[:, mask] all_waveforms[spike_index, :, : wf.shape[1]] = wf - if worker_ctx["mode"] == "memmap": + if worker_dict["mode"] == "memmap": all_waveforms.flush() @@ -843,12 +843,6 @@ def estimate_templates_with_accumulator( waveform_squared_accumulator_per_worker = None shm_squared_name = None - # trick to get the work_index given pid arrays - lock = multiprocessing.Lock() - array_pid = multiprocessing.Array("i", num_worker) - for i in range(num_worker): - array_pid[i] = -1 - func = _worker_estimate_templates init_func = _init_worker_estimate_templates @@ -862,14 +856,12 @@ def estimate_templates_with_accumulator( nbefore, 
nafter, return_scaled, - lock, - array_pid, ) if job_name is None: job_name = "estimate_templates_with_accumulator" processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs + recording, func, init_func, init_args, job_name=job_name, verbose=verbose, need_worker_index=True, **job_kwargs ) processor.run() @@ -920,15 +912,13 @@ def _init_worker_estimate_templates( nbefore, nafter, return_scaled, - lock, - array_pid, ): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled + worker_dict = {} + worker_dict["recording"] = recording + worker_dict["spikes"] = spikes + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled from multiprocessing.shared_memory import SharedMemory import multiprocessing @@ -936,48 +926,36 @@ def _init_worker_estimate_templates( shm = SharedMemory(shm_name) waveform_accumulator_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - worker_ctx["shm"] = shm - worker_ctx["waveform_accumulator_per_worker"] = waveform_accumulator_per_worker + worker_dict["shm"] = shm + worker_dict["waveform_accumulator_per_worker"] = waveform_accumulator_per_worker if shm_squared_name is not None: shm_squared = SharedMemory(shm_squared_name) waveform_squared_accumulator_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm_squared.buf) - worker_ctx["shm_squared"] = shm_squared - worker_ctx["waveform_squared_accumulator_per_worker"] = waveform_squared_accumulator_per_worker + worker_dict["shm_squared"] = shm_squared + worker_dict["waveform_squared_accumulator_per_worker"] = waveform_squared_accumulator_per_worker # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, 
segment_index + 1]) segment_slices.append((s0, s1)) - worker_ctx["segment_slices"] = segment_slices - - child_process = multiprocessing.current_process() - - lock.acquire() - num_worker = None - for i in range(len(array_pid)): - if array_pid[i] == -1: - num_worker = i - array_pid[i] = child_process.ident - break - worker_ctx["worker_index"] = num_worker - lock.release() + worker_dict["segment_slices"] = segment_slices - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_ctx): +def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - segment_slices = worker_ctx["segment_slices"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - waveform_accumulator_per_worker = worker_ctx["waveform_accumulator_per_worker"] - waveform_squared_accumulator_per_worker = worker_ctx.get("waveform_squared_accumulator_per_worker", None) - worker_index = worker_ctx["worker_index"] - return_scaled = worker_ctx["return_scaled"] + recording = worker_dict["recording"] + segment_slices = worker_dict["segment_slices"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + waveform_accumulator_per_worker = worker_dict["waveform_accumulator_per_worker"] + waveform_squared_accumulator_per_worker = worker_dict.get("waveform_squared_accumulator_per_worker", None) + worker_index = worker_dict["worker_index"] + return_scaled = worker_dict["return_scaled"] seg_size = recording.get_num_samples(segment_index=segment_index) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 809f2c5bba..84fbfc5965 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ 
-316,13 +316,13 @@ def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] mp_context = job_kwargs["mp_context"] # fit model/models # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": - pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context) + pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_worker, mp_context) for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models @@ -415,7 +415,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): ) processor.run() - def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context): + def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_worker, mp_context): from sklearn.decomposition import IncrementalPCA p = self.params @@ -444,10 +444,10 @@ def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, m pca = pca_models[chan_ind] pca.partial_fit(wfs[:, :, wf_ind]) else: - # create list of args to parallelize. For convenience, the max_threads_per_process is passed + # create list of args to parallelize. 
For convenience, the max_threads_per_worker is passed # as last argument items = [ - (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process) + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_worker) for wf_ind, chan_ind in enumerate(channel_inds) ] n_jobs = min(n_jobs, len(items)) @@ -687,12 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte def _partial_fit_one_channel(args): - chan_ind, pca_model, wf_chan, max_threads_per_process = args + chan_ind, pca_model, wf_chan, max_threads_per_worker = args - if max_threads_per_process is None: + if max_threads_per_worker is None: pca_model.partial_fit(wf_chan) return chan_ind, pca_model else: - with threadpool_limits(limits=int(max_threads_per_process)): + with threadpool_limits(limits=int(max_threads_per_worker)): pca_model.partial_fit(wf_chan) return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 7a509c410f..ecfc39f2c6 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -27,7 +27,7 @@ def test_multi_processing(self): ) sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) sorting_analyzer.compute( - "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn" + "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_worker=4, mp_context="spawn" ) def test_mode_concatenated(self): diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index ca21f1e45f..96fa58f4ef 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -60,7 +60,7 @@ def compute_pc_metrics( n_jobs=1, progress_bar=False, 
mp_context=None, - max_threads_per_process=None, + max_threads_per_worker=None, ) -> dict: """ Calculate principal component derived metrics. @@ -157,7 +157,8 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_worker) + items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -987,12 +988,13 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_process) = args - if max_threads_per_process is None: + (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_worker) = args + + if max_threads_per_worker is None: return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) else: - with threadpool_limits(limits=int(max_threads_per_process)): + with threadpool_limits(limits=int(max_threads_per_worker)): return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..9878adf142 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -8,8 +8,8 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -@pytest.fixture(scope="module") -def small_sorting_analyzer(): + +def make_small_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -34,6 +34,9 @@ def small_sorting_analyzer(): return sorting_analyzer +@pytest.fixture(scope="module") +def 
small_sorting_analyzer(): + return make_small_analyzer() @pytest.fixture(scope="module") def sorting_analyzer_simple(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index f2e912c6b4..312c3949b3 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -15,11 +15,25 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res2 = pd.DataFrame(res2) for metric_name in res1.columns: - if metric_name != "nn_unit_id": - assert not np.all(np.isnan(res1[metric_name].values)) - assert not np.all(np.isnan(res2[metric_name].values)) + values1 = res1[metric_name].values + values2 = res2[metric_name].values - assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + if metric_name != "nn_unit_id": + assert not np.all(np.isnan(values1)) + assert not np.all(np.isnan(values2)) + + if values1.dtype.kind == "f": + np.testing.assert_almost_equal(values1, values2, decimal=4) + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, sharex=True) + # ax = axs[0] + # ax.plot(res1[metric_name].values) + # ax.plot(res2[metric_name].values) + # ax = axs[1] + # ax.plot(res2[metric_name].values - res1[metric_name].values) + # plt.show() + else: + assert np.array_equal(values1, values2) def test_pca_metrics_multi_processing(small_sorting_analyzer): @@ -31,13 +45,18 @@ def test_pca_metrics_multi_processing(small_sorting_analyzer): print(f"Computing PCA metrics with 1 thread per process") res1 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=1, progress_bar=True ) print(f"Computing PCA metrics with 2 thread per process") res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, 
progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) print("Computing PCA metrics with spawn context") res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) + +if __name__ == "__main__": + from spikeinterface.qualitymetrics.tests.conftest import make_small_analyzer + small_sorting_analyzer = make_small_analyzer() + test_calculate_pc_metrics(small_sorting_analyzer) \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 4a7b722aea..e618cfbfb6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -261,7 +261,7 @@ def find_merge_pairs( **job_kwargs, # n_jobs=1, # mp_context="fork", - # max_threads_per_process=1, + # max_threads_per_worker=1, # progress_bar=True, ): """ @@ -299,7 +299,7 @@ def find_merge_pairs( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) progress_bar = job_kwargs["progress_bar"] Executor = get_poolexecutor(n_jobs) @@ -316,7 +316,7 @@ def find_merge_pairs( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ), ) as pool: jobs = [] @@ -354,7 +354,7 @@ def find_pair_worker_init( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ): global _ctx _ctx = {} @@ -366,7 +366,7 @@ def find_pair_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + 
_ctx["max_threads_per_worker"] = max_threads_per_worker # if isinstance(features_dict_or_folder, dict): # _ctx["features"] = features_dict_or_folder @@ -380,7 +380,7 @@ def find_pair_worker_init( def find_pair_function_wrapper(label0, label1): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( label0, label1, diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 15917934a8..3c2e878c39 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -65,7 +65,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) original_labels = peak_labels peak_labels = peak_labels.copy() @@ -77,7 +77,7 @@ def split_clusters( max_workers=n_jobs, initializer=split_worker_init, mp_context=get_context(method=mp_context), - initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker), ) as pool: labels_set = np.setdiff1d(peak_labels, [-1]) current_max_label = np.max(labels_set) + 1 @@ -133,7 +133,7 @@ def split_clusters( def split_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker ): global _ctx _ctx = {} @@ -144,14 +144,14 @@ def split_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs 
_ctx["method_class"] = split_methods_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) _ctx["peaks"] = _ctx["features"]["peaks"] def split_function_wrapper(peak_indices, recursion_level): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_split, local_labels = _ctx["method_class"].split( peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, **_ctx["method_kwargs"] )