@@ -272,6 +272,8 @@ def on_close_view(view_, gui):
272272
273273
274274class FeatureMixin (object ):
275+ # Spike attributes that can be used for visualization in addition to the features.
276+ _spike_attributes = ('amplitudes' , 'depths' )
275277 n_spikes_features = 2500
276278 n_spikes_features_background = 2500
277279
@@ -286,22 +288,107 @@ class FeatureMixin(object):
286288 )
287289
288290 _cached = (
289- '_get_features' ,
290291 'get_spike_feature_amplitudes' ,
291292 )
292293
294+ _memcached = (
295+ # '_get_features_for_view',
296+ # 'get_spike_attributes_for_views',
297+ )
298+
299+ # This property provides a consistent public interface for views to get feature data,
300+ # abstracting away the underlying implementation.
301+ get_features = property (lambda self : self ._get_features_for_view )
302+
303+ def _get_feature_spike_ids (self , cluster_id , load_all = False ):
304+ """Return spike ids to be used in the feature view."""
305+ if load_all :
306+ return self .supervisor .get_spike_ids (cluster_id )
307+ # Background spikes.
308+ if cluster_id is None :
309+ return self .selector (self .n_spikes_features_background , [])
310+ # Spikes in a cluster.
311+ return self .selector (self .n_spikes_features , [cluster_id ])
312+
313+ def _get_features_for_view (self , cluster_ids , channel_ids = None , load_all = False ):
314+ """Get features for a list of clusters.
315+
316+ This function is the main entry point for views to retrieve feature data.
317+ It handles fetching data for both background spikes and specific clusters,
318+ and determines the appropriate channels to use if not specified.
319+ """
320+ if self .model .features is None :
321+ return
322+
323+ # Special case for background spikes.
324+ if cluster_ids is None :
325+ spike_ids = self ._get_feature_spike_ids (None , load_all = load_all )
326+ if spike_ids is None or not len (spike_ids ):
327+ return
328+ features = self .model .features [spike_ids , ...]
329+ # We need to specify the channel ids, which are all channels in this case.
330+ b_channel_ids = np .arange (self .model .channel_positions .shape [0 ])
331+ b = Bunch (
332+ data = features ,
333+ spike_ids = spike_ids ,
334+ channel_ids = b_channel_ids ,
335+ cluster_id = None ,
336+ )
337+ # This is a list of bunches.
338+ return [b ]
339+
340+ bunchs = []
341+ for cluster_id in cluster_ids :
342+ spike_ids = self ._get_feature_spike_ids (cluster_id , load_all = load_all )
343+ if spike_ids is None or not len (spike_ids ):
344+ continue
345+
346+ # If channel_ids are not provided, get the best channels for the cluster.
347+ if channel_ids is None :
348+ c_ids = self .get_best_channels (cluster_id )
349+ else :
350+ c_ids = channel_ids
351+
352+ # Get the features for the specified channels.
353+ features_bunch = self ._get_spike_features (spike_ids , c_ids )
354+ if not features_bunch :
355+ continue
356+
357+ features_bunch .cluster_id = cluster_id
358+ bunchs .append (features_bunch )
359+ return bunchs
360+
361+ def get_spike_attributes_for_views (self ):
362+ """Return a dictionary of functions `cluster_id => values`.
363+
364+ This method provides a flexible "data menu" for views. Instead of returning data
365+ directly, it returns a dictionary of callable functions. Each function can be
366+ invoked by a view to get a specific data attribute (e.g., depths, amplitudes)
367+ for a cluster on demand. This design enables the creation of complex views
368+ (like a 3D view) that require multiple independent data sources.
369+ """
370+ d = {}
371+ for name in self ._spike_attributes :
372+ # The function takes a cluster_id and returns an array.
373+ d [name ] = lambda cluster_id , name = name : getattr (
374+ self .model , 'get_spike_%s' % name )(self ._get_feature_spike_ids (cluster_id ))
375+ # Use helper that works across models (TemplateModel may not implement get_spike_times)
376+ d ['time' ] = lambda cluster_id , load_all = False : self ._get_feature_view_spike_times (
377+ cluster_id , load_all = load_all )
378+ return d
379+
293380 def get_spike_feature_amplitudes (
294- self , spike_ids , channel_id = None , channel_ids = None , pc = None , ** kwargs ):
295- """Return the features for the specified channel and PC ."""
381+ self , spike_ids , channel_id = None , ** kwargs ):
382+ """Return the maximum amplitude of the features on one channel ."""
296383 if self .model .features is None :
297384 return
298385 channel_id = channel_id if channel_id is not None else channel_ids [0 ]
299386 features = self ._get_spike_features (spike_ids , [channel_id ]).get ('data' , None )
300387 if features is None : # pragma: no cover
301388 return
302389 assert features .shape [0 ] == len (spike_ids )
303- logger .log (5 , "Show channel %s and PC %s in amplitude view." , channel_id , pc )
304- return features [:, 0 , pc or 0 ]
390+ logger .log (5 , "Show channel %s and PC 0 in amplitude view." , channel_id )
391+ return features [:, 0 , 0 ]
305392
306393 def create_amplitude_view (self ):
307394 view = super (FeatureMixin , self ).create_amplitude_view ()
0 commit comments