diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 4f1bedb92..2bac74b22 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1164,6 +1164,109 @@ def _snap_axes_to_pixel_grid(self, renderer) -> None: which="both", ) + def _find_misaligned_spans( + self, axes: List[paxes.Axes], *, tol: float = 1e-9 + ) -> List[Tuple[str, int, int, mtransforms.Bbox, mtransforms.Bbox, paxes.Axes]]: + """ + Identify spanning axes whose actual position differs from their + gridspec slot (e.g. because of an aspect constraint). + + Returns a list of ``(axis, start, stop, slot, pos, ref_ax)`` tuples + where *axis* is ``'y'`` for row-spanning or ``'x'`` for column-spanning. + """ + spans = [] + for ax in axes: + try: + ax.apply_aspect() + ss = ax.get_subplotspec().get_topmost_subplotspec() + row1, row2, col1, col2 = ss._get_rows_columns() + slot = ss.get_position(self) + pos = ax.get_position(original=False) + except (AttributeError, TypeError): + continue + + if row2 > row1 and ( + abs(pos.y0 - slot.y0) > tol + or abs((pos.y0 + pos.height) - (slot.y0 + slot.height)) > tol + ): + spans.append(("y", row1, row2, slot, pos, ax)) + if col2 > col1 and ( + abs(pos.x0 - slot.x0) > tol + or abs((pos.x0 + pos.width) - (slot.x0 + slot.width)) > tol + ): + spans.append(("x", col1, col2, slot, pos, ax)) + return spans + + def _remap_axes_to_span( + self, + axes: List[paxes.Axes], + spans: List[ + Tuple[str, int, int, mtransforms.Bbox, mtransforms.Bbox, paxes.Axes] + ], + *, + tol: float = 1e-9, + ) -> None: + """ + Remap sibling axes so they align with the actual bounds of + spanning axes described by *spans*. Siblings with their own + fixed aspect are skipped since they have independent constraints. + """ + for axis, start, stop, slot, pos, ref_ax in spans: + slot0 = slot.y0 if axis == "y" else slot.x0 + slotsize = slot.height if axis == "y" else slot.width + pos0 = pos.y0 if axis == "y" else pos.x0 + possize = pos.height if axis == "y" else pos.width + if slotsize <= tol or possize <= tol: + continue + + for ax in axes: + if ax is ref_ax: + continue + try: + if ax.get_aspect() != "auto": + continue + ss = ax.get_subplotspec().get_topmost_subplotspec() + row1, row2, col1, col2 = ss._get_rows_columns() + if axis == "y": + if row1 < start or row2 > stop: + continue + else: + if col1 < start or col2 > stop: + continue + old = ss.get_position(self) + except (AttributeError, TypeError): + continue + + if axis == "y": + rel0 = (old.y0 - slot0) / slotsize + rel1 = (old.y0 + old.height - slot0) / slotsize + new0 = pos0 + rel0 * possize + new1 = pos0 + rel1 * possize + bounds = [old.x0, new0, old.width, new1 - new0] + else: + rel0 = (old.x0 - slot0) / slotsize + rel1 = (old.x0 + old.width - slot0) / slotsize + new0 = pos0 + rel0 * possize + new1 = pos0 + rel1 * possize + bounds = [new0, old.y0, new1 - new0, old.height] + ax.set_position(bounds, which="both") + + def _align_spanning_axes(self, *, tol: float = 1e-9) -> None: + """ + Align sibling subplots to spanning axes whose actual position + differs from their gridspec slot. + + When a subplot spans multiple rows or columns and is shrunk inside + its slot (e.g. by a fixed aspect ratio), the adjacent subplots keep + their full extent and visibly stick out. This method detects the + mismatch and remaps the sibling positions proportionally. + """ + axes = list(self._iter_axes(hidden=False, children=False, panels=False)) + if not axes: + return + spans = self._find_misaligned_spans(axes, tol=tol) + self._remap_axes_to_span(axes, spans, tol=tol) + def _share_ticklabels(self, *, axis: str) -> None: """ Tick label sharing is determined at the figure level. While @@ -2562,6 +2665,61 @@ def _align_super_title(self, renderer): y = y_target - y_bbox self._suptitle.set_position((x, y)) + @staticmethod + def _deduplicate_axes(axes: Iterable[paxes.Axes]) -> List[paxes.Axes]: + """ + Resolve panel parents and remove duplicates, preserving order. + """ + seen = set() + unique = [] + for ax in axes: + ax = ax._panel_parent or ax + ax_id = id(ax) + if ax_id not in seen: + seen.add(ax_id) + unique.append(ax) + return unique + + @staticmethod + def _normalize_title_alignment(loc: str) -> str: + """ + Convert a *loc* string to a horizontal alignment for ``Text.set_ha``. + """ + align = _translate_loc(loc, "text") + match align: + case "left" | "outer left" | "upper left" | "lower left": + return "left" + case "center" | "upper center" | "lower center": + return "center" + case "right" | "outer right" | "upper right" | "lower right": + return "right" + case _: + raise ValueError(f"Invalid shared subplot title location {loc!r}.") + + @staticmethod + def _resolve_title_props( + fontdict: dict[str, Any] | None, kwargs: dict[str, Any] + ) -> dict[str, Any]: + """ + Build the property dict for a title from rc defaults, *fontdict*, + and extra *kwargs*. + """ + kw = rc.fill( + { + "size": "title.size", + "weight": "title.weight", + "color": "title.color", + "family": "font.family", + }, + context=True, + ) + if "color" in kw and kw["color"] == "auto": + del kw["color"] + if fontdict: + kw.update(fontdict) + kw.update(kwargs) + return kw + def _update_subset_title( self, axes: Iterable[paxes.Axes], @@ -2594,16 +2752,7 @@ def _update_subset_title( if not axes: raise ValueError("Need at least one axes to create a shared subplot title.") - seen = set() - unique_axes = [] - for ax in axes: - ax = ax._panel_parent or ax - ax_id = id(ax) - if ax_id in seen: - continue - seen.add(ax_id) - unique_axes.append(ax) - axes = unique_axes + axes = self._deduplicate_axes(axes) if len(axes) < 2: return axes[0].set_title( title, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs @@ -2611,30 +2760,9 @@ def _update_subset_title( key = tuple(sorted(id(ax) for ax in axes)) group = self._subset_title_dict.get(key) - kw = rc.fill( - { - "size": "title.size", - "weight": "title.weight", - "color": "title.color", - "family": "font.family", - }, - context=True, - ) - if "color" in kw and kw["color"] == "auto": - del kw["color"] - if fontdict: - kw.update(fontdict) - kw.update(kwargs) - align = _translate_loc(loc, "text") - match align: - case "left" | "outer left" | "upper left" | "lower left": - align = "left" - case "center" | "upper center" | "lower center": - align = "center" - case "right" | "outer right" | "upper right" | "lower right": - align = "right" - case _: - raise ValueError(f"Invalid shared subplot title location {loc!r}.") + kw = self._resolve_title_props(fontdict, kwargs) + align = self._normalize_title_alignment(loc) + if group is None: artist = self.text( 0.5, @@ -2660,6 +2788,16 @@ def _update_subset_title( artist.update(kw) return artist + def _visible_subset_group_axes(self, group: dict[str, Any]) -> List[paxes.Axes]: + """ + Return visible axes from a subset-title group that belong to this figure. + """ + return [ + ax + for ax in group["axes"] + if ax is not None and ax.figure is self and ax.get_visible() + ] + def _get_subset_title_bbox( self, ax: paxes.Axes, renderer ) -> mtransforms.Bbox | None: @@ -2677,32 +2815,22 @@ def _get_subset_title_bbox( artist = group["artist"] if not artist.get_visible() or not artist.get_text(): continue - axs = [ - group_ax._panel_parent or group_ax - for group_ax in group["axes"] - if group_ax is not None - and group_ax.figure is self - and group_ax.get_visible() - ] + axs = [a._panel_parent or a for a in self._visible_subset_group_axes(group)] if not axs or ax not in axs: continue - top = min(group_ax._range_subplotspec("y")[0] for group_ax in axs) + top = min(a._range_subplotspec("y")[0] for a in axs) if ax._range_subplotspec("y")[0] == top: bboxes.append(artist.get_window_extent(renderer)) return mtransforms.Bbox.union(bboxes) if bboxes else None - def _align_subset_titles(self, renderer): + def _align_subset_titles(self, renderer: Any) -> None: """ Update the positions of titles spanning subplot subsets. """ for key in list(self._subset_title_dict): group = self._subset_title_dict[key] artist = group["artist"] - axs = [ - ax - for ax in group["axes"] - if ax is not None and ax.figure is self and ax.get_visible() - ] + axs = self._visible_subset_group_axes(group) if not axs: artist.remove() del self._subset_title_dict[key] @@ -2979,9 +3107,11 @@ def _align_content(): # noqa: E306 return if aspect: gs._auto_layout_aspect() + self._align_spanning_axes() _align_content() if tight: gs._auto_layout_tight(renderer) + self._align_spanning_axes() _align_content() @warnings._rename_kwargs( diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index 53a297399..a78ac0402 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -479,3 +479,394 @@ def test_subplots_pixelsnap_aligns_axes_bounds(): [bbox.x0 * width, bbox.y0 * height, bbox.x1 * width, bbox.y1 * height] ) assert np.allclose(coords, np.round(coords), atol=1e-8) + + +def test_figure_repr(): + fig, axs = uplt.subplots(ncols=2, nrows=3) + r = repr(fig) + assert "Figure(" in r + assert "nrows=3" in r + assert "ncols=2" in r + uplt.close(fig) + + +def test_register_share_label_group_basic(): + fig, axs = uplt.subplots(ncols=3) + axs[0].set_xlabel("shared x") + axs[1].set_xlabel("also x") + fig._register_share_label_group([axs[0], axs[1]], target="x", source=axs[0]) + assert fig._has_share_label_groups("x") + assert fig._is_share_label_group_member(axs[0], "x") + assert fig._is_share_label_group_member(axs[1], "x") + assert not fig._is_share_label_group_member(axs[2], "x") + uplt.close(fig) + + +def test_register_share_label_group_y(): + fig, axs = uplt.subplots(nrows=3) + axs[0].set_ylabel("shared y") + axs[1].set_ylabel("also y") + fig._register_share_label_group([axs[0], axs[1]], target="y", source=axs[0]) + assert fig._has_share_label_groups("y") + assert fig._is_share_label_group_member(axs[0], "y") + uplt.close(fig) + + +def test_register_share_label_group_empty_and_single(): + fig, axs = uplt.subplots(ncols=2) + fig._register_share_label_group([], target="x") + assert not fig._has_share_label_groups("x") + fig._register_share_label_group([axs[0]], target="x") + assert not fig._has_share_label_groups("x") + uplt.close(fig) + + +def test_register_share_label_group_deduplicates(): + fig, axs = uplt.subplots(ncols=2) + axs[0].set_xlabel("x") + fig._register_share_label_group([axs[0], axs[0], axs[1]], target="x") + assert fig._has_share_label_groups("x") + uplt.close(fig) + + +def test_clear_share_label_groups_all(): + fig, axs = uplt.subplots(ncols=3) + axs[0].set_xlabel("x") + fig._register_share_label_group([axs[0], axs[1]], target="x") + fig._register_share_label_group([axs[0], axs[1]], target="y") + assert fig._has_share_label_groups("x") + fig._clear_share_label_groups() + assert not fig._has_share_label_groups("x") + assert not fig._has_share_label_groups("y") + uplt.close(fig) + + +def test_clear_share_label_groups_by_axes(): + fig, axs = uplt.subplots(ncols=3) + axs[0].set_xlabel("x0") + axs[2].set_xlabel("x2") + fig._register_share_label_group([axs[0], axs[1]], target="x") + fig._clear_share_label_groups(axes=[axs[0]], target="x") + assert not fig._has_share_label_groups("x") + uplt.close(fig) + + +def test_clear_share_label_groups_with_spanning_labels(): + fig, axs = uplt.subplots(ncols=3) + axs[0].set_xlabel("shared x") + axs[1].set_xlabel("shared x") + fig._register_share_label_group([axs[0], axs[1]], target="x", source=axs[0]) + fig.canvas.draw() + fig._clear_share_label_groups(axes=[axs[0], axs[1]], target="x") + assert not fig._has_share_label_groups("x") + uplt.close(fig) + + +def test_apply_share_label_groups(): + fig, axs = uplt.subplots(ncols=3, share=False) + axs[0].set_xlabel("shared label") + axs[1].set_xlabel("") + fig._register_share_label_group([axs[0], axs[1]], target="x", source=axs[0]) + fig.canvas.draw() + uplt.close(fig) + + +def test_apply_share_label_groups_y(): + fig, axs = uplt.subplots(nrows=3, share=False) + axs[0].set_ylabel("shared label") + axs[1].set_ylabel("") + fig._register_share_label_group([axs[0], axs[1]], target="y", source=axs[0]) + fig.canvas.draw() + uplt.close(fig) + + +def test_register_share_label_group_updates_existing(): + fig, axs = uplt.subplots(ncols=3) + axs[0].set_xlabel("original") + fig._register_share_label_group([axs[0], axs[1]], target="x", source=axs[0]) + axs[0].set_xlabel("updated") + fig._register_share_label_group([axs[0], axs[1]], target="x", source=axs[0]) + fig.canvas.draw() + uplt.close(fig) + + +def test_share_label_group_mixed_label_position_splits(): + fig, axs = uplt.subplots(ncols=3, share=False) + axs[0].set_xlabel("bottom") + axs[1].xaxis.set_label_position("top") + axs[1].set_xlabel("top") + axs[2].set_xlabel("bottom") + fig._register_share_label_group([axs[0], axs[1], axs[2]], target="x") + fig.canvas.draw() + uplt.close(fig) + + +def test_deduplicate_axes(): + fig, axs = uplt.subplots(ncols=3) + result = fig._deduplicate_axes([axs[0], axs[0], axs[1]]) + assert len(result) == 2 + uplt.close(fig) + + +def test_normalize_title_alignment_left(): + from ultraplot.figure import Figure + + assert Figure._normalize_title_alignment("left") == "left" + + +def test_normalize_title_alignment_center(): + from ultraplot.figure import Figure + + assert Figure._normalize_title_alignment("center") == "center" + + +def test_normalize_title_alignment_right(): + from ultraplot.figure import Figure + + assert Figure._normalize_title_alignment("right") == "right" + + +def test_normalize_title_alignment_invalid(): + from ultraplot.figure import Figure + + with pytest.raises((ValueError, KeyError)): + Figure._normalize_title_alignment("invalid_loc_xyz") + + +def test_resolve_title_props_defaults(): + from ultraplot.figure import Figure + + kw = Figure._resolve_title_props(None, {}) + assert isinstance(kw, dict) + + +def test_resolve_title_props_with_fontdict(): + from ultraplot.figure import Figure + + kw = Figure._resolve_title_props({"size": 20}, {"weight": "bold"}) + assert kw["size"] == 20 + assert kw["weight"] == "bold" + + +def test_visible_subset_group_axes(): + fig, axs = uplt.subplots(ncols=3) + group = {"axes": list(axs), "artist": None} + result = fig._visible_subset_group_axes(group) + assert len(result) == 3 + uplt.close(fig) + + +def test_update_subset_title_single_axes_delegates(): + fig, axs = uplt.subplots(ncols=3) + artist = fig._update_subset_title([axs[0]], "Solo title") + assert artist.get_text() == "Solo title" + uplt.close(fig) + + +def test_update_subset_title_empty_raises(): + fig, axs = uplt.subplots(ncols=2) + with pytest.raises(ValueError, match="Need at least one"): + fig._update_subset_title([], "No axes") + uplt.close(fig) + + +def test_update_subset_title_creates_group(): + fig, axs = uplt.subplots(ncols=3) + artist = fig._update_subset_title([axs[0], axs[1]], "Two-panel title") + assert artist.get_text() == "Two-panel title" + assert len(fig._subset_title_dict) == 1 + uplt.close(fig) + + +def test_update_subset_title_update_existing(): + fig, axs = uplt.subplots(ncols=3) + fig._update_subset_title([axs[0], axs[1]], "First") + fig._update_subset_title([axs[0], axs[1]], "Updated") + assert len(fig._subset_title_dict) == 1 + group = next(iter(fig._subset_title_dict.values())) + assert group["artist"].get_text() == "Updated" + uplt.close(fig) + + +def test_get_subset_title_bbox_returns_none_when_empty(): + fig, axs = uplt.subplots(ncols=2) + renderer = fig._get_renderer() + assert fig._get_subset_title_bbox(axs[0], renderer) is None + uplt.close(fig) + + +def test_get_subset_title_bbox_for_top_row_only(): + fig, axs = uplt.subplots(nrows=2, ncols=2) + fig._update_subset_title([axs[0], axs[1]], "Top row title") + fig.canvas.draw() + renderer = fig._get_renderer() + bbox_top = fig._get_subset_title_bbox(axs[0], renderer) + bbox_bottom = fig._get_subset_title_bbox(axs[2], renderer) + assert bbox_top is not None + assert bbox_bottom is None + uplt.close(fig) + + +def test_align_subset_titles_removes_orphaned(): + fig, axs = uplt.subplots(ncols=3) + fig._update_subset_title([axs[0], axs[1]], "Will be orphaned") + key = next(iter(fig._subset_title_dict)) + fig._subset_title_dict[key]["axes"] = [] + renderer = fig._get_renderer() + fig._align_subset_titles(renderer) + assert len(fig._subset_title_dict) == 0 + uplt.close(fig) + + +def test_align_subset_titles_with_manual_y(): + fig, axs = uplt.subplots(ncols=3) + fig._update_subset_title([axs[0], axs[1]], "Manual Y", y=0.95) + fig.canvas.draw() + key = next(iter(fig._subset_title_dict)) + artist = fig._subset_title_dict[key]["artist"] + assert np.isclose(artist.get_position()[1], 0.95) + uplt.close(fig) + + +def test_subset_title_left_alignment(): + fig, axs = uplt.subplots(ncols=3) + fig._update_subset_title([axs[0], axs[1]], "Left title", loc="left") + key = next(iter(fig._subset_title_dict)) + artist = fig._subset_title_dict[key]["artist"] + assert artist.get_ha() == "left" + uplt.close(fig) + + +def test_subset_title_right_alignment(): + fig, axs = uplt.subplots(ncols=3) + fig._update_subset_title([axs[0], axs[1]], "Right title", loc="right") + key = next(iter(fig._subset_title_dict)) + artist = fig._subset_title_dict[key]["artist"] + assert artist.get_ha() == "right" + uplt.close(fig) + + +def test_find_aspect_spans_empty(): + fig, axs = uplt.subplots(ncols=2) + spans = fig._find_misaligned_spans([]) + assert spans == [] + uplt.close(fig) + + +def test_find_aspect_spans_no_aspect(): + fig, axs = uplt.subplots(ncols=2) + axes = list(fig._iter_axes(hidden=False, children=False, panels=False)) + spans = fig._find_misaligned_spans(axes) + assert spans == [] + uplt.close(fig) + + +def test_remap_axes_with_empty_spans(): + fig, axs = uplt.subplots(ncols=2) + axes = list(fig._iter_axes(hidden=False, children=False, panels=False)) + fig._remap_axes_to_span(axes, []) + uplt.close(fig) + + +def test_align_spanning_axes_no_axes(): + fig = uplt.figure() + fig._align_spanning_axes() + uplt.close(fig) + + +def test_aspect_row_spanning_layout(): + fig, axs = uplt.subplots([[1, 2], [1, 3]]) + axs[0].set_aspect("equal") + axs[0].plot([0, 1], [0, 1]) + axs[1].plot([0, 1], [0, 1]) + axs[2].plot([0, 1], [0, 1]) + fig.canvas.draw() + axes = list(fig._iter_axes(hidden=False, children=False, panels=False)) + spans = fig._find_misaligned_spans(axes) + assert len(spans) >= 1 + assert any(s[0] == "y" for s in spans) + uplt.close(fig) + + +def test_aspect_col_spanning_layout(): + fig, axs = uplt.subplots([[1, 1], [2, 3]]) + axs[0].set_aspect("equal") + axs[0].plot([0, 1], [0, 1]) + axs[1].plot([0, 1], [0, 1]) + axs[2].plot([0, 1], [0, 1]) + fig.canvas.draw() + axes = list(fig._iter_axes(hidden=False, children=False, panels=False)) + spans = fig._find_misaligned_spans(axes) + assert len(spans) >= 1 + assert any(s[0] == "x" for s in spans) + uplt.close(fig) + + +def test_full_align_aspect_row_spanning(): + fig, axs = uplt.subplots([[1, 2], [1, 3]]) + axs[0].set_aspect("equal") + axs[0].plot([0, 1], [0, 1]) + axs[1].plot([0, 1], [0, 1]) + axs[2].plot([0, 1], [0, 1]) + fig.canvas.draw() + pos0 = axs[0].get_position() + pos1 = axs[1].get_position() + pos2 = axs[2].get_position() + assert pos1.y0 + pos1.height <= pos0.y0 + pos0.height + 0.01 + uplt.close(fig) + + +def test_add_subplot_three_integer_args(): + fig = uplt.figure() + ax = fig.add_subplot(2, 2, 1) + assert ax is not None + ax2 = fig.add_subplot(2, 2, (3, 4)) + assert ax2 is not None + uplt.close(fig) + + +def test_explicit_figwidth_figheight(): + fig, axs = uplt.subplots(figwidth=6, figheight=4) + w, h = fig.get_size_inches() + assert np.isclose(w, 6, atol=0.1) + assert np.isclose(h, 4, atol=0.1) + uplt.close(fig) + + +def test_figwidth_overrides_refwidth(): + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + fig, axs = uplt.subplots(figwidth=6, refwidth=3) + conflict_warnings = [w for w in record if "conflicting" in str(w.message).lower()] + assert len(conflict_warnings) >= 1 + uplt.close(fig) + + +def test_figheight_overrides_refheight(): + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + fig, axs = uplt.subplots(figheight=4, refheight=2) + conflict_warnings = [w for w in record if "conflicting" in str(w.message).lower()] + assert len(conflict_warnings) >= 1 + uplt.close(fig) + + +def test_journal_size(): + fig, axs = uplt.subplots(journal="ams1") + fig.canvas.draw() + uplt.close(fig) + + +def test_subplots_with_gridspec_kw_warns(): + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + fig, axs = uplt.subplots([[1, 2], [3, 4]], gridspec_kw={"hspace": 0.5}) + kw_warnings = [w for w in record if "not necessary" in str(w.message).lower()] + assert len(kw_warnings) >= 1 + uplt.close(fig) + + +def test_refaspect_as_tuple(): + fig, axs = uplt.subplots(refaspect=(16, 9)) + fig.canvas.draw() + uplt.close(fig) diff --git a/ultraplot/tests/test_ultralayout.py b/ultraplot/tests/test_ultralayout.py index 3ea43b1d8..8a72b7aab 100644 --- a/ultraplot/tests/test_ultralayout.py +++ b/ultraplot/tests/test_ultralayout.py @@ -184,6 +184,26 @@ def test_ultralayout_panel_alignment_matches_parent(): uplt.close(fig) +@pytest.mark.parametrize("ultra_layout", [False, True]) +def test_fixed_aspect_spanning_axes_keeps_adjacent_stack_aligned(ultra_layout): + """Fixed-aspect spanning axes should keep adjacent subplot stacks aligned.""" + if ultra_layout: + pytest.importorskip("kiwisolver") + + fig, axs = uplt.subplots(array=[[1, 2], [1, 3]], ultra_layout=ultra_layout) + axs[0].plot([0, 1], [0, 1]) + axs[0].format(aspect="equal") + fig.canvas.draw() + + left = axs[0].get_position() + top_right = axs[1].get_position() + bottom_right = axs[2].get_position() + + assert np.isclose(top_right.y1, left.y1) + assert np.isclose(bottom_right.y0, left.y0) + uplt.close(fig) + + def test_subplots_with_orthogonal_layout(): """Test creating subplots with orthogonal layout (should work as before).""" layout = [[1, 2], [3, 4]]