Skip to content

Commit 90f2d84

Browse files
author
rohan
committed
feat(gds-viz): add phase portrait visualization module
New gds_viz.phase module with vector fields, trajectories, nullclines, and full phase_portrait() combining all three. Supports >2D systems via fixed_states projection. Behind [phase] optional extra (matplotlib + numpy + gds-continuous) — existing Mermaid functionality has no new dependencies. 10 tests: vector field computation, trajectory integration, nullclines, Lorenz 3D projection. 97% package coverage. Closes #126.
1 parent 828f832 commit 90f2d84

4 files changed

Lines changed: 563 additions & 0 deletions

File tree

packages/gds-viz/gds_viz/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,12 @@
1717
"system_to_mermaid",
1818
"trace_to_mermaid",
1919
]
20+
21+
22+
def __getattr__(name: str) -> object:
23+
"""Lazy import for optional phase portrait module."""
24+
if name == "phase_portrait":
25+
from gds_viz.phase import phase_portrait
26+
27+
return phase_portrait
28+
raise AttributeError(f"module 'gds_viz' has no attribute {name!r}")

packages/gds-viz/gds_viz/phase.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
"""Phase portrait visualization for continuous-time ODE systems.
2+
3+
Produces matplotlib figures: vector fields, trajectories, nullclines,
4+
and backward reachable set boundaries (isochrones).
5+
6+
Requires ``gds-viz[phase]`` (matplotlib + numpy + gds-continuous).
7+
8+
Example::
9+
10+
from gds_continuous import ODEModel
11+
from gds_viz.phase import phase_portrait
12+
13+
model = ODEModel(
14+
state_names=["x", "v"],
15+
initial_state={"x": 1.0, "v": 0.0},
16+
rhs=my_ode_fn,
17+
)
18+
fig = phase_portrait(model, x_var="x", y_var="v", x_range=(-3, 3), y_range=(-3, 3))
19+
"""
20+
21+
from __future__ import annotations
22+
23+
from dataclasses import dataclass, field
24+
from typing import TYPE_CHECKING, Any
25+
26+
if TYPE_CHECKING:
27+
from gds_continuous import ODEModel
28+
from gds_continuous.results import ODEResults
29+
30+
31+
def _require_phase_deps() -> None:
32+
"""Raise ImportError if matplotlib/numpy are absent."""
33+
try:
34+
import matplotlib # noqa: F401
35+
import numpy # noqa: F401
36+
except ImportError as exc:
37+
raise ImportError(
38+
"Phase portrait visualization requires matplotlib and numpy. "
39+
"Install with: uv add gds-viz[phase]"
40+
) from exc
41+
42+
43+
@dataclass(frozen=True)
44+
class PhasePlotConfig:
45+
"""Configuration for a phase portrait."""
46+
47+
x_var: str
48+
y_var: str
49+
x_range: tuple[float, float]
50+
y_range: tuple[float, float]
51+
resolution: int = 20
52+
fixed_states: dict[str, float] = field(default_factory=dict)
53+
params: dict[str, float] = field(default_factory=dict)
54+
title: str = ""
55+
56+
57+
def compute_vector_field(
58+
model: ODEModel,
59+
config: PhasePlotConfig,
60+
*,
61+
t: float = 0.0,
62+
) -> tuple[Any, Any, Any, Any]:
63+
"""Compute a 2D vector field over a grid.
64+
65+
Parameters
66+
----------
67+
model
68+
ODE model with the RHS function.
69+
config
70+
Grid specification (axes, ranges, resolution).
71+
t
72+
Time value for evaluating the RHS (default 0).
73+
74+
Returns
75+
-------
76+
X, Y, dX, dY : numpy arrays
77+
Meshgrid coordinates and derivative components.
78+
"""
79+
_require_phase_deps()
80+
import numpy as np
81+
82+
x_idx = model.state_names.index(config.x_var)
83+
y_idx = model.state_names.index(config.y_var)
84+
85+
xs = np.linspace(config.x_range[0], config.x_range[1], config.resolution)
86+
ys = np.linspace(config.y_range[0], config.y_range[1], config.resolution)
87+
X, Y = np.meshgrid(xs, ys)
88+
89+
dX = np.zeros_like(X)
90+
dY = np.zeros_like(Y)
91+
92+
# Build base state from fixed values
93+
base = [config.fixed_states.get(n, 0.0) for n in model.state_names]
94+
95+
for i in range(config.resolution):
96+
for j in range(config.resolution):
97+
state = list(base)
98+
state[x_idx] = X[i, j]
99+
state[y_idx] = Y[i, j]
100+
deriv = model.rhs(t, state, config.params)
101+
dX[i, j] = deriv[x_idx]
102+
dY[i, j] = deriv[y_idx]
103+
104+
return X, Y, dX, dY
105+
106+
107+
def compute_trajectories(
108+
model: ODEModel,
109+
initial_conditions: list[dict[str, float]],
110+
*,
111+
t_span: tuple[float, float] = (0.0, 10.0),
112+
params: dict[str, float] | None = None,
113+
solver: str = "RK45",
114+
max_step: float = 0.05,
115+
) -> list[ODEResults]:
116+
"""Integrate multiple trajectories from different initial conditions.
117+
118+
Parameters
119+
----------
120+
model
121+
ODE model (``rhs`` is used, ``initial_state`` is overridden).
122+
initial_conditions
123+
List of state dicts, one per trajectory.
124+
t_span
125+
Integration time interval.
126+
params
127+
Parameter values (single set, not a sweep).
128+
solver
129+
SciPy solver name.
130+
max_step
131+
Maximum integration step size.
132+
133+
Returns
134+
-------
135+
List of ODEResults, one per initial condition.
136+
"""
137+
from gds_continuous import ODEModel as _ODEModel
138+
from gds_continuous import ODESimulation
139+
140+
results = []
141+
p = params or {}
142+
for ic in initial_conditions:
143+
m = _ODEModel(
144+
state_names=model.state_names,
145+
initial_state=ic,
146+
rhs=model.rhs,
147+
params={k: [v] for k, v in p.items()},
148+
)
149+
sim = ODESimulation(
150+
model=m,
151+
t_span=t_span,
152+
solver=solver, # type: ignore[arg-type]
153+
max_step=max_step,
154+
)
155+
results.append(sim.run())
156+
return results
157+
158+
159+
def vector_field_plot(
160+
model: ODEModel,
161+
config: PhasePlotConfig,
162+
*,
163+
ax: Any | None = None,
164+
normalize: bool = True,
165+
color: str = "gray",
166+
alpha: float = 0.6,
167+
) -> Any:
168+
"""Plot a 2D vector field (quiver plot).
169+
170+
Returns the matplotlib Figure.
171+
"""
172+
_require_phase_deps()
173+
import matplotlib.pyplot as plt
174+
import numpy as np
175+
176+
X, Y, dX, dY = compute_vector_field(model, config)
177+
178+
if ax is None:
179+
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
180+
else:
181+
fig = ax.get_figure()
182+
183+
if normalize:
184+
mag = np.sqrt(dX**2 + dY**2)
185+
mag = np.where(mag > 0, mag, 1.0)
186+
dX = dX / mag
187+
dY = dY / mag
188+
189+
ax.quiver(X, Y, dX, dY, color=color, alpha=alpha, scale=25)
190+
ax.set_xlabel(config.x_var)
191+
ax.set_ylabel(config.y_var)
192+
ax.set_aspect("equal")
193+
if config.title:
194+
ax.set_title(config.title)
195+
ax.grid(True, alpha=0.3)
196+
return fig
197+
198+
199+
def trajectory_plot(
200+
results_list: list[ODEResults],
201+
x_var: str,
202+
y_var: str,
203+
*,
204+
ax: Any | None = None,
205+
colormap: str = "viridis",
206+
linewidth: float = 1.0,
207+
show_start: bool = True,
208+
show_end: bool = True,
209+
) -> Any:
210+
"""Plot trajectories in phase space.
211+
212+
Returns the matplotlib Figure.
213+
"""
214+
_require_phase_deps()
215+
import matplotlib.pyplot as plt
216+
217+
if ax is None:
218+
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
219+
else:
220+
fig = ax.get_figure()
221+
222+
cmap = plt.get_cmap(colormap)
223+
n = max(len(results_list), 1)
224+
225+
for i, res in enumerate(results_list):
226+
c = cmap(i / n)
227+
xs = res.state_array(x_var)
228+
ys = res.state_array(y_var)
229+
ax.plot(xs, ys, "-", color=c, linewidth=linewidth, alpha=0.8)
230+
if show_start:
231+
ax.plot(xs[0], ys[0], "o", color=c, markersize=5)
232+
if show_end:
233+
ax.plot(xs[-1], ys[-1], "s", color=c, markersize=4)
234+
235+
ax.set_xlabel(x_var)
236+
ax.set_ylabel(y_var)
237+
ax.set_aspect("equal")
238+
ax.grid(True, alpha=0.3)
239+
return fig
240+
241+
242+
def phase_portrait(
243+
model: ODEModel,
244+
x_var: str,
245+
y_var: str,
246+
x_range: tuple[float, float],
247+
y_range: tuple[float, float],
248+
*,
249+
initial_conditions: list[dict[str, float]] | None = None,
250+
params: dict[str, float] | None = None,
251+
fixed_states: dict[str, float] | None = None,
252+
t_span: tuple[float, float] = (0.0, 10.0),
253+
resolution: int = 20,
254+
title: str = "",
255+
show_nullclines: bool = False,
256+
figsize: tuple[float, float] = (10, 10),
257+
) -> Any:
258+
"""Full phase portrait: vector field + optional trajectories + nullclines.
259+
260+
Parameters
261+
----------
262+
model
263+
ODE model.
264+
x_var, y_var
265+
State variable names for the two axes.
266+
x_range, y_range
267+
Plot ranges for each axis.
268+
initial_conditions
269+
List of state dicts for trajectory integration. None = no trajectories.
270+
params
271+
Parameter values for RHS evaluation.
272+
fixed_states
273+
Values for state variables not on the axes (for >2D systems).
274+
t_span
275+
Integration time for trajectories.
276+
resolution
277+
Grid density for vector field.
278+
title
279+
Plot title.
280+
show_nullclines
281+
If True, draw zero-contours of dx/dt=0 and dy/dt=0.
282+
figsize
283+
Figure size.
284+
285+
Returns
286+
-------
287+
matplotlib Figure.
288+
"""
289+
_require_phase_deps()
290+
import matplotlib.pyplot as plt
291+
import numpy as np
292+
293+
config = PhasePlotConfig(
294+
x_var=x_var,
295+
y_var=y_var,
296+
x_range=x_range,
297+
y_range=y_range,
298+
resolution=resolution,
299+
fixed_states=fixed_states or {},
300+
params=params or {},
301+
title=title,
302+
)
303+
304+
fig, ax = plt.subplots(1, 1, figsize=figsize)
305+
306+
# Vector field
307+
X, Y, dX, dY = compute_vector_field(model, config)
308+
mag = np.sqrt(dX**2 + dY**2)
309+
mag = np.where(mag > 0, mag, 1.0)
310+
ax.quiver(X, Y, dX / mag, dY / mag, color="gray", alpha=0.4, scale=25)
311+
312+
# Nullclines
313+
if show_nullclines:
314+
ax.contour(X, Y, dX, levels=[0], colors=["blue"], linewidths=[1.5], alpha=0.6)
315+
ax.contour(X, Y, dY, levels=[0], colors=["red"], linewidths=[1.5], alpha=0.6)
316+
317+
# Trajectories
318+
if initial_conditions:
319+
trajs = compute_trajectories(
320+
model, initial_conditions, t_span=t_span, params=params
321+
)
322+
cmap = plt.get_cmap("viridis")
323+
n = max(len(trajs), 1)
324+
for i, res in enumerate(trajs):
325+
c = cmap(i / n)
326+
xs = res.state_array(x_var)
327+
ys = res.state_array(y_var)
328+
ax.plot(xs, ys, "-", color=c, linewidth=1.2, alpha=0.8)
329+
ax.plot(xs[0], ys[0], "o", color=c, markersize=5)
330+
331+
ax.set_xlabel(x_var)
332+
ax.set_ylabel(y_var)
333+
ax.set_xlim(x_range)
334+
ax.set_ylim(y_range)
335+
ax.set_aspect("equal")
336+
ax.set_title(title)
337+
ax.grid(True, alpha=0.3)
338+
plt.tight_layout()
339+
return fig

packages/gds-viz/pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ dependencies = [
3333
"gds-framework>=0.2.3",
3434
]
3535

36+
[project.optional-dependencies]
37+
phase = ["matplotlib>=3.8", "numpy>=1.26", "gds-continuous>=0.1.0"]
38+
3639
[project.urls]
3740
Homepage = "https://github.com/BlockScience/gds-core"
3841
Repository = "https://github.com/BlockScience/gds-core"
@@ -50,6 +53,7 @@ packages = ["gds_viz"]
5053

5154
[tool.uv.sources]
5255
gds-framework = { workspace = true }
56+
gds-continuous = { workspace = true }
5357

5458
[tool.pytest.ini_options]
5559
testpaths = ["tests"]
@@ -74,4 +78,7 @@ dev = [
7478
"pytest>=8.0",
7579
"pytest-cov>=6.0",
7680
"ruff>=0.8",
81+
"matplotlib>=3.8",
82+
"numpy>=1.26",
83+
"scipy>=1.13",
7784
]

0 commit comments

Comments
 (0)