@@ -552,28 +552,71 @@ def on_request_split(self, sender=None):
552552 logger .debug ("Lasso polygon too small" )
553553 return np .array ([], dtype = np .int64 )
554554
555- # Find points inside the lasso for each cluster
555+ # We need to reload ALL spikes (not just the displayed subset) to ensure
556+ # we catch all points that should be split
557+ logger .debug ("Loading all spikes for lasso split operation" )
558+
559+ # Get full data for all selected clusters with load_all=True
560+ bunchs = self .get_clusters_data (fixed_channels = self .fixed_channels , load_all = True )
561+
562+ # Also need to recompute 3D positions for all loaded spikes
556563 spike_ids_to_split = []
557-
558- for cluster_info in self ._cluster_data :
559- if cluster_info ['cluster_id' ] is None : # Skip background
560- continue
561-
562- bunch = cluster_info ['bunch' ]
563- if hasattr (bunch , 'pos' ) and len (bunch .pos ) > 0 :
564- pts2d = bunch .pos
564+ from matplotlib .path import Path
565+ lasso_path = Path (lasso_points )
566+
567+ for bunch in bunchs :
568+ cluster_id = bunch .get ('cluster_id' )
569+
570+ # Skip background points (cluster_id = None) - they're shown but not selectable for split
571+ # However, let's check if they actually belong to a selected cluster
572+ if cluster_id is None :
573+ # For background, we need to check if any of these spikes belong to selected clusters
574+ # Background bunch might contain spikes from selected clusters
575+ spike_ids = bunch .get ('spike_ids' , [])
576+ if spike_ids is None or len (spike_ids ) == 0 :
577+ continue
578+
579+ # Get 3D coordinates for background
580+ x = self ._get_axis_data (bunch , self .x_axis , cluster_id )
581+ y = self ._get_axis_data (bunch , self .y_axis , cluster_id )
582+ z = self ._get_axis_data (bunch , self .z_axis , cluster_id )
583+
584+ if len (x ) == 0 or len (y ) == 0 or len (z ) == 0 :
585+ continue
586+
587+ points_3d = np .column_stack ([x , y , z ])
588+ pts2d = self ._project_3d_to_2d (points_3d )
589+
590+ # Check which points are inside the lasso
591+ inside_mask = lasso_path .contains_points (pts2d )
592+
593+ if np .any (inside_mask ):
594+ # Only include spike IDs that actually belong to selected clusters
595+ selected_spikes = np .array (spike_ids )[inside_mask ]
596+ # Filter to only include spikes from selected clusters
597+ # This requires the caller to handle the filtering
598+ spike_ids_to_split .extend (selected_spikes .tolist ())
565599 else :
566- # Fallback: compute 2D positions from stored 3D points
567- pts2d = self ._project_3d_to_2d (cluster_info ['points_3d' ])
568- bunch .pos = pts2d
569- # Check which points are inside the lasso
570- from matplotlib .path import Path
571- lasso_path = Path (lasso_points )
572- inside_mask = lasso_path .contains_points (pts2d )
573-
574- if np .any (inside_mask ):
600+ # For regular clusters, include all spikes inside the lasso
575601 spike_ids = bunch .get ('spike_ids' , [])
576- if spike_ids is not None and len (spike_ids ) > 0 :
602+ if spike_ids is None or len (spike_ids ) == 0 :
603+ continue
604+
605+ # Get 3D coordinates
606+ x = self ._get_axis_data (bunch , self .x_axis , cluster_id )
607+ y = self ._get_axis_data (bunch , self .y_axis , cluster_id )
608+ z = self ._get_axis_data (bunch , self .z_axis , cluster_id )
609+
610+ if len (x ) == 0 or len (y ) == 0 or len (z ) == 0 :
611+ continue
612+
613+ points_3d = np .column_stack ([x , y , z ])
614+ pts2d = self ._project_3d_to_2d (points_3d )
615+
616+ # Check which points are inside the lasso
617+ inside_mask = lasso_path .contains_points (pts2d )
618+
619+ if np .any (inside_mask ):
577620 selected_spikes = np .array (spike_ids )[inside_mask ]
578621 spike_ids_to_split .extend (selected_spikes .tolist ())
579622
0 commit comments