Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 178 additions & 48 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,109 @@ def _snap_axes_to_pixel_grid(self, renderer) -> None:
which="both",
)

def _find_misaligned_spans(
self, axes: List[paxes.Axes], *, tol: float = 1e-9
) -> List[Tuple[str, int, int, mtransforms.Bbox, mtransforms.Bbox, paxes.Axes]]:
"""
Identify spanning axes whose actual position differs from their
gridspec slot (e.g. because of an aspect constraint).

Returns a list of ``(axis, start, stop, slot, pos, ref_ax)`` tuples
where *axis* is ``'y'`` for row-spanning or ``'x'`` for column-spanning.
"""
spans = []
for ax in axes:
try:
ax.apply_aspect()
ss = ax.get_subplotspec().get_topmost_subplotspec()
row1, row2, col1, col2 = ss._get_rows_columns()
slot = ss.get_position(self)
pos = ax.get_position(original=False)
except (AttributeError, TypeError):
continue

if row2 > row1 and (
abs(pos.y0 - slot.y0) > tol
or abs((pos.y0 + pos.height) - (slot.y0 + slot.height)) > tol
):
spans.append(("y", row1, row2, slot, pos, ax))
if col2 > col1 and (
abs(pos.x0 - slot.x0) > tol
or abs((pos.x0 + pos.width) - (slot.x0 + slot.width)) > tol
):
spans.append(("x", col1, col2, slot, pos, ax))
return spans

def _remap_axes_to_span(
self,
axes: List[paxes.Axes],
spans: List[
Tuple[str, int, int, mtransforms.Bbox, mtransforms.Bbox, paxes.Axes]
],
*,
tol: float = 1e-9,
) -> None:
"""
Remap sibling axes so they align with the actual bounds of
spanning axes described by *spans*. Siblings with their own
fixed aspect are skipped since they have independent constraints.
"""
for axis, start, stop, slot, pos, ref_ax in spans:
slot0 = slot.y0 if axis == "y" else slot.x0
slotsize = slot.height if axis == "y" else slot.width
pos0 = pos.y0 if axis == "y" else pos.x0
possize = pos.height if axis == "y" else pos.width
if slotsize <= tol or possize <= tol:
continue

for ax in axes:
if ax is ref_ax:
continue
try:
if ax.get_aspect() != "auto":
continue
ss = ax.get_subplotspec().get_topmost_subplotspec()
row1, row2, col1, col2 = ss._get_rows_columns()
if axis == "y":
if row1 < start or row2 > stop:
continue
else:
if col1 < start or col2 > stop:
continue
old = ss.get_position(self)
except (AttributeError, TypeError):
continue

if axis == "y":
rel0 = (old.y0 - slot0) / slotsize
rel1 = (old.y0 + old.height - slot0) / slotsize
new0 = pos0 + rel0 * possize
new1 = pos0 + rel1 * possize
bounds = [old.x0, new0, old.width, new1 - new0]
else:
rel0 = (old.x0 - slot0) / slotsize
rel1 = (old.x0 + old.width - slot0) / slotsize
new0 = pos0 + rel0 * possize
new1 = pos0 + rel1 * possize
bounds = [new0, old.y0, new1 - new0, old.height]
ax.set_position(bounds, which="both")

def _align_spanning_axes(self, *, tol: float = 1e-9) -> None:
"""
Align sibling subplots to spanning axes whose actual position
differs from their gridspec slot.

When a subplot spans multiple rows or columns and is shrunk inside
its slot (e.g. by a fixed aspect ratio), the adjacent subplots keep
their full extent and visibly stick out. This method detects the
mismatch and remaps the sibling positions proportionally.
"""
axes = list(self._iter_axes(hidden=False, children=False, panels=False))
if not axes:
return
spans = self._find_misaligned_spans(axes, tol=tol)
self._remap_axes_to_span(axes, spans, tol=tol)

def _share_ticklabels(self, *, axis: str) -> None:
"""
Tick label sharing is determined at the figure level. While
Expand Down Expand Up @@ -2562,6 +2665,61 @@ def _align_super_title(self, renderer):
y = y_target - y_bbox
self._suptitle.set_position((x, y))

@staticmethod
def _deduplicate_axes(axes: Iterable[paxes.Axes]) -> List[paxes.Axes]:
"""
Resolve panel parents and remove duplicates, preserving order.
"""
seen = set()
unique = []
for ax in axes:
ax = ax._panel_parent or ax
ax_id = id(ax)
if ax_id not in seen:
seen.add(ax_id)
unique.append(ax)
return unique

@staticmethod
def _normalize_title_alignment(loc: str) -> str:
"""
Convert a *loc* string to a horizontal alignment for ``Text.set_ha``.
"""
align = _translate_loc(loc, "text")
match align:
case "left" | "outer left" | "upper left" | "lower left":
return "left"
case "center" | "upper center" | "lower center":
return "center"
case "right" | "outer right" | "upper right" | "lower right":
return "right"
case _:
raise ValueError(f"Invalid shared subplot title location {loc!r}.")

@staticmethod
def _resolve_title_props(
fontdict: dict[str, Any] | None, kwargs: dict[str, Any]
) -> dict[str, Any]:
"""
Build the property dict for a title from rc defaults, *fontdict*,
and extra *kwargs*.
"""
kw = rc.fill(
{
"size": "title.size",
"weight": "title.weight",
"color": "title.color",
"family": "font.family",
},
context=True,
)
if "color" in kw and kw["color"] == "auto":
del kw["color"]
if fontdict:
kw.update(fontdict)
kw.update(kwargs)
return kw

def _update_subset_title(
self,
axes: Iterable[paxes.Axes],
Expand Down Expand Up @@ -2594,47 +2752,17 @@ def _update_subset_title(
if not axes:
raise ValueError("Need at least one axes to create a shared subplot title.")

seen = set()
unique_axes = []
for ax in axes:
ax = ax._panel_parent or ax
ax_id = id(ax)
if ax_id in seen:
continue
seen.add(ax_id)
unique_axes.append(ax)
axes = unique_axes
axes = self._deduplicate_axes(axes)
if len(axes) < 2:
return axes[0].set_title(
title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs
)

key = tuple(sorted(id(ax) for ax in axes))
group = self._subset_title_dict.get(key)
kw = rc.fill(
{
"size": "title.size",
"weight": "title.weight",
"color": "title.color",
"family": "font.family",
},
context=True,
)
if "color" in kw and kw["color"] == "auto":
del kw["color"]
if fontdict:
kw.update(fontdict)
kw.update(kwargs)
align = _translate_loc(loc, "text")
match align:
case "left" | "outer left" | "upper left" | "lower left":
align = "left"
case "center" | "upper center" | "lower center":
align = "center"
case "right" | "outer right" | "upper right" | "lower right":
align = "right"
case _:
raise ValueError(f"Invalid shared subplot title location {loc!r}.")
kw = self._resolve_title_props(fontdict, kwargs)
align = self._normalize_title_alignment(loc)

if group is None:
artist = self.text(
0.5,
Expand All @@ -2660,6 +2788,16 @@ def _update_subset_title(
artist.update(kw)
return artist

def _visible_subset_group_axes(self, group: dict[str, Any]) -> List[paxes.Axes]:
"""
Return visible axes from a subset-title group that belong to this figure.
"""
return [
ax
for ax in group["axes"]
if ax is not None and ax.figure is self and ax.get_visible()
]

def _get_subset_title_bbox(
self, ax: paxes.Axes, renderer
) -> mtransforms.Bbox | None:
Expand All @@ -2677,32 +2815,22 @@ def _get_subset_title_bbox(
artist = group["artist"]
if not artist.get_visible() or not artist.get_text():
continue
axs = [
group_ax._panel_parent or group_ax
for group_ax in group["axes"]
if group_ax is not None
and group_ax.figure is self
and group_ax.get_visible()
]
axs = [a._panel_parent or a for a in self._visible_subset_group_axes(group)]
if not axs or ax not in axs:
continue
top = min(group_ax._range_subplotspec("y")[0] for group_ax in axs)
top = min(a._range_subplotspec("y")[0] for a in axs)
if ax._range_subplotspec("y")[0] == top:
bboxes.append(artist.get_window_extent(renderer))
return mtransforms.Bbox.union(bboxes) if bboxes else None

def _align_subset_titles(self, renderer):
def _align_subset_titles(self, renderer: Any) -> None:
"""
Update the positions of titles spanning subplot subsets.
"""
for key in list(self._subset_title_dict):
group = self._subset_title_dict[key]
artist = group["artist"]
axs = [
ax
for ax in group["axes"]
if ax is not None and ax.figure is self and ax.get_visible()
]
axs = self._visible_subset_group_axes(group)
if not axs:
artist.remove()
del self._subset_title_dict[key]
Expand Down Expand Up @@ -2979,9 +3107,11 @@ def _align_content(): # noqa: E306
return
if aspect:
gs._auto_layout_aspect()
self._align_spanning_axes()
_align_content()
if tight:
gs._auto_layout_tight(renderer)
self._align_spanning_axes()
_align_content()

@warnings._rename_kwargs(
Expand Down
Loading
Loading