diff --git a/phylib/io/model.py b/phylib/io/model.py index 4e5a63a..3fdfdcf 100644 --- a/phylib/io/model.py +++ b/phylib/io/model.py @@ -683,10 +683,13 @@ def _load_templates(self): try: path = self._find_path( 'templates.npy', 'templates.waveforms.npy', 'templates.waveforms.*.npy') - data = self._read_array(path, mmap_mode='r') + data = self._read_array(path, mmap_mode='r+') data = np.atleast_3d(data) assert data.ndim == 3 assert data.dtype in (np.float32, np.float64) + # WARNING: this will load the full array in memory, might cause memory problems + empty_templates = np.all(np.all(np.isnan(data), axis=1), axis=1) + data[empty_templates, ...] = 0 n_templates, n_samples, n_channels_loc = data.shape except IOError: return @@ -818,6 +821,7 @@ def _find_best_channels(self, template, amplitude_threshold=None): # Compute the template amplitude on each channel. assert template.ndim == 2 # shape: (n_samples, n_channels) amplitude = template.max(axis=0) - template.min(axis=0) + assert not np.all(np.isnan(amplitude)), "Template is all NaN!" assert amplitude.ndim == 1 # shape: (n_channels,) # Find the peak channel. best_channel = np.argmax(amplitude)