diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 3ba71f09f..ae387927f 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -193,6 +193,8 @@ Width of streamlines. cmap, norm : optional Colormap and normalization for array colors. +colorbar, colorbar_kw : optional + Add a colorbar for array-valued streamline colors. arrowsize : float, optional Arrow size scaling. arrowstyle : str, optional @@ -1918,6 +1920,8 @@ def curved_quiver( grains: Optional[int] = None, density: Optional[int] = None, arrow_at_end: Optional[bool] = None, + colorbar: Optional[str] = None, + colorbar_kw: Optional[dict[str, Any]] = None, ): """ %(plot.curved_quiver)s @@ -1935,6 +1939,7 @@ def curved_quiver( zorder = _not_none(zorder, mlines.Line2D.zorder) transform = _not_none(transform, self.transData) color = _not_none(color, self._get_lines.get_next_color()) + colorbar_kw = colorbar_kw or {} linewidth = _not_none(linewidth, rc["lines.linewidth"]) scale = _not_none(scale, rc["curved_quiver.scale"]) grains = _not_none(grains, rc["curved_quiver.grains"]) @@ -1968,6 +1973,7 @@ def curved_quiver( raise ValueError( "If 'linewidth' is given, must have the shape of 'Grid(x,y)'" ) + linewidth = np.ma.masked_invalid(linewidth) line_kw["linewidth"] = [] else: line_kw["linewidth"] = linewidth @@ -1990,7 +1996,6 @@ def curved_quiver( integrate = solver.get_integrator(u, v, minlength, resolution, magnitude) trajectories = [] - edges = [] if start_points is None: start_points = solver.gen_starting_points(x, y, grains) @@ -2026,18 +2031,19 @@ def curved_quiver( for xs, ys in sp2: xg, yg = solver.domain_map.data2grid(xs, ys) - t = integrate(xg, yg) - if t is not None: - trajectories.append(t[0]) - edges.append(t[1]) + trajectory = integrate(xg, yg) + if trajectory is not None: + trajectories.append(trajectory) streamlines = [] arrows = [] - for t, edge in zip(trajectories, edges): - tgx = np.array(t[0]) - tgy = np.array(t[1]) + for trajectory in trajectories: + tgx = np.array(trajectory.x) + tgy = np.array(trajectory.y) # Rescale from grid-coordinates to data-coordinates. - tx, ty = solver.domain_map.grid2data(*np.array(t)) + tx, ty = solver.domain_map.grid2data( + *np.array([trajectory.x, trajectory.y]) + ) tx += solver.grid.x_origin ty += solver.grid.y_origin @@ -2054,14 +2060,9 @@ def curved_quiver( continue arrow_tail = (tx[-1], ty[-1]) - - # Extrapolate to find arrow head - xg, yg = solver.domain_map.data2grid( - tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin - ) - - ui = solver.interpgrid(u, xg, yg) - vi = solver.interpgrid(v, xg, yg) + if trajectory.end_direction is None: + continue + ui, vi = trajectory.end_direction norm_v = np.sqrt(ui**2 + vi**2) if norm_v > 0: @@ -2087,6 +2088,8 @@ def curved_quiver( if isinstance(linewidth, np.ndarray): line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1] line_kw["linewidth"].extend(line_widths) + if np.ma.is_masked(line_widths[n]): + continue arrow_kw["linewidth"] = line_widths[n] if use_multicolor_lines: @@ -2094,7 +2097,7 @@ def curved_quiver( line_colors.append(color_values) arrow_kw["color"] = cmap(norm(color_values[n])) - if not edge: + if not trajectory.hit_edge: p = mpatches.FancyArrowPatch( arrow_tail, arrow_head, transform=transform, **arrow_kw ) @@ -2125,6 +2128,12 @@ def curved_quiver( lc.set_array(np.ma.hstack(line_colors)) lc.set_cmap(cmap) lc.set_norm(norm) + self._update_guide( + lc, + colorbar=colorbar, + colorbar_kw=colorbar_kw, + queue_colorbar=False, + ) self.add_collection(lc) self.autoscale_view() diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 5489a0a90..639f51d93 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -20,6 +20,14 @@ class CurvedQuiverSet(StreamplotSet): arrows: object +@dataclass +class _CurvedQuiverTrajectory: + x: list[float] + y: list[float] + hit_edge: bool + end_direction: tuple[float, float] | None + + class _DomainMap(object): """Map representing different coordinate systems. @@ -197,7 +205,7 @@ def get_integrator( minlength: float, resolution: float, magnitude: np.ndarray, - ) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: + ) -> Callable[[float, float], _CurvedQuiverTrajectory | None]: # rescale velocity onto grid-coordinates for integrations. u, v = self.domain_map.data2grid(u, v) @@ -215,9 +223,7 @@ def forward_time(xi: float, yi: float) -> tuple[float, float]: vi = self.interpgrid(v, xi, yi) return ui * dt_ds, vi * dt_ds - def integrate( - x0: float, y0: float - ) -> tuple[tuple[list[float], list[float], bool]] | None: + def integrate(x0: float, y0: float) -> _CurvedQuiverTrajectory | None: """Return x, y grid-coordinates of trajectory based on starting point. Integrate both forward and backward in time from starting point @@ -226,15 +232,26 @@ def integrate( occupied cell in the StreamMask. The resulting trajectory is None if it is shorter than `minlength`. """ - stotal, x_traj, y_traj = 0.0, [], [] self.domain_map.start_trajectory(x0, y0) self.domain_map.reset_start_point(x0, y0) - stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12( + x_traj, y_traj, hit_edge = self.integrate_rk12( x0, y0, forward_time, resolution, magnitude ) if len(x_traj) > 1: - return (x_traj, y_traj), hit_edge + end_dx = x_traj[-1] - x_traj[-2] + end_dy = y_traj[-1] - y_traj[-2] + end_direction = ( + None + if end_dx == 0 and end_dy == 0 + else self.domain_map.grid2data(end_dx, end_dy) + ) + return _CurvedQuiverTrajectory( + x=x_traj, + y=y_traj, + hit_edge=hit_edge, + end_direction=end_direction, + ) else: # reject short trajectories self.domain_map.undo_trajectory() @@ -249,7 +266,7 @@ def integrate_rk12( f: Callable[[float, float], tuple[float, float]], resolution: float, magnitude: np.ndarray, - ) -> tuple[float, list[float], list[float], list[float], bool]: + ) -> tuple[list[float], list[float], bool]: """2nd-order Runge-Kutta algorithm with adaptive step size. This method is also referred to as the improved Euler's method, or @@ -296,9 +313,14 @@ def integrate_rk12( hit_edge = False while self.domain_map.grid.within_grid(xi, yi): + try: + current_magnitude = self.interpgrid(magnitude, xi, yi) + except _CurvedQuiverTerminateTrajectory: + break + xf_traj.append(xi) yf_traj.append(yi) - m_total.append(self.interpgrid(magnitude, xi, yi)) + m_total.append(current_magnitude) try: k1x, k1y = f(xi, yi) @@ -324,8 +346,15 @@ def integrate_rk12( # Only save step if within error tolerance if error < maxerror: - xi += dx2 - yi += dy2 + next_xi = xi + dx2 + next_yi = yi + dy2 + if self.domain_map.grid.within_grid(next_xi, next_yi): + try: + self.interpgrid(magnitude, next_xi, next_yi) + except _CurvedQuiverTerminateTrajectory: + break + xi = next_xi + yi = next_yi self.domain_map.update_trajectory(xi, yi) if not self.domain_map.grid.within_grid(xi, yi): hit_edge = True @@ -339,7 +368,7 @@ def integrate_rk12( else: ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) - return stotal, xf_traj, yf_traj, m_total, hit_edge + return xf_traj, yf_traj, hit_edge def euler_step(self, xf_traj, yf_traj, f): """Simple Euler integration step that extends streamline to boundary.""" @@ -400,7 +429,7 @@ def interpgrid(self, a, xi, yi): if not isinstance(xi, np.ndarray): if np.ma.is_masked(ai): - raise _CurvedQuiverTerminateTrajectory + raise _CurvedQuiverTerminateTrajectory() return ai def gen_starting_points(self, x, y, grains): diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 1f8811ed2..652f71856 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -671,6 +671,32 @@ def test_curved_quiver(rng): return fig +def test_curved_quiver_integrator_skips_nan_seed(): + """ + Test that masked seed points terminate cleanly instead of escaping the solver. + """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver + + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + u = np.ones((5, 5)) + v = np.ones((5, 5)) + u[2, 2] = np.nan + v[2, 2] = np.nan + u = np.ma.masked_invalid(u) + v = np.ma.masked_invalid(v) + magnitude = np.sqrt(u**2 + v**2) + magnitude /= np.max(magnitude) + + solver = CurvedQuiverSolver(x, y, density=5) + integrator = solver.get_integrator( + u, v, minlength=0.1, resolution=1.0, magnitude=magnitude + ) + + assert integrator(2.0, 2.0) is None + assert not solver.mask._mask.any() + + def test_validate_vector_shapes_pass(): """ Test that vector shapes match the grid shape using CurvedQuiverSolver. @@ -738,8 +764,8 @@ def test_generate_start_points(): def test_calculate_trajectories(): """ - Test that CurvedQuiverSolver.get_integrator returns callable for each seed point - and returns lists of trajectories and edges of correct length. + Test that CurvedQuiverSolver.get_integrator returns trajectory objects for each + seed point with the expected rendering metadata. """ from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver @@ -755,6 +781,17 @@ def test_calculate_trajectories(): seeds = solver.gen_starting_points(x, y, grains=2) results = [integrator(pt[0], pt[1]) for pt in seeds] assert len(results) == seeds.shape[0] + trajectories = [result for result in results if result is not None] + assert trajectories + for trajectory in trajectories: + assert len(trajectory.x) == len(trajectory.y) + assert isinstance(trajectory.hit_edge, bool) + if trajectory.end_direction is not None: + expected = solver.domain_map.grid2data( + trajectory.x[-1] - trajectory.x[-2], + trajectory.y[-1] - trajectory.y[-2], + ) + assert np.allclose(trajectory.end_direction, expected) @pytest.mark.mpl_image_compare @@ -779,6 +816,62 @@ def test_curved_quiver_multicolor_lines(): return fig +def test_curved_quiver_nan_vectors(): + """ + Test that curved_quiver skips NaN vector regions without failing. + """ + x = np.linspace(-1, 1, 21) + y = np.linspace(-1, 1, 21) + X, Y = np.meshgrid(x, y) + U = -Y.copy() + V = X.copy() + speed = np.sqrt(U**2 + V**2) + invalid = (np.abs(X) < 0.2) & (np.abs(Y) < 0.2) + U[invalid] = np.nan + V[invalid] = np.nan + speed[invalid] = np.nan + + fig, ax = uplt.subplots() + m = ax.curved_quiver( + X, Y, U, V, color=speed, arrow_at_end=True, scale=2.0, grains=10 + ) + + segments = m.lines.get_segments() + assert segments + assert all(np.isfinite(segment).all() for segment in segments) + assert len(ax.patches) > 0 + uplt.close(fig) + + +def test_curved_quiver_colorbar_argument(): + """ + Test that curved_quiver forwards array colors to the shared colorbar guide path. + """ + x = np.linspace(-1, 1, 11) + y = np.linspace(-1, 1, 11) + X, Y = np.meshgrid(x, y) + U = -Y + V = X + speed = np.sqrt(U**2 + V**2) + + fig, ax = uplt.subplots() + m = ax.curved_quiver( + X, + Y, + U, + V, + color=speed, + colorbar="r", + colorbar_kw={"label": "speed"}, + ) + + assert ("right", "center") in ax[0]._colorbar_dict + cbar = ax[0]._colorbar_dict[("right", "center")] + assert cbar.mappable is m.lines + assert cbar.ax.get_ylabel() == "speed" + uplt.close(fig) + + @pytest.mark.mpl_image_compare @pytest.mark.parametrize( "cmap",