diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1c9ece728f..65a4be0274 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -366,10 +366,16 @@ def __init__( added_unit_indices = np.arange(len(self.parent_unit_ids)) self._cached_spike_vector = np.concatenate((self._cached_spike_vector, added_spikes_existing_units)) self.added_spikes_from_existing_mask = np.concatenate( - (self.added_spikes_from_existing_mask, np.ones(len(added_spikes_existing_units), dtype=bool)) + ( + self.added_spikes_from_existing_mask, + np.ones(len(added_spikes_existing_units), dtype=bool), + ) ) self.added_spikes_from_new_mask = np.concatenate( - (self.added_spikes_from_new_mask, np.zeros(len(added_spikes_existing_units), dtype=bool)) + ( + self.added_spikes_from_new_mask, + np.zeros(len(added_spikes_existing_units), dtype=bool), + ) ) if added_spikes_new_units is not None and len(added_spikes_new_units) > 0: @@ -378,13 +384,24 @@ def __init__( ), "added_spikes_new_units should be a spike vector" self._cached_spike_vector = np.concatenate((self._cached_spike_vector, added_spikes_new_units)) self.added_spikes_from_existing_mask = np.concatenate( - (self.added_spikes_from_existing_mask, np.zeros(len(added_spikes_new_units), dtype=bool)) + ( + self.added_spikes_from_existing_mask, + np.zeros(len(added_spikes_new_units), dtype=bool), + ) ) self.added_spikes_from_new_mask = np.concatenate( - (self.added_spikes_from_new_mask, np.ones(len(added_spikes_new_units), dtype=bool)) + ( + self.added_spikes_from_new_mask, + np.ones(len(added_spikes_new_units), dtype=bool), + ) ) - sort_idxs = np.lexsort([self._cached_spike_vector["sample_index"], self._cached_spike_vector["segment_index"]]) + sort_idxs = np.lexsort( + [ + self._cached_spike_vector["sample_index"], + self._cached_spike_vector["segment_index"], + ] + ) self._cached_spike_vector = self._cached_spike_vector[sort_idxs] self.added_spikes_from_existing_mask = 
self.added_spikes_from_existing_mask[sort_idxs] self.added_spikes_from_new_mask = self.added_spikes_from_new_mask[sort_idxs] @@ -484,7 +501,9 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe @staticmethod def add_from_unit_dict( - sorting1: BaseSorting, units_dict_list: list[dict] | dict, refractory_period_ms=None + sorting1: BaseSorting, + units_dict_list: list[dict] | dict, + refractory_period_ms=None, ) -> "TransformSorting": """ Construct TransformSorting by adding one sorting with a @@ -509,7 +528,12 @@ def add_from_unit_dict( @staticmethod def from_samples_and_labels( - sorting1, times_list, labels_list, sampling_frequency, unit_ids=None, refractory_period_ms=None + sorting1, + times_list, + labels_list, + sampling_frequency, + unit_ids=None, + refractory_period_ms=None, ) -> "NumpySorting": """ Construct TransformSorting from: @@ -554,7 +578,8 @@ def clean_refractory_period(self): * (self._cached_spike_vector["segment_index"] == segment_index) ) to_keep[indices[1:]] = np.logical_or( - to_keep[indices[1:]], np.diff(self._cached_spike_vector[indices]["sample_index"]) > rpv + to_keep[indices[1:]], + np.diff(self._cached_spike_vector[indices]["sample_index"]) > rpv, ) self._cached_spike_vector = self._cached_spike_vector[to_keep] @@ -647,11 +672,19 @@ def generate_snippets( ) sorting = generate_sorting( - num_units=num_units, sampling_frequency=sampling_frequency, durations=durations, empty_units=empty_units + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=durations, + empty_units=empty_units, ) snippets = snippets_from_sorting( - recording=recording, sorting=sorting, nbefore=nbefore, nafter=nafter, wf_folder=wf_folder, **job_kwargs + recording=recording, + sorting=sorting, + nbefore=nbefore, + nafter=nafter, + wf_folder=wf_folder, + **job_kwargs, ) if set_probe: @@ -1250,6 +1283,14 @@ class NoiseGeneratorRecording(BaseRecording): Std of the white noise (if an array, defined by per channels) 
cov_matrix : np.ndarray | None, default: None The covariance matrix of the noise + spectral_density : np.ndarray | None, default: None + The spectral density of the noise, such as would be estimated from an array of snippets of shape + `(n_snippets, spectral_snippet_length)` by the following method (Welch's method): + + ```python + periodogram = rfft(snippets, n=next_fast_len(snippets.shape[1]), norm="ortho") + spectral_density = np.sqrt((periodogram * periodogram.conj()).mean(axis=0)) + ``` dtype : np.dtype | str | None, default: "float32" The dtype of the recording. Note that only np.float32 and np.float64 are supported. seed : int | None, default: None @@ -1260,6 +1301,10 @@ class NoiseGeneratorRecording(BaseRecording): very fast and cusume only one noise block. * "on_the_fly": generate on the fly a new noise block by combining seed + noise block index no memory preallocation but a bit more computaion (random) + temporal_use_overlap_add: bool, default: False + If applying a spectral density, it's faster to use overlap-add convolutions, but these can + introduce numerical differences when the args to get_traces() change. But, it is okay to set + this flag if you will only ever call get_traces with non-overlapping chunks. noise_block_size : int, default: 30000 Size in sample of noise block. 
@@ -1276,9 +1321,11 @@ def __init__( durations: list[float], noise_levels: float | np.ndarray = 1.0, cov_matrix: np.ndarray | None = None, + spectral_density: np.ndarray | None = None, dtype: np.dtype | str | None = "float32", seed: int | None = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", + temporal_use_overlap_add: bool = False, noise_block_size: int = 30000, ): @@ -1297,7 +1344,12 @@ def __init__( assert len(noise_levels[0]) == num_channels, "Noise levels should have a size of num_channels" - BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) + BaseRecording.__init__( + self, + sampling_frequency=sampling_frequency, + channel_ids=channel_ids, + dtype=dtype, + ) num_segments = len(durations) @@ -1322,9 +1374,11 @@ def __init__( noise_block_size, noise_levels, cov_matrix, + spectral_density, dtype, segments_seeds[i], strategy, + temporal_use_overlap_add, ) self.add_recording_segment(rec_segment) @@ -1350,9 +1404,11 @@ def __init__( noise_block_size, noise_levels, cov_matrix, + spectral_density, dtype, seed, strategy, + use_overlap_add, ): assert seed is not None @@ -1363,21 +1419,33 @@ def __init__( self.noise_block_size = noise_block_size self.noise_levels = noise_levels self.cov_matrix = cov_matrix + self.spectral_density = spectral_density self.dtype = dtype self.seed = seed self.strategy = strategy + self.use_overlap_add = use_overlap_add if self.strategy == "tile_pregenerated": rng = np.random.default_rng(seed=self.seed) if self.cov_matrix is None: self.noise_block = ( - rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) + rng.standard_normal( + size=(self.noise_block_size, self.num_channels), + dtype=self.dtype, + ) * noise_levels ) else: self.noise_block = rng.multivariate_normal( - np.zeros(self.num_channels), self.cov_matrix, size=self.noise_block_size + np.zeros(self.num_channels), + self.cov_matrix, + 
size=self.noise_block_size, + ) + + if spectral_density is not None: + self.noise_block = _apply_temporal_psd( + self.noise_block, spectral_density, use_overlap_add=use_overlap_add ) elif self.strategy == "on_the_fly": @@ -1386,7 +1454,7 @@ def __init__( def get_num_samples(self) -> int: return self.num_samples - def get_traces( + def get_traces_spatial_only( self, start_frame: int | None = None, end_frame: int | None = None, @@ -1414,13 +1482,20 @@ def get_traces( elif self.strategy == "on_the_fly": rng = np.random.default_rng(seed=(self.seed, block_index)) if self.cov_matrix is None: - noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) + noise_block = rng.standard_normal( + size=(self.noise_block_size, self.num_channels), + dtype=self.dtype, + ) else: noise_block = rng.multivariate_normal( - np.zeros(self.num_channels), self.cov_matrix, size=self.noise_block_size + np.zeros(self.num_channels), + self.cov_matrix, + size=self.noise_block_size, ) noise_block *= self.noise_levels + else: + assert False if block_index == first_block_index: if first_block_index != last_block_index: @@ -1442,6 +1517,61 @@ def get_traces( return traces + def get_traces( + self, + start_frame: int | None = None, + end_frame: int | None = None, + channel_indices: list | None = None, + ): + if self.spectral_density is None or self.strategy == "tile_pregenerated": + return self.get_traces_spatial_only(start_frame, end_frame, channel_indices) + + n_samples = self.get_num_samples() + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = n_samples + + # margin logic for temporal correlations + # temporal correlations are done "circularly" so that the noise wraps continuously + # having the spatial only helper makes the logic way easier! + + # circular margin logic + # left_circ_len is how much margin to grab from the right side to fill out the left + # margin circularly. similarly right_circ_len. 
+ margin = self.spectral_density.shape[0] - 1 + left_pad_len = min(start_frame, margin) + left_circ_len = margin - left_pad_len + right_pad_len = min(n_samples - end_frame, margin) + right_circ_len = margin - right_pad_len + + # grab padded chunk + pad_start_frame = start_frame - left_pad_len + pad_end_frame = end_frame + right_pad_len + # sanity check: the padded window must stay inside the segment + assert 0 <= pad_start_frame < pad_end_frame <= n_samples + chunk = self.get_traces_spatial_only(pad_start_frame, pad_end_frame, channel_indices) + empty = np.zeros_like(chunk[:0]) + if left_circ_len: + circ_left = self.get_traces_spatial_only(n_samples - left_circ_len, n_samples, channel_indices) + else: + circ_left = empty + if right_circ_len: + circ_right = self.get_traces_spatial_only(0, right_circ_len, channel_indices) + else: + circ_right = empty + + chunk_main = chunk[left_pad_len : chunk.shape[0] - right_pad_len] + pad_left = np.concatenate([circ_left, chunk[:left_pad_len]]) + pad_right = np.concatenate([chunk[chunk.shape[0] - right_pad_len : chunk.shape[0]], circ_right]) + return _apply_temporal_psd( + chunk_main, + self.spectral_density, + pad_left, + pad_right, + use_overlap_add=self.use_overlap_add, + ) + noise_generator_recording = define_function_from_class( source_class=NoiseGeneratorRecording, name="noise_generator_recording" @@ -1502,6 +1632,50 @@ def generate_recording_by_size( return recording +def _apply_temporal_psd( + chunk: np.ndarray, + psd: np.ndarray, + pad_left: np.ndarray | None = None, + pad_right: np.ndarray | None = None, + use_overlap_add: bool = False, +): + """Apply a convolution so that chunk's first dimension has psd `psd`, if it started out as white noise + + If padding arrays are not supplied, circular padding is extracted. + + It is very sad because it's relatively slow, but we are absolutely forced to use direct convolution here. 
+ If not, then the output will have numerical disagreement depending on the args to get_traces(), since the + ffts will vary slightly. If you can tolerate errors of 5e-7 or if you'll only ever call get_traces() with + the same args in non-overlapping chunks, you can use_overlap_add=True. + """ + from scipy.signal import oaconvolve, convolve + + klen = psd.shape[0] + block_len = 2 * klen - 1 + pad_len = klen - 1 + + # padding + T = chunk.shape[0] + if pad_left is None: + pad_left = chunk[T - pad_len : T] + if pad_right is None: + pad_right = chunk[:pad_len] + assert pad_left.shape == pad_right.shape == (pad_len, chunk.shape[1]) + + # stack and convolve + chunk_conv = np.concatenate([pad_left.T, chunk.T, pad_right.T], axis=1) + kernel = np.fft.fftshift(np.fft.irfft(psd, n=block_len)) + if use_overlap_add: + conv = oaconvolve(chunk_conv, kernel[None], mode="valid", axes=1) + else: + conv = convolve(chunk_conv, kernel[None], mode="valid", method="direct") + del chunk_conv + conv = np.ascontiguousarray(conv.T) + assert conv.shape == chunk.shape + + return conv + + ## Waveforms zone ## @@ -1615,7 +1789,12 @@ def generate_single_fake_waveform( nrepol = int(repolarization_ms * sampling_frequency / 1000.0) tau_ms = repolarization_ms * 0.5 wf[nbefore : nbefore + nrepol] = exp_growth( - negative_amplitude, positive_amplitude, repolarization_ms, tau_ms, sampling_frequency, flip=True + negative_amplitude, + positive_amplitude, + repolarization_ms, + tau_ms, + sampling_frequency, + flip=True, ) # recovery @@ -2089,8 +2268,16 @@ def get_traces( else: traces = np.zeros([end_frame - start_frame, n_channels], dtype=self.dtype) - start = np.searchsorted(self.spike_vector["sample_index"], start_frame - self.templates.shape[1], side="left") - end = np.searchsorted(self.spike_vector["sample_index"], end_frame + self.templates.shape[1], side="right") + start = np.searchsorted( + self.spike_vector["sample_index"], + start_frame - self.templates.shape[1], + side="left", + ) + end = 
np.searchsorted( + self.spike_vector["sample_index"], + end_frame + self.templates.shape[1], + side="right", + ) for i in range(start, end): spike = self.spike_vector[i] @@ -2241,8 +2428,14 @@ def generate_unit_locations( rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um - minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + minimum_x, maximum_x = ( + np.min(channel_locations[:, 0]) - margin_um, + np.max(channel_locations[:, 0]) + margin_um, + ) + minimum_y, maximum_y = ( + np.min(channel_locations[:, 1]) - margin_um, + np.max(channel_locations[:, 1]) + margin_um, + ) units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) if distribution == "uniform": diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index e0e28d09cd..16a180a95f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -63,7 +63,10 @@ def test_generate_sorting_with_spikes_on_borders(): # at least num_border spikes at borders for all segments for spikes_in_segment in spikes: # check that sample indices are correctly sorted within segments - np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])) + np.testing.assert_array_equal( + spikes_in_segment["sample_index"], + np.sort(spikes_in_segment["sample_index"]), + ) num_samples = int(segment_duration * 30000) assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders assert ( @@ -501,15 +504,28 @@ def test_inject_templates(): # generate some sutff rec_noise = generate_recording( - num_channels=num_channels, durations=durations, sampling_frequency=sampling_frequency, seed=42 + num_channels=num_channels, + 
durations=durations, + sampling_frequency=sampling_frequency, + seed=42, ) channel_locations = rec_noise.get_channel_locations() sorting = generate_sorting( - num_units=num_units, durations=durations, sampling_frequency=sampling_frequency, firing_rates=1.0, seed=42 + num_units=num_units, + durations=durations, + sampling_frequency=sampling_frequency, + firing_rates=1.0, + seed=42, ) units_locations = generate_unit_locations(num_units, channel_locations, margin_um=10.0, seed=42) templates_3d = generate_templates( - channel_locations, units_locations, sampling_frequency, ms_before, ms_after, seed=42, upsample_factor=None + channel_locations, + units_locations, + sampling_frequency, + ms_before, + ms_after, + seed=42, + upsample_factor=None, ) templates_4d = generate_templates( channel_locations, @@ -536,7 +552,11 @@ def test_inject_templates(): rng = np.random.default_rng(seed=42) upsample_vector = rng.integers(0, upsample_factor, size=sorting.to_spike_vector().size) rec3 = InjectTemplatesRecording( - sorting, templates_4d, nbefore=nbefore, parent_recording=rec_noise, upsample_vector=upsample_vector + sorting, + templates_4d, + nbefore=nbefore, + parent_recording=rec_noise, + upsample_vector=upsample_vector, ) for rec in (rec1, rec2, rec3): @@ -565,7 +585,8 @@ def test_transformsorting(): sorting_1 = NumpySorting.from_unit_dict({46: np.array([0, 150], dtype=int)}, sampling_frequency=20000.0) sorting_2 = NumpySorting.from_unit_dict( - {0: np.array([100, 2000], dtype=int), 3: np.array([200, 4000], dtype=int)}, sampling_frequency=20000.0 + {0: np.array([100, 2000], dtype=int), 3: np.array([200, 4000], dtype=int)}, + sampling_frequency=20000.0, ) transformed = TransformSorting.add_from_sorting(sorting_1, sorting_2) assert len(transformed.unit_ids) == 3 @@ -573,7 +594,8 @@ def test_transformsorting(): sorting_1 = NumpySorting.from_unit_dict({0: np.array([12], dtype=int)}, sampling_frequency=20000.0) sorting_2 = NumpySorting.from_unit_dict( - {0: np.array([150], 
 dtype=int), 3: np.array([12, 150], dtype=int)}, sampling_frequency=20000.0 + {0: np.array([150], dtype=int), 3: np.array([12, 150], dtype=int)}, + sampling_frequency=20000.0, ) transformed = TransformSorting.add_from_sorting(sorting_1, sorting_2) assert len(transformed.unit_ids) == 2 @@ -613,9 +635,17 @@ def test_generate_ground_truth_recording(): def test_generate_sorting_to_inject(): durations = [10.0, 20.0] - sorting = generate_sorting(num_units=10, durations=durations, sampling_frequency=30000, firing_rates=1.0, seed=2205) + sorting = generate_sorting( + num_units=10, + durations=durations, + sampling_frequency=30000, + firing_rates=1.0, + seed=2205, + ) injected_sorting = generate_sorting_to_inject( - sorting, [int(duration * sorting.sampling_frequency) for duration in durations], seed=2308 + sorting, + [int(duration * sorting.sampling_frequency) for duration in durations], + seed=2308, ) num_spikes = sorting.count_num_spikes_per_unit() num_injected_spikes = injected_sorting.count_num_spikes_per_unit() diff --git a/src/spikeinterface/generation/tests/test_noise_tools.py b/src/spikeinterface/generation/tests/test_noise_tools.py index f633ed1de4..766229032c 100644 --- a/src/spikeinterface/generation/tests/test_noise_tools.py +++ b/src/spikeinterface/generation/tests/test_noise_tools.py @@ -1,7 +1,10 @@ +import numpy as np import probeinterface +import pytest from spikeinterface.generation import ( generate_noise, + NoiseGeneratorRecording, ) @@ -40,5 +43,63 @@ def test_generate_noise(): # plt.show() +@pytest.mark.parametrize("duration", [1.0, 2.0, 2.2]) +@pytest.mark.parametrize("strategy", ["tile_pregenerated", "on_the_fly"]) +def test_noise_generator_temporal(strategy, duration): + psdlen = 25 + kdomain = np.linspace(0.0, 10.0, psdlen) + fake_psd = (kdomain + 0.1) * np.exp(-kdomain) + # this ensures std dev of output ~= 1 + fake_psd /= np.sqrt((fake_psd**2).mean()) + + # Test that the recording has the correct size in shape + sampling_frequency = 30000 # Hz + 
durations = [duration] + dtype = np.dtype("float32") + num_channels = 2 + seed = 0 + + rec = NoiseGeneratorRecording( + num_channels=num_channels, + sampling_frequency=sampling_frequency, + durations=durations, + dtype=dtype, + seed=seed, + spectral_density=fake_psd, + strategy=strategy, + ) + + # check output matches at different chunks + full_traces = rec.get_traces() + end_frame = rec.get_num_frames() + for t0 in [0, 100]: + for t1 in [end_frame, end_frame - 100]: + chk = rec.get_traces(0, t0, t1) + chk0 = full_traces[t0:t1] + np.testing.assert_array_equal(chk, chk0) + + np.testing.assert_allclose(full_traces.std(), 1.0, rtol=0.02) + + # re-estimate the psd from the result + # it will not be perfect, okay! + n = 2 * psdlen - 1 + snips = full_traces[: n * (full_traces.shape[0] // n)] + snips = snips.reshape(-1, n, snips.shape[-1]) + psd = np.fft.rfft(snips, n=n, axis=1, norm="ortho") + psd = np.sqrt(np.square(np.abs(psd)).mean(axis=(0, 2))) + + sample_size = snips.shape[0] * snips.shape[2] + standard_error = 1.0 / np.sqrt(sample_size) + + # accuracy is good at low freqs + np.testing.assert_allclose( + psd[1 : psdlen // 3], + fake_psd[1 : psdlen // 3], + atol=3 * standard_error, + rtol=0.1, + ) + np.testing.assert_allclose(psd, fake_psd, atol=0.5) + + if __name__ == "__main__": test_generate_noise()