diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index c8d68089..65e12ac7 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -30,6 +30,7 @@ from spatialdata_plot._accessor import register_spatial_data_accessor from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl.render import ( + _draw_channel_legend, _render_images, _render_labels, _render_points, @@ -40,6 +41,7 @@ CBAR_DEFAULT_FRACTION, CBAR_DEFAULT_LOCATION, CBAR_DEFAULT_PAD, + ChannelLegendEntry, CmapParams, ColorbarSpec, ImageRenderParams, @@ -523,6 +525,7 @@ def render_images( transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None, colorbar: bool | str | None = "auto", colorbar_params: dict[str, object] | None = None, + channels_as_categories: bool = False, **kwargs: Any, ) -> sd.SpatialData: """ @@ -600,6 +603,13 @@ def render_images( colorbar_params : Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``, and ``label``. + channels_as_categories : bool, default False + When ``True`` and rendering multiple channels, show a categorical + legend mapping each channel name to its compositing color. The + legend uses the ``legend_*`` parameters from :meth:`show`. + Ignored for single-channel and RGB(A) images. When multiple + ``render_images`` calls use this flag on the same axes, all + channel entries are combined into a single legend. kwargs Additional arguments to be passed to cmap, norm, and other rendering functions. @@ -681,6 +691,7 @@ def render_images( colorbar_params=param_values["colorbar_params"], transfunc=transfunc, grayscale=grayscale, + channels_as_categories=channels_as_categories, ) n_steps += 1 @@ -1140,6 +1151,7 @@ def _draw_colorbar( ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i] assert isinstance(ax, Axes) axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None + axis_channel_legend_entries: list[ChannelLegendEntry] = [] wants_images = False wants_labels = False @@ -1170,6 +1182,7 @@ def _draw_colorbar( scalebar_params=scalebar_params, legend_params=legend_params, colorbar_requests=axis_colorbar_requests, + channel_legend_entries=axis_channel_legend_entries, rasterize=rasterize, ) @@ -1279,6 +1292,9 @@ def _draw_colorbar( if legend_params.colorbar and axis_colorbar_requests: pending_colorbars.append((ax, axis_colorbar_requests)) + if axis_channel_legend_entries: + _draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params) + if pending_colorbars and fig_params.fig is not None: fig = fig_params.fig fig.canvas.draw() diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index ca67d026..7750f864 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -2,6 +2,7 @@ import dataclasses from collections import abc +from collections.abc import Sequence from copy import copy from typing import Any @@ -18,9 +19,11 @@ import spatialdata as sd import xarray as xr from anndata import AnnData +from matplotlib import patheffects from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings +from scanpy.plotting._tools.scatterplots import _add_categorical_legend from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata._core.query.relational_query import match_table_to_element from spatialdata.models import PointsModel, ShapesModel, get_table_keys @@ -41,6 +44,7 @@ _render_ds_outlines, ) from spatialdata_plot.pl.render_params import ( + ChannelLegendEntry, CmapParams, Color, ColorbarSpec, @@ -1094,6 +1098,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]: return False, False +def _collect_channel_legend_entries( + channels: Sequence[str | int], + seed_colors: Sequence[str | tuple[float, ...]], + channel_legend_entries: list[ChannelLegendEntry], +) -> None: + """Accumulate channel-to-color mappings for a deferred combined legend.""" + channel_names = [str(ch) for ch in channels] + if len(set(channel_names)) != len(channel_names): + logger.warning("channels_as_categories: duplicate channel names detected; skipping legend entries.") + return + + color_hexes = [matplotlib.colors.to_hex(c, keep_alpha=False) for c in seed_colors] + for name, color in zip(channel_names, color_hexes, strict=True): + channel_legend_entries.append(ChannelLegendEntry(channel_name=name, color_hex=color)) + + +def _draw_channel_legend( + ax: matplotlib.axes.SubplotBase, + entries: list[ChannelLegendEntry], + legend_params: LegendParams, + fig_params: FigParams, +) -> None: + """Draw a single combined categorical legend from accumulated channel entries. + + Because ``_add_categorical_legend`` adds invisible labeled scatter artists, + calling it here automatically merges with any earlier legend entries + (e.g. from labels or shapes) on the same axes via ``ax.legend()``. + + ``multi_panel`` is only set when no prior legend exists on the axis, + to avoid shrinking the axes twice (once for labels/shapes, once for + channels). + """ + # Deduplicate: if the same channel name appears twice, keep the last color + palette_dict: dict[str, str] = {} + for entry in entries: + palette_dict[entry.channel_name] = entry.color_hex + + legend_loc = legend_params.legend_loc + if legend_loc == "on data": + logger.warning( + "legend_loc='on data' is not supported for channel legends (no scatter coordinates); " + "falling back to 'right margin'." + ) + legend_loc = "right margin" + + categories = pd.Categorical(list(palette_dict)) + + path_effect = ( + [patheffects.withStroke(linewidth=legend_params.legend_fontoutline, foreground="w")] + if legend_params.legend_fontoutline is not None + else [] + ) + + # Only apply multi_panel shrink if no legend already exists on this axis + # (labels/shapes draw their legend during the render loop and already shrink). + has_existing_legend = ax.get_legend() is not None + needs_multi_panel = fig_params.axs is not None and not has_existing_legend + + _add_categorical_legend( + ax, + categories, + palette=palette_dict, + legend_loc=legend_loc, + legend_fontweight=legend_params.legend_fontweight, + legend_fontsize=legend_params.legend_fontsize, + legend_fontoutline=path_effect, + na_color=["lightgray"], + na_in_legend=False, + multi_panel=needs_multi_panel, + ) + + def _render_images( sdata: sd.SpatialData, render_params: ImageRenderParams, @@ -1104,6 +1180,7 @@ def _render_images( legend_params: LegendParams, rasterize: bool, colorbar_requests: list[ColorbarSpec] | None = None, + channel_legend_entries: list[ChannelLegendEntry] | None = None, ) -> None: _log_context.set("render_images") sdata_filt = sdata.filter_by_coordinate_system( @@ -1319,10 +1396,14 @@ def _render_images( layers[ch] = ch_norm(layers[ch]) + # Colors for the channel legend (set by each branch if applicable) + legend_colors: list[str] | None = None + # 2A) Image has 3 channels, no palette info, and no/only one cmap was given if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): if render_params.cmap_params.cmap_is_default: # -> use RGB stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1) + legend_colors = ["red", "green", "blue"] else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels stacked = ( @@ -1404,6 +1485,8 @@ def _render_images( f"multichannel strategy 'stack' to render." ) # TODO: update when pca is added as strategy + legend_colors = seed_colors + _ax_show_and_transform( colored, trans_data, @@ -1421,6 +1504,8 @@ def _render_images( colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) colored = np.clip(colored[:, :, :3], 0, 1) + legend_colors = list(palette) + _ax_show_and_transform( colored, trans_data, @@ -1440,6 +1525,8 @@ def _render_images( ) colored = colored[:, :, :3] + legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps] + _ax_show_and_transform( colored, trans_data, @@ -1452,6 +1539,17 @@ def _render_images( elif palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") + # Collect channel legend entries (single point for all multi-channel paths) + if render_params.channels_as_categories and channel_legend_entries is not None: + if legend_colors is not None: + _collect_channel_legend_entries(channels, legend_colors, channel_legend_entries) + else: + logger.warning( + "channels_as_categories requires distinct per-channel colors; " + "ignored when a single cmap is shared across channels. " + "Use 'palette' or a list of cmaps instead." + ) + def _render_labels( sdata: sd.SpatialData, diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 344df5f9..92dcf179 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -199,6 +199,14 @@ class ColorbarSpec: alpha: float | None = None +@dataclass +class ChannelLegendEntry: + """A single channel-to-color mapping for the categorical channel legend.""" + + channel_name: str + color_hex: str + + CBAR_DEFAULT_LOCATION = "right" CBAR_DEFAULT_FRACTION = 0.075 CBAR_DEFAULT_PAD = 0.015 @@ -275,6 +283,7 @@ class ImageRenderParams: colorbar_params: dict[str, object] | None = None transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None grayscale: bool = False + channels_as_categories: bool = False @dataclass diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_legend_lower_right.png b/tests/_images/ChannelsAsCategories_channels_as_categories_legend_lower_right.png new file mode 100644 index 00000000..df57f52c Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_legend_lower_right.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_legend_upper_left.png b/tests/_images/ChannelsAsCategories_channels_as_categories_legend_upper_left.png new file mode 100644 index 00000000..79502b90 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_legend_upper_left.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_many_channels.png b/tests/_images/ChannelsAsCategories_channels_as_categories_many_channels.png new file mode 100644 index 00000000..3332f0fe Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_many_channels.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_three_channels_default.png b/tests/_images/ChannelsAsCategories_channels_as_categories_three_channels_default.png new file mode 100644 index 00000000..537aa53e Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_three_channels_default.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_two_channels.png b/tests/_images/ChannelsAsCategories_channels_as_categories_two_channels.png new file mode 100644 index 00000000..b2a0829d Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_two_channels.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_with_cmap_list.png b/tests/_images/ChannelsAsCategories_channels_as_categories_with_cmap_list.png new file mode 100644 index 00000000..41a9cc62 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_with_cmap_list.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_categories_with_palette.png b/tests/_images/ChannelsAsCategories_channels_as_categories_with_palette.png new file mode 100644 index 00000000..1c018591 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_categories_with_palette.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 0bae024b..58635d2d 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -491,3 +491,99 @@ def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData) sdata_blobs.pl.render_images("blobs_image", channel=[0], cmap=["gray"]).pl.show(ax=ax) assert len(ax.get_images()) == 1 plt.close(fig) + + +# --------------------------------------------------------------------------- +# channels_as_categories visual tests (#459) +# --------------------------------------------------------------------------- + + +class TestChannelsAsCategories(PlotTester, metaclass=PlotTesterMeta): + def test_plot_channels_as_categories_two_channels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show() + + def test_plot_channels_as_categories_three_channels_default(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channels_as_categories=True).pl.show() + + def test_plot_channels_as_categories_with_palette(self, sdata_blobs_str: SpatialData): + sdata_blobs_str.pl.render_images( + element="blobs_image", + channel=["c1", "c2", "c3"], + palette=["red", "green", "blue"], + channels_as_categories=True, + ).pl.show() + + def test_plot_channels_as_categories_many_channels(self, sdata_blobs_str: SpatialData): + sdata_blobs_str.pl.render_images(element="blobs_image", channels_as_categories=True).pl.show() + + def test_plot_channels_as_categories_with_cmap_list(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1, 2], + cmap=["Reds", "Greens", "Blues"], + channels_as_categories=True, + ).pl.show() + + def test_plot_channels_as_categories_legend_upper_left(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show( + legend_loc="upper left" + ) + + def test_plot_channels_as_categories_legend_lower_right(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show( + legend_loc="lower right" + ) + + +class TestChannelsAsCategoriesNonVisual: + """Non-visual tests for channels_as_categories edge cases.""" + + def test_channels_as_categories_ignored_for_single_channel(self, sdata_blobs: SpatialData): + fig, ax = plt.subplots() + sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_categories=True).pl.show(ax=ax) + assert ax.get_legend() is None + plt.close("all") + + def test_channels_as_categories_false_no_legend(self, sdata_blobs: SpatialData): + fig, ax = plt.subplots() + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=False).pl.show(ax=ax) + assert ax.get_legend() is None + plt.close("all") + + def test_channels_as_categories_chained_renders_combine(self, sdata_blobs: SpatialData): + """Multiple render_images with channels_as_categories should produce one combined legend.""" + fig, ax = plt.subplots() + ( + sdata_blobs.pl.render_images( + element="blobs_image", channel=[0, 1], palette=["red", "green"], channels_as_categories=True + ) + .pl.render_images( + element="blobs_image", channel=[1, 2], palette=["cyan", "blue"], channels_as_categories=True + ) + .pl.show(ax=ax) + ) + legend = ax.get_legend() + assert legend is not None + labels = [t.get_text() for t in legend.get_texts()] + # Both render calls contribute: channels 0, 1, 2. + # Channel "1" appears in both calls — dedup keeps the last color. + assert "0" in labels + assert "1" in labels + assert "2" in labels + assert len(labels) == 3 + plt.close("all") + + def test_channels_as_categories_coexists_with_other_elements(self, sdata_blobs: SpatialData): + """Channel legend should not crash when combined with other render calls.""" + fig, ax = plt.subplots() + ( + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True) + .pl.render_labels(element="blobs_labels") + .pl.show(ax=ax) + ) + legend = ax.get_legend() + assert legend is not None + labels = [t.get_text() for t in legend.get_texts()] + assert "0" in labels + assert "1" in labels + plt.close("all")