Skip to content

Commit a1bb307

Browse files
authored
Use format() for shared subplot slice titles (#652)
1 parent 61366b0 commit a1bb307

5 files changed

Lines changed: 420 additions & 25 deletions

File tree

ultraplot/axes/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3191,13 +3191,31 @@ def get_tightbbox(self, renderer, *args, **kwargs):
31913191
# Perform extra post-processing steps
31923192
# NOTE: This should be updated alongside draw(). We also cache the resulting
31933193
# bounding box to speed up tight layout calculations (see _range_tightbbox).
3194+
include_subset_titles = kwargs.pop("include_subset_titles", True)
31943195
self._add_queued_guides()
31953196
self._apply_title_above()
31963197
if self._colorbar_fill:
31973198
self._colorbar_fill.update_ticks(manual_only=True) # only if needed
31983199
if self._inset_parent is not None and self._inset_zoom:
31993200
self.indicate_inset_zoom()
3200-
self._tight_bbox = super().get_tightbbox(renderer, *args, **kwargs)
3201+
bbox = super().get_tightbbox(renderer, *args, **kwargs)
3202+
fig = self.figure
3203+
if (
3204+
bbox is not None
3205+
and fig is not None
3206+
and self._panel_parent is None
3207+
and include_subset_titles
3208+
and hasattr(fig, "_get_subset_title_bbox")
3209+
):
3210+
title_bbox = fig._get_subset_title_bbox(self, renderer)
3211+
if title_bbox is not None:
3212+
bbox = mtransforms.Bbox.from_extents(
3213+
bbox.xmin,
3214+
min(bbox.ymin, title_bbox.ymin),
3215+
bbox.xmax,
3216+
max(bbox.ymax, title_bbox.ymax),
3217+
)
3218+
self._tight_bbox = bbox
32013219
return self._tight_bbox
32023220

32033221
def get_default_bbox_extra_artists(self):

ultraplot/figure.py

Lines changed: 196 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from packaging import version
1212

1313
try:
14-
from typing import List, Optional, Tuple, Union
14+
from typing import Any, Iterable, List, Optional, Tuple, Union
1515
except ImportError:
16-
from typing_extensions import List, Optional, Tuple, Union
16+
from typing_extensions import Any, Iterable, List, Optional, Tuple, Union
1717

1818
import matplotlib.axes as maxes
1919
import matplotlib.figure as mfigure
@@ -868,6 +868,7 @@ def _normalize_share(value):
868868
self._supylabel_dict = {} # an axes: label mapping
869869
self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}}
870870
self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups
871+
self._subset_title_dict = {}
871872
self._suptitle_pad = rc["suptitle.pad"]
872873
d = self._suplabel_props = {} # store the super label props
873874
d["left"] = {"va": "center", "ha": "right"}
@@ -1662,7 +1663,9 @@ def _get_align_coord(self, side, axs, align="center", includepanels=False):
16621663
ax = ax._panel_parent or ax # always use main subplot for spanning labels
16631664
return pos, ax
16641665

1665-
def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None):
1666+
def _get_offset_coord(
1667+
self, side, axs, renderer, *, pad=None, extra=None, include_subset_titles=True
1668+
):
16661669
"""
16671670
Return the figure coordinate for offsetting super labels and super titles.
16681671
"""
@@ -1675,7 +1678,12 @@ def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None):
16751678
) # noqa: E501
16761679
objs = objs + (extra or ()) # e.g. top super labels
16771680
for obj in objs:
1678-
bbox = obj.get_tightbbox(renderer) # cannot use cached bbox
1681+
if isinstance(obj, paxes.Axes):
1682+
bbox = obj.get_tightbbox(
1683+
renderer, include_subset_titles=include_subset_titles
1684+
)
1685+
else:
1686+
bbox = obj.get_tightbbox(renderer) # cannot use cached bbox
16791687
attr = s + "max" if side in ("top", "right") else s + "min"
16801688
c = getattr(bbox, attr)
16811689
c = (c, 0) if side in ("left", "right") else (0, c)
@@ -2523,6 +2531,12 @@ def _align_super_title(self, renderer):
25232531
if not axs:
25242532
return
25252533
labs = tuple(t for t in self._suplabel_dict["top"].values() if t.get_text())
2534+
subset_titles = tuple(
2535+
group["artist"]
2536+
for group in self._subset_title_dict.values()
2537+
if group["artist"].get_text()
2538+
)
2539+
labs = labs + subset_titles
25262540
pad = (self._suptitle_pad / 72) / self.get_size_inches()[1]
25272541

25282542
# Get current alignment settings from suptitle (may be set via suptitle_kw)
@@ -2548,6 +2562,183 @@ def _align_super_title(self, renderer):
25482562
y = y_target - y_bbox
25492563
self._suptitle.set_position((x, y))
25502564

2565+
def _update_subset_title(
2566+
self,
2567+
axes: Iterable[paxes.Axes],
2568+
title: str | None,
2569+
*,
2570+
fontdict: dict[str, Any] | None = None,
2571+
loc: str | None = None,
2572+
pad: float | str | None = None,
2573+
y: float | None = None,
2574+
**kwargs: Any,
2575+
) -> mtext.Text:
2576+
"""
2577+
Create or update a title spanning a subset of subplots.
2578+
"""
2579+
fontdict = _not_none(fontdict, kwargs.pop("fontdict", None))
2580+
loc = _not_none(
2581+
loc,
2582+
kwargs.pop("loc", None),
2583+
rc.find("title.loc", context=True),
2584+
rc["title.loc"],
2585+
)
2586+
pad = _not_none(
2587+
pad,
2588+
kwargs.pop("pad", None),
2589+
rc.find("title.pad", context=True),
2590+
rc["title.pad"],
2591+
)
2592+
y = _not_none(y, kwargs.pop("y", None))
2593+
axes = [ax for ax in axes if ax is not None and ax.figure is self]
2594+
if not axes:
2595+
raise ValueError("Need at least one axes to create a shared subplot title.")
2596+
2597+
seen = set()
2598+
unique_axes = []
2599+
for ax in axes:
2600+
ax = ax._panel_parent or ax
2601+
ax_id = id(ax)
2602+
if ax_id in seen:
2603+
continue
2604+
seen.add(ax_id)
2605+
unique_axes.append(ax)
2606+
axes = unique_axes
2607+
if len(axes) < 2:
2608+
return axes[0].set_title(
2609+
title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs
2610+
)
2611+
2612+
key = tuple(sorted(id(ax) for ax in axes))
2613+
group = self._subset_title_dict.get(key)
2614+
kw = rc.fill(
2615+
{
2616+
"size": "title.size",
2617+
"weight": "title.weight",
2618+
"color": "title.color",
2619+
"family": "font.family",
2620+
},
2621+
context=True,
2622+
)
2623+
if "color" in kw and kw["color"] == "auto":
2624+
del kw["color"]
2625+
if fontdict:
2626+
kw.update(fontdict)
2627+
kw.update(kwargs)
2628+
align = _translate_loc(loc, "text")
2629+
match align:
2630+
case "left" | "outer left" | "upper left" | "lower left":
2631+
align = "left"
2632+
case "center" | "upper center" | "lower center":
2633+
align = "center"
2634+
case "right" | "outer right" | "upper right" | "lower right":
2635+
align = "right"
2636+
case _:
2637+
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
2638+
if group is None:
2639+
artist = self.text(
2640+
0.5,
2641+
0.0,
2642+
"",
2643+
transform=self.transFigure,
2644+
ha=align,
2645+
va="baseline",
2646+
zorder=3.5,
2647+
)
2648+
group = {"axes": axes, "artist": artist, "pad": None, "y": None}
2649+
self._subset_title_dict[key] = group
2650+
else:
2651+
artist = group["artist"]
2652+
group["axes"] = axes
2653+
group["pad"] = pad
2654+
group["y"] = y
2655+
artist.set_ha(align)
2656+
artist.set_va("baseline")
2657+
if title is not None:
2658+
artist.set_text(title)
2659+
if kw:
2660+
artist.update(kw)
2661+
return artist
2662+
2663+
def _get_subset_title_bbox(
2664+
self, ax: paxes.Axes, renderer
2665+
) -> mtransforms.Bbox | None:
2666+
"""
2667+
Return the union bbox for shared titles covering the given axes.
2668+
2669+
Shared subset titles live above the subset's top edge, so they should
2670+
only contribute to the tight bounding boxes for axes that actually touch
2671+
that top boundary. Otherwise, multi-row subsets can incorrectly claim
2672+
the title as extra inter-row spacing.
2673+
"""
2674+
ax = ax._panel_parent or ax
2675+
bboxes = []
2676+
for group in self._subset_title_dict.values():
2677+
artist = group["artist"]
2678+
if not artist.get_visible() or not artist.get_text():
2679+
continue
2680+
axs = [
2681+
group_ax._panel_parent or group_ax
2682+
for group_ax in group["axes"]
2683+
if group_ax is not None
2684+
and group_ax.figure is self
2685+
and group_ax.get_visible()
2686+
]
2687+
if not axs or ax not in axs:
2688+
continue
2689+
top = min(group_ax._range_subplotspec("y")[0] for group_ax in axs)
2690+
if ax._range_subplotspec("y")[0] == top:
2691+
bboxes.append(artist.get_window_extent(renderer))
2692+
return mtransforms.Bbox.union(bboxes) if bboxes else None
2693+
2694+
def _align_subset_titles(self, renderer):
2695+
"""
2696+
Update the positions of titles spanning subplot subsets.
2697+
"""
2698+
for key in list(self._subset_title_dict):
2699+
group = self._subset_title_dict[key]
2700+
artist = group["artist"]
2701+
axs = [
2702+
ax
2703+
for ax in group["axes"]
2704+
if ax is not None and ax.figure is self and ax.get_visible()
2705+
]
2706+
if not axs:
2707+
artist.remove()
2708+
del self._subset_title_dict[key]
2709+
continue
2710+
if not artist.get_text():
2711+
continue
2712+
align = artist.get_ha()
2713+
x, _ = self._get_align_coord(
2714+
"top",
2715+
axs,
2716+
includepanels=self._includepanels,
2717+
align=align,
2718+
)
2719+
top_labels = tuple(
2720+
lab
2721+
for ax, lab in self._suplabel_dict["top"].items()
2722+
if lab.get_text() and ax in axs
2723+
)
2724+
artist.set_x(x)
2725+
manual_y = group["y"]
2726+
if manual_y is not None:
2727+
artist.set_y(manual_y)
2728+
continue
2729+
pad = group["pad"]
2730+
if pad is not None:
2731+
pad = units(pad, "pt") / (72 * self.get_size_inches()[1])
2732+
y_target = self._get_offset_coord(
2733+
"top",
2734+
axs,
2735+
renderer,
2736+
pad=pad,
2737+
extra=top_labels,
2738+
include_subset_titles=False,
2739+
)
2740+
artist.set_y(y_target)
2741+
25512742
def _update_axis_label(self, side, axs):
25522743
"""
25532744
Update the aligned axis label for the input axes.
@@ -2777,6 +2968,7 @@ def _align_content(): # noqa: E306
27772968
self._align_axis_label(axis)
27782969
for side in ("left", "right", "top", "bottom"):
27792970
self._align_super_labels(side, renderer)
2971+
self._align_subset_titles(renderer)
27802972
self._align_super_title(renderer)
27812973

27822974
# Update the layout

ultraplot/gridspec.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .config import rc
2121
from .internals import (
2222
_not_none,
23+
_pop_rc,
2324
docstring,
2425
ic, # noqa: F401
2526
warnings,
@@ -2083,29 +2084,52 @@ def format(self, **kwargs):
20832084
share_ylabels = kwargs.get("share_ylabels", None)
20842085
xlabel = kwargs.get("xlabel", None)
20852086
ylabel = kwargs.get("ylabel", None)
2087+
title = kwargs.get("title", None)
20862088
axes = [ax for ax in self if ax is not None]
20872089
all_axes = set(self.figure._subplot_dict.values())
20882090
is_subset = bool(axes) and all_axes and set(axes) != all_axes
2089-
if len(self) > 1:
2090-
if share_xlabels is False:
2091-
self.figure._clear_share_label_groups(self, target="x")
2092-
if share_ylabels is False:
2093-
self.figure._clear_share_label_groups(self, target="y")
2094-
if not is_subset and share_xlabels is None and xlabel is not None:
2095-
self.figure._clear_share_label_groups(self, target="x")
2096-
if not is_subset and share_ylabels is None and ylabel is not None:
2097-
self.figure._clear_share_label_groups(self, target="y")
2098-
if is_subset and share_xlabels is None and xlabel is not None:
2099-
self.figure._register_share_label_group(self, target="x")
2100-
if is_subset and share_ylabels is None and ylabel is not None:
2101-
self.figure._register_share_label_group(self, target="y")
2102-
self.figure.format(axs=self, **kwargs)
2103-
# Refresh groups after labels are set
2104-
if len(self) > 1:
2105-
if is_subset and share_xlabels is None and xlabel is not None:
2106-
self.figure._register_share_label_group(self, target="x")
2107-
if is_subset and share_ylabels is None and ylabel is not None:
2108-
self.figure._register_share_label_group(self, target="y")
2091+
shared_subset_title = len(self) > 1 and is_subset and isinstance(title, str)
2092+
shared_title_kw = (
2093+
dict(kwargs.pop("title_kw", None) or {}) if shared_subset_title else None
2094+
)
2095+
if shared_subset_title:
2096+
kwargs.pop("title", None)
2097+
shared_title_loc = kwargs.pop("titleloc", None)
2098+
shared_title_pad = kwargs.pop("titlepad", None)
2099+
kwargs.pop("titleabove", None)
2100+
else:
2101+
shared_title_loc = None
2102+
shared_title_pad = None
2103+
rc_kw, rc_mode = _pop_rc(kwargs)
2104+
with rc.context(rc_kw, mode=rc_mode):
2105+
if len(self) > 1:
2106+
if share_xlabels is False:
2107+
self.figure._clear_share_label_groups(self, target="x")
2108+
if share_ylabels is False:
2109+
self.figure._clear_share_label_groups(self, target="y")
2110+
if not is_subset and share_xlabels is None and xlabel is not None:
2111+
self.figure._clear_share_label_groups(self, target="x")
2112+
if not is_subset and share_ylabels is None and ylabel is not None:
2113+
self.figure._clear_share_label_groups(self, target="y")
2114+
if is_subset and share_xlabels is None and xlabel is not None:
2115+
self.figure._register_share_label_group(self, target="x")
2116+
if is_subset and share_ylabels is None and ylabel is not None:
2117+
self.figure._register_share_label_group(self, target="y")
2118+
self.figure.format(axs=self, **kwargs)
2119+
if shared_subset_title:
2120+
self.figure._update_subset_title(
2121+
self,
2122+
title,
2123+
loc=shared_title_loc,
2124+
pad=shared_title_pad,
2125+
**(shared_title_kw or {}),
2126+
)
2127+
# Refresh groups after labels are set
2128+
if len(self) > 1:
2129+
if is_subset and share_xlabels is None and xlabel is not None:
2130+
self.figure._register_share_label_group(self, target="x")
2131+
if is_subset and share_ylabels is None and ylabel is not None:
2132+
self.figure._register_share_label_group(self, target="y")
21092133

21102134
def share_labels(self, *, axis="x"):
21112135
"""

ultraplot/tests/test_figure.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,22 @@ def test_suptitle_vertical_alignment_preserves_top_spacing(va):
447447
uplt.close("all")
448448

449449

450+
def test_suptitle_clears_shared_subset_titles():
451+
fig, axs = uplt.subplots(nrows=2, ncols=2)
452+
axs[0, :].format(title="Row title")
453+
fig.format(suptitle="Figure title")
454+
fig.canvas.draw()
455+
renderer = fig.canvas.get_renderer()
456+
457+
subset_title = next(iter(fig._subset_title_dict.values()))["artist"]
458+
subset_bbox = subset_title.get_window_extent(renderer)
459+
suptitle_bbox = fig._suptitle.get_window_extent(renderer)
460+
461+
assert subset_bbox.y1 <= suptitle_bbox.y0
462+
463+
uplt.close("all")
464+
465+
450466
def test_subplots_pixelsnap_aligns_axes_bounds():
451467
with uplt.rc.context({"subplots.pixelsnap": True}):
452468
fig, axs = uplt.subplots(ncols=2, nrows=2)

0 commit comments

Comments
 (0)