From a89f62ed347b79507a6092d35847291eea537a12 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Apr 2026 18:25:32 +1000 Subject: [PATCH 1/4] Teach curved_quiver to treat NaN values as masked regions instead of hard failures so vector fields with gaps still render cleanly. This updates the trajectory solver to stop gracefully at invalid samples, hardens arrow construction around masked endpoints, and adds regression coverage for both solver-level and plotting-level NaN cases. --- ultraplot/axes/plot.py | 16 +++++-- ultraplot/axes/plot_types/curved_quiver.py | 20 ++++++-- ultraplot/tests/test_plot.py | 53 ++++++++++++++++++++++ 3 files changed, 82 insertions(+), 7 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 3ba71f09f..08f99a04d 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1927,7 +1927,11 @@ def curved_quiver( The implementation of this function is based on the `dfm_tools` repository. Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ - from .plot_types.curved_quiver import CurvedQuiverSet, CurvedQuiverSolver + from .plot_types.curved_quiver import ( + CurvedQuiverSet, + CurvedQuiverSolver, + _CurvedQuiverTerminateTrajectory, + ) # Parse inputs arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) @@ -1968,6 +1972,7 @@ def curved_quiver( raise ValueError( "If 'linewidth' is given, must have the shape of 'Grid(x,y)'" ) + linewidth = np.ma.masked_invalid(linewidth) line_kw["linewidth"] = [] else: line_kw["linewidth"] = linewidth @@ -2060,8 +2065,11 @@ def curved_quiver( tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin ) - ui = solver.interpgrid(u, xg, yg) - vi = solver.interpgrid(v, xg, yg) + try: + ui = solver.interpgrid(u, xg, yg) + vi = solver.interpgrid(v, xg, yg) + except _CurvedQuiverTerminateTrajectory: + continue norm_v = np.sqrt(ui**2 + vi**2) if norm_v > 0: @@ -2087,6 +2095,8 @@ def curved_quiver( if isinstance(linewidth, np.ndarray): line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1] line_kw["linewidth"].extend(line_widths) + if np.ma.is_masked(line_widths[n]): + continue arrow_kw["linewidth"] = line_widths[n] if use_multicolor_lines: diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 5489a0a90..0ec5d8e67 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -296,9 +296,14 @@ def integrate_rk12( hit_edge = False while self.domain_map.grid.within_grid(xi, yi): + try: + current_magnitude = self.interpgrid(magnitude, xi, yi) + except _CurvedQuiverTerminateTrajectory: + break + xf_traj.append(xi) yf_traj.append(yi) - m_total.append(self.interpgrid(magnitude, xi, yi)) + m_total.append(current_magnitude) try: k1x, k1y = f(xi, yi) @@ -324,8 +329,15 @@ def integrate_rk12( # Only save step if within error tolerance if error < maxerror: - xi += dx2 - yi += dy2 + next_xi = xi + dx2 + next_yi = yi + dy2 + if self.domain_map.grid.within_grid(next_xi, next_yi): + try: + self.interpgrid(magnitude, next_xi, next_yi) + except _CurvedQuiverTerminateTrajectory: + break + xi = next_xi + yi = next_yi self.domain_map.update_trajectory(xi, yi) if not self.domain_map.grid.within_grid(xi, yi): hit_edge = True @@ -400,7 +412,7 @@ def interpgrid(self, a, xi, yi): if not isinstance(xi, np.ndarray): if np.ma.is_masked(ai): - raise _CurvedQuiverTerminateTrajectory + raise _CurvedQuiverTerminateTrajectory() return ai def gen_starting_points(self, x, y, grains): diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 1f8811ed2..299206302 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -671,6 +671,32 @@ def test_curved_quiver(rng): return fig +def test_curved_quiver_integrator_skips_nan_seed(): + """ + Test that masked seed points terminate cleanly instead of escaping the solver. + """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver + + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + u = np.ones((5, 5)) + v = np.ones((5, 5)) + u[2, 2] = np.nan + v[2, 2] = np.nan + u = np.ma.masked_invalid(u) + v = np.ma.masked_invalid(v) + magnitude = np.sqrt(u**2 + v**2) + magnitude /= np.max(magnitude) + + solver = CurvedQuiverSolver(x, y, density=5) + integrator = solver.get_integrator( + u, v, minlength=0.1, resolution=1.0, magnitude=magnitude + ) + + assert integrator(2.0, 2.0) is None + assert not solver.mask._mask.any() + + def test_validate_vector_shapes_pass(): """ Test that vector shapes match the grid shape using CurvedQuiverSolver. @@ -779,6 +805,33 @@ def test_curved_quiver_multicolor_lines(): return fig +def test_curved_quiver_nan_vectors(): + """ + Test that curved_quiver skips NaN vector regions without failing. + """ + x = np.linspace(-1, 1, 21) + y = np.linspace(-1, 1, 21) + X, Y = np.meshgrid(x, y) + U = -Y.copy() + V = X.copy() + speed = np.sqrt(U**2 + V**2) + invalid = (np.abs(X) < 0.2) & (np.abs(Y) < 0.2) + U[invalid] = np.nan + V[invalid] = np.nan + speed[invalid] = np.nan + + fig, ax = uplt.subplots() + m = ax.curved_quiver( + X, Y, U, V, color=speed, arrow_at_end=True, scale=2.0, grains=10 + ) + + segments = m.lines.get_segments() + assert segments + assert all(np.isfinite(segment).all() for segment in segments) + assert len(ax.patches) > 0 + uplt.close(fig) + + @pytest.mark.mpl_image_compare @pytest.mark.parametrize( "cmap", From 2e55add5271f5a21106224fee60779fe42219670 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Apr 2026 18:32:31 +1000 Subject: [PATCH 2/4] Stop exposing the curved quiver solver's private termination exception to the plotting layer so NaN handling stays encapsulated where the trajectories are computed. The end-arrow path now uses the final streamline segment as its tangent, which removes the extra endpoint resampling and keeps the public plotting code out of the solver's internal control flow. --- ultraplot/axes/plot.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 08f99a04d..a9093c01e 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1927,11 +1927,7 @@ def curved_quiver( The implementation of this function is based on the `dfm_tools` repository. Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ - from .plot_types.curved_quiver import ( - CurvedQuiverSet, - CurvedQuiverSolver, - _CurvedQuiverTerminateTrajectory, - ) + from .plot_types.curved_quiver import CurvedQuiverSet, CurvedQuiverSolver # Parse inputs arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) @@ -2059,17 +2055,8 @@ def curved_quiver( continue arrow_tail = (tx[-1], ty[-1]) - - # Extrapolate to find arrow head - xg, yg = solver.domain_map.data2grid( - tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin - ) - - try: - ui = solver.interpgrid(u, xg, yg) - vi = solver.interpgrid(v, xg, yg) - except _CurvedQuiverTerminateTrajectory: - continue + ui = tx[-1] - tx[-2] + vi = ty[-1] - ty[-2] norm_v = np.sqrt(ui**2 + vi**2) if norm_v > 0: From 0f48aa0013c4e81132bec2a8e5f053465c108043 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Apr 2026 18:38:10 +1000 Subject: [PATCH 3/4] Refine the curved quiver NaN fix by making the solver return a small trajectory result object instead of leaking tuple-shaped internal state into the plotting layer. This keeps endpoint direction and edge handling with the integration code, simplifies the renderer, and adds assertions that the solver now exposes the metadata the plot path actually consumes. --- ultraplot/axes/plot.py | 25 ++++++++-------- ultraplot/axes/plot_types/curved_quiver.py | 35 ++++++++++++++++------ ultraplot/tests/test_plot.py | 15 ++++++++-- 3 files changed, 52 insertions(+), 23 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index a9093c01e..dd33c12a3 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1991,7 +1991,6 @@ def curved_quiver( integrate = solver.get_integrator(u, v, minlength, resolution, magnitude) trajectories = [] - edges = [] if start_points is None: start_points = solver.gen_starting_points(x, y, grains) @@ -2027,18 +2026,19 @@ def curved_quiver( for xs, ys in sp2: xg, yg = solver.domain_map.data2grid(xs, ys) - t = integrate(xg, yg) - if t is not None: - trajectories.append(t[0]) - edges.append(t[1]) + trajectory = integrate(xg, yg) + if trajectory is not None: + trajectories.append(trajectory) streamlines = [] arrows = [] - for t, edge in zip(trajectories, edges): - tgx = np.array(t[0]) - tgy = np.array(t[1]) + for trajectory in trajectories: + tgx = np.array(trajectory.x) + tgy = np.array(trajectory.y) # Rescale from grid-coordinates to data-coordinates. - tx, ty = solver.domain_map.grid2data(*np.array(t)) + tx, ty = solver.domain_map.grid2data( + *np.array([trajectory.x, trajectory.y]) + ) tx += solver.grid.x_origin ty += solver.grid.y_origin @@ -2055,8 +2055,9 @@ def curved_quiver( continue arrow_tail = (tx[-1], ty[-1]) - ui = tx[-1] - tx[-2] - vi = ty[-1] - ty[-2] + if trajectory.end_direction is None: + continue + ui, vi = trajectory.end_direction norm_v = np.sqrt(ui**2 + vi**2) if norm_v > 0: @@ -2091,7 +2092,7 @@ def curved_quiver( line_colors.append(color_values) arrow_kw["color"] = cmap(norm(color_values[n])) - if not edge: + if not trajectory.hit_edge: p = mpatches.FancyArrowPatch( arrow_tail, arrow_head, transform=transform, **arrow_kw ) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 0ec5d8e67..639f51d93 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -20,6 +20,14 @@ class CurvedQuiverSet(StreamplotSet): arrows: object +@dataclass +class _CurvedQuiverTrajectory: + x: list[float] + y: list[float] + hit_edge: bool + end_direction: tuple[float, float] | None + + class _DomainMap(object): """Map representing different coordinate systems. @@ -197,7 +205,7 @@ def get_integrator( minlength: float, resolution: float, magnitude: np.ndarray, - ) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: + ) -> Callable[[float, float], _CurvedQuiverTrajectory | None]: # rescale velocity onto grid-coordinates for integrations. u, v = self.domain_map.data2grid(u, v) @@ -215,9 +223,7 @@ def forward_time(xi: float, yi: float) -> tuple[float, float]: vi = self.interpgrid(v, xi, yi) return ui * dt_ds, vi * dt_ds - def integrate( - x0: float, y0: float - ) -> tuple[tuple[list[float], list[float], bool]] | None: + def integrate(x0: float, y0: float) -> _CurvedQuiverTrajectory | None: """Return x, y grid-coordinates of trajectory based on starting point. Integrate both forward and backward in time from starting point @@ -226,15 +232,26 @@ def integrate( occupied cell in the StreamMask. The resulting trajectory is None if it is shorter than `minlength`. """ - stotal, x_traj, y_traj = 0.0, [], [] self.domain_map.start_trajectory(x0, y0) self.domain_map.reset_start_point(x0, y0) - stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12( + x_traj, y_traj, hit_edge = self.integrate_rk12( x0, y0, forward_time, resolution, magnitude ) if len(x_traj) > 1: - return (x_traj, y_traj), hit_edge + end_dx = x_traj[-1] - x_traj[-2] + end_dy = y_traj[-1] - y_traj[-2] + end_direction = ( + None + if end_dx == 0 and end_dy == 0 + else self.domain_map.grid2data(end_dx, end_dy) + ) + return _CurvedQuiverTrajectory( + x=x_traj, + y=y_traj, + hit_edge=hit_edge, + end_direction=end_direction, + ) else: # reject short trajectories self.domain_map.undo_trajectory() @@ -249,7 +266,7 @@ def integrate_rk12( f: Callable[[float, float], tuple[float, float]], resolution: float, magnitude: np.ndarray, - ) -> tuple[float, list[float], list[float], list[float], bool]: + ) -> tuple[list[float], list[float], bool]: """2nd-order Runge-Kutta algorithm with adaptive step size. This method is also referred to as the improved Euler's method, or @@ -351,7 +368,7 @@ def integrate_rk12( else: ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) - return stotal, xf_traj, yf_traj, m_total, hit_edge + return xf_traj, yf_traj, hit_edge def euler_step(self, xf_traj, yf_traj, f): """Simple Euler integration step that extends streamline to boundary.""" diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 299206302..9cf014696 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -764,8 +764,8 @@ def test_generate_start_points(): def test_calculate_trajectories(): """ - Test that CurvedQuiverSolver.get_integrator returns callable for each seed point - and returns lists of trajectories and edges of correct length. + Test that CurvedQuiverSolver.get_integrator returns trajectory objects for each + seed point with the expected rendering metadata. """ from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver @@ -781,6 +781,17 @@ def test_calculate_trajectories(): seeds = solver.gen_starting_points(x, y, grains=2) results = [integrator(pt[0], pt[1]) for pt in seeds] assert len(results) == seeds.shape[0] + trajectories = [result for result in results if result is not None] + assert trajectories + for trajectory in trajectories: + assert len(trajectory.x) == len(trajectory.y) + assert isinstance(trajectory.hit_edge, bool) + if trajectory.end_direction is not None: + expected = solver.domain_map.grid2data( + trajectory.x[-1] - trajectory.x[-2], + trajectory.y[-1] - trajectory.y[-2], + ) + assert np.allclose(trajectory.end_direction, expected) @pytest.mark.mpl_image_compare From f90453c237d69d79d65af1ef900217c8d545b375 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Apr 2026 18:43:17 +1000 Subject: [PATCH 4/4] Teach curved_quiver to participate in the shared colorbar guide path when its streamline colors come from array data, so it behaves like the other color-mapped field plots. This adds the familiar colorbar arguments to the method signature, documents that they apply to array-valued colors, and verifies that the created colorbar is registered against the streamline LineCollection with the requested label. --- ultraplot/axes/plot.py | 11 +++++++++++ ultraplot/tests/test_plot.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index dd33c12a3..ae387927f 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -193,6 +193,8 @@ Width of streamlines. cmap, norm : optional Colormap and normalization for array colors. +colorbar, colorbar_kw : optional + Add a colorbar for array-valued streamline colors. arrowsize : float, optional Arrow size scaling. arrowstyle : str, optional @@ -1918,6 +1920,8 @@ def curved_quiver( grains: Optional[int] = None, density: Optional[int] = None, arrow_at_end: Optional[bool] = None, + colorbar: Optional[str] = None, + colorbar_kw: Optional[dict[str, Any]] = None, ): """ %(plot.curved_quiver)s @@ -1935,6 +1939,7 @@ def curved_quiver( zorder = _not_none(zorder, mlines.Line2D.zorder) transform = _not_none(transform, self.transData) color = _not_none(color, self._get_lines.get_next_color()) + colorbar_kw = colorbar_kw or {} linewidth = _not_none(linewidth, rc["lines.linewidth"]) scale = _not_none(scale, rc["curved_quiver.scale"]) grains = _not_none(grains, rc["curved_quiver.grains"]) @@ -2123,6 +2128,12 @@ def curved_quiver( lc.set_array(np.ma.hstack(line_colors)) lc.set_cmap(cmap) lc.set_norm(norm) + self._update_guide( + lc, + colorbar=colorbar, + colorbar_kw=colorbar_kw, + queue_colorbar=False, + ) self.add_collection(lc) self.autoscale_view() diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 9cf014696..652f71856 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -843,6 +843,35 @@ def test_curved_quiver_nan_vectors(): uplt.close(fig) +def test_curved_quiver_colorbar_argument(): + """ + Test that curved_quiver forwards array colors to the shared colorbar guide path. + """ + x = np.linspace(-1, 1, 11) + y = np.linspace(-1, 1, 11) + X, Y = np.meshgrid(x, y) + U = -Y + V = X + speed = np.sqrt(U**2 + V**2) + + fig, ax = uplt.subplots() + m = ax.curved_quiver( + X, + Y, + U, + V, + color=speed, + colorbar="r", + colorbar_kw={"label": "speed"}, + ) + + assert ("right", "center") in ax[0]._colorbar_dict + cbar = ax[0]._colorbar_dict[("right", "center")] + assert cbar.mappable is m.lines + assert cbar.ax.get_ylabel() == "speed" + uplt.close(fig) + + @pytest.mark.mpl_image_compare @pytest.mark.parametrize( "cmap",