Skip to content

Commit 60d716d

Browse files
committed
Fix(view): Ensure all spikes are loaded for 3D view lasso split
1 parent 4a353eb commit 60d716d

1 file changed

Lines changed: 62 additions & 19 deletions

File tree

phy/cluster/views/featureview3d.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -552,28 +552,71 @@ def on_request_split(self, sender=None):
552552
logger.debug("Lasso polygon too small")
553553
return np.array([], dtype=np.int64)
554554

555-
# Find points inside the lasso for each cluster
555+
# We need to reload ALL spikes (not just the displayed subset) to ensure
556+
# we catch all points that should be split
557+
logger.debug("Loading all spikes for lasso split operation")
558+
559+
# Get full data for all selected clusters with load_all=True
560+
bunchs = self.get_clusters_data(fixed_channels=self.fixed_channels, load_all=True)
561+
562+
# Also need to recompute 3D positions for all loaded spikes
556563
spike_ids_to_split = []
557-
558-
for cluster_info in self._cluster_data:
559-
if cluster_info['cluster_id'] is None: # Skip background
560-
continue
561-
562-
bunch = cluster_info['bunch']
563-
if hasattr(bunch, 'pos') and len(bunch.pos) > 0:
564-
pts2d = bunch.pos
564+
from matplotlib.path import Path
565+
lasso_path = Path(lasso_points)
566+
567+
for bunch in bunchs:
568+
cluster_id = bunch.get('cluster_id')
569+
570+
# Skip background points (cluster_id = None) - they're shown but not selectable for split
571+
# However, let's check if they actually belong to a selected cluster
572+
if cluster_id is None:
573+
# For background, we need to check if any of these spikes belong to selected clusters
574+
# Background bunch might contain spikes from selected clusters
575+
spike_ids = bunch.get('spike_ids', [])
576+
if spike_ids is None or len(spike_ids) == 0:
577+
continue
578+
579+
# Get 3D coordinates for background
580+
x = self._get_axis_data(bunch, self.x_axis, cluster_id)
581+
y = self._get_axis_data(bunch, self.y_axis, cluster_id)
582+
z = self._get_axis_data(bunch, self.z_axis, cluster_id)
583+
584+
if len(x) == 0 or len(y) == 0 or len(z) == 0:
585+
continue
586+
587+
points_3d = np.column_stack([x, y, z])
588+
pts2d = self._project_3d_to_2d(points_3d)
589+
590+
# Check which points are inside the lasso
591+
inside_mask = lasso_path.contains_points(pts2d)
592+
593+
if np.any(inside_mask):
594+
# Only include spike IDs that actually belong to selected clusters
595+
selected_spikes = np.array(spike_ids)[inside_mask]
596+
# Filter to only include spikes from selected clusters
597+
# This requires the caller to handle the filtering
598+
spike_ids_to_split.extend(selected_spikes.tolist())
565599
else:
566-
# Fallback: compute 2D positions from stored 3D points
567-
pts2d = self._project_3d_to_2d(cluster_info['points_3d'])
568-
bunch.pos = pts2d
569-
# Check which points are inside the lasso
570-
from matplotlib.path import Path
571-
lasso_path = Path(lasso_points)
572-
inside_mask = lasso_path.contains_points(pts2d)
573-
574-
if np.any(inside_mask):
600+
# For regular clusters, include all spikes inside the lasso
575601
spike_ids = bunch.get('spike_ids', [])
576-
if spike_ids is not None and len(spike_ids) > 0:
602+
if spike_ids is None or len(spike_ids) == 0:
603+
continue
604+
605+
# Get 3D coordinates
606+
x = self._get_axis_data(bunch, self.x_axis, cluster_id)
607+
y = self._get_axis_data(bunch, self.y_axis, cluster_id)
608+
z = self._get_axis_data(bunch, self.z_axis, cluster_id)
609+
610+
if len(x) == 0 or len(y) == 0 or len(z) == 0:
611+
continue
612+
613+
points_3d = np.column_stack([x, y, z])
614+
pts2d = self._project_3d_to_2d(points_3d)
615+
616+
# Check which points are inside the lasso
617+
inside_mask = lasso_path.contains_points(pts2d)
618+
619+
if np.any(inside_mask):
577620
selected_spikes = np.array(spike_ids)[inside_mask]
578621
spike_ids_to_split.extend(selected_spikes.tolist())
579622

0 commit comments

Comments
 (0)