diff --git a/autogalaxy/__init__.py b/autogalaxy/__init__.py index 79ac197a..5ac5bd04 100644 --- a/autogalaxy/__init__.py +++ b/autogalaxy/__init__.py @@ -58,7 +58,6 @@ from .analysis.adapt_images.adapt_images import galaxy_name_image_dict_via_result_from from . import aggregator as agg from . import exc -from . import plot from . import util from .ellipse.dataset_interp import DatasetInterp from .ellipse.ellipse.ellipse import Ellipse @@ -120,3 +119,12 @@ from autoconf.fitsable import hdu_list_for_output_from __version__ = "2026.4.5.3" + + +def __getattr__(name): + if name == "plot": + from . import plot + + globals()["plot"] = plot + return plot + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/autogalaxy/analysis/plotter.py b/autogalaxy/analysis/plotter.py index 6f9bbb87..97928974 100644 --- a/autogalaxy/analysis/plotter.py +++ b/autogalaxy/analysis/plotter.py @@ -4,11 +4,11 @@ if TYPE_CHECKING: from pathlib import Path + from autoarray.plot.output import Output from autoconf import conf import autoarray as aa -import autoarray.plot as aplt from autogalaxy.analysis.adapt_images.adapt_images import AdaptImages from autogalaxy.galaxy.galaxy import Galaxy @@ -68,9 +68,11 @@ def fmt(self) -> List[str]: except KeyError: return conf.instance["visualize"]["plots"]["format"] - def output_from(self) -> aplt.Output: + def output_from(self) -> "Output": """Return an ``autoarray`` ``Output`` object pointed at ``image_path``.""" - return aplt.Output(path=self.image_path, format=self.fmt) + from autoarray.plot.output import Output + + return Output(path=self.image_path, format=self.fmt) def galaxies( self, diff --git a/autogalaxy/ellipse/plot/fit_ellipse_plot_util.py b/autogalaxy/ellipse/plot/fit_ellipse_plot_util.py index 68bf8982..7c12d7e2 100644 --- a/autogalaxy/ellipse/plot/fit_ellipse_plot_util.py +++ b/autogalaxy/ellipse/plot/fit_ellipse_plot_util.py @@ -1,9 +1,10 @@ import itertools -import matplotlib.pyplot as plt import numpy as np def plot_ellipse_residuals(array, fit_list, colors, output, for_subplot: bool = False): + import matplotlib.pyplot as plt + """Plot the 1-D ellipse residuals as a function of position angle. For each :class:`~autogalaxy.ellipse.fit_ellipse.FitEllipse` in diff --git a/autogalaxy/ellipse/plot/fit_ellipse_plots.py b/autogalaxy/ellipse/plot/fit_ellipse_plots.py index cbbd40c4..745ce5d6 100644 --- a/autogalaxy/ellipse/plot/fit_ellipse_plots.py +++ b/autogalaxy/ellipse/plot/fit_ellipse_plots.py @@ -1,15 +1,13 @@ import numpy as np import math -import matplotlib.pyplot as plt from typing import List, Optional import autoarray as aa -from autoarray import plot as aplt -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.ellipse.plot import fit_ellipse_plot_util from autogalaxy.ellipse.fit_ellipse import FitEllipse -from autogalaxy.plot.plot_utils import plot_array, _save_subplot +from autogalaxy.util.plot_utils import plot_array, _save_subplot from autogalaxy.util import error_util @@ -108,7 +106,9 @@ def _plot_ellipse_residuals( ax : matplotlib.axes.Axes or None Reserved for future direct-axes support (currently unused). """ - output = aplt.Output(path=output_path, format=output_format) if output_path else aplt.Output() + from autoarray.plot.output import Output + + output = Output(path=output_path, format=output_format) if output_path else Output() fit_ellipse_plot_util.plot_ellipse_residuals( array=fit_list[0].dataset.data.native, @@ -148,7 +148,7 @@ def subplot_fit_ellipse( disable_data_contours : bool If ``True``, suppress ellipse contour overlays on the image panel. """ - fig, axes = plt.subplots(1, 2, figsize=conf_subplot_figsize(1, 2)) + fig, axes = subplots(1, 2, figsize=conf_subplot_figsize(1, 2)) _plot_data( fit_list=fit_list, @@ -213,7 +213,7 @@ def subplot_ellipse_errors( fit_ellipse_list[i].append(aa.Grid2DIrregular.from_yx_1d(y=y, x=x)) n = len(fit_ellipse_list) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = [axes] if n == 1 else list(axes.flatten()) for i in range(n): diff --git a/autogalaxy/galaxy/plot/adapt_plots.py b/autogalaxy/galaxy/plot/adapt_plots.py index c3143425..f50ecb62 100644 --- a/autogalaxy/galaxy/plot/adapt_plots.py +++ b/autogalaxy/galaxy/plot/adapt_plots.py @@ -1,12 +1,11 @@ -import matplotlib.pyplot as plt import numpy as np from typing import Dict import autoarray as aa -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.galaxy.galaxy import Galaxy -from autogalaxy.plot.plot_utils import plot_array, _save_subplot +from autogalaxy.util.plot_utils import plot_array, _save_subplot def subplot_adapt_images( @@ -46,7 +45,7 @@ def subplot_adapt_images( n = len(adapt_galaxy_name_image_dict) cols = min(n, 3) rows = (n + cols - 1) // cols - fig, axes = plt.subplots(rows, cols, figsize=conf_subplot_figsize(rows, cols)) + fig, axes = subplots(rows, cols, figsize=conf_subplot_figsize(rows, cols)) axes_list = [axes] if n == 1 else list(np.array(axes).flatten()) for i, (_, galaxy_image) in enumerate(adapt_galaxy_name_image_dict.items()): diff --git a/autogalaxy/galaxy/plot/galaxies_plots.py b/autogalaxy/galaxy/plot/galaxies_plots.py index 67f16add..5ee76ea9 100644 --- a/autogalaxy/galaxy/plot/galaxies_plots.py +++ b/autogalaxy/galaxy/plot/galaxies_plots.py @@ -1,11 +1,10 @@ -import matplotlib.pyplot as plt import numpy as np import autoarray as aa from autogalaxy.galaxy.galaxies import Galaxies -from autogalaxy.plot.plot_utils import _to_lines, _to_positions, plot_array, plot_grid, _save_subplot, _critical_curves_from -from autoarray.plot.utils import hide_unused_axes, conf_subplot_figsize, tight_layout +from autogalaxy.util.plot_utils import _to_lines, _to_positions, plot_array, plot_grid, _save_subplot, _critical_curves_from +from autoarray.plot.utils import subplots, hide_unused_axes, conf_subplot_figsize, tight_layout from autogalaxy import exc @@ -141,7 +140,7 @@ def _defl_x(): n = len(panels) cols = min(n, 3) rows = (n + cols - 1) // cols - fig, axes = plt.subplots(rows, cols, figsize=conf_subplot_figsize(rows, cols)) + fig, axes = subplots(rows, cols, figsize=conf_subplot_figsize(rows, cols)) axes_flat = [axes] if n == 1 else list(np.array(axes).flatten()) for i, (_, array, title, p, l) in enumerate(panels): @@ -201,7 +200,7 @@ def subplot_galaxy_images( gs = Galaxies(galaxies=galaxies) n = len(gs) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = [axes] if n == 1 else list(axes.flatten()) for i in range(n): diff --git a/autogalaxy/galaxy/plot/galaxy_plots.py b/autogalaxy/galaxy/plot/galaxy_plots.py index 3505ffff..b83a55df 100644 --- a/autogalaxy/galaxy/plot/galaxy_plots.py +++ b/autogalaxy/galaxy/plot/galaxy_plots.py @@ -1,14 +1,13 @@ from __future__ import annotations -import matplotlib.pyplot as plt from typing import TYPE_CHECKING import autoarray as aa -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.galaxy.galaxy import Galaxy from autogalaxy.profiles.light.abstract import LightProfile from autogalaxy.profiles.mass.abstract.abstract import MassProfile -from autogalaxy.plot.plot_utils import plot_array, _save_subplot +from autogalaxy.util.plot_utils import plot_array, _save_subplot def subplot_of_light_profiles( @@ -48,7 +47,7 @@ def subplot_of_light_profiles( return n = len(light_profiles) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = [axes] if n == 1 else list(axes.flatten()) for i, lp in enumerate(light_profiles): @@ -130,7 +129,7 @@ def _deflections_x(mp): if not flag: continue - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = [axes] if n == 1 else list(axes.flatten()) for i, mp in enumerate(mass_profiles): diff --git a/autogalaxy/gui/clicker.py b/autogalaxy/gui/clicker.py index ab569a07..8590ec9c 100644 --- a/autogalaxy/gui/clicker.py +++ b/autogalaxy/gui/clicker.py @@ -1,9 +1,6 @@ import numpy as np -from matplotlib import pyplot as plt import autoarray as aa -import autoarray.plot as aplt -from autoarray.plot.utils import _conf_imshow_origin from autogalaxy import exc @@ -23,6 +20,9 @@ def __init__(self, image, pixel_scales, search_box_size, in_pixels: bool = False self.in_pixels = in_pixels def start(self, data, pixel_scales): + from matplotlib import pyplot as plt + import autoarray.plot as aplt + from autoarray.plot.utils import _conf_imshow_origin n_y, n_x = data.shape_native hw = int(n_x / 2) * pixel_scales diff --git a/autogalaxy/gui/scribbler.py b/autogalaxy/gui/scribbler.py index 4111b09f..2e03fc7b 100644 --- a/autogalaxy/gui/scribbler.py +++ b/autogalaxy/gui/scribbler.py @@ -1,11 +1,7 @@ from collections import OrderedDict import numpy as np -import matplotlib -import matplotlib.pyplot as plt from typing import Tuple -from autoarray.plot.utils import _conf_imshow_origin - class Scribbler: def __init__( @@ -60,6 +56,10 @@ def __init__( extent = (x0_pix, x1_pix, y0_pix, y1_pix) + import matplotlib + import matplotlib.pyplot as plt + from autoarray.plot.utils import _conf_imshow_origin + matplotlib.use(backend) self.im = image diff --git a/autogalaxy/imaging/plot/fit_imaging_plots.py b/autogalaxy/imaging/plot/fit_imaging_plots.py index 58bf0660..a2148fd0 100644 --- a/autogalaxy/imaging/plot/fit_imaging_plots.py +++ b/autogalaxy/imaging/plot/fit_imaging_plots.py @@ -1,12 +1,11 @@ -import matplotlib.pyplot as plt from pathlib import Path import autoarray as aa from autoconf.fitsable import hdu_list_for_output_from -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.imaging.fit_imaging import FitImaging -from autogalaxy.plot.plot_utils import plot_array, _save_subplot +from autogalaxy.util.plot_utils import plot_array, _save_subplot def subplot_fit( @@ -50,7 +49,7 @@ def subplot_fit( (fit.chi_squared_map, "Chi-Squared Map", r"$\chi^2$"), ] n = len(panels) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = list(axes.flatten()) for i, (array, title, cb_unit) in enumerate(panels): @@ -116,7 +115,7 @@ def subplot_of_galaxy( ), ] n = len(panels) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = list(axes.flatten()) for i, (array, title) in enumerate(panels): @@ -156,7 +155,7 @@ def subplot_fit_imaging_list( File format string or list, e.g. ``"png"`` or ``["png"]``. """ n = len(fit_list) - fig, axes = plt.subplots(n, 5, figsize=conf_subplot_figsize(n, 5)) + fig, axes = subplots(n, 5, figsize=conf_subplot_figsize(n, 5)) if n == 1: axes = [axes] for i, fit in enumerate(fit_list): diff --git a/autogalaxy/interferometer/plot/fit_interferometer_plots.py b/autogalaxy/interferometer/plot/fit_interferometer_plots.py index 1e30fd35..25636d6f 100644 --- a/autogalaxy/interferometer/plot/fit_interferometer_plots.py +++ b/autogalaxy/interferometer/plot/fit_interferometer_plots.py @@ -1,15 +1,14 @@ -import matplotlib.pyplot as plt import numpy as np from pathlib import Path import autoarray as aa from autoconf.fitsable import hdu_list_for_output_from from autoarray.plot import plot_visibilities_1d -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.interferometer.fit_interferometer import FitInterferometer from autogalaxy.galaxy.plot import galaxies_plots -from autogalaxy.plot.plot_utils import plot_array, _save_subplot +from autogalaxy.util.plot_utils import plot_array, _save_subplot def subplot_fit( @@ -48,7 +47,7 @@ def subplot_fit( (fit.chi_squared_map, "Chi-Squared Map"), ] n = len(panels) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = list(axes.flatten()) for i, (vis, title) in enumerate(panels): @@ -97,7 +96,7 @@ def subplot_fit_dirty_images( (fit.dirty_chi_squared_map, "Dirty Chi-Squared Map", r"$\chi^2$"), ] n = len(panels) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = list(axes.flatten()) for i, (array, title, cb_unit) in enumerate(panels): @@ -165,7 +164,7 @@ def subplot_fit_real_space( (fit.dirty_residual_map, "Dirty Residual Map"), ] n = len(panels) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = list(axes.flatten()) for i, (array, title) in enumerate(panels): plot_array( diff --git a/autogalaxy/plot/__init__.py b/autogalaxy/plot/__init__.py index aa1b3529..b679de0b 100644 --- a/autogalaxy/plot/__init__.py +++ b/autogalaxy/plot/__init__.py @@ -6,7 +6,7 @@ output_figure, ) -from autogalaxy.plot.plot_utils import plot_array, plot_grid, fits_array +from autogalaxy.util.plot_utils import plot_array, plot_grid, fits_array from autoarray.dataset.plot.imaging_plots import ( subplot_imaging_dataset, diff --git a/autogalaxy/plot/plot_utils.py b/autogalaxy/plot/plot_utils.py index 28a66de8..0e631ae0 100644 --- a/autogalaxy/plot/plot_utils.py +++ b/autogalaxy/plot/plot_utils.py @@ -1,7 +1,6 @@ import logging import os import numpy as np -import matplotlib.pyplot as plt logger = logging.getLogger(__name__) @@ -74,6 +73,7 @@ def _save_subplot(fig, output_path, output_filename, output_format=None, For FITS output use the dedicated ``fits_*`` functions instead. """ + import matplotlib.pyplot as plt from autoarray.plot.utils import _output_mode_save, _conf_output_format, _FAST_PLOTS if _output_mode_save(fig, output_filename): diff --git a/autogalaxy/profiles/plot/basis_plots.py b/autogalaxy/profiles/plot/basis_plots.py index 3fb1c13e..274da8de 100644 --- a/autogalaxy/profiles/plot/basis_plots.py +++ b/autogalaxy/profiles/plot/basis_plots.py @@ -1,11 +1,10 @@ -import matplotlib.pyplot as plt import numpy as np import autoarray as aa -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.profiles.basis import Basis -from autogalaxy.plot.plot_utils import _to_positions, plot_array, _save_subplot +from autogalaxy.util.plot_utils import _to_positions, plot_array, _save_subplot from autogalaxy import exc @@ -60,7 +59,7 @@ def subplot_image( n = len(basis.light_profile_list) cols = min(n, 4) rows = (n + cols - 1) // cols - fig, axes = plt.subplots(rows, cols, figsize=conf_subplot_figsize(rows, cols)) + fig, axes = subplots(rows, cols, figsize=conf_subplot_figsize(rows, cols)) axes_flat = [axes] if n == 1 else list(np.array(axes).flatten()) _positions = _to_positions(positions) diff --git a/autogalaxy/quantity/plot/fit_quantity_plots.py b/autogalaxy/quantity/plot/fit_quantity_plots.py index 12e7396b..8f2d35fc 100644 --- a/autogalaxy/quantity/plot/fit_quantity_plots.py +++ b/autogalaxy/quantity/plot/fit_quantity_plots.py @@ -1,10 +1,9 @@ -import matplotlib.pyplot as plt import autoarray as aa -from autoarray.plot.utils import conf_subplot_figsize, tight_layout +from autoarray.plot.utils import subplots, conf_subplot_figsize, tight_layout from autogalaxy.quantity.fit_quantity import FitQuantity -from autogalaxy.plot.plot_utils import plot_array, _save_subplot +from autogalaxy.util.plot_utils import plot_array, _save_subplot def _subplot_fit_array(fit, output_path, output_format, colormap, use_log10, positions, filename="fit"): @@ -43,7 +42,7 @@ def _subplot_fit_array(fit, output_path, output_format, colormap, use_log10, pos (fit.chi_squared_map, "Chi-Squared Map"), ] n = len(panels) - fig, axes = plt.subplots(1, n, figsize=conf_subplot_figsize(1, n)) + fig, axes = subplots(1, n, figsize=conf_subplot_figsize(1, n)) axes_flat = list(axes.flatten()) for i, (array, title) in enumerate(panels): diff --git a/autogalaxy/util/plot_utils.py b/autogalaxy/util/plot_utils.py new file mode 100644 index 00000000..0e631ae0 --- /dev/null +++ b/autogalaxy/util/plot_utils.py @@ -0,0 +1,447 @@ +import logging +import os +import numpy as np + +logger = logging.getLogger(__name__) + + +def _to_lines(*items): + """Convert multiple line sources into a flat list of (N,2) numpy arrays. + + Each item may be ``None`` (skipped), a list of line-like objects, or a + single line-like object. A line-like object is anything that either has + an ``.array`` attribute or can be coerced to a 2-D numpy array with shape + ``(N, 2)``. Items that cannot be converted, or that are empty, are + silently dropped. + + Parameters + ---------- + *items + Any number of line sources to merge. + + Returns + ------- + list of np.ndarray or None + A flat list of ``(N, 2)`` arrays, or ``None`` if nothing valid was + found. + """ + result = [] + for item in items: + if item is None: + continue + if isinstance(item, list): + for sub in item: + try: + arr = np.array(sub.array if hasattr(sub, "array") else sub) + if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: + result.append(arr) + except Exception: + pass + else: + try: + arr = np.array(item.array if hasattr(item, "array") else item) + if arr.ndim == 2 and arr.shape[1] == 2 and len(arr) > 0: + result.append(arr) + except Exception: + pass + return result or None + + +def _to_positions(*items): + """Convert multiple position sources into a flat list of (N,2) numpy arrays. + + Thin wrapper around :func:`_to_lines` — positions and lines share the same + underlying representation (lists of ``(N, 2)`` coordinate arrays). + + Parameters + ---------- + *items + Any number of position sources to merge. + + Returns + ------- + list of np.ndarray or None + A flat list of ``(N, 2)`` arrays, or ``None`` if nothing valid was + found. + """ + return _to_lines(*items) + + +def _save_subplot(fig, output_path, output_filename, output_format=None, + dpi=300): + """Save a subplot figure to disk (or show it if output_format/output_path say so). + + For FITS output use the dedicated ``fits_*`` functions instead. + """ + import matplotlib.pyplot as plt + from autoarray.plot.utils import _output_mode_save, _conf_output_format, _FAST_PLOTS + + if _output_mode_save(fig, output_filename): + return + + if _FAST_PLOTS: + plt.close(fig) + return + + fmt = output_format[0] if isinstance(output_format, (list, tuple)) else (output_format or _conf_output_format()) + if fmt == "show" or not output_path: + plt.show() + else: + os.makedirs(str(output_path), exist_ok=True) + fpath = os.path.join(str(output_path), f"{output_filename}.{fmt}") + fig.savefig(fpath, dpi=dpi, bbox_inches="tight", pad_inches=0.1) + plt.close(fig) + + +def _resolve_colormap(colormap): + """Resolve 'default' or None to the autoarray default colormap.""" + if colormap in ("default", None): + from autoarray.plot.utils import _default_colormap + return _default_colormap() + return colormap + + +def _resolve_format(output_format): + """Normalise output_format: accept a list/tuple or a plain string.""" + from autoarray.plot.utils import _conf_output_format + + if isinstance(output_format, (list, tuple)): + return output_format[0] + return output_format or _conf_output_format() + + +def _numpy_grid(grid): + """Convert a grid-like object to a numpy array, or return None.""" + if grid is None: + return None + try: + return np.array(grid.array if hasattr(grid, "array") else grid) + except Exception: + return None + + +def plot_array( + array, + title="", + output_path=None, + output_filename="array", + output_format=None, + colormap="default", + use_log10=False, + vmin=None, + vmax=None, + symmetric=False, + positions=None, + lines=None, + line_colors=None, + grid=None, + cb_unit=None, + ax=None, +): + """Plot an autoarray ``Array2D`` to file or onto an existing ``Axes``. + + All array preprocessing (zoom, mask-edge extraction, native/extent + unpacking) is handled internally so callers never need to duplicate it. + The actual rendering is delegated to ``autoarray.plot.plot_array``. + + Parameters + ---------- + array + The ``Array2D`` (or array-like) to plot. + title : str + Title displayed above the panel. + output_path : str or None + Directory in which to save the figure. ``None`` → call + ``plt.show()`` instead. + output_filename : str + Stem of the output file name (extension is added from + *output_format*). + output_format : str + File format, e.g. ``"png"`` or ``"pdf"``. + colormap : str + Matplotlib colormap name, or ``"default"`` to use the autoarray + default (``"jet"``). + use_log10 : bool + If ``True`` apply a log₁₀ stretch to the array values. + vmin, vmax : float or None + Explicit colour-bar limits. Ignored when *symmetric* is ``True``. + symmetric : bool + If ``True`` set ``vmin = -vmax`` so that zero maps to the middle of + the colormap. + positions : list or array-like or None + Point positions to scatter-plot over the image. + lines : list or array-like or None + Line coordinates to overlay on the image. + line_colors : list or None + Colours for each entry in *lines*. + grid : array-like or None + An additional grid of points to overlay. + ax : matplotlib.axes.Axes or None + Existing ``Axes`` to draw into. When provided the figure is *not* + saved — the caller is responsible for saving. + """ + from autoarray.plot import plot_array as _aa_plot_array + + colormap = _resolve_colormap(colormap) + output_format = _resolve_format(output_format) + + if symmetric: + try: + arr = array.native.array + except AttributeError: + arr = np.asarray(array) + finite = arr[np.isfinite(arr)] + abs_max = float(np.max(np.abs(finite))) if len(finite) > 0 else 1.0 + vmin, vmax = -abs_max, abs_max + + _positions_list = positions if isinstance(positions, list) else _to_positions(positions) + _lines_list = lines if isinstance(lines, list) else _to_lines(lines) + + if ax is not None: + _output_path = None + else: + _output_path = output_path if output_path is not None else "." + + _aa_plot_array( + array=array, + ax=ax, + grid=_numpy_grid(grid), + positions=_positions_list, + lines=_lines_list, + line_colors=line_colors, + title=title or "", + colormap=colormap, + use_log10=use_log10, + vmin=vmin, + vmax=vmax, + cb_unit=cb_unit, + output_path=_output_path, + output_filename=output_filename, + output_format=output_format, + ) + + +def _fits_values_and_header(array): + """Extract raw numpy values and header dict from an autoarray object. + + Returns ``(values, header_dict, ext_name)`` where *header_dict* and + *ext_name* may be ``None`` for plain arrays. + """ + from autoarray.structures.visibilities import AbstractVisibilities + from autoarray.mask.abstract_mask import Mask + + if isinstance(array, AbstractVisibilities): + return np.asarray(array.in_array), None, None + if isinstance(array, Mask): + header = array.header_dict if hasattr(array, "header_dict") else None + return np.asarray(array.astype("float")), header, "mask" + if hasattr(array, "native"): + try: + header = array.mask.header_dict + except (AttributeError, TypeError): + header = None + return np.asarray(array.native.array).astype("float"), header, None + + return np.asarray(array), None, None + + +def fits_array(array, file_path, overwrite=False, ext_name=None): + """Write an autoarray ``Array2D``, ``Mask2D``, or array-like to a ``.fits`` file. + + Handles header metadata (pixel scales, origin) automatically for + autoarray objects. + + Parameters + ---------- + array + The data to write. + file_path : str or Path + Full path including filename and ``.fits`` extension. + overwrite : bool + If ``True`` an existing file at *file_path* is replaced. + ext_name : str or None + FITS extension name. Auto-detected for masks (``"mask"``). + """ + from autoconf.fitsable import output_to_fits + + values, header_dict, auto_ext_name = _fits_values_and_header(array) + if ext_name is None: + ext_name = auto_ext_name + + output_to_fits( + values=values, + file_path=file_path, + overwrite=overwrite, + header_dict=header_dict, + ext_name=ext_name, + ) + + +def plot_grid( + grid, + title="", + output_path=None, + output_filename="grid", + output_format=None, + lines=None, + ax=None, +): + """Plot an autoarray ``Grid2D`` as a scatter plot. + + Delegates to ``autoarray.plot.plot_grid`` after converting the grid to a + plain numpy array. + + Parameters + ---------- + grid + The ``Grid2D`` (or grid-like) to plot. + title : str + Title displayed above the panel. + output_path : str or None + Directory in which to save the figure. ``None`` → call + ``plt.show()`` instead. + output_filename : str + Stem of the output file name. + output_format : str + File format, e.g. ``"png"``. + lines : list or None + Line coordinates to overlay on the grid plot. + ax : matplotlib.axes.Axes or None + Existing ``Axes`` to draw into. + """ + from autoarray.plot import plot_grid as _aa_plot_grid + + output_format = _resolve_format(output_format) + + if ax is not None: + _output_path = None + else: + _output_path = output_path if output_path is not None else "." + + _aa_plot_grid( + grid=np.array(grid.array if hasattr(grid, "array") else grid), + ax=ax, + title=title or "", + output_path=_output_path, + output_filename=output_filename, + output_format=output_format, + ) + + +def _critical_curves_method(): + """Read ``general.critical_curves_method`` from the visualize config. + + Returns ``"marching_squares"`` (the default) or ``"zero_contour"``. + Any unrecognised value falls back to ``"marching_squares"`` with a warning. + """ + from autoconf import conf + + try: + method = conf.instance["visualize"]["general"]["general"]["critical_curves_method"] + except (KeyError, TypeError): + method = "marching_squares" + + if method not in ("zero_contour", "marching_squares"): + logger.warning( + f"visualize/general.yaml: unrecognised critical_curves_method " + f"'{method}'. Falling back to 'marching_squares'." + ) + return "marching_squares" + return method + + +def _caustics_from(mass_obj, grid): + """Compute tangential and radial caustics for a mass object via LensCalc. + + The algorithm used is controlled by ``general.critical_curves_method`` in + ``visualize/general.yaml``: + + - ``"zero_contour"`` *(default)* — uses ``jax_zero_contour`` to trace the + zero contour of each eigen value directly. No dense evaluation grid is + needed; a coarse 25 × 25 scan finds the seed points automatically. + - ``"marching_squares"`` — evaluates eigen values on the full *grid* and + uses marching squares to find the contours. + + Parameters + ---------- + mass_obj + Any object understood by ``LensCalc.from_mass_obj`` (e.g. a + :class:`~autogalaxy.galaxy.galaxies.Galaxies` or autolens ``Tracer``). + grid : aa.type.Grid2DLike + The grid on which to evaluate the caustics (used only for the + ``"marching_squares"`` path; ignored by ``"zero_contour"``). + + Returns + ------- + tuple[list, list] + ``(tangential_caustics, radial_caustics)``. + """ + if os.environ.get("PYAUTO_DISABLE_CRITICAL_CAUSTICS") == "1": + return [], [] + + from autogalaxy.operate.lens_calc import LensCalc + + od = LensCalc.from_mass_obj(mass_obj) + method = _critical_curves_method() + + if method == "zero_contour": + tan_ca = od.tangential_caustic_list_via_zero_contour_from() + rad_ca = od.radial_caustic_list_via_zero_contour_from() + else: + tan_ca = od.tangential_caustic_list_from(grid=grid) + rad_ca = od.radial_caustic_list_from(grid=grid) + + return tan_ca, rad_ca + + +def _critical_curves_from(mass_obj, grid, tc=None, rc=None): + """Compute tangential and radial critical curves for a mass object. + + If *tc* is already provided it is returned unchanged (along with *rc*), + allowing callers to cache the curves across multiple plot calls. + + The algorithm used when *tc* is ``None`` is controlled by + ``general.critical_curves_method`` in ``visualize/general.yaml``: +/btw ok + - ``"zero_contour"`` *(default)* — uses ``jax_zero_contour``; no dense + grid needed, seed points found automatically via a coarse grid scan. + - ``"marching_squares"`` — evaluates eigen values on the full *grid* and + uses marching squares. Radial critical curves are only computed when at + least one radial critical-curve area exceeds the grid pixel scale. + + Parameters + ---------- + mass_obj + Any object understood by ``LensCalc.from_mass_obj``. + grid : aa.type.Grid2DLike + Evaluation grid (used only for the ``"marching_squares"`` path). + tc : list or None + Pre-computed tangential critical curves; ``None`` to trigger + computation. + rc : list or None + Pre-computed radial critical curves; ``None`` to trigger computation. + + Returns + ------- + tuple[list, list or None] + ``(tangential_critical_curves, radial_critical_curves)``. + """ + from autogalaxy.operate.lens_calc import LensCalc + + if os.environ.get("PYAUTO_DISABLE_CRITICAL_CAUSTICS") == "1": + return [], [] + + if tc is None: + od = LensCalc.from_mass_obj(mass_obj) + method = _critical_curves_method() + + if method == "zero_contour": + tc = od.tangential_critical_curve_list_via_zero_contour_from() + rc = od.radial_critical_curve_list_via_zero_contour_from() + else: + tc = od.tangential_critical_curve_list_from(grid=grid) + rc_area = od.radial_critical_curve_area_list_from(grid=grid) + if any(area > grid.pixel_scale for area in rc_area): + rc = od.radial_critical_curve_list_from(grid=grid) + + return tc, rc