From edb4f43c557c88d27775e796c53fe383bf7b6ffd Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 1 Apr 2026 12:48:12 -0400 Subject: [PATCH 1/6] generate: Add temporal correlations to noise simulator --- src/spikeinterface/core/generate.py | 605 ++++++++++++++---- .../core/tests/test_generate.py | 245 +++++-- 2 files changed, 689 insertions(+), 161 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 35116a9e4c..420af71b44 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -159,9 +159,13 @@ def generate_sorting( spikes.append(spikes_in_seg) if add_spikes_on_borders: - spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) + spikes_on_borders = np.zeros( + 2 * num_spikes_per_border, dtype=minimum_spike_dtype + ) spikes_on_borders["segment_index"] = segment_index - spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True) + spikes_on_borders["unit_index"] = rng.choice( + num_units, size=2 * num_spikes_per_border, replace=True + ) # at start spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers( 0, border_size_samples, num_spikes_per_border @@ -173,7 +177,11 @@ def generate_sorting( spikes.append(spikes_on_borders) spikes = np.concatenate(spikes) - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + spikes = spikes[ + np.lexsort( + (spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]) + ) + ] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -217,18 +225,26 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[ + 
~np.isin(unit_ids, units_used_for_spike[sample_index]) + ] if len(units_not_used) == 0: continue new_unit_indices[i] = rng.choice(units_not_used) - units_used_for_spike[sample_index] = np.append(units_used_for_spike[sample_index], new_unit_indices[i]) + units_used_for_spike[sample_index] = np.append( + units_used_for_spike[sample_index], new_unit_indices[i] + ) spikes_duplicated["unit_index"] = new_unit_indices - sort_idxs = np.lexsort([spikes_duplicated["sample_index"], spikes_duplicated["segment_index"]]) + sort_idxs = np.lexsort( + [spikes_duplicated["sample_index"], spikes_duplicated["segment_index"]] + ) spikes_duplicated = spikes_duplicated[sort_idxs] - synchronous_spikes = NumpySorting(spikes_duplicated, sorting.get_sampling_frequency(), unit_ids) + synchronous_spikes = NumpySorting( + spikes_duplicated, sorting.get_sampling_frequency(), unit_ids + ) sorting = TransformSorting.add_from_sorting(sorting, synchronous_spikes) return sorting @@ -276,12 +292,18 @@ def generate_sorting_to_inject( for segment_index in range(sorting.get_num_segments()): for unit_id in sorting.unit_ids: - spike_train = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - n_injection = min(max_injected_per_unit, int(round(injected_rate * len(spike_train)))) + spike_train = sorting.get_unit_spike_train( + unit_id, segment_index=segment_index + ) + n_injection = min( + max_injected_per_unit, int(round(injected_rate * len(spike_train))) + ) # Inject more, then take out all that violate the refractory period. n = int(n_injection + 10 * np.sqrt(n_injection)) injected_spike_train = np.sort( - np.random.uniform(low=0, high=num_samples[segment_index], size=n).astype(np.int64) + np.random.uniform( + low=0, high=num_samples[segment_index], size=n + ).astype(np.int64) ) # Remove spikes that are in the refractory period. @@ -290,16 +312,22 @@ def generate_sorting_to_inject( # Remove spikes that violate the refractory period of the real spikes. 
# TODO: Need a better & faster way than this. - min_diff = np.min(np.abs(injected_spike_train[:, None] - spike_train[None, :]), axis=1) + min_diff = np.min( + np.abs(injected_spike_train[:, None] - spike_train[None, :]), axis=1 + ) violations = min_diff < t_r injected_spike_train = injected_spike_train[~violations] if len(injected_spike_train) > n_injection: - injected_spike_train = np.sort(rng.choice(injected_spike_train, n_injection, replace=False)) + injected_spike_train = np.sort( + rng.choice(injected_spike_train, n_injection, replace=False) + ) injected_spike_trains[segment_index][unit_id] = injected_spike_train - return NumpySorting.from_unit_dict(injected_spike_trains, sorting.get_sampling_frequency()) + return NumpySorting.from_unit_dict( + injected_spike_trains, sorting.get_sampling_frequency() + ) class TransformSorting(BaseSorting): @@ -343,11 +371,13 @@ def __init__( if new_unit_ids is not None: new_unit_ids = list(new_unit_ids) - assert ~np.any( - np.isin(new_unit_ids, sorting.unit_ids) - ), "some units ids are already present. Consider using added_spikes_existing_units" + assert ~np.any(np.isin(new_unit_ids, sorting.unit_ids)), ( + "some units ids are already present. 
Consider using added_spikes_existing_units" + ) if len(new_unit_ids) > 0: - assert type(unit_ids[0]) == type(new_unit_ids[0]), "unit_ids should have the same type" + assert type(unit_ids[0]) == type(new_unit_ids[0]), ( + "unit_ids should have the same type" + ) unit_ids = unit_ids + list(new_unit_ids) BaseSorting.__init__(self, sampling_frequency, unit_ids) @@ -356,42 +386,74 @@ def __init__( self._cached_spike_vector = sorting.to_spike_vector().copy() self.refractory_period_ms = refractory_period_ms - self.added_spikes_from_existing_mask = np.zeros(len(self._cached_spike_vector), dtype=bool) - self.added_spikes_from_new_mask = np.zeros(len(self._cached_spike_vector), dtype=bool) + self.added_spikes_from_existing_mask = np.zeros( + len(self._cached_spike_vector), dtype=bool + ) + self.added_spikes_from_new_mask = np.zeros( + len(self._cached_spike_vector), dtype=bool + ) - if added_spikes_existing_units is not None and len(added_spikes_existing_units) > 0: - assert ( - added_spikes_existing_units.dtype == minimum_spike_dtype - ), "added_spikes_existing_units should be a spike vector" + if ( + added_spikes_existing_units is not None + and len(added_spikes_existing_units) > 0 + ): + assert added_spikes_existing_units.dtype == minimum_spike_dtype, ( + "added_spikes_existing_units should be a spike vector" + ) added_unit_indices = np.arange(len(self.parent_unit_ids)) - self._cached_spike_vector = np.concatenate((self._cached_spike_vector, added_spikes_existing_units)) + self._cached_spike_vector = np.concatenate( + (self._cached_spike_vector, added_spikes_existing_units) + ) self.added_spikes_from_existing_mask = np.concatenate( - (self.added_spikes_from_existing_mask, np.ones(len(added_spikes_existing_units), dtype=bool)) + ( + self.added_spikes_from_existing_mask, + np.ones(len(added_spikes_existing_units), dtype=bool), + ) ) self.added_spikes_from_new_mask = np.concatenate( - (self.added_spikes_from_new_mask, np.zeros(len(added_spikes_existing_units), dtype=bool)) + 
( + self.added_spikes_from_new_mask, + np.zeros(len(added_spikes_existing_units), dtype=bool), + ) ) if added_spikes_new_units is not None and len(added_spikes_new_units) > 0: - assert ( - added_spikes_new_units.dtype == minimum_spike_dtype - ), "added_spikes_new_units should be a spike vector" - self._cached_spike_vector = np.concatenate((self._cached_spike_vector, added_spikes_new_units)) + assert added_spikes_new_units.dtype == minimum_spike_dtype, ( + "added_spikes_new_units should be a spike vector" + ) + self._cached_spike_vector = np.concatenate( + (self._cached_spike_vector, added_spikes_new_units) + ) self.added_spikes_from_existing_mask = np.concatenate( - (self.added_spikes_from_existing_mask, np.zeros(len(added_spikes_new_units), dtype=bool)) + ( + self.added_spikes_from_existing_mask, + np.zeros(len(added_spikes_new_units), dtype=bool), + ) ) self.added_spikes_from_new_mask = np.concatenate( - (self.added_spikes_from_new_mask, np.ones(len(added_spikes_new_units), dtype=bool)) + ( + self.added_spikes_from_new_mask, + np.ones(len(added_spikes_new_units), dtype=bool), + ) ) - sort_idxs = np.lexsort([self._cached_spike_vector["sample_index"], self._cached_spike_vector["segment_index"]]) + sort_idxs = np.lexsort( + [ + self._cached_spike_vector["sample_index"], + self._cached_spike_vector["segment_index"], + ] + ) self._cached_spike_vector = self._cached_spike_vector[sort_idxs] - self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[sort_idxs] + self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[ + sort_idxs + ] self.added_spikes_from_new_mask = self.added_spikes_from_new_mask[sort_idxs] # We need to add the sorting segments for segment_index in range(sorting.get_num_segments()): - segment = SpikeVectorSortingSegment(self._cached_spike_vector, segment_index, unit_ids=self.unit_ids) + segment = SpikeVectorSortingSegment( + self._cached_spike_vector, segment_index, unit_ids=self.unit_ids + ) 
self.add_sorting_segment(segment) if self.refractory_period_ms is not None: @@ -407,7 +469,9 @@ def __init__( @property def added_spikes_mask(self): - return np.logical_or(self.added_spikes_from_existing_mask, self.added_spikes_from_new_mask) + return np.logical_or( + self.added_spikes_from_existing_mask, self.added_spikes_from_new_mask + ) def get_added_spikes_indices(self): return np.nonzero(self.added_spikes_mask)[0] @@ -422,7 +486,9 @@ def get_added_units_inds(self): return self.unit_ids[len(self.parent_unit_ids) :] @staticmethod - def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_period_ms=None) -> "TransformSorting": + def add_from_sorting( + sorting1: BaseSorting, sorting2: BaseSorting, refractory_period_ms=None + ) -> "TransformSorting": """ Construct TransformSorting by adding one sorting to one other. @@ -437,10 +503,12 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe of spikes. Any spike times in added_spikes violating the refractory period will be discarded. 
""" - assert ( - sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency() - ), "sampling_frequency should be the same" - assert type(sorting1.unit_ids[0]) == type(sorting2.unit_ids[0]), "unit_ids should have the same type" + assert sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency(), ( + "sampling_frequency should be the same" + ) + assert type(sorting1.unit_ids[0]) == type(sorting2.unit_ids[0]), ( + "unit_ids should have the same type" + ) # We detect the indices that are shared by the two sortings mask1 = np.isin(sorting2.unit_ids, sorting1.unit_ids) common_ids = sorting2.unit_ids[mask1] @@ -484,7 +552,9 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe @staticmethod def add_from_unit_dict( - sorting1: BaseSorting, units_dict_list: list[dict] | dict, refractory_period_ms=None + sorting1: BaseSorting, + units_dict_list: list[dict] | dict, + refractory_period_ms=None, ) -> "TransformSorting": """ Construct TransformSorting by adding one sorting with a @@ -503,13 +573,22 @@ def add_from_unit_dict( of spikes. Any spike times in added_spikes violating the refractory period will be discarded. """ - sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency()) - sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) + sorting2 = NumpySorting.from_unit_dict( + units_dict_list, sorting1.get_sampling_frequency() + ) + sorting = TransformSorting.add_from_sorting( + sorting1, sorting2, refractory_period_ms + ) return sorting @staticmethod def from_samples_and_labels( - sorting1, times_list, labels_list, sampling_frequency, unit_ids=None, refractory_period_ms=None + sorting1, + times_list, + labels_list, + sampling_frequency, + unit_ids=None, + refractory_period_ms=None, ) -> "NumpySorting": """ Construct TransformSorting from: @@ -536,8 +615,12 @@ def from_samples_and_labels( discarded. 
""" - sorting2 = NumpySorting.from_samples_and_labels(times_list, labels_list, sampling_frequency, unit_ids) - sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) + sorting2 = NumpySorting.from_samples_and_labels( + times_list, labels_list, sampling_frequency, unit_ids + ) + sorting = TransformSorting.add_from_sorting( + sorting1, sorting2, refractory_period_ms + ) return sorting def clean_refractory_period(self): @@ -554,11 +637,14 @@ def clean_refractory_period(self): * (self._cached_spike_vector["segment_index"] == segment_index) ) to_keep[indices[1:]] = np.logical_or( - to_keep[indices[1:]], np.diff(self._cached_spike_vector[indices]["sample_index"]) > rpv + to_keep[indices[1:]], + np.diff(self._cached_spike_vector[indices]["sample_index"]) > rpv, ) self._cached_spike_vector = self._cached_spike_vector[to_keep] - self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[to_keep] + self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[ + to_keep + ] self.added_spikes_from_new_mask = self.added_spikes_from_new_mask[to_keep] @@ -647,11 +733,19 @@ def generate_snippets( ) sorting = generate_sorting( - num_units=num_units, sampling_frequency=sampling_frequency, durations=durations, empty_units=empty_units + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=durations, + empty_units=empty_units, ) snippets = snippets_from_sorting( - recording=recording, sorting=sorting, nbefore=nbefore, nafter=nafter, wf_folder=wf_folder, **job_kwargs + recording=recording, + sorting=sorting, + nbefore=nbefore, + nafter=nafter, + wf_folder=wf_folder, + **job_kwargs, ) if set_probe: @@ -743,7 +837,9 @@ def synthesize_poisson_spike_vector( refractory_period_seconds = refractory_period_ms / 1000.0 refractory_period_frames = int(refractory_period_seconds * sampling_frequency) - is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates) + 
is_refractory_period_too_long = np.any( + refractory_period_seconds >= 1.0 / firing_rates + ) if is_refractory_period_too_long: raise ValueError( f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" @@ -764,7 +860,9 @@ def synthesize_poisson_spike_vector( binomial_p_modified = np.minimum(binomial_p_modified, 1.0) # Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames - inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max)) + inter_spike_frames = rng.geometric( + p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max) + ) inter_spike_frames[:, 1:] += refractory_period_frames spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames) spike_frames = spike_frames.ravel() @@ -780,7 +878,9 @@ def synthesize_poisson_spike_vector( # Sort globaly spike_frames = spike_frames[:num_correct_frames] - sort_indices = np.argsort(spike_frames, kind="stable") # I profiled the different kinds, this is the fastest. + sort_indices = np.argsort( + spike_frames, kind="stable" + ) # I profiled the different kinds, this is the fastest. 
unit_indices = unit_indices[sort_indices] spike_frames = spike_frames[sort_indices] @@ -926,7 +1026,9 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No """ rng = np.random.default_rng(seed) - other_ids = np.arange(np.max(sorting.unit_ids) + 1, np.max(sorting.unit_ids) + num + 1) + other_ids = np.arange( + np.max(sorting.unit_ids) + 1, np.max(sorting.unit_ids) + num + 1 + ) shifts = rng.integers(low=-max_shift, high=max_shift, size=num) shifts[shifts == 0] += max_shift @@ -936,7 +1038,8 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No for segment_index in range(sorting.get_num_segments()): # sorting to dict d = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) for unit_id in sorting.unit_ids + unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + for unit_id in sorting.unit_ids } r = {} @@ -957,13 +1060,17 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No r[unit_id] = times spiketrains.append(r) - sorting_new_units = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) + sorting_new_units = NumpySorting.from_unit_dict( + spiketrains, sampling_frequency=sorting.get_sampling_frequency() + ) sorting_with_dup = TransformSorting.add_from_sorting(sorting, sorting_new_units) return sorting_with_dup -def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None): +def inject_some_split_units( + sorting, split_ids: list, num_split=2, output_ids=False, seed=None +): """ Inject some split units in a sorting. 
@@ -1001,7 +1108,8 @@ def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=Fa for segment_index in range(sorting.get_num_segments()): # sorting to dict d = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) for unit_id in sorting.unit_ids + unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + for unit_id in sorting.unit_ids } new_units = {} @@ -1017,14 +1125,18 @@ def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=Fa new_units[unit_id] = original_times spiketrains.append(new_units) - sorting_with_split = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) + sorting_with_split = NumpySorting.from_unit_dict( + spiketrains, sampling_frequency=sorting.get_sampling_frequency() + ) if output_ids: return sorting_with_split, other_ids else: return sorting_with_split -def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, violation_delta=1e-5): +def synthetize_spike_train_bad_isi( + duration, baseline_rate, num_violations, violation_delta=1e-5 +): """Create a spike train. Has uniform inter-spike intervals, except where isis violations occur. 
Parameters @@ -1119,7 +1231,9 @@ def __init__( self.durations = durations self.refractory_period_seconds = refractory_period_ms / 1000.0 - is_refractory_period_too_long = np.any(self.refractory_period_seconds >= 1.0 / firing_rates) + is_refractory_period_too_long = np.any( + self.refractory_period_seconds >= 1.0 / firing_rates + ) if is_refractory_period_too_long: raise ValueError( f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" @@ -1175,15 +1289,21 @@ def __init__( self.firing_rates = firing_rates if np.isscalar(self.refractory_period_seconds): - self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") + self.refractory_period_seconds = np.full( + num_units, self.refractory_period_seconds, dtype="float64" + ) self.segment_seed = seed - self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} + self.units_seed = { + unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids + } self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) - def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: + def get_unit_spike_train( + self, unit_id, start_frame: int | None = None, end_frame: int | None = None + ) -> np.ndarray: unit_seed = self.units_seed[unit_id] unit_index = self.parent_extractor.id_to_index(unit_id) @@ -1218,7 +1338,9 @@ def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_fram start_index = 0 if end_frame is not None: - end_index = np.searchsorted(spike_frames[start_index:], end_frame, side="left") + end_index = np.searchsorted( + spike_frames[start_index:], end_frame, side="left" + ) else: end_index = int(self.duration * self.sampling_frequency) @@ -1250,6 +1372,14 @@ class NoiseGeneratorRecording(BaseRecording): Std of the white noise (if an array, defined by per channels) cov_matrix : np.ndarray | 
None, default: None The covariance matrix of the noise + spectral_density : np.ndarray | None, default: None + The spectral density of the noise, such as would be estimated from an array of snippets of shape + `(n_snippets, spectral_snippet_length)` by the following method (Welch's method): + + ```python + periodogram = rfft(snippets, n=next_fast_len(snippets.shape[1]), norm="ortho") + spectral_density = np.sqrt((periodogram * periodogram.conj()).mean(axis=0)) + ``` dtype : np.dtype | str | None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. seed : int | None, default: None @@ -1260,6 +1390,10 @@ class NoiseGeneratorRecording(BaseRecording): very fast and cusume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) + temporal_use_overlap_add: bool, default: False + If applying a spectral density, it's faster to use overlap-add convolutions, but these can + introduce numerical differences when the args to get_traces() change. But, it is okay to set + this flag if you will only ever call get_traces with non-overlapping chunks. noise_block_size : int, default: 30000 Size in sample of noise block. 
@@ -1276,9 +1410,11 @@ def __init__( durations: list[float], noise_levels: float | np.ndarray = 1.0, cov_matrix: np.ndarray | None = None, + spectral_density: np.ndarray | None = None, dtype: np.dtype | str | None = "float32", seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", + temporal_use_overlap_add: bool = False, noise_block_size: int = 30000, ): @@ -1286,7 +1422,9 @@ def __init__( dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") - assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" + assert strategy in ("tile_pregenerated", "on_the_fly"), ( + "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" + ) if np.isscalar(noise_levels): noise_levels = np.ones((1, num_channels)) * noise_levels @@ -1295,16 +1433,23 @@ def __init__( if len(noise_levels.shape) < 2: noise_levels = noise_levels[np.newaxis, :] - assert len(noise_levels[0]) == num_channels, "Noise levels should have a size of num_channels" + assert len(noise_levels[0]) == num_channels, ( + "Noise levels should have a size of num_channels" + ) - BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) + BaseRecording.__init__( + self, + sampling_frequency=sampling_frequency, + channel_ids=channel_ids, + dtype=dtype, + ) num_segments = len(durations) if cov_matrix is not None: - assert ( - cov_matrix.shape[0] == cov_matrix.shape[1] == num_channels - ), "cov_matrix should have a size (num_channels, num_channels)" + assert cov_matrix.shape[0] == cov_matrix.shape[1] == num_channels, ( + "cov_matrix should have a size (num_channels, num_channels)" + ) # very important here when multiprocessing and dump/load seed = _ensure_seed(seed) @@ -1322,9 +1467,11 @@ def __init__( noise_block_size, noise_levels, cov_matrix, + 
spectral_density, dtype, segments_seeds[i], strategy, + temporal_use_overlap_add, ) self.add_recording_segment(rec_segment) @@ -1350,9 +1497,11 @@ def __init__( noise_block_size, noise_levels, cov_matrix, + spectral_density, dtype, seed, strategy, + use_overlap_add, ): assert seed is not None @@ -1363,21 +1512,33 @@ def __init__( self.noise_block_size = noise_block_size self.noise_levels = noise_levels self.cov_matrix = cov_matrix + self.spectral_density = spectral_density self.dtype = dtype self.seed = seed self.strategy = strategy + self.use_overlap_add = use_overlap_add if self.strategy == "tile_pregenerated": rng = np.random.default_rng(seed=self.seed) if self.cov_matrix is None: self.noise_block = ( - rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) + rng.standard_normal( + size=(self.noise_block_size, self.num_channels), + dtype=self.dtype, + ) * noise_levels ) else: self.noise_block = rng.multivariate_normal( - np.zeros(self.num_channels), self.cov_matrix, size=self.noise_block_size + np.zeros(self.num_channels), + self.cov_matrix, + size=self.noise_block_size, + ) + + if spectral_density is not None: + self.noise_block = _apply_temporal_psd( + self.noise_block, spectral_density, use_overlap_add=use_overlap_add ) elif self.strategy == "on_the_fly": @@ -1386,7 +1547,7 @@ def __init__( def get_num_samples(self) -> int: return self.num_samples - def get_traces( + def get_traces_spatial_only( self, start_frame: int | None = None, end_frame: int | None = None, @@ -1414,13 +1575,20 @@ def get_traces( elif self.strategy == "on_the_fly": rng = np.random.default_rng(seed=(self.seed, block_index)) if self.cov_matrix is None: - noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) + noise_block = rng.standard_normal( + size=(self.noise_block_size, self.num_channels), + dtype=self.dtype, + ) else: noise_block = rng.multivariate_normal( - np.zeros(self.num_channels), self.cov_matrix, 
size=self.noise_block_size + np.zeros(self.num_channels), + self.cov_matrix, + size=self.noise_block_size, ) noise_block *= self.noise_levels + else: + assert False if block_index == first_block_index: if first_block_index != last_block_index: @@ -1429,7 +1597,10 @@ def get_traces( pos += end_first_block else: # special case when unique block - traces[:] = noise_block[start_frame_within_block : start_frame_within_block + num_samples] + traces[:] = noise_block[ + start_frame_within_block : start_frame_within_block + + num_samples + ] elif block_index == last_block_index: if end_frame_within_block > 0: traces[pos:] = noise_block[:end_frame_within_block] @@ -1442,6 +1613,69 @@ def get_traces( return traces + def get_traces( + self, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, + ): + if self.spectral_density is None or self.strategy == "tile_pregenerated": + return self.get_traces_spatial_only(start_frame, end_frame, channel_indices) + + n_samples = self.get_num_samples() + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = n_samples + + # margin logic for temporal correlations + # temporal correlations are done "circularly" so that the noise wraps continuously + # having the spatial only helper makes the logic way easier! + + # circular margin logic + # left_circ_len is how much margin to grab from the right side to fill out the left + # margin circularly. similarly right_circ_len. 
+ margin = self.spectral_density.shape[0] - 1 + left_pad_len = min(start_frame, margin) + left_circ_len = margin - left_pad_len + right_pad_len = min(n_samples - end_frame, margin) + right_circ_len = margin - right_pad_len + + # grab padded chunk + pad_start_frame = start_frame - left_pad_len + pad_end_frame = end_frame + right_pad_len + # NOTE(review): leftover debug print removed; bounds are checked below + assert 0 <= pad_start_frame < pad_end_frame <= n_samples + chunk = self.get_traces_spatial_only( + pad_start_frame, pad_end_frame, channel_indices + ) + empty = np.zeros_like(chunk[:0]) + if left_circ_len: + circ_left = self.get_traces_spatial_only( + n_samples - left_circ_len, n_samples, channel_indices + ) + else: + circ_left = empty + if right_circ_len: + circ_right = self.get_traces_spatial_only( + 0, right_circ_len, channel_indices + ) + else: + circ_right = empty + + chunk_main = chunk[left_pad_len : chunk.shape[0] - right_pad_len] + pad_left = np.concatenate([circ_left, chunk[:left_pad_len]]) + pad_right = np.concatenate( + [chunk[chunk.shape[0] - right_pad_len : chunk.shape[0]], circ_right] + ) + return _apply_temporal_psd( + chunk_main, + self.spectral_density, + pad_left, + pad_right, + use_overlap_add=self.use_overlap_add, + ) + noise_generator_recording = define_function_from_class( source_class=NoiseGeneratorRecording, name="noise_generator_recording" @@ -1502,6 +1736,50 @@ def generate_recording_by_size( return recording +def _apply_temporal_psd( + chunk: np.ndarray, + psd: np.ndarray, + pad_left: np.ndarray | None = None, + pad_right: np.ndarray | None = None, + use_overlap_add: bool = False, +): + """Apply a convolution so that chunk's first dimension has psd `psd`, if it started out as white noise + + If padding arrays are not supplied, circular padding is extracted. + + It is very sad because it's relatively slow, but we are absolutely forced to use direct convolution here. 
+ If not, then the output will have numerical disagreement depending on the args to get_traces(), since the + ffts will vary slightly. If you can tolerate errors of 5e-7 or if you'll only ever call get_traces() with + the same args in non-overlapping chunks, you can use_overlap_add=True. + """ + from scipy.signal import oaconvolve, convolve + + klen = psd.shape[0] + block_len = 2 * klen - 1 + pad_len = klen - 1 + + # padding + T = chunk.shape[0] + if pad_left is None: + pad_left = chunk[T - pad_len : T] + if pad_right is None: + pad_right = chunk[:pad_len] + assert pad_left.shape == pad_right.shape == (pad_len, chunk.shape[1]) + + # stack and convolve + chunk_conv = np.concatenate([pad_left.T, chunk.T, pad_right.T], axis=1) + kernel = np.fft.fftshift(np.fft.irfft(psd, n=block_len)) + if use_overlap_add: + conv = oaconvolve(chunk_conv, kernel[None], mode="valid", axes=1) + else: + conv = convolve(chunk_conv, kernel[None], mode="valid", method="direct") + del chunk_conv + conv = np.ascontiguousarray(conv.T) + assert conv.shape == chunk.shape + + return conv + + ## Waveforms zone ## @@ -1518,7 +1796,9 @@ def exp_growth(start_amp, end_amp, duration_ms, tau_ms, sampling_frequency, flip return y[:-1] -def get_ellipse(positions, center, x_factor=1, y_factor=1, x_angle=0, y_angle=0, z_angle=0): +def get_ellipse( + positions, center, x_factor=1, y_factor=1, x_angle=0, y_angle=0, z_angle=0 +): """ Compute the distances to a particular ellipsoid in order to take into account spatial inhomogeneities while generating the template. 
In a carthesian, centered @@ -1575,7 +1855,9 @@ def get_ellipse(positions, center, x_factor=1, y_factor=1, x_angle=0, y_angle=0, rot_matrix = Rx @ Ry @ Rz P = rot_matrix @ p - distances = np.sqrt((P[0] / x_factor) ** 2 + (P[1] / y_factor) ** 2 + (P[2] / 1) ** 2) + distances = np.sqrt( + (P[0] / x_factor) ** 2 + (P[1] / y_factor) ** 2 + (P[2] / 1) ** 2 + ) return distances @@ -1615,7 +1897,12 @@ def generate_single_fake_waveform( nrepol = int(repolarization_ms * sampling_frequency / 1000.0) tau_ms = repolarization_ms * 0.5 wf[nbefore : nbefore + nrepol] = exp_growth( - negative_amplitude, positive_amplitude, repolarization_ms, tau_ms, sampling_frequency, flip=True + negative_amplitude, + positive_amplitude, + repolarization_ms, + tau_ms, + sampling_frequency, + flip=True, ) # recovery @@ -1680,7 +1967,9 @@ def _ensure_unit_params(unit_params, num_units, seed): elif isinstance(v, (list, np.ndarray)): # already vector values = np.asarray(v) - assert values.shape == (num_units,), f"generate_templates: wrong shape for {k} in unit_params" + assert values.shape == (num_units,), ( + f"generate_templates: wrong shape for {k} in unit_params" + ) elif v is None: values = [None] * num_units else: @@ -1772,7 +2061,9 @@ def generate_templates( # channel_locations to 3D if channel_locations.shape[1] == 2: - channel_locations = np.hstack([channel_locations, np.zeros((channel_locations.shape[0], 1))]) + channel_locations = np.hstack( + [channel_locations, np.zeros((channel_locations.shape[0], 1))] + ) num_units = units_locations.shape[0] num_channels = channel_locations.shape[0] @@ -1783,7 +2074,9 @@ def generate_templates( if upsample_factor is not None: upsample_factor = int(upsample_factor) assert upsample_factor >= 1 - templates = np.zeros((num_units, width, num_channels, upsample_factor), dtype=dtype) + templates = np.zeros( + (num_units, width, num_channels, upsample_factor), dtype=dtype + ) fs = sampling_frequency * upsample_factor else: templates = np.zeros((num_units, 
width, num_channels), dtype=dtype) @@ -1928,9 +2221,17 @@ def __init__( check_borders = False self.templates = templates - channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2])) - dtype = parent_recording.dtype if parent_recording is not None else templates.dtype - BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + channel_ids = ( + parent_recording.channel_ids + if parent_recording is not None + else list(range(templates.shape[2])) + ) + dtype = ( + parent_recording.dtype if parent_recording is not None else templates.dtype + ) + BaseRecording.__init__( + self, sorting.get_sampling_frequency(), channel_ids, dtype + ) # Important : self._serializability is not change here because it will depend on the sorting parents itself. @@ -1962,7 +2263,9 @@ def __init__( if amplitude_factor is None: amplitude_vector = None elif np.isscalar(amplitude_factor): - amplitude_vector = np.full(self.spike_vector.size, amplitude_factor, dtype="float32") + amplitude_vector = np.full( + self.spike_vector.size, amplitude_factor, dtype="float32" + ) else: amplitude_factor = np.asarray(amplitude_factor) assert amplitude_factor.shape == self.spike_vector.shape @@ -1970,13 +2273,18 @@ def __init__( if parent_recording is not None: assert parent_recording.get_num_segments() == sorting.get_num_segments() - assert parent_recording.get_sampling_frequency() == sorting.get_sampling_frequency() + assert ( + parent_recording.get_sampling_frequency() + == sorting.get_sampling_frequency() + ) assert parent_recording.get_num_channels() == templates.shape[2] parent_recording.copy_metadata(self) if num_samples is None: if parent_recording is None: - num_samples = [self.spike_vector["sample_index"][-1] + templates.shape[1]] + num_samples = [ + self.spike_vector["sample_index"][-1] + templates.shape[1] + ] else: num_samples = [ parent_recording.get_num_frames(segment_index) @@ -1987,13 +2295,25 @@ def __init__( 
num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left") - end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right") + start = np.searchsorted( + self.spike_vector["segment_index"], segment_index, side="left" + ) + end = np.searchsorted( + self.spike_vector["segment_index"], segment_index, side="right" + ) spikes = self.spike_vector[start:end] - amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None - upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None + amplitude_vec = ( + amplitude_vector[start:end] if amplitude_vector is not None else None + ) + upsample_vec = ( + upsample_vector[start:end] if upsample_vector is not None else None + ) - parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] + parent_recording_segment = ( + None + if parent_recording is None + else parent_recording.segments[segment_index] + ) recording_segment = InjectTemplatesRecordingSegment( self.sampling_frequency, self.dtype, @@ -2033,7 +2353,10 @@ def _check_templates(templates: np.ndarray): max_value = np.max(np.abs(templates)) threshold = 0.01 * max_value - if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold: + if ( + max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) + > threshold + ): warnings.warn( "Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger." 
) @@ -2055,7 +2378,9 @@ def __init__( BaseRecordingSegment.__init__( self, sampling_frequency, - t_start=0 if parent_recording_segment is None else parent_recording_segment.t_start, + t_start=0 + if parent_recording_segment is None + else parent_recording_segment.t_start, ) assert not (parent_recording_segment is None and num_samples is None) @@ -2066,7 +2391,11 @@ def __init__( self.amplitude_vector = amplitude_vector self.upsample_vector = upsample_vector self.parent_recording = parent_recording_segment - self.num_samples = parent_recording_segment.get_num_frames() if num_samples is None else num_samples + self.num_samples = ( + parent_recording_segment.get_num_frames() + if num_samples is None + else num_samples + ) def get_traces( self, @@ -2077,7 +2406,11 @@ def get_traces( if channel_indices is None: n_channels = self.templates.shape[2] elif isinstance(channel_indices, slice): - stop = channel_indices.stop if channel_indices.stop is not None else self.templates.shape[2] + stop = ( + channel_indices.stop + if channel_indices.stop is not None + else self.templates.shape[2] + ) start = channel_indices.start if channel_indices.start is not None else 0 step = channel_indices.step if channel_indices.step is not None else 1 n_channels = math.ceil((stop - start) / step) @@ -2085,12 +2418,22 @@ def get_traces( n_channels = len(channel_indices) if self.parent_recording is not None: - traces = self.parent_recording.get_traces(start_frame, end_frame, channel_indices).copy() + traces = self.parent_recording.get_traces( + start_frame, end_frame, channel_indices + ).copy() else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") - end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") + start = np.searchsorted( + self.spike_vector["sample_index"], + start_frame - self.templates.shape[1], 
+ side="left", + ) + end = np.searchsorted( + self.spike_vector["sample_index"], + end_frame + self.templates.shape[1], + side="right", + ) for i in range(start, end): spike = self.spike_vector[i] @@ -2133,7 +2476,9 @@ def get_num_samples(self) -> int: return self.num_samples -inject_templates = define_function_from_class(source_class=InjectTemplatesRecording, name="inject_templates") +inject_templates = define_function_from_class( + source_class=InjectTemplatesRecording, name="inject_templates" +) ## toy example zone ## @@ -2147,7 +2492,9 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): num_contact_per_column = num_channels // num_columns j = 0 for i in range(num_columns): - channel_locations[j : j + num_contact_per_column, 0] = i * contact_spacing_um + channel_locations[j : j + num_contact_per_column, 0] = ( + i * contact_spacing_um + ) channel_locations[j : j + num_contact_per_column, 1] = ( np.arange(num_contact_per_column) * contact_spacing_um ) @@ -2166,7 +2513,9 @@ def _generate_multimodal(rng, size, num_modes, lim0, lim1): prob += np.exp(-((bins - center) ** 2) / (2 * sigma**2)) prob /= np.sum(prob) choices = rng.choice(np.arange(bins.size), size, p=prob) - values = bins[choices] + rng.uniform(low=-bin_step / 2, high=bin_step / 2, size=size) + values = bins[choices] + rng.uniform( + low=-bin_step / 2, high=bin_step / 2, size=size + ) return values @@ -2241,23 +2590,35 @@ def generate_unit_locations( rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um - minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + minimum_x, maximum_x = ( + np.min(channel_locations[:, 0]) - margin_um, + np.max(channel_locations[:, 0]) + margin_um, + ) + minimum_y, maximum_y = ( + np.min(channel_locations[:, 1]) - margin_um, + 
np.max(channel_locations[:, 1]) + margin_um, + ) units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) if distribution == "uniform": units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) elif distribution == "multimodal": - units_locations[:, 1] = _generate_multimodal(rng, num_units, num_modes, minimum_y, maximum_y) + units_locations[:, 1] = _generate_multimodal( + rng, num_units, num_modes, minimum_y, maximum_y + ) else: - raise ValueError("generate_unit_locations has wrong distribution must be 'uniform' or ") + raise ValueError( + "generate_unit_locations has wrong distribution must be 'uniform' or " + ) units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) if minimum_distance is not None: solution_found = False renew_inds = None for i in range(max_iteration): - distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) + distances = np.linalg.norm( + units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2 + ) inds0, inds1 = np.nonzero(distances < minimum_distance) mask = inds0 != inds1 inds0 = inds0[mask] @@ -2270,15 +2631,21 @@ def generate_unit_locations( # random only bad ones in the previous set renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] - units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) + units_locations[:, 0][renew_inds] = rng.uniform( + minimum_x, maximum_x, size=renew_inds.size + ) if distribution == "uniform": - units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) + units_locations[:, 1][renew_inds] = rng.uniform( + minimum_y, maximum_y, size=renew_inds.size + ) elif distribution == "multimodal": units_locations[:, 1][renew_inds] = _generate_multimodal( rng, renew_inds.size, num_modes, minimum_y, maximum_y ) - units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) + units_locations[:, 2][renew_inds] = 
rng.uniform( + minimum_z, maximum_z, size=renew_inds.size + ) else: solution_found = True @@ -2291,7 +2658,9 @@ def generate_unit_locations( "You can use distance_strict=False or reduce minimum distance" ) else: - warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") + warnings.warn( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}" + ) return units_locations @@ -2317,7 +2686,9 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), + generate_unit_locations_kwargs=dict( + margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20 + ), generate_templates_kwargs=None, dtype="float32", seed=None, @@ -2416,7 +2787,9 @@ def generate_ground_truth_recording( num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] prb_kwargs["num_contact_per_column"] = num_contact_per_column else: - raise ValueError("num_columns should be provided in dict generate_probe_kwargs") + raise ValueError( + "num_columns should be provided in dict generate_probe_kwargs" + ) probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index e0e28d09cd..3bbbf66bab 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -57,17 +57,29 @@ def test_generate_sorting_with_spikes_on_borders(): ) # check that segments are correctly sorted all_spikes = sorting.to_spike_vector() - np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"])) + np.testing.assert_array_equal( + all_spikes["segment_index"], 
np.sort(all_spikes["segment_index"]) + ) spikes = sorting.to_spike_vector(concatenated=False) # at least num_border spikes at borders for all segments for spikes_in_segment in spikes: # check that sample indices are correctly sorted within segments - np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])) + np.testing.assert_array_equal( + spikes_in_segment["sample_index"], + np.sort(spikes_in_segment["sample_index"]), + ) num_samples = int(segment_duration * 30000) - assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders assert ( - np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders + np.sum(spikes_in_segment["sample_index"] < border_size_samples) + >= num_spikes_on_borders + ) + assert ( + np.sum( + spikes_in_segment["sample_index"] + >= num_samples - border_size_samples + ) + >= num_spikes_on_borders ) @@ -117,9 +129,9 @@ def test_memory_sorting_generator(): memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB ratio = memory_usage_MiB / before_instanciation_MiB expected_allocation_MiB = 0 - assert ( - ratio <= 1.0 + relative_tolerance - ), f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" + assert ratio <= 1.0 + relative_tolerance, ( + f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" + ) def test_sorting_generator_consisency_across_calls(): @@ -156,8 +168,12 @@ def test_sorting_generator_consisency_within_trains(): ) for unit_id in sorting.get_unit_ids(): - spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000) - spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000) + spike_train = sorting.get_unit_spike_train( + unit_id=unit_id, start_frame=0, end_frame=1000 + ) + spike_train_again = sorting.get_unit_spike_train( + unit_id=unit_id, start_frame=0, 
end_frame=1000 + ) assert np.allclose(spike_train, spike_train_again) @@ -189,11 +205,13 @@ def test_noise_generator_memory(): ) after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB - expected_allocation_MiB = dtype.itemsize * num_channels * noise_block_size / bytes_to_MiB_factor + expected_allocation_MiB = ( + dtype.itemsize * num_channels * noise_block_size / bytes_to_MiB_factor + ) ratio = expected_allocation_MiB / expected_allocation_MiB - assert ( - ratio <= 1.0 + relative_tolerance - ), f"NoiseGeneratorRecording with 'tile_pregenerated' wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" + assert ratio <= 1.0 + relative_tolerance, ( + f"NoiseGeneratorRecording with 'tile_pregenerated' wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" + ) # case 2: no preallocation very few memory (under 2 MiB) before_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor @@ -208,7 +226,9 @@ def test_noise_generator_memory(): ) after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB - assert memory_usage_MiB < 2, f"NoiseGeneratorRecording with 'on_the_fly wrong memory {memory_usage_MiB}MiB" + assert memory_usage_MiB < 2, ( + f"NoiseGeneratorRecording with 'on_the_fly wrong memory {memory_usage_MiB}MiB" + ) def test_noise_generator_several_noise_levels(): @@ -279,6 +299,66 @@ def test_noise_generator_correct_shape(strategy): assert traces.shape == (num_frames, num_channels) +@pytest.mark.parametrize("duration", [1.0, 2.0, 2.2]) +@pytest.mark.parametrize("strategy", strategy_list) +def test_noise_generator_temporal(strategy, duration): + psdlen = 25 + kdomain = np.linspace(0.0, 10.0, psdlen) + fake_psd = (kdomain + 0.1) * np.exp(-kdomain) + # this ensures std dev of output ~= 1 + fake_psd /= np.sqrt((fake_psd**2).mean()) + + # Test that the 
recording has the correct size in shape + sampling_frequency = 30000 # Hz + durations = [duration] + dtype = np.dtype("float32") + num_channels = 2 + seed = 0 + + rec = NoiseGeneratorRecording( + num_channels=num_channels, + sampling_frequency=sampling_frequency, + durations=durations, + dtype=dtype, + seed=seed, + spectral_density=fake_psd, + strategy=strategy, + ) + + # check output matches at different chunks + full_traces = rec.get_traces() + end_frame = rec.get_num_frames() + for t0 in [0, 100]: + for t1 in [end_frame, end_frame - 100]: + print(f"{t0=} {t1=}") + chk = rec.get_traces(0, t0, t1) + chk0 = full_traces[t0:t1] + print(f"{np.flatnonzero((chk!=chk0).any(1))=}") + np.testing.assert_array_equal(chk, chk0) + + np.testing.assert_allclose(full_traces.std(), 1.0, rtol=0.02) + + # re-estimate the psd from the result + # it will not be perfect, okay! + n = 2 * psdlen - 1 + snips = full_traces[: n * (full_traces.shape[0] // n)] + snips = snips.reshape(-1, n, snips.shape[-1]) + psd = np.fft.rfft(snips, n=n, axis=1, norm="ortho") + psd = np.sqrt(np.square(np.abs(psd)).mean(axis=(0, 2))) + + sample_size = snips.shape[0] * snips.shape[2] + standard_error = 1.0 / np.sqrt(sample_size) + + # accuracy is good at low freqs + np.testing.assert_allclose( + psd[1 : psdlen // 3], + fake_psd[1 : psdlen // 3], + atol=3 * standard_error, + rtol=0.1, + ) + np.testing.assert_allclose(psd, fake_psd, atol=0.5) + + @pytest.mark.parametrize("strategy", strategy_list) @pytest.mark.parametrize( "start_frame, end_frame", @@ -308,7 +388,9 @@ def test_noise_generator_consistency_across_calls(strategy, start_frame, end_fra ) traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame) - same_traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame) + same_traces = lazy_recording.get_traces( + start_frame=start_frame, end_frame=end_frame + ) assert np.allclose(traces, same_traces) @@ -324,14 +406,18 @@ def 
test_noise_generator_consistency_across_calls(strategy, start_frame, end_fra (0, 60_000, 10_000), ], ) -def test_noise_generator_consistency_across_traces(strategy, start_frame, end_frame, extra_samples): +def test_noise_generator_consistency_across_traces( + strategy, start_frame, end_frame, extra_samples +): # Test that the generated traces behave like true arrays. Calling a larger array and then slicing it should # give the same result as calling the slice directly sampling_frequency = 30000 # Hz durations = [10.0] dtype = np.dtype("float32") num_channels = 2 - seed = start_frame + end_frame + extra_samples # To make sure that the seed is different for each test + seed = ( + start_frame + end_frame + extra_samples + ) # To make sure that the seed is different for each test lazy_recording = NoiseGeneratorRecording( num_channels=num_channels, @@ -344,8 +430,12 @@ def test_noise_generator_consistency_across_traces(strategy, start_frame, end_fr traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame) end_frame_larger_array = end_frame + extra_samples - larger_traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame_larger_array) - equivalent_trace_from_larger_traces = larger_traces[:-extra_samples, :] # Remove the extra samples + larger_traces = lazy_recording.get_traces( + start_frame=start_frame, end_frame=end_frame_larger_array + ) + equivalent_trace_from_larger_traces = larger_traces[ + :-extra_samples, : + ] # Remove the extra samples assert np.allclose(traces, equivalent_trace_from_larger_traces) @@ -378,7 +468,9 @@ def test_generate_single_fake_waveform(): sampling_frequency = 30000.0 ms_before = 1.0 ms_after = 3.0 - wf = generate_single_fake_waveform(ms_before=ms_before, ms_after=ms_after, sampling_frequency=sampling_frequency) + wf = generate_single_fake_waveform( + ms_before=ms_before, ms_after=ms_after, sampling_frequency=sampling_frequency + ) # import matplotlib.pyplot as plt # times = np.arange(wf.size) / 
sampling_frequency * 1000 - ms_before @@ -391,7 +483,9 @@ def test_generate_single_fake_waveform(): def test_generate_unit_locations(): seed = 0 - probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) + probe = generate_multi_columns_probe( + num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20 + ) channel_locations = probe.contact_positions num_units = 100 @@ -407,7 +501,9 @@ def test_generate_unit_locations(): distance_strict=False, seed=seed, ) - distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) + distances = np.linalg.norm( + unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2 + ) dist_flat = np.triu(distances, k=1).flatten() dist_flat = dist_flat[dist_flat > 0] assert np.all(dist_flat > minimum_distance) @@ -430,7 +526,9 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed) + unit_locations = generate_unit_locations( + num_units, channel_locations, margin_um=margin_um, seed=seed + ) sampling_frequency = 30000.0 ms_before = 1.0 @@ -501,15 +599,30 @@ def test_inject_templates(): # generate some sutff rec_noise = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency, seed=42 + num_channels=num_channels, + durations=durations, + sampling_frequency=sampling_frequency, + seed=42, ) channel_locations = rec_noise.get_channel_locations() sorting = generate_sorting( - num_units=num_units, durations=durations, sampling_frequency=sampling_frequency, firing_rates=1.0, seed=42 + num_units=num_units, + durations=durations, + sampling_frequency=sampling_frequency, + firing_rates=1.0, + seed=42, + ) + units_locations = generate_unit_locations( + num_units, channel_locations, margin_um=10.0, seed=42 ) - 
units_locations = generate_unit_locations(num_units, channel_locations, margin_um=10.0, seed=42) templates_3d = generate_templates( - channel_locations, units_locations, sampling_frequency, ms_before, ms_after, seed=42, upsample_factor=None + channel_locations, + units_locations, + sampling_frequency, + ms_before, + ms_after, + seed=42, + upsample_factor=None, ) templates_4d = generate_templates( channel_locations, @@ -526,23 +639,38 @@ def test_inject_templates(): sorting, templates_3d, nbefore=nbefore, - num_samples=[rec_noise.get_num_frames(seg_ind) for seg_ind in range(rec_noise.get_num_segments())], + num_samples=[ + rec_noise.get_num_frames(seg_ind) + for seg_ind in range(rec_noise.get_num_segments()) + ], ) # Case 2: with parent_recording - rec2 = InjectTemplatesRecording(sorting, templates_3d, nbefore=nbefore, parent_recording=rec_noise) + rec2 = InjectTemplatesRecording( + sorting, templates_3d, nbefore=nbefore, parent_recording=rec_noise + ) # Case 3: with parent_recording + upsample_factor rng = np.random.default_rng(seed=42) - upsample_vector = rng.integers(0, upsample_factor, size=sorting.to_spike_vector().size) + upsample_vector = rng.integers( + 0, upsample_factor, size=sorting.to_spike_vector().size + ) rec3 = InjectTemplatesRecording( - sorting, templates_4d, nbefore=nbefore, parent_recording=rec_noise, upsample_vector=upsample_vector + sorting, + templates_4d, + nbefore=nbefore, + parent_recording=rec_noise, + upsample_vector=upsample_vector, ) for rec in (rec1, rec2, rec3): assert rec.get_traces(end_frame=600, segment_index=0).shape == (600, 4) - assert rec.get_traces(start_frame=100, end_frame=600, segment_index=1).shape == (500, 4) - assert rec.get_traces(start_frame=rec_noise.get_num_frames(0) - 200, segment_index=0).shape == (200, 4) + assert rec.get_traces( + start_frame=100, end_frame=600, segment_index=1 + ).shape == (500, 4) + assert rec.get_traces( + start_frame=rec_noise.get_num_frames(0) - 200, segment_index=0 + ).shape == (200, 4) # 
Check dumpability saved_loaded = load(rec.to_dict()) @@ -563,41 +691,60 @@ def test_transformsorting(): transformed = TransformSorting.add_from_sorting(sorting_1, sorting_3) assert len(transformed.unit_ids) == 50 - sorting_1 = NumpySorting.from_unit_dict({46: np.array([0, 150], dtype=int)}, sampling_frequency=20000.0) + sorting_1 = NumpySorting.from_unit_dict( + {46: np.array([0, 150], dtype=int)}, sampling_frequency=20000.0 + ) sorting_2 = NumpySorting.from_unit_dict( - {0: np.array([100, 2000], dtype=int), 3: np.array([200, 4000], dtype=int)}, sampling_frequency=20000.0 + {0: np.array([100, 2000], dtype=int), 3: np.array([200, 4000], dtype=int)}, + sampling_frequency=20000.0, ) transformed = TransformSorting.add_from_sorting(sorting_1, sorting_2) assert len(transformed.unit_ids) == 3 - assert np.all(np.array([k for k in transformed.count_num_spikes_per_unit(outputs="array")]) == 2) + assert np.all( + np.array([k for k in transformed.count_num_spikes_per_unit(outputs="array")]) + == 2 + ) - sorting_1 = NumpySorting.from_unit_dict({0: np.array([12], dtype=int)}, sampling_frequency=20000.0) + sorting_1 = NumpySorting.from_unit_dict( + {0: np.array([12], dtype=int)}, sampling_frequency=20000.0 + ) sorting_2 = NumpySorting.from_unit_dict( - {0: np.array([150], dtype=int), 3: np.array([12, 150], dtype=int)}, sampling_frequency=20000.0 + {0: np.array([150], dtype=int), 3: np.array([12, 150], dtype=int)}, + sampling_frequency=20000.0, ) transformed = TransformSorting.add_from_sorting(sorting_1, sorting_2) assert len(transformed.unit_ids) == 2 target_array = np.array([2, 2]) - source_array = np.array([k for k in transformed.count_num_spikes_per_unit(outputs="array")]) + source_array = np.array( + [k for k in transformed.count_num_spikes_per_unit(outputs="array")] + ) assert np.array_equal(source_array, target_array) assert transformed.get_added_spikes_from_existing_indices().size == 1 assert transformed.get_added_spikes_from_new_indices().size == 2 assert 
transformed.get_added_units_inds() == [3] - transformed = TransformSorting.add_from_unit_dict(sorting_1, {46: np.array([12, 150], dtype=int)}) + transformed = TransformSorting.add_from_unit_dict( + sorting_1, {46: np.array([12, 150], dtype=int)} + ) sorting_1 = generate_sorting(seed=0) - transformed = TransformSorting(sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=0) + transformed = TransformSorting( + sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=0 + ) assert len(sorting_1.to_spike_vector()) == len(transformed.to_spike_vector()) - transformed = TransformSorting(sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=5) + transformed = TransformSorting( + sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=5 + ) assert 2 * len(sorting_1.to_spike_vector()) > len(transformed.to_spike_vector()) transformed_2 = TransformSorting(sorting_1, transformed.to_spike_vector()) assert len(transformed_2.to_spike_vector()) > len(transformed.to_spike_vector()) - assert np.sum(transformed_2.get_added_spikes_indices()) >= np.sum(transformed_2.get_added_spikes_from_new_indices()) + assert np.sum(transformed_2.get_added_spikes_indices()) >= np.sum( + transformed_2.get_added_spikes_from_new_indices() + ) assert np.sum(transformed_2.get_added_spikes_indices()) >= np.sum( transformed_2.get_added_spikes_from_existing_indices() ) @@ -613,9 +760,17 @@ def test_generate_ground_truth_recording(): def test_generate_sorting_to_inject(): durations = [10.0, 20.0] - sorting = generate_sorting(num_units=10, durations=durations, sampling_frequency=30000, firing_rates=1.0, seed=2205) + sorting = generate_sorting( + num_units=10, + durations=durations, + sampling_frequency=30000, + firing_rates=1.0, + seed=2205, + ) injected_sorting = generate_sorting_to_inject( - sorting, [int(duration * sorting.sampling_frequency) for duration in durations], seed=2308 + sorting, + [int(duration * sorting.sampling_frequency) for duration in durations], + seed=2308, ) 
num_spikes = sorting.count_num_spikes_per_unit() num_injected_spikes = injected_sorting.count_num_spikes_per_unit() From 592307418554d145a51e6f885d6c629c968e7bd3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:54:24 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 378 +++++------------- .../core/tests/test_generate.py | 137 ++----- 2 files changed, 135 insertions(+), 380 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 420af71b44..9eb87c97de 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -159,13 +159,9 @@ def generate_sorting( spikes.append(spikes_in_seg) if add_spikes_on_borders: - spikes_on_borders = np.zeros( - 2 * num_spikes_per_border, dtype=minimum_spike_dtype - ) + spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) spikes_on_borders["segment_index"] = segment_index - spikes_on_borders["unit_index"] = rng.choice( - num_units, size=2 * num_spikes_per_border, replace=True - ) + spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True) # at start spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers( 0, border_size_samples, num_spikes_per_border @@ -177,11 +173,7 @@ def generate_sorting( spikes.append(spikes_on_borders) spikes = np.concatenate(spikes) - spikes = spikes[ - np.lexsort( - (spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]) - ) - ] + spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -225,26 +217,18 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if 
sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[ - ~np.isin(unit_ids, units_used_for_spike[sample_index]) - ] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue new_unit_indices[i] = rng.choice(units_not_used) - units_used_for_spike[sample_index] = np.append( - units_used_for_spike[sample_index], new_unit_indices[i] - ) + units_used_for_spike[sample_index] = np.append(units_used_for_spike[sample_index], new_unit_indices[i]) spikes_duplicated["unit_index"] = new_unit_indices - sort_idxs = np.lexsort( - [spikes_duplicated["sample_index"], spikes_duplicated["segment_index"]] - ) + sort_idxs = np.lexsort([spikes_duplicated["sample_index"], spikes_duplicated["segment_index"]]) spikes_duplicated = spikes_duplicated[sort_idxs] - synchronous_spikes = NumpySorting( - spikes_duplicated, sorting.get_sampling_frequency(), unit_ids - ) + synchronous_spikes = NumpySorting(spikes_duplicated, sorting.get_sampling_frequency(), unit_ids) sorting = TransformSorting.add_from_sorting(sorting, synchronous_spikes) return sorting @@ -292,18 +276,12 @@ def generate_sorting_to_inject( for segment_index in range(sorting.get_num_segments()): for unit_id in sorting.unit_ids: - spike_train = sorting.get_unit_spike_train( - unit_id, segment_index=segment_index - ) - n_injection = min( - max_injected_per_unit, int(round(injected_rate * len(spike_train))) - ) + spike_train = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + n_injection = min(max_injected_per_unit, int(round(injected_rate * len(spike_train)))) # Inject more, then take out all that violate the refractory period. 
n = int(n_injection + 10 * np.sqrt(n_injection)) injected_spike_train = np.sort( - np.random.uniform( - low=0, high=num_samples[segment_index], size=n - ).astype(np.int64) + np.random.uniform(low=0, high=num_samples[segment_index], size=n).astype(np.int64) ) # Remove spikes that are in the refractory period. @@ -312,22 +290,16 @@ def generate_sorting_to_inject( # Remove spikes that violate the refractory period of the real spikes. # TODO: Need a better & faster way than this. - min_diff = np.min( - np.abs(injected_spike_train[:, None] - spike_train[None, :]), axis=1 - ) + min_diff = np.min(np.abs(injected_spike_train[:, None] - spike_train[None, :]), axis=1) violations = min_diff < t_r injected_spike_train = injected_spike_train[~violations] if len(injected_spike_train) > n_injection: - injected_spike_train = np.sort( - rng.choice(injected_spike_train, n_injection, replace=False) - ) + injected_spike_train = np.sort(rng.choice(injected_spike_train, n_injection, replace=False)) injected_spike_trains[segment_index][unit_id] = injected_spike_train - return NumpySorting.from_unit_dict( - injected_spike_trains, sorting.get_sampling_frequency() - ) + return NumpySorting.from_unit_dict(injected_spike_trains, sorting.get_sampling_frequency()) class TransformSorting(BaseSorting): @@ -371,13 +343,11 @@ def __init__( if new_unit_ids is not None: new_unit_ids = list(new_unit_ids) - assert ~np.any(np.isin(new_unit_ids, sorting.unit_ids)), ( - "some units ids are already present. Consider using added_spikes_existing_units" - ) + assert ~np.any( + np.isin(new_unit_ids, sorting.unit_ids) + ), "some units ids are already present. 
Consider using added_spikes_existing_units" if len(new_unit_ids) > 0: - assert type(unit_ids[0]) == type(new_unit_ids[0]), ( - "unit_ids should have the same type" - ) + assert type(unit_ids[0]) == type(new_unit_ids[0]), "unit_ids should have the same type" unit_ids = unit_ids + list(new_unit_ids) BaseSorting.__init__(self, sampling_frequency, unit_ids) @@ -386,24 +356,15 @@ def __init__( self._cached_spike_vector = sorting.to_spike_vector().copy() self.refractory_period_ms = refractory_period_ms - self.added_spikes_from_existing_mask = np.zeros( - len(self._cached_spike_vector), dtype=bool - ) - self.added_spikes_from_new_mask = np.zeros( - len(self._cached_spike_vector), dtype=bool - ) + self.added_spikes_from_existing_mask = np.zeros(len(self._cached_spike_vector), dtype=bool) + self.added_spikes_from_new_mask = np.zeros(len(self._cached_spike_vector), dtype=bool) - if ( - added_spikes_existing_units is not None - and len(added_spikes_existing_units) > 0 - ): - assert added_spikes_existing_units.dtype == minimum_spike_dtype, ( - "added_spikes_existing_units should be a spike vector" - ) + if added_spikes_existing_units is not None and len(added_spikes_existing_units) > 0: + assert ( + added_spikes_existing_units.dtype == minimum_spike_dtype + ), "added_spikes_existing_units should be a spike vector" added_unit_indices = np.arange(len(self.parent_unit_ids)) - self._cached_spike_vector = np.concatenate( - (self._cached_spike_vector, added_spikes_existing_units) - ) + self._cached_spike_vector = np.concatenate((self._cached_spike_vector, added_spikes_existing_units)) self.added_spikes_from_existing_mask = np.concatenate( ( self.added_spikes_from_existing_mask, @@ -418,12 +379,10 @@ def __init__( ) if added_spikes_new_units is not None and len(added_spikes_new_units) > 0: - assert added_spikes_new_units.dtype == minimum_spike_dtype, ( - "added_spikes_new_units should be a spike vector" - ) - self._cached_spike_vector = np.concatenate( - (self._cached_spike_vector, 
added_spikes_new_units) - ) + assert ( + added_spikes_new_units.dtype == minimum_spike_dtype + ), "added_spikes_new_units should be a spike vector" + self._cached_spike_vector = np.concatenate((self._cached_spike_vector, added_spikes_new_units)) self.added_spikes_from_existing_mask = np.concatenate( ( self.added_spikes_from_existing_mask, @@ -444,16 +403,12 @@ def __init__( ] ) self._cached_spike_vector = self._cached_spike_vector[sort_idxs] - self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[ - sort_idxs - ] + self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[sort_idxs] self.added_spikes_from_new_mask = self.added_spikes_from_new_mask[sort_idxs] # We need to add the sorting segments for segment_index in range(sorting.get_num_segments()): - segment = SpikeVectorSortingSegment( - self._cached_spike_vector, segment_index, unit_ids=self.unit_ids - ) + segment = SpikeVectorSortingSegment(self._cached_spike_vector, segment_index, unit_ids=self.unit_ids) self.add_sorting_segment(segment) if self.refractory_period_ms is not None: @@ -469,9 +424,7 @@ def __init__( @property def added_spikes_mask(self): - return np.logical_or( - self.added_spikes_from_existing_mask, self.added_spikes_from_new_mask - ) + return np.logical_or(self.added_spikes_from_existing_mask, self.added_spikes_from_new_mask) def get_added_spikes_indices(self): return np.nonzero(self.added_spikes_mask)[0] @@ -486,9 +439,7 @@ def get_added_units_inds(self): return self.unit_ids[len(self.parent_unit_ids) :] @staticmethod - def add_from_sorting( - sorting1: BaseSorting, sorting2: BaseSorting, refractory_period_ms=None - ) -> "TransformSorting": + def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_period_ms=None) -> "TransformSorting": """ Construct TransformSorting by adding one sorting to one other. @@ -503,12 +454,10 @@ def add_from_sorting( of spikes. 
Any spike times in added_spikes violating the refractory period will be discarded. """ - assert sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency(), ( - "sampling_frequency should be the same" - ) - assert type(sorting1.unit_ids[0]) == type(sorting2.unit_ids[0]), ( - "unit_ids should have the same type" - ) + assert ( + sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency() + ), "sampling_frequency should be the same" + assert type(sorting1.unit_ids[0]) == type(sorting2.unit_ids[0]), "unit_ids should have the same type" # We detect the indices that are shared by the two sortings mask1 = np.isin(sorting2.unit_ids, sorting1.unit_ids) common_ids = sorting2.unit_ids[mask1] @@ -573,12 +522,8 @@ def add_from_unit_dict( of spikes. Any spike times in added_spikes violating the refractory period will be discarded. """ - sorting2 = NumpySorting.from_unit_dict( - units_dict_list, sorting1.get_sampling_frequency() - ) - sorting = TransformSorting.add_from_sorting( - sorting1, sorting2, refractory_period_ms - ) + sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency()) + sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) return sorting @staticmethod @@ -615,12 +560,8 @@ def from_samples_and_labels( discarded. 
""" - sorting2 = NumpySorting.from_samples_and_labels( - times_list, labels_list, sampling_frequency, unit_ids - ) - sorting = TransformSorting.add_from_sorting( - sorting1, sorting2, refractory_period_ms - ) + sorting2 = NumpySorting.from_samples_and_labels(times_list, labels_list, sampling_frequency, unit_ids) + sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) return sorting def clean_refractory_period(self): @@ -642,9 +583,7 @@ def clean_refractory_period(self): ) self._cached_spike_vector = self._cached_spike_vector[to_keep] - self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[ - to_keep - ] + self.added_spikes_from_existing_mask = self.added_spikes_from_existing_mask[to_keep] self.added_spikes_from_new_mask = self.added_spikes_from_new_mask[to_keep] @@ -837,9 +776,7 @@ def synthesize_poisson_spike_vector( refractory_period_seconds = refractory_period_ms / 1000.0 refractory_period_frames = int(refractory_period_seconds * sampling_frequency) - is_refractory_period_too_long = np.any( - refractory_period_seconds >= 1.0 / firing_rates - ) + is_refractory_period_too_long = np.any(refractory_period_seconds >= 1.0 / firing_rates) if is_refractory_period_too_long: raise ValueError( f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" @@ -860,9 +797,7 @@ def synthesize_poisson_spike_vector( binomial_p_modified = np.minimum(binomial_p_modified, 1.0) # Generate inter spike frames, add the refractory samples and accumulate for sorted spike frames - inter_spike_frames = rng.geometric( - p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max) - ) + inter_spike_frames = rng.geometric(p=binomial_p_modified[:, np.newaxis], size=(num_units, num_spikes_max)) inter_spike_frames[:, 1:] += refractory_period_frames spike_frames = np.cumsum(inter_spike_frames, axis=1, out=inter_spike_frames) spike_frames = spike_frames.ravel() @@ -878,9 +813,7 @@ def 
synthesize_poisson_spike_vector( # Sort globaly spike_frames = spike_frames[:num_correct_frames] - sort_indices = np.argsort( - spike_frames, kind="stable" - ) # I profiled the different kinds, this is the fastest. + sort_indices = np.argsort(spike_frames, kind="stable") # I profiled the different kinds, this is the fastest. unit_indices = unit_indices[sort_indices] spike_frames = spike_frames[sort_indices] @@ -1026,9 +959,7 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No """ rng = np.random.default_rng(seed) - other_ids = np.arange( - np.max(sorting.unit_ids) + 1, np.max(sorting.unit_ids) + num + 1 - ) + other_ids = np.arange(np.max(sorting.unit_ids) + 1, np.max(sorting.unit_ids) + num + 1) shifts = rng.integers(low=-max_shift, high=max_shift, size=num) shifts[shifts == 0] += max_shift @@ -1038,8 +969,7 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No for segment_index in range(sorting.get_num_segments()): # sorting to dict d = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - for unit_id in sorting.unit_ids + unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) for unit_id in sorting.unit_ids } r = {} @@ -1060,17 +990,13 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No r[unit_id] = times spiketrains.append(r) - sorting_new_units = NumpySorting.from_unit_dict( - spiketrains, sampling_frequency=sorting.get_sampling_frequency() - ) + sorting_new_units = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) sorting_with_dup = TransformSorting.add_from_sorting(sorting, sorting_new_units) return sorting_with_dup -def inject_some_split_units( - sorting, split_ids: list, num_split=2, output_ids=False, seed=None -): +def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None): """ Inject some split units in a sorting. 
@@ -1108,8 +1034,7 @@ def inject_some_split_units( for segment_index in range(sorting.get_num_segments()): # sorting to dict d = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - for unit_id in sorting.unit_ids + unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index) for unit_id in sorting.unit_ids } new_units = {} @@ -1125,18 +1050,14 @@ def inject_some_split_units( new_units[unit_id] = original_times spiketrains.append(new_units) - sorting_with_split = NumpySorting.from_unit_dict( - spiketrains, sampling_frequency=sorting.get_sampling_frequency() - ) + sorting_with_split = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) if output_ids: return sorting_with_split, other_ids else: return sorting_with_split -def synthetize_spike_train_bad_isi( - duration, baseline_rate, num_violations, violation_delta=1e-5 -): +def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, violation_delta=1e-5): """Create a spike train. Has uniform inter-spike intervals, except where isis violations occur. 
Parameters @@ -1231,9 +1152,7 @@ def __init__( self.durations = durations self.refractory_period_seconds = refractory_period_ms / 1000.0 - is_refractory_period_too_long = np.any( - self.refractory_period_seconds >= 1.0 / firing_rates - ) + is_refractory_period_too_long = np.any(self.refractory_period_seconds >= 1.0 / firing_rates) if is_refractory_period_too_long: raise ValueError( f"The given refractory period {refractory_period_ms} is too long for the firing rates {firing_rates}" @@ -1289,21 +1208,15 @@ def __init__( self.firing_rates = firing_rates if np.isscalar(self.refractory_period_seconds): - self.refractory_period_seconds = np.full( - num_units, self.refractory_period_seconds, dtype="float64" - ) + self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") self.segment_seed = seed - self.units_seed = { - unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids - } + self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) - def get_unit_spike_train( - self, unit_id, start_frame: int | None = None, end_frame: int | None = None - ) -> np.ndarray: + def get_unit_spike_train(self, unit_id, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: unit_seed = self.units_seed[unit_id] unit_index = self.parent_extractor.id_to_index(unit_id) @@ -1338,9 +1251,7 @@ def get_unit_spike_train( start_index = 0 if end_frame is not None: - end_index = np.searchsorted( - spike_frames[start_index:], end_frame, side="left" - ) + end_index = np.searchsorted(spike_frames[start_index:], end_frame, side="left") else: end_index = int(self.duration * self.sampling_frequency) @@ -1422,9 +1333,7 @@ def __init__( dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") - 
assert strategy in ("tile_pregenerated", "on_the_fly"), ( - "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" - ) + assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" if np.isscalar(noise_levels): noise_levels = np.ones((1, num_channels)) * noise_levels @@ -1433,9 +1342,7 @@ def __init__( if len(noise_levels.shape) < 2: noise_levels = noise_levels[np.newaxis, :] - assert len(noise_levels[0]) == num_channels, ( - "Noise levels should have a size of num_channels" - ) + assert len(noise_levels[0]) == num_channels, "Noise levels should have a size of num_channels" BaseRecording.__init__( self, @@ -1447,9 +1354,9 @@ def __init__( num_segments = len(durations) if cov_matrix is not None: - assert cov_matrix.shape[0] == cov_matrix.shape[1] == num_channels, ( - "cov_matrix should have a size (num_channels, num_channels)" - ) + assert ( + cov_matrix.shape[0] == cov_matrix.shape[1] == num_channels + ), "cov_matrix should have a size (num_channels, num_channels)" # very important here when multiprocessing and dump/load seed = _ensure_seed(seed) @@ -1597,10 +1504,7 @@ def get_traces_spatial_only( pos += end_first_block else: # special case when unique block - traces[:] = noise_block[ - start_frame_within_block : start_frame_within_block - + num_samples - ] + traces[:] = noise_block[start_frame_within_block : start_frame_within_block + num_samples] elif block_index == last_block_index: if end_frame_within_block > 0: traces[pos:] = noise_block[:end_frame_within_block] @@ -1646,28 +1550,20 @@ def get_traces( pad_end_frame = end_frame + right_pad_len print(f"{pad_start_frame=} {pad_end_frame=}") assert 0 <= pad_start_frame < pad_end_frame <= n_samples - chunk = self.get_traces_spatial_only( - pad_start_frame, pad_end_frame, channel_indices - ) + chunk = self.get_traces_spatial_only(pad_start_frame, pad_end_frame, channel_indices) empty = np.zeros_like(chunk[:0]) if left_circ_len: - circ_left = 
self.get_traces_spatial_only( - n_samples - left_circ_len, n_samples, channel_indices - ) + circ_left = self.get_traces_spatial_only(n_samples - left_circ_len, n_samples, channel_indices) else: circ_left = empty if right_circ_len: - circ_right = self.get_traces_spatial_only( - 0, right_circ_len, channel_indices - ) + circ_right = self.get_traces_spatial_only(0, right_circ_len, channel_indices) else: circ_right = empty chunk_main = chunk[left_pad_len : chunk.shape[0] - right_pad_len] pad_left = np.concatenate([circ_left, chunk[:left_pad_len]]) - pad_right = np.concatenate( - [chunk[chunk.shape[0] - right_pad_len : chunk.shape[0]], circ_right] - ) + pad_right = np.concatenate([chunk[chunk.shape[0] - right_pad_len : chunk.shape[0]], circ_right]) return _apply_temporal_psd( chunk_main, self.spectral_density, @@ -1796,9 +1692,7 @@ def exp_growth(start_amp, end_amp, duration_ms, tau_ms, sampling_frequency, flip return y[:-1] -def get_ellipse( - positions, center, x_factor=1, y_factor=1, x_angle=0, y_angle=0, z_angle=0 -): +def get_ellipse(positions, center, x_factor=1, y_factor=1, x_angle=0, y_angle=0, z_angle=0): """ Compute the distances to a particular ellipsoid in order to take into account spatial inhomogeneities while generating the template. 
In a carthesian, centered @@ -1855,9 +1749,7 @@ def get_ellipse( rot_matrix = Rx @ Ry @ Rz P = rot_matrix @ p - distances = np.sqrt( - (P[0] / x_factor) ** 2 + (P[1] / y_factor) ** 2 + (P[2] / 1) ** 2 - ) + distances = np.sqrt((P[0] / x_factor) ** 2 + (P[1] / y_factor) ** 2 + (P[2] / 1) ** 2) return distances @@ -1967,9 +1859,7 @@ def _ensure_unit_params(unit_params, num_units, seed): elif isinstance(v, (list, np.ndarray)): # already vector values = np.asarray(v) - assert values.shape == (num_units,), ( - f"generate_templates: wrong shape for {k} in unit_params" - ) + assert values.shape == (num_units,), f"generate_templates: wrong shape for {k} in unit_params" elif v is None: values = [None] * num_units else: @@ -2061,9 +1951,7 @@ def generate_templates( # channel_locations to 3D if channel_locations.shape[1] == 2: - channel_locations = np.hstack( - [channel_locations, np.zeros((channel_locations.shape[0], 1))] - ) + channel_locations = np.hstack([channel_locations, np.zeros((channel_locations.shape[0], 1))]) num_units = units_locations.shape[0] num_channels = channel_locations.shape[0] @@ -2074,9 +1962,7 @@ def generate_templates( if upsample_factor is not None: upsample_factor = int(upsample_factor) assert upsample_factor >= 1 - templates = np.zeros( - (num_units, width, num_channels, upsample_factor), dtype=dtype - ) + templates = np.zeros((num_units, width, num_channels, upsample_factor), dtype=dtype) fs = sampling_frequency * upsample_factor else: templates = np.zeros((num_units, width, num_channels), dtype=dtype) @@ -2221,17 +2107,9 @@ def __init__( check_borders = False self.templates = templates - channel_ids = ( - parent_recording.channel_ids - if parent_recording is not None - else list(range(templates.shape[2])) - ) - dtype = ( - parent_recording.dtype if parent_recording is not None else templates.dtype - ) - BaseRecording.__init__( - self, sorting.get_sampling_frequency(), channel_ids, dtype - ) + channel_ids = parent_recording.channel_ids if 
parent_recording is not None else list(range(templates.shape[2])) + dtype = parent_recording.dtype if parent_recording is not None else templates.dtype + BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) # Important : self._serializability is not change here because it will depend on the sorting parents itself. @@ -2263,9 +2141,7 @@ def __init__( if amplitude_factor is None: amplitude_vector = None elif np.isscalar(amplitude_factor): - amplitude_vector = np.full( - self.spike_vector.size, amplitude_factor, dtype="float32" - ) + amplitude_vector = np.full(self.spike_vector.size, amplitude_factor, dtype="float32") else: amplitude_factor = np.asarray(amplitude_factor) assert amplitude_factor.shape == self.spike_vector.shape @@ -2273,18 +2149,13 @@ def __init__( if parent_recording is not None: assert parent_recording.get_num_segments() == sorting.get_num_segments() - assert ( - parent_recording.get_sampling_frequency() - == sorting.get_sampling_frequency() - ) + assert parent_recording.get_sampling_frequency() == sorting.get_sampling_frequency() assert parent_recording.get_num_channels() == templates.shape[2] parent_recording.copy_metadata(self) if num_samples is None: if parent_recording is None: - num_samples = [ - self.spike_vector["sample_index"][-1] + templates.shape[1] - ] + num_samples = [self.spike_vector["sample_index"][-1] + templates.shape[1]] else: num_samples = [ parent_recording.get_num_frames(segment_index) @@ -2295,25 +2166,13 @@ def __init__( num_samples = [num_samples] for segment_index in range(sorting.get_num_segments()): - start = np.searchsorted( - self.spike_vector["segment_index"], segment_index, side="left" - ) - end = np.searchsorted( - self.spike_vector["segment_index"], segment_index, side="right" - ) + start = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="left") + end = np.searchsorted(self.spike_vector["segment_index"], segment_index, side="right") spikes = 
self.spike_vector[start:end] - amplitude_vec = ( - amplitude_vector[start:end] if amplitude_vector is not None else None - ) - upsample_vec = ( - upsample_vector[start:end] if upsample_vector is not None else None - ) + amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None + upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None - if parent_recording is None - else parent_recording.segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectTemplatesRecordingSegment( self.sampling_frequency, self.dtype, @@ -2353,10 +2212,7 @@ def _check_templates(templates: np.ndarray): max_value = np.max(np.abs(templates)) threshold = 0.01 * max_value - if ( - max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) - > threshold - ): + if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold: warnings.warn( "Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger." 
) @@ -2378,9 +2234,7 @@ def __init__( BaseRecordingSegment.__init__( self, sampling_frequency, - t_start=0 - if parent_recording_segment is None - else parent_recording_segment.t_start, + t_start=0 if parent_recording_segment is None else parent_recording_segment.t_start, ) assert not (parent_recording_segment is None and num_samples is None) @@ -2391,11 +2245,7 @@ def __init__( self.amplitude_vector = amplitude_vector self.upsample_vector = upsample_vector self.parent_recording = parent_recording_segment - self.num_samples = ( - parent_recording_segment.get_num_frames() - if num_samples is None - else num_samples - ) + self.num_samples = parent_recording_segment.get_num_frames() if num_samples is None else num_samples def get_traces( self, @@ -2406,11 +2256,7 @@ def get_traces( if channel_indices is None: n_channels = self.templates.shape[2] elif isinstance(channel_indices, slice): - stop = ( - channel_indices.stop - if channel_indices.stop is not None - else self.templates.shape[2] - ) + stop = channel_indices.stop if channel_indices.stop is not None else self.templates.shape[2] start = channel_indices.start if channel_indices.start is not None else 0 step = channel_indices.step if channel_indices.step is not None else 1 n_channels = math.ceil((stop - start) / step) @@ -2418,9 +2264,7 @@ def get_traces( n_channels = len(channel_indices) if self.parent_recording is not None: - traces = self.parent_recording.get_traces( - start_frame, end_frame, channel_indices - ).copy() + traces = self.parent_recording.get_traces(start_frame, end_frame, channel_indices).copy() else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) @@ -2476,9 +2320,7 @@ def get_num_samples(self) -> int: return self.num_samples -inject_templates = define_function_from_class( - source_class=InjectTemplatesRecording, name="inject_templates" -) +inject_templates = define_function_from_class(source_class=InjectTemplatesRecording, name="inject_templates") ## toy example zone ## 
@@ -2492,9 +2334,7 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): num_contact_per_column = num_channels // num_columns j = 0 for i in range(num_columns): - channel_locations[j : j + num_contact_per_column, 0] = ( - i * contact_spacing_um - ) + channel_locations[j : j + num_contact_per_column, 0] = i * contact_spacing_um channel_locations[j : j + num_contact_per_column, 1] = ( np.arange(num_contact_per_column) * contact_spacing_um ) @@ -2513,9 +2353,7 @@ def _generate_multimodal(rng, size, num_modes, lim0, lim1): prob += np.exp(-((bins - center) ** 2) / (2 * sigma**2)) prob /= np.sum(prob) choices = rng.choice(np.arange(bins.size), size, p=prob) - values = bins[choices] + rng.uniform( - low=-bin_step / 2, high=bin_step / 2, size=size - ) + values = bins[choices] + rng.uniform(low=-bin_step / 2, high=bin_step / 2, size=size) return values @@ -2603,22 +2441,16 @@ def generate_unit_locations( if distribution == "uniform": units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) elif distribution == "multimodal": - units_locations[:, 1] = _generate_multimodal( - rng, num_units, num_modes, minimum_y, maximum_y - ) + units_locations[:, 1] = _generate_multimodal(rng, num_units, num_modes, minimum_y, maximum_y) else: - raise ValueError( - "generate_unit_locations has wrong distribution must be 'uniform' or " - ) + raise ValueError("generate_unit_locations has wrong distribution must be 'uniform' or ") units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) if minimum_distance is not None: solution_found = False renew_inds = None for i in range(max_iteration): - distances = np.linalg.norm( - units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2 - ) + distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) inds0, inds1 = np.nonzero(distances < minimum_distance) mask = inds0 != inds1 inds0 = inds0[mask] @@ -2631,21 +2463,15 @@ def 
generate_unit_locations( # random only bad ones in the previous set renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] - units_locations[:, 0][renew_inds] = rng.uniform( - minimum_x, maximum_x, size=renew_inds.size - ) + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) if distribution == "uniform": - units_locations[:, 1][renew_inds] = rng.uniform( - minimum_y, maximum_y, size=renew_inds.size - ) + units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) elif distribution == "multimodal": units_locations[:, 1][renew_inds] = _generate_multimodal( rng, renew_inds.size, num_modes, minimum_y, maximum_y ) - units_locations[:, 2][renew_inds] = rng.uniform( - minimum_z, maximum_z, size=renew_inds.size - ) + units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) else: solution_found = True @@ -2658,9 +2484,7 @@ def generate_unit_locations( "You can use distance_strict=False or reduce minimum distance" ) else: - warnings.warn( - f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}" - ) + warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") return units_locations @@ -2686,9 +2510,7 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict( - margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20 - ), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs=None, dtype="float32", seed=None, @@ -2787,9 +2609,7 @@ def generate_ground_truth_recording( num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] prb_kwargs["num_contact_per_column"] = num_contact_per_column else: - raise ValueError( - 
"num_columns should be provided in dict generate_probe_kwargs" - ) + raise ValueError("num_columns should be provided in dict generate_probe_kwargs") probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 3bbbf66bab..6b6287b99e 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -57,9 +57,7 @@ def test_generate_sorting_with_spikes_on_borders(): ) # check that segments are correctly sorted all_spikes = sorting.to_spike_vector() - np.testing.assert_array_equal( - all_spikes["segment_index"], np.sort(all_spikes["segment_index"]) - ) + np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"])) spikes = sorting.to_spike_vector(concatenated=False) # at least num_border spikes at borders for all segments @@ -70,16 +68,9 @@ def test_generate_sorting_with_spikes_on_borders(): np.sort(spikes_in_segment["sample_index"]), ) num_samples = int(segment_duration * 30000) + assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders assert ( - np.sum(spikes_in_segment["sample_index"] < border_size_samples) - >= num_spikes_on_borders - ) - assert ( - np.sum( - spikes_in_segment["sample_index"] - >= num_samples - border_size_samples - ) - >= num_spikes_on_borders + np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders ) @@ -129,9 +120,9 @@ def test_memory_sorting_generator(): memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB ratio = memory_usage_MiB / before_instanciation_MiB expected_allocation_MiB = 0 - assert ratio <= 1.0 + relative_tolerance, ( - f"SortingGenerator wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" - ) + assert ( + ratio <= 1.0 + relative_tolerance + ), f"SortingGenerator 
wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" def test_sorting_generator_consisency_across_calls(): @@ -168,12 +159,8 @@ def test_sorting_generator_consisency_within_trains(): ) for unit_id in sorting.get_unit_ids(): - spike_train = sorting.get_unit_spike_train( - unit_id=unit_id, start_frame=0, end_frame=1000 - ) - spike_train_again = sorting.get_unit_spike_train( - unit_id=unit_id, start_frame=0, end_frame=1000 - ) + spike_train = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000) + spike_train_again = sorting.get_unit_spike_train(unit_id=unit_id, start_frame=0, end_frame=1000) assert np.allclose(spike_train, spike_train_again) @@ -205,13 +192,11 @@ def test_noise_generator_memory(): ) after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB - expected_allocation_MiB = ( - dtype.itemsize * num_channels * noise_block_size / bytes_to_MiB_factor - ) + expected_allocation_MiB = dtype.itemsize * num_channels * noise_block_size / bytes_to_MiB_factor ratio = expected_allocation_MiB / expected_allocation_MiB - assert ratio <= 1.0 + relative_tolerance, ( - f"NoiseGeneratorRecording with 'tile_pregenerated' wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" - ) + assert ( + ratio <= 1.0 + relative_tolerance + ), f"NoiseGeneratorRecording with 'tile_pregenerated' wrong memory {memory_usage_MiB} instead of {expected_allocation_MiB}" # case 2: no preallocation very few memory (under 2 MiB) before_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor @@ -226,9 +211,7 @@ def test_noise_generator_memory(): ) after_instanciation_MiB = measure_memory_allocation() / bytes_to_MiB_factor memory_usage_MiB = after_instanciation_MiB - before_instanciation_MiB - assert memory_usage_MiB < 2, ( - f"NoiseGeneratorRecording with 'on_the_fly wrong memory {memory_usage_MiB}MiB" - ) + assert memory_usage_MiB < 2, 
f"NoiseGeneratorRecording with 'on_the_fly wrong memory {memory_usage_MiB}MiB" def test_noise_generator_several_noise_levels(): @@ -388,9 +371,7 @@ def test_noise_generator_consistency_across_calls(strategy, start_frame, end_fra ) traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame) - same_traces = lazy_recording.get_traces( - start_frame=start_frame, end_frame=end_frame - ) + same_traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame) assert np.allclose(traces, same_traces) @@ -406,18 +387,14 @@ def test_noise_generator_consistency_across_calls(strategy, start_frame, end_fra (0, 60_000, 10_000), ], ) -def test_noise_generator_consistency_across_traces( - strategy, start_frame, end_frame, extra_samples -): +def test_noise_generator_consistency_across_traces(strategy, start_frame, end_frame, extra_samples): # Test that the generated traces behave like true arrays. Calling a larger array and then slicing it should # give the same result as calling the slice directly sampling_frequency = 30000 # Hz durations = [10.0] dtype = np.dtype("float32") num_channels = 2 - seed = ( - start_frame + end_frame + extra_samples - ) # To make sure that the seed is different for each test + seed = start_frame + end_frame + extra_samples # To make sure that the seed is different for each test lazy_recording = NoiseGeneratorRecording( num_channels=num_channels, @@ -430,12 +407,8 @@ def test_noise_generator_consistency_across_traces( traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame) end_frame_larger_array = end_frame + extra_samples - larger_traces = lazy_recording.get_traces( - start_frame=start_frame, end_frame=end_frame_larger_array - ) - equivalent_trace_from_larger_traces = larger_traces[ - :-extra_samples, : - ] # Remove the extra samples + larger_traces = lazy_recording.get_traces(start_frame=start_frame, end_frame=end_frame_larger_array) + equivalent_trace_from_larger_traces = 
larger_traces[:-extra_samples, :] # Remove the extra samples assert np.allclose(traces, equivalent_trace_from_larger_traces) @@ -468,9 +441,7 @@ def test_generate_single_fake_waveform(): sampling_frequency = 30000.0 ms_before = 1.0 ms_after = 3.0 - wf = generate_single_fake_waveform( - ms_before=ms_before, ms_after=ms_after, sampling_frequency=sampling_frequency - ) + wf = generate_single_fake_waveform(ms_before=ms_before, ms_after=ms_after, sampling_frequency=sampling_frequency) # import matplotlib.pyplot as plt # times = np.arange(wf.size) / sampling_frequency * 1000 - ms_before @@ -483,9 +454,7 @@ def test_generate_single_fake_waveform(): def test_generate_unit_locations(): seed = 0 - probe = generate_multi_columns_probe( - num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20 - ) + probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) channel_locations = probe.contact_positions num_units = 100 @@ -501,9 +470,7 @@ def test_generate_unit_locations(): distance_strict=False, seed=seed, ) - distances = np.linalg.norm( - unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2 - ) + distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) dist_flat = np.triu(distances, k=1).flatten() dist_flat = dist_flat[dist_flat > 0] assert np.all(dist_flat > minimum_distance) @@ -526,9 +493,7 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations( - num_units, channel_locations, margin_um=margin_um, seed=seed - ) + unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed) sampling_frequency = 30000.0 ms_before = 1.0 @@ -612,9 +577,7 @@ def test_inject_templates(): firing_rates=1.0, seed=42, ) - units_locations = generate_unit_locations( - num_units, channel_locations, margin_um=10.0, seed=42 - ) + 
units_locations = generate_unit_locations(num_units, channel_locations, margin_um=10.0, seed=42) templates_3d = generate_templates( channel_locations, units_locations, @@ -639,22 +602,15 @@ def test_inject_templates(): sorting, templates_3d, nbefore=nbefore, - num_samples=[ - rec_noise.get_num_frames(seg_ind) - for seg_ind in range(rec_noise.get_num_segments()) - ], + num_samples=[rec_noise.get_num_frames(seg_ind) for seg_ind in range(rec_noise.get_num_segments())], ) # Case 2: with parent_recording - rec2 = InjectTemplatesRecording( - sorting, templates_3d, nbefore=nbefore, parent_recording=rec_noise - ) + rec2 = InjectTemplatesRecording(sorting, templates_3d, nbefore=nbefore, parent_recording=rec_noise) # Case 3: with parent_recording + upsample_factor rng = np.random.default_rng(seed=42) - upsample_vector = rng.integers( - 0, upsample_factor, size=sorting.to_spike_vector().size - ) + upsample_vector = rng.integers(0, upsample_factor, size=sorting.to_spike_vector().size) rec3 = InjectTemplatesRecording( sorting, templates_4d, @@ -665,12 +621,8 @@ def test_inject_templates(): for rec in (rec1, rec2, rec3): assert rec.get_traces(end_frame=600, segment_index=0).shape == (600, 4) - assert rec.get_traces( - start_frame=100, end_frame=600, segment_index=1 - ).shape == (500, 4) - assert rec.get_traces( - start_frame=rec_noise.get_num_frames(0) - 200, segment_index=0 - ).shape == (200, 4) + assert rec.get_traces(start_frame=100, end_frame=600, segment_index=1).shape == (500, 4) + assert rec.get_traces(start_frame=rec_noise.get_num_frames(0) - 200, segment_index=0).shape == (200, 4) # Check dumpability saved_loaded = load(rec.to_dict()) @@ -691,23 +643,16 @@ def test_transformsorting(): transformed = TransformSorting.add_from_sorting(sorting_1, sorting_3) assert len(transformed.unit_ids) == 50 - sorting_1 = NumpySorting.from_unit_dict( - {46: np.array([0, 150], dtype=int)}, sampling_frequency=20000.0 - ) + sorting_1 = NumpySorting.from_unit_dict({46: np.array([0, 150], 
dtype=int)}, sampling_frequency=20000.0) sorting_2 = NumpySorting.from_unit_dict( {0: np.array([100, 2000], dtype=int), 3: np.array([200, 4000], dtype=int)}, sampling_frequency=20000.0, ) transformed = TransformSorting.add_from_sorting(sorting_1, sorting_2) assert len(transformed.unit_ids) == 3 - assert np.all( - np.array([k for k in transformed.count_num_spikes_per_unit(outputs="array")]) - == 2 - ) + assert np.all(np.array([k for k in transformed.count_num_spikes_per_unit(outputs="array")]) == 2) - sorting_1 = NumpySorting.from_unit_dict( - {0: np.array([12], dtype=int)}, sampling_frequency=20000.0 - ) + sorting_1 = NumpySorting.from_unit_dict({0: np.array([12], dtype=int)}, sampling_frequency=20000.0) sorting_2 = NumpySorting.from_unit_dict( {0: np.array([150], dtype=int), 3: np.array([12, 150], dtype=int)}, sampling_frequency=20000.0, @@ -715,36 +660,26 @@ def test_transformsorting(): transformed = TransformSorting.add_from_sorting(sorting_1, sorting_2) assert len(transformed.unit_ids) == 2 target_array = np.array([2, 2]) - source_array = np.array( - [k for k in transformed.count_num_spikes_per_unit(outputs="array")] - ) + source_array = np.array([k for k in transformed.count_num_spikes_per_unit(outputs="array")]) assert np.array_equal(source_array, target_array) assert transformed.get_added_spikes_from_existing_indices().size == 1 assert transformed.get_added_spikes_from_new_indices().size == 2 assert transformed.get_added_units_inds() == [3] - transformed = TransformSorting.add_from_unit_dict( - sorting_1, {46: np.array([12, 150], dtype=int)} - ) + transformed = TransformSorting.add_from_unit_dict(sorting_1, {46: np.array([12, 150], dtype=int)}) sorting_1 = generate_sorting(seed=0) - transformed = TransformSorting( - sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=0 - ) + transformed = TransformSorting(sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=0) assert len(sorting_1.to_spike_vector()) == len(transformed.to_spike_vector()) - 
transformed = TransformSorting( - sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=5 - ) + transformed = TransformSorting(sorting_1, sorting_1.to_spike_vector(), refractory_period_ms=5) assert 2 * len(sorting_1.to_spike_vector()) > len(transformed.to_spike_vector()) transformed_2 = TransformSorting(sorting_1, transformed.to_spike_vector()) assert len(transformed_2.to_spike_vector()) > len(transformed.to_spike_vector()) - assert np.sum(transformed_2.get_added_spikes_indices()) >= np.sum( - transformed_2.get_added_spikes_from_new_indices() - ) + assert np.sum(transformed_2.get_added_spikes_indices()) >= np.sum(transformed_2.get_added_spikes_from_new_indices()) assert np.sum(transformed_2.get_added_spikes_indices()) >= np.sum( transformed_2.get_added_spikes_from_existing_indices() ) From 9877d077c0218b1967f819b1daf87f21d2c8581d Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 1 Apr 2026 13:42:00 -0400 Subject: [PATCH 3/6] Support running without scipy --- src/spikeinterface/core/generate.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9eb87c97de..8eca16b9d6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1648,7 +1648,11 @@ def _apply_temporal_psd( ffts will vary slightly. If you can tolerate errors of 5e-7 or if you'll only ever call get_traces() with the same args in non-overlapping chunks, you can use_overlap_add=True. 
""" - from scipy.signal import oaconvolve, convolve + try: + from scipy.signal import oaconvolve, convolve + have_scipy = True + except ImportError: + have_scipy = False klen = psd.shape[0] block_len = 2 * klen - 1 @@ -1665,7 +1669,9 @@ def _apply_temporal_psd( # stack and convolve chunk_conv = np.concatenate([pad_left.T, chunk.T, pad_right.T], axis=1) kernel = np.fft.fftshift(np.fft.irfft(psd, n=block_len)) - if use_overlap_add: + if not have_scipy: + conv = np.convolve(chunk_conv, kernel[None], mode="valid") + elif use_overlap_add: conv = oaconvolve(chunk_conv, kernel[None], mode="valid", axes=1) else: conv = convolve(chunk_conv, kernel[None], mode="valid", method="direct") From 771262aac7ef8b2b608d1587e8897cc3ddf56410 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:42:34 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 8eca16b9d6..d438c00e82 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1650,6 +1650,7 @@ def _apply_temporal_psd( """ try: from scipy.signal import oaconvolve, convolve + have_scipy = True except ImportError: have_scipy = False From 34a8b619f41af4e4e9e8c45acfa8fb235a8d25db Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 1 Apr 2026 13:55:44 -0400 Subject: [PATCH 5/6] Revert "Support running without scipy" This reverts commit 9877d077c0218b1967f819b1daf87f21d2c8581d. 
--- src/spikeinterface/core/generate.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 8eca16b9d6..9eb87c97de 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1648,11 +1648,7 @@ def _apply_temporal_psd( ffts will vary slightly. If you can tolerate errors of 5e-7 or if you'll only ever call get_traces() with the same args in non-overlapping chunks, you can use_overlap_add=True. """ - try: - from scipy.signal import oaconvolve, convolve - have_scipy = True - except ImportError: - have_scipy = False + from scipy.signal import oaconvolve, convolve klen = psd.shape[0] block_len = 2 * klen - 1 @@ -1669,9 +1665,7 @@ def _apply_temporal_psd( # stack and convolve chunk_conv = np.concatenate([pad_left.T, chunk.T, pad_right.T], axis=1) kernel = np.fft.fftshift(np.fft.irfft(psd, n=block_len)) - if not have_scipy: - conv = np.convolve(chunk_conv, kernel[None], mode="valid") - elif use_overlap_add: + if use_overlap_add: conv = oaconvolve(chunk_conv, kernel[None], mode="valid", axes=1) else: conv = convolve(chunk_conv, kernel[None], mode="valid", method="direct") From 5a0787b83d35734d5cd0665fc63b29fafcf549a8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 2 Apr 2026 09:16:39 -0400 Subject: [PATCH 6/6] tests: Move temporal noise test out of core --- .../core/tests/test_generate.py | 60 ------------------ .../generation/tests/test_noise_tools.py | 61 +++++++++++++++++++ 2 files changed, 61 insertions(+), 60 deletions(-) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 6b6287b99e..16a180a95f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -282,66 +282,6 @@ def test_noise_generator_correct_shape(strategy): assert traces.shape == (num_frames, num_channels) 
-@pytest.mark.parametrize("duration", [1.0, 2.0, 2.2]) -@pytest.mark.parametrize("strategy", strategy_list) -def test_noise_generator_temporal(strategy, duration): - psdlen = 25 - kdomain = np.linspace(0.0, 10.0, psdlen) - fake_psd = (kdomain + 0.1) * np.exp(-kdomain) - # this ensures std dev of output ~= 1 - fake_psd /= np.sqrt((fake_psd**2).mean()) - - # Test that the recording has the correct size in shape - sampling_frequency = 30000 # Hz - durations = [duration] - dtype = np.dtype("float32") - num_channels = 2 - seed = 0 - - rec = NoiseGeneratorRecording( - num_channels=num_channels, - sampling_frequency=sampling_frequency, - durations=durations, - dtype=dtype, - seed=seed, - spectral_density=fake_psd, - strategy=strategy, - ) - - # check output matches at different chunks - full_traces = rec.get_traces() - end_frame = rec.get_num_frames() - for t0 in [0, 100]: - for t1 in [end_frame, end_frame - 100]: - print(f"{t0=} {t1=}") - chk = rec.get_traces(0, t0, t1) - chk0 = full_traces[t0:t1] - print(f"{np.flatnonzero((chk!=chk0).any(1))=}") - np.testing.assert_array_equal(chk, chk0) - - np.testing.assert_allclose(full_traces.std(), 1.0, rtol=0.02) - - # re-estimate the psd from the result - # it will not be perfect, okay! 
- n = 2 * psdlen - 1 - snips = full_traces[: n * (full_traces.shape[0] // n)] - snips = snips.reshape(-1, n, snips.shape[-1]) - psd = np.fft.rfft(snips, n=n, axis=1, norm="ortho") - psd = np.sqrt(np.square(np.abs(psd)).mean(axis=(0, 2))) - - sample_size = snips.shape[0] * snips.shape[2] - standard_error = 1.0 / np.sqrt(sample_size) - - # accuracy is good at low freqs - np.testing.assert_allclose( - psd[1 : psdlen // 3], - fake_psd[1 : psdlen // 3], - atol=3 * standard_error, - rtol=0.1, - ) - np.testing.assert_allclose(psd, fake_psd, atol=0.5) - - @pytest.mark.parametrize("strategy", strategy_list) @pytest.mark.parametrize( "start_frame, end_frame", diff --git a/src/spikeinterface/generation/tests/test_noise_tools.py b/src/spikeinterface/generation/tests/test_noise_tools.py index f633ed1de4..766229032c 100644 --- a/src/spikeinterface/generation/tests/test_noise_tools.py +++ b/src/spikeinterface/generation/tests/test_noise_tools.py @@ -1,7 +1,10 @@ +import numpy as np import probeinterface +import pytest from spikeinterface.generation import ( generate_noise, + NoiseGeneratorRecording, ) @@ -40,5 +43,63 @@ def test_generate_noise(): # plt.show() +@pytest.mark.parametrize("duration", [1.0, 2.0, 2.2]) +@pytest.mark.parametrize("strategy", ["tile_precomputed", "on_the_fly"]) +def test_noise_generator_temporal(strategy, duration): + psdlen = 25 + kdomain = np.linspace(0.0, 10.0, psdlen) + fake_psd = (kdomain + 0.1) * np.exp(-kdomain) + # this ensures std dev of output ~= 1 + fake_psd /= np.sqrt((fake_psd**2).mean()) + + # Test that the recording has the correct size in shape + sampling_frequency = 30000 # Hz + durations = [duration] + dtype = np.dtype("float32") + num_channels = 2 + seed = 0 + + rec = NoiseGeneratorRecording( + num_channels=num_channels, + sampling_frequency=sampling_frequency, + durations=durations, + dtype=dtype, + seed=seed, + spectral_density=fake_psd, + strategy=strategy, + ) + + # check output matches at different chunks + full_traces = 
rec.get_traces() + end_frame = rec.get_num_frames() + for t0 in [0, 100]: + for t1 in [end_frame, end_frame - 100]: + chk = rec.get_traces(0, t0, t1) + chk0 = full_traces[t0:t1] + np.testing.assert_array_equal(chk, chk0) + + np.testing.assert_allclose(full_traces.std(), 1.0, rtol=0.02) + + # re-estimate the psd from the result + # it will not be perfect, okay! + n = 2 * psdlen - 1 + snips = full_traces[: n * (full_traces.shape[0] // n)] + snips = snips.reshape(-1, n, snips.shape[-1]) + psd = np.fft.rfft(snips, n=n, axis=1, norm="ortho") + psd = np.sqrt(np.square(np.abs(psd)).mean(axis=(0, 2))) + + sample_size = snips.shape[0] * snips.shape[2] + standard_error = 1.0 / np.sqrt(sample_size) + + # accuracy is good at low freqs + np.testing.assert_allclose( + psd[1 : psdlen // 3], + fake_psd[1 : psdlen // 3], + atol=3 * standard_error, + rtol=0.1, + ) + np.testing.assert_allclose(psd, fake_psd, atol=0.5) + + if __name__ == "__main__": test_generate_noise()