Skip to content

Commit 4a353eb

Browse files
committed
feat: Add 3D feature view and update related components
1 parent 7a2494b commit 4a353eb

7 files changed

Lines changed: 1325 additions & 12 deletions

File tree

phy/apps/base.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def on_close_view(view_, gui):
272272

273273

274274
class FeatureMixin(object):
275+
# Spike attributes that can be used for visualization in addition to the features.
276+
_spike_attributes = ('amplitudes', 'depths')
275277
n_spikes_features = 2500
276278
n_spikes_features_background = 2500
277279

@@ -286,22 +288,107 @@ class FeatureMixin(object):
286288
)
287289

288290
_cached = (
289-
'_get_features',
290291
'get_spike_feature_amplitudes',
291292
)
292293

294+
_memcached = (
295+
# '_get_features_for_view',
296+
# 'get_spike_attributes_for_views',
297+
)
298+
299+
# This property provides a consistent public interface for views to get feature data,
300+
# abstracting away the underlying implementation.
301+
get_features = property(lambda self: self._get_features_for_view)
302+
303+
def _get_feature_spike_ids(self, cluster_id, load_all=False):
304+
"""Return spike ids to be used in the feature view."""
305+
if load_all:
306+
return self.supervisor.get_spike_ids(cluster_id)
307+
# Background spikes.
308+
if cluster_id is None:
309+
return self.selector(self.n_spikes_features_background, [])
310+
# Spikes in a cluster.
311+
return self.selector(self.n_spikes_features, [cluster_id])
312+
313+
def _get_features_for_view(self, cluster_ids, channel_ids=None, load_all=False):
314+
"""Get features for a list of clusters.
315+
316+
This function is the main entry point for views to retrieve feature data.
317+
It handles fetching data for both background spikes and specific clusters,
318+
and determines the appropriate channels to use if not specified.
319+
"""
320+
if self.model.features is None:
321+
return
322+
323+
# Special case for background spikes.
324+
if cluster_ids is None:
325+
spike_ids = self._get_feature_spike_ids(None, load_all=load_all)
326+
if spike_ids is None or not len(spike_ids):
327+
return
328+
features = self.model.features[spike_ids, ...]
329+
# We need to specify the channel ids, which are all channels in this case.
330+
b_channel_ids = np.arange(self.model.channel_positions.shape[0])
331+
b = Bunch(
332+
data=features,
333+
spike_ids=spike_ids,
334+
channel_ids=b_channel_ids,
335+
cluster_id=None,
336+
)
337+
# This is a list of bunches.
338+
return [b]
339+
340+
bunchs = []
341+
for cluster_id in cluster_ids:
342+
spike_ids = self._get_feature_spike_ids(cluster_id, load_all=load_all)
343+
if spike_ids is None or not len(spike_ids):
344+
continue
345+
346+
# If channel_ids are not provided, get the best channels for the cluster.
347+
if channel_ids is None:
348+
c_ids = self.get_best_channels(cluster_id)
349+
else:
350+
c_ids = channel_ids
351+
352+
# Get the features for the specified channels.
353+
features_bunch = self._get_spike_features(spike_ids, c_ids)
354+
if not features_bunch:
355+
continue
356+
357+
features_bunch.cluster_id = cluster_id
358+
bunchs.append(features_bunch)
359+
return bunchs
360+
361+
def get_spike_attributes_for_views(self):
362+
"""Return a dictionary of functions `cluster_id => values`.
363+
364+
This method provides a flexible "data menu" for views. Instead of returning data
365+
directly, it returns a dictionary of callable functions. Each function can be
366+
invoked by a view to get a specific data attribute (e.g., depths, amplitudes)
367+
for a cluster on demand. This design enables the creation of complex views
368+
(like a 3D view) that require multiple independent data sources.
369+
"""
370+
d = {}
371+
for name in self._spike_attributes:
372+
# The function takes a cluster_id and returns an array.
373+
d[name] = lambda cluster_id, name=name: getattr(
374+
self.model, 'get_spike_%s' % name)(self._get_feature_spike_ids(cluster_id))
375+
# Use helper that works across models (TemplateModel may not implement get_spike_times)
376+
d['time'] = lambda cluster_id, load_all=False: self._get_feature_view_spike_times(
377+
cluster_id, load_all=load_all)
378+
return d
379+
293380
def get_spike_feature_amplitudes(
294-
self, spike_ids, channel_id=None, channel_ids=None, pc=None, **kwargs):
295-
"""Return the features for the specified channel and PC."""
381+
self, spike_ids, channel_id=None, **kwargs):
382+
"""Return the maximum amplitude of the features on one channel."""
296383
if self.model.features is None:
297384
return
298385
channel_id = channel_id if channel_id is not None else channel_ids[0]
299386
features = self._get_spike_features(spike_ids, [channel_id]).get('data', None)
300387
if features is None: # pragma: no cover
301388
return
302389
assert features.shape[0] == len(spike_ids)
303-
logger.log(5, "Show channel %s and PC %s in amplitude view.", channel_id, pc)
304-
return features[:, 0, pc or 0]
390+
logger.log(5, "Show channel %s and PC 0 in amplitude view.", channel_id)
391+
return features[:, 0, 0]
305392

306393
def create_amplitude_view(self):
307394
view = super(FeatureMixin, self).create_amplitude_view()

phy/apps/template/gui.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from phylib import _add_log_file
1717
from phylib.io.model import TemplateModel, load_model
1818
from phylib.io.traces import MtscompEphysReader
19-
from phylib.utils import Bunch, connect
19+
from phylib.utils import Bunch, connect, unconnect
2020

21-
from phy.cluster.views import ScatterView
21+
from phy.cluster.views import ScatterView, Feature3DView
2222
from phy.gui import create_app, run_app
2323
from ..base import WaveformMixin, FeatureMixin, TemplateMixin, TraceMixin, BaseController
2424

@@ -70,6 +70,7 @@ class TemplateController(WaveformMixin, FeatureMixin, TemplateMixin, TraceMixin,
7070
'CorrelogramView',
7171
'ISIView',
7272
'FeatureView',
73+
'Feature3DView',
7374
'AmplitudeView',
7475
'FiringRateView',
7576
'TraceView',
@@ -141,6 +142,7 @@ def _get_template_features(self, cluster_ids, load_all=False):
141142
def _set_view_creator(self):
142143
super(TemplateController, self)._set_view_creator()
143144
self.view_creator['TemplateFeatureView'] = self.create_template_feature_view
145+
self.view_creator['Feature3DView'] = self.create_feature_3d_view
144146

145147
# Public methods
146148
# -------------------------------------------------------------------------
@@ -194,6 +196,49 @@ def create_template_feature_view(self):
194196
return
195197
return TemplateFeatureView(coords=self._get_template_features)
196198

199+
def create_feature_3d_view(self):
200+
"""Create and configure the 3D feature view.
201+
202+
This view requires multiple data sources to render the 3D scatter plot:
203+
* `features`: The main feature data, typically used for the X and Y axes.
204+
* `attributes`: A dictionary of other data vectors (like depth), used for the
205+
Z axis and color. This is provided by `get_spike_attributes_for_views`.
206+
* `channel_positions`: The physical layout of the probe channels.
207+
"""
208+
logger.debug("Creating Feature3DView")
209+
try:
210+
# Gather the different data sources required by the view.
211+
features = self.get_features
212+
attributes = self.get_spike_attributes_for_views()
213+
channel_positions = self.model.channel_positions
214+
logger.debug(f"Features: {features}")
215+
logger.debug(f"Attributes: {attributes}")
216+
logger.debug(f"Channel positions: {channel_positions.shape if channel_positions is not None else 'None'}")
217+
view = Feature3DView(
218+
features=features,
219+
attributes=attributes,
220+
channel_positions=channel_positions,
221+
cluster_ids=self.supervisor.selected
222+
)
223+
logger.debug("Feature3DView created successfully")
224+
225+
# Connect the view to the supervisor's select event.
226+
# This ensures the view is updated when the cluster selection changes.
227+
@connect(sender=self.supervisor)
228+
def on_select(sender, cluster_ids, **kwargs):
229+
if view.auto_update:
230+
view.on_select(cluster_ids=cluster_ids)
231+
232+
# Disconnect the view when it's closed to prevent memory leaks.
233+
@connect(sender=view)
234+
def on_close_view(view_, gui):
235+
unconnect(on_select)
236+
237+
return view
238+
except Exception as e:
239+
logger.error(f"Error creating Feature3DView: {e}", exc_info=True)
240+
raise
241+
197242

198243
#------------------------------------------------------------------------------
199244
# Template commands
@@ -226,4 +271,4 @@ def template_describe(params_path):
226271
"""Describe a template dataset."""
227272
model = load_model(params_path)
228273
model.describe()
229-
model.close()
274+
model.close()

phy/cluster/supervisor.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from phylib.utils import Bunch, emit, connect, unconnect
2121
from phy.gui.actions import Actions
22-
from phy.gui.qt import _block, set_busy, _wait
22+
from phy.gui.qt import _block, set_busy, _wait, QMessageBox
2323
from phy.gui.widgets import Table, HTMLWidget, _uniq, Barrier
2424

2525
logger = logging.getLogger(__name__)
@@ -1044,6 +1044,21 @@ def split(self, spike_ids=None, spike_clusters_rel=0):
10441044
out = self.clustering.split(
10451045
spike_ids, spike_clusters_rel=spike_clusters_rel)
10461046
self._global_history.action(self.clustering)
1047+
1048+
# Show a pop-up with the split information.
1049+
if out:
1050+
added = out.get('added', [])
1051+
deleted = out.get('deleted', [])
1052+
message = f"Split successful.\n\n"
1053+
if added:
1054+
message += f"New clusters created: {', '.join(map(str, added))}\n"
1055+
if deleted:
1056+
message += f"Original clusters affected: {', '.join(map(str, deleted))}"
1057+
1058+
box = QMessageBox()
1059+
box.setText(message)
1060+
box.exec_()
1061+
10471062
return out
10481063

10491064
# Move actions

phy/cluster/views/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .amplitude import AmplitudeView # noqa
1313
from .correlogram import CorrelogramView # noqa
1414
from .feature import FeatureView # noqa
15+
from .featureview3d import Feature3DView # noqa
1516
from .histogram import HistogramView, ISIView, FiringRateView # noqa
1617
from .probe import ProbeView # noqa
1718
from .raster import RasterView # noqa

phy/cluster/views/feature.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ def attach(self, gui):
396396
def toggle_automatic_channel_selection(self, checked):
397397
"""Toggle the automatic selection of channels when the cluster selection changes."""
398398
self.fixed_channels = not checked
399+
# The status bar needs to be updated manually to reflect the change.
400+
self.update_status()
399401

400402
@property
401403
def status(self):

0 commit comments

Comments
 (0)