Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl.render import (
_draw_channel_legend,
_render_images,
_render_labels,
_render_points,
Expand All @@ -40,6 +41,7 @@
CBAR_DEFAULT_FRACTION,
CBAR_DEFAULT_LOCATION,
CBAR_DEFAULT_PAD,
ChannelLegendEntry,
CmapParams,
ColorbarSpec,
ImageRenderParams,
Expand Down Expand Up @@ -523,6 +525,7 @@ def render_images(
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
channels_as_categories: bool = False,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -600,6 +603,13 @@ def render_images(
colorbar_params :
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
and ``label``.
channels_as_categories : bool, default False
When ``True`` and rendering multiple channels, show a categorical
legend mapping each channel name to its compositing color. The
legend uses the ``legend_*`` parameters from :meth:`show`.
Ignored for single-channel and RGB(A) images. When multiple
``render_images`` calls use this flag on the same axes, all
channel entries are combined into a single legend.
kwargs
Additional arguments to be passed to cmap, norm, and other rendering functions.

Expand Down Expand Up @@ -681,6 +691,7 @@ def render_images(
colorbar_params=param_values["colorbar_params"],
transfunc=transfunc,
grayscale=grayscale,
channels_as_categories=channels_as_categories,
)
n_steps += 1

Expand Down Expand Up @@ -1140,6 +1151,7 @@ def _draw_colorbar(
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
assert isinstance(ax, Axes)
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
axis_channel_legend_entries: list[ChannelLegendEntry] = []

wants_images = False
wants_labels = False
Expand Down Expand Up @@ -1170,6 +1182,7 @@ def _draw_colorbar(
scalebar_params=scalebar_params,
legend_params=legend_params,
colorbar_requests=axis_colorbar_requests,
channel_legend_entries=axis_channel_legend_entries,
rasterize=rasterize,
)

Expand Down Expand Up @@ -1279,6 +1292,9 @@ def _draw_colorbar(
if legend_params.colorbar and axis_colorbar_requests:
pending_colorbars.append((ax, axis_colorbar_requests))

if axis_channel_legend_entries:
_draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params)

if pending_colorbars and fig_params.fig is not None:
fig = fig_params.fig
fig.canvas.draw()
Expand Down
98 changes: 98 additions & 0 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
from collections import abc
from collections.abc import Sequence
from copy import copy
from typing import Any

Expand All @@ -18,9 +19,11 @@
import spatialdata as sd
import xarray as xr
from anndata import AnnData
from matplotlib import patheffects
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from spatialdata import get_extent, get_values, join_spatialelement_table
from spatialdata._core.query.relational_query import match_table_to_element
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
Expand All @@ -41,6 +44,7 @@
_render_ds_outlines,
)
from spatialdata_plot.pl.render_params import (
ChannelLegendEntry,
CmapParams,
Color,
ColorbarSpec,
Expand Down Expand Up @@ -1094,6 +1098,78 @@ def _is_rgb_image(channel_coords: list[Any]) -> tuple[bool, bool]:
return False, False


def _collect_channel_legend_entries(
channels: Sequence[str | int],
seed_colors: Sequence[str | tuple[float, ...]],
channel_legend_entries: list[ChannelLegendEntry],
) -> None:
"""Accumulate channel-to-color mappings for a deferred combined legend."""
channel_names = [str(ch) for ch in channels]
if len(set(channel_names)) != len(channel_names):
logger.warning("channels_as_categories: duplicate channel names detected; skipping legend entries.")
return

color_hexes = [matplotlib.colors.to_hex(c, keep_alpha=False) for c in seed_colors]
for name, color in zip(channel_names, color_hexes, strict=True):
channel_legend_entries.append(ChannelLegendEntry(channel_name=name, color_hex=color))


def _draw_channel_legend(
ax: matplotlib.axes.SubplotBase,
entries: list[ChannelLegendEntry],
legend_params: LegendParams,
fig_params: FigParams,
) -> None:
"""Draw a single combined categorical legend from accumulated channel entries.

Because ``_add_categorical_legend`` adds invisible labeled scatter artists,
calling it here automatically merges with any earlier legend entries
(e.g. from labels or shapes) on the same axes via ``ax.legend()``.

``multi_panel`` is only set when no prior legend exists on the axis,
to avoid shrinking the axes twice (once for labels/shapes, once for
channels).
"""
# Deduplicate: if the same channel name appears twice, keep the last color
palette_dict: dict[str, str] = {}
for entry in entries:
palette_dict[entry.channel_name] = entry.color_hex

legend_loc = legend_params.legend_loc
if legend_loc == "on data":
logger.warning(
"legend_loc='on data' is not supported for channel legends (no scatter coordinates); "
"falling back to 'right margin'."
)
legend_loc = "right margin"

categories = pd.Categorical(list(palette_dict))

path_effect = (
[patheffects.withStroke(linewidth=legend_params.legend_fontoutline, foreground="w")]
if legend_params.legend_fontoutline is not None
else []
)

# Only apply multi_panel shrink if no legend already exists on this axis
# (labels/shapes draw their legend during the render loop and already shrink).
has_existing_legend = ax.get_legend() is not None
needs_multi_panel = fig_params.axs is not None and not has_existing_legend

_add_categorical_legend(
ax,
categories,
palette=palette_dict,
legend_loc=legend_loc,
legend_fontweight=legend_params.legend_fontweight,
legend_fontsize=legend_params.legend_fontsize,
legend_fontoutline=path_effect,
na_color=["lightgray"],
na_in_legend=False,
multi_panel=needs_multi_panel,
)


def _render_images(
sdata: sd.SpatialData,
render_params: ImageRenderParams,
Expand All @@ -1104,6 +1180,7 @@ def _render_images(
legend_params: LegendParams,
rasterize: bool,
colorbar_requests: list[ColorbarSpec] | None = None,
channel_legend_entries: list[ChannelLegendEntry] | None = None,
) -> None:
_log_context.set("render_images")
sdata_filt = sdata.filter_by_coordinate_system(
Expand Down Expand Up @@ -1319,10 +1396,14 @@ def _render_images(

layers[ch] = ch_norm(layers[ch])

# Colors for the channel legend (set by each branch if applicable)
legend_colors: list[str] | None = None

# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.cmap_is_default: # -> use RGB
stacked = np.clip(np.stack([layers[ch] for ch in layers], axis=-1), 0, 1)
legend_colors = ["red", "green", "blue"]
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
Expand Down Expand Up @@ -1404,6 +1485,8 @@ def _render_images(
f"multichannel strategy 'stack' to render."
) # TODO: update when pca is added as strategy

legend_colors = seed_colors

_ax_show_and_transform(
colored,
trans_data,
Expand All @@ -1421,6 +1504,8 @@ def _render_images(
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
colored = np.clip(colored[:, :, :3], 0, 1)

legend_colors = list(palette)

_ax_show_and_transform(
colored,
trans_data,
Expand All @@ -1440,6 +1525,8 @@ def _render_images(
)
colored = colored[:, :, :3]

legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps]

_ax_show_and_transform(
colored,
trans_data,
Expand All @@ -1452,6 +1539,17 @@ def _render_images(
elif palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")

# Collect channel legend entries (single point for all multi-channel paths)
if render_params.channels_as_categories and channel_legend_entries is not None:
if legend_colors is not None:
_collect_channel_legend_entries(channels, legend_colors, channel_legend_entries)
else:
logger.warning(
"channels_as_categories requires distinct per-channel colors; "
"ignored when a single cmap is shared across channels. "
"Use 'palette' or a list of cmaps instead."
)


def _render_labels(
sdata: sd.SpatialData,
Expand Down
9 changes: 9 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ class ColorbarSpec:
alpha: float | None = None


@dataclass
class ChannelLegendEntry:
"""A single channel-to-color mapping for the categorical channel legend."""

channel_name: str
color_hex: str


CBAR_DEFAULT_LOCATION = "right"
CBAR_DEFAULT_FRACTION = 0.075
CBAR_DEFAULT_PAD = 0.015
Expand Down Expand Up @@ -275,6 +283,7 @@ class ImageRenderParams:
colorbar_params: dict[str, object] | None = None
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
grayscale: bool = False
channels_as_categories: bool = False


@dataclass
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
96 changes: 96 additions & 0 deletions tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,99 @@ def test_cmap_matches_selected_channels_not_full_image(sdata_blobs: SpatialData)
sdata_blobs.pl.render_images("blobs_image", channel=[0], cmap=["gray"]).pl.show(ax=ax)
assert len(ax.get_images()) == 1
plt.close(fig)


# ---------------------------------------------------------------------------
# channels_as_categories visual tests (#459)
# ---------------------------------------------------------------------------


class TestChannelsAsCategories(PlotTester, metaclass=PlotTesterMeta):
def test_plot_channels_as_categories_two_channels(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show()

def test_plot_channels_as_categories_three_channels_default(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channels_as_categories=True).pl.show()

def test_plot_channels_as_categories_with_palette(self, sdata_blobs_str: SpatialData):
sdata_blobs_str.pl.render_images(
element="blobs_image",
channel=["c1", "c2", "c3"],
palette=["red", "green", "blue"],
channels_as_categories=True,
).pl.show()

def test_plot_channels_as_categories_many_channels(self, sdata_blobs_str: SpatialData):
sdata_blobs_str.pl.render_images(element="blobs_image", channels_as_categories=True).pl.show()

def test_plot_channels_as_categories_with_cmap_list(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(
element="blobs_image",
channel=[0, 1, 2],
cmap=["Reds", "Greens", "Blues"],
channels_as_categories=True,
).pl.show()

def test_plot_channels_as_categories_legend_upper_left(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show(
legend_loc="upper left"
)

def test_plot_channels_as_categories_legend_lower_right(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True).pl.show(
legend_loc="lower right"
)


class TestChannelsAsCategoriesNonVisual:
"""Non-visual tests for channels_as_categories edge cases."""

def test_channels_as_categories_ignored_for_single_channel(self, sdata_blobs: SpatialData):
fig, ax = plt.subplots()
sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_categories=True).pl.show(ax=ax)
assert ax.get_legend() is None
plt.close("all")

def test_channels_as_categories_false_no_legend(self, sdata_blobs: SpatialData):
fig, ax = plt.subplots()
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=False).pl.show(ax=ax)
assert ax.get_legend() is None
plt.close("all")

def test_channels_as_categories_chained_renders_combine(self, sdata_blobs: SpatialData):
"""Multiple render_images with channels_as_categories should produce one combined legend."""
fig, ax = plt.subplots()
(
sdata_blobs.pl.render_images(
element="blobs_image", channel=[0, 1], palette=["red", "green"], channels_as_categories=True
)
.pl.render_images(
element="blobs_image", channel=[1, 2], palette=["cyan", "blue"], channels_as_categories=True
)
.pl.show(ax=ax)
)
legend = ax.get_legend()
assert legend is not None
labels = [t.get_text() for t in legend.get_texts()]
# Both render calls contribute: channels 0, 1, 2.
# Channel "1" appears in both calls — dedup keeps the last color.
assert "0" in labels
assert "1" in labels
assert "2" in labels
assert len(labels) == 3
plt.close("all")

def test_channels_as_categories_coexists_with_other_elements(self, sdata_blobs: SpatialData):
"""Channel legend should not crash when combined with other render calls."""
fig, ax = plt.subplots()
(
sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1], channels_as_categories=True)
.pl.render_labels(element="blobs_labels")
.pl.show(ax=ax)
)
legend = ax.get_legend()
assert legend is not None
labels = [t.get_text() for t in legend.get_texts()]
assert "0" in labels
assert "1" in labels
plt.close("all")
Loading