Source code for cdiutils.plot.slice

import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import AxesGrid

from cdiutils.plot.formatting import (
    CXI_VIEW_PARAMETERS,
    NATURAL_VIEW_PARAMETERS,
    XU_VIEW_PARAMETERS,
    add_colorbar,
    get_figure_size,
    get_x_y_limits_extents,
    set_x_y_limits_extents,
)
from cdiutils.utils import (
    extract_reduced_shape,
    get_centred_slices,
    nan_to_zero,
)


def plot_volume_slices(
    data: np.ndarray,
    support: np.ndarray = None,
    voxel_size: tuple | list = None,
    data_centre: tuple | list = None,
    views: tuple[str] = None,
    convention: str = None,
    title: str = None,
    equal_limits: bool = True,
    slice_shift: tuple | list = None,
    integrate: bool = False,
    opacity: np.ndarray = None,
    plot_type: str = "imshow",
    contour_levels: int = 100,
    show: bool = True,
    **plot_params,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Generic function for plotting 2D slices (cross section or sum, with
    option 'integrate') of 3D volumes. The slices are plotted according
    to the specified views and conventions. If not specified, natural
    views are plotted in matrix convention (x-axis: 2nd dim, y-axis: 1st
    dim), i.e:
        * first slice: taken at the centre of axis0
        * second slice: taken at the centre of axis1
        * third slice: taken at the centre of axis2

    Args:
        data (np.ndarray): the data to plot.
        support (np.ndarray, optional): a support for the data, used to
            compute the axis limits. Defaults to None.
        voxel_size (tuple | list, optional): the voxel size to modify the
            aspect ratio accordingly. Defaults to None.
        data_centre (tuple | list, optional): the centre to take the data
            at. Defaults to None.
        views (tuple[str], optional): the views for each plot according
            to the provided convention. If None default views of the
            specified convention are plotted. Defaults to None.
        convention (str, optional): the convention employed to plot the
            multiple slices, if views not specified, will set the default
            views for the specified convention, i.e.:
            ("x-", "y+", "z-") for XU convention and ("z+", "y-", "x+")
            for the CXI convention. If None, natural views are plotted.
            Defaults to None.
        title (str, optional): the title of the plot. Defaults to None.
        equal_limits (bool, optional): whether to have the same limit
            extent for all axes. Defaults to True.
        slice_shift (tuple | list, optional): the shift in the slice
            selection, by default will use the centre for each dim.
            Defaults to None.
        integrate (bool, optional): whether to sum the data instead of
            taking the slice. Defaults to False.
        opacity (np.ndarray, optional): the opacity 3D array of the data.
            Defaults to None. If constant opacity is required, use the
            'alpha' parameter. When both are given, 'opacity' takes
            precedence.
        plot_type (str, optional): Type of plot to use. Options are
            'imshow' or 'contourf' ('contour' is accepted as an alias of
            'contourf'). Defaults to 'imshow'.
        contour_levels (int, optional): Number of contour levels when
            using 'contourf' plot type. Defaults to 100.
        show (bool, optional): whether to show the plot. Defaults to True.
            False might be useful if the function is only used for
            generating the axes that are then redrawn afterwards. When
            False, the figure is closed (but still returned).
        **plot_params: additional plot params that will be parsed into
            the matplotlib imshow() (or contourf()) function. A constant
            'alpha' is applied to every slice.

    Returns:
        tuple[plt.Figure, plt.Axes]: the generated figure and axes.

    Raises:
        ValueError: if the convention is unknown and no views are given,
            or if plot_type is not supported.
    """
    _plot_params = {"cmap": "turbo"}
    if plot_params:
        _plot_params.update(plot_params)

    # Keep the caller's constant alpha aside: the per-slice alpha arrays
    # built in the loop below must not overwrite it.
    user_alpha = _plot_params.pop("alpha", None)

    # The view parameter tables are only read here, never modified.
    view_params = CXI_VIEW_PARAMETERS
    if convention is None:
        if views is None:
            # Simplest case, no swapping, no flipping etc.
            # For the default behaviour we use the 'natural views'
            view_params = NATURAL_VIEW_PARAMETERS
            views = ("dim0", "dim1", "dim2")
    elif convention.lower() in ("xu", "lab"):
        view_params = XU_VIEW_PARAMETERS  # overwrite the params
        if views is None:
            views = ("x-", "y+", "z-")
    elif convention.lower() == "cxi":
        if views is None:
            views = ("z+", "y-", "x+")

    if views is None:
        # Only reachable with an unrecognised convention and no views.
        raise ValueError(
            f"Unknown convention '{convention}' and no views provided. "
            "Use 'xu', 'lab', 'cxi' or None, or pass 'views' explicitly."
        )

    slices = get_centred_slices(data.shape, shift=slice_shift)

    shape = data.shape
    if support is not None:
        shape = extract_reduced_shape(support)

    if voxel_size is not None:
        extents = get_x_y_limits_extents(data.shape, voxel_size, data_centre)
        limits = get_x_y_limits_extents(
            shape, voxel_size, data_centre, equal_limits=equal_limits
        )

    figure, axes = plt.subplots(1, 3, layout="tight", figsize=(6, 2))

    for i, v in enumerate(views):
        plane = view_params[v]["plane"]
        to_plot = data.sum(axis=i) if integrate else data[slices[i]]

        # Per-pixel alpha for this slice: the opacity volume wins, then a
        # constant user alpha, then fully opaque.
        if opacity is not None:
            alpha = opacity[slices[i]]
        elif user_alpha is not None:
            alpha = np.full(to_plot.shape, user_alpha, dtype=float)
        else:
            alpha = np.ones_like(to_plot)

        # Orient data and alpha identically so they stay pixel-aligned.
        if plane[0] > plane[1]:
            to_plot = np.swapaxes(to_plot, 1, 0)
            alpha = np.swapaxes(alpha, 1, 0)
        if view_params[v]["xaxis_points_left"]:
            to_plot = to_plot[np.s_[:, ::-1]]
            alpha = alpha[np.s_[:, ::-1]]

        if plot_type in ("contourf", "contour"):
            ny, nx = to_plot.shape
            if voxel_size is not None:
                y_coords = np.linspace(
                    extents[plane[0]][0], extents[plane[0]][1], ny
                )
                x_coords = np.linspace(
                    extents[plane[1]][0], extents[plane[1]][1], nx
                )
                if view_params[v]["xaxis_points_left"]:
                    x_coords = np.flip(x_coords)
                X, Y = np.meshgrid(x_coords, y_coords)
            else:
                X, Y = np.meshgrid(np.arange(nx), np.arange(ny))

            contour_params = dict(_plot_params)
            if user_alpha is not None:
                contour_params["alpha"] = user_alpha
            im = axes[i].contourf(
                X, Y, to_plot, levels=contour_levels, **contour_params
            )
            # 2D array of opacity is not supported in contourf, so we
            # need a workaround: we add a contourf with the alpha values
            # drawn as increasingly transparent white layers.
            if opacity is not None:
                n_steps = max(contour_levels - 1, 1)  # avoid 0-division
                whites = [
                    (1, 1, 1, 1 - k / n_steps) for k in range(contour_levels)
                ]
                axes[i].contourf(
                    X, Y, alpha, levels=contour_levels, colors=whites
                )
            add_colorbar(axes[i], im)
            axes[i].set_aspect("equal")

        elif plot_type == "imshow":
            im = axes[i].imshow(to_plot, alpha=alpha, **_plot_params)
            add_colorbar(axes[i], im)
            if voxel_size is not None:
                set_x_y_limits_extents(
                    axes[i],
                    extents,
                    limits,
                    plane,
                    view_params[v]["xaxis_points_left"],
                )
        else:
            raise ValueError(
                f"Unknown plot type '{plot_type}'. "
                "Options are 'imshow' or 'contourf'."
            )

    figure.suptitle(title)
    if show:
        plt.show()
    else:
        plt.close(figure)
    return figure, axes
def plot_multiple_volume_slices(
    *data_arrays: np.ndarray,
    data_labels: list[str] = None,
    supports: list[np.ndarray] = None,
    voxel_sizes: list = None,
    data_centres: list = None,
    slice_shifts: list = None,
    data_stacking: str = "horizontal",
    pvs_args: dict = None,
    cbar_args: dict = None,
    xlim: tuple = None,
    ylim: tuple = None,
    remove_ticks: bool = False,
    figsize: tuple = None,
    title: str = None,
    show: bool = True,
    **plot_params,
) -> plt.Figure:
    """
    Plot 2D slices of multiple 3D volumes with customizable layout.

    This function uses plot_volume_slices as a building block to create a
    composite figure comparing multiple datasets. The panels are rebuilt
    from the images of the individual figures, so only the 'imshow' plot
    type is supported.

    Args:
        *data_arrays (np.ndarray): Multiple 3D arrays to plot.
        data_labels (list[str], optional): Labels for each dataset.
        supports (list[np.ndarray], optional): Support masks for each
            dataset.
        voxel_sizes (list, optional): List of voxel sizes for each dataset.
        data_centres (list, optional): List of data centres for each
            dataset.
        slice_shifts (list, optional): List of slice shifts for each
            dataset.
        data_stacking (str, optional): How to arrange plots. "vertical" (or
            "v") puts one dataset per row; any other value puts one dataset
            per column. Defaults to "horizontal".
        pvs_args (dict, optional): Dictionary of parameters for
            plot_volume_slices. Its "show" entry is ignored: individual
            figures are never displayed.
        cbar_args (dict, optional): Dictionary of colorbar parameters
            ("location", "title", "extend", "ticks", "size", "pad", "show").
            A non-empty dict enables the shared colorbar unless it contains
            show=False. The caller's dict is not modified.
        xlim (tuple, optional): Custom x-axis limits (min, max) applied to
            the plots of datasets that have a voxel size.
        ylim (tuple, optional): Custom y-axis limits (min, max) applied to
            the plots of datasets that have a voxel size.
        remove_ticks (bool, optional): If True, remove all ticks. Otherwise
            only the ticks between subplots are hidden. Defaults to False.
        figsize (tuple, optional): Figure size. If None, one inch per panel,
            i.e. (ncols, nrows).
        title (str, optional): Overall figure title.
        show (bool, optional): Whether to display the figure. If False the
            figure is closed but still returned. Defaults to True.
        **plot_params: Additional plotting parameters passed to
            plot_volume_slices.

    Returns:
        plt.Figure: The generated figure.

    Raises:
        ValueError: if no dataset is provided or plot_type is not 'imshow'.
        NotImplementedError: if a colorbar is requested on the 'left' or
            'top'.
    """
    n_datasets = len(data_arrays)
    if n_datasets == 0:
        raise ValueError("At least one dataset must be provided")

    _pvs_args = {
        "views": None,
        "convention": None,
        "equal_limits": True,
        "integrate": False,
    }
    if pvs_args:
        _pvs_args.update(pvs_args)
    _pvs_args.update(plot_params)
    # The individual figures are only scaffolding for the composite one,
    # so they must never be displayed, whatever pvs_args says.
    _pvs_args["show"] = False

    if _pvs_args.get("plot_type", "imshow") != "imshow":
        raise ValueError(
            "plot_multiple_volume_slices only supports plot_type='imshow'."
        )

    view_labels = _resolve_view_labels(_pvs_args)
    show_cbar, _cbar_args = _resolve_cbar_args(cbar_args)

    # Validate before any figure is created, so nothing is left open if
    # we raise.
    if show_cbar and _cbar_args["location"] in ("left", "top"):
        raise NotImplementedError(
            "Colorbar location 'top' and 'left' are not currently "
            "supported. Please use 'right', 'bottom'."
        )

    input_lists = _prepare_input_lists(
        n_datasets,
        data_labels,
        supports,
        voxel_sizes,
        data_centres,
        slice_shifts,
    )

    stacking_vertical = data_stacking.lower() in ("vertical", "v")

    individual_plots = _generate_individual_plots(
        data_arrays, input_lists, _pvs_args
    )

    # the number of views is 3 because we are plotting 3 slices
    n_views = 3
    nrows, ncols = (
        (n_datasets, n_views) if stacking_vertical else (n_views, n_datasets)
    )
    if figsize is None:
        figsize = (ncols, nrows)

    composite_fig = plt.figure(figsize=figsize)
    grid = AxesGrid(
        composite_fig,
        111,
        nrows_ncols=(nrows, ncols),
        axes_pad=0.05,
        share_all=False,
        cbar_mode="single" if show_cbar else None,
        cbar_location=_cbar_args["location"] if show_cbar else "right",
        cbar_size=_cbar_args["size"] if show_cbar else None,
        cbar_pad=_cbar_args["pad"] if show_cbar else None,
    )

    # Track global min/max so all panels share one colour scale.
    vmin_global, vmax_global = float("inf"), float("-inf")
    images = []

    for dataset_idx, (fig, axes) in enumerate(individual_plots):
        has_voxel_size = input_lists["voxel_sizes"][dataset_idx] is not None
        for view_idx, ax in enumerate(axes):
            if stacking_vertical:
                grid_idx = dataset_idx * n_views + view_idx
            else:  # horizontal stacking
                grid_idx = view_idx * n_datasets + dataset_idx
            target_ax = grid[grid_idx]

            target_im = _copy_image_to_axes(ax.get_images()[0], target_ax)
            images.append(target_im)

            vmin, vmax = target_im.get_clim()
            vmin_global = min(vmin_global, vmin)
            vmax_global = max(vmax_global, vmax)

            target_ax.set_xlim(ax.get_xlim())
            target_ax.set_ylim(ax.get_ylim())
            # Custom limits are only applied to datasets with a voxel size.
            if has_voxel_size and xlim is not None:
                target_ax.set_xlim(xlim)
            if has_voxel_size and ylim is not None:
                target_ax.set_ylim(ylim)

            _add_axis_labels(
                target_ax,
                stacking_vertical,
                view_idx,
                dataset_idx,
                input_lists["data_labels"][dataset_idx],
                view_labels[view_idx]
                if view_idx < len(view_labels)
                else f"View {view_idx + 1}",
            )

        # Close the individual figure to free memory
        plt.close(fig)

    for im in images:
        im.set_clim(vmin_global, vmax_global)

    _hide_inner_ticks(grid, nrows, ncols, remove_ticks)

    if show_cbar and images:
        _add_grid_colorbar(grid, images[0], _cbar_args)

    if title:
        composite_fig.suptitle(title, fontsize=10, y=1.02)

    if show:
        plt.show()
    else:
        plt.close(composite_fig)
    return composite_fig


def _resolve_view_labels(pvs_args: dict):
    """Return the per-view panel labels for the given plot_volume_slices args."""
    if pvs_args.get("views") is not None:
        return pvs_args["views"]

    convention = pvs_args.get("convention")
    convention = convention.lower() if convention is not None else None
    if convention == "cxi":
        # Default CXI views
        return ["z+ (-x, y)", "y- (-x, z)", "x+ (z, y)"]
    if convention in ("xu", "lab"):
        # Default XU views.
        # NOTE(review): plot_volume_slices uses ("x-", "y+", "z-") as default
        # XU views, but the second label says "y-". Confirm which is right.
        return ["x- (y, z)", "y- (x, z)", "z- (x, y)"]
    # Fallback to generic labels for unknown convention
    return [f"View {i + 1}" for i in range(3)]


def _resolve_cbar_args(cbar_args: dict | None) -> tuple[bool, dict]:
    """Return (show_cbar, full colorbar settings) without mutating cbar_args."""
    user_args = dict(cbar_args) if cbar_args else {}
    # A non-empty cbar_args enables the colorbar unless "show" says otherwise.
    show_cbar = bool(user_args)
    if "show" in user_args:
        show_cbar = user_args.pop("show")

    settings = {
        "location": "right",
        "title": None,
        "extend": "both",
        "ticks": None,
        "size": "5%",
        "pad": "3%",
    }
    if user_args.get("location", "right") == "bottom":
        settings["pad"] = "10%"
    settings.update(user_args)
    return show_cbar, settings


def _hide_inner_ticks(
    grid: AxesGrid, nrows: int, ncols: int, remove_all: bool
) -> None:
    """Remove all ticks, or only those between subplots of the grid."""
    for row in range(nrows):
        for col in range(ncols):
            ax = grid[row * ncols + col]
            if remove_all:
                ax.set_xticks([])
                ax.set_yticks([])
                continue
            if col != 0:
                # Remove y-ticks for all columns except the first
                ax.tick_params(
                    axis="y",
                    which="both",
                    left=False,
                    right=False,
                    labelleft=False,
                )
            if row != nrows - 1:
                # Remove x-ticks for all rows except last
                ax.tick_params(
                    axis="x",
                    which="both",
                    bottom=False,
                    top=False,
                    labelbottom=False,
                )


def _add_grid_colorbar(
    grid: AxesGrid, image: AxesImage, cbar_args: dict
) -> None:
    """Draw the single shared colorbar of the grid for the given image."""
    cbar = grid.cbar_axes[0].colorbar(image, extend=cbar_args["extend"])
    if cbar_args["title"]:
        if cbar_args["location"] == "right":
            # Vertical colorbar: rotate the title and keep it on the right.
            cbar.ax.set_ylabel(
                cbar_args["title"],
                rotation=270,
                labelpad=10,
                va="bottom",
            )
            cbar.ax.yaxis.set_label_position("right")
        elif cbar_args["location"] == "bottom":
            cbar.ax.set_xlabel(cbar_args["title"], ha="center", labelpad=5)
    if cbar_args["ticks"] is not None:
        cbar.set_ticks(cbar_args["ticks"])
def _prepare_input_lists(
    n_datasets: int,
    data_labels: list[str] | None,
    supports: list[np.ndarray] | None,
    voxel_sizes: list[tuple[float, float, float]] | None,
    data_centres: list[tuple[float, float, float]] | None,
    slice_shifts: list[tuple[int, int, int]] | None,
) -> dict[str, list]:
    """
    Prepare per-dataset input lists with proper defaults.

    A list is used as given only when it has exactly one entry per dataset.
    Otherwise (None, empty or wrong length), it is replaced by a default:
    "Dataset 1", "Dataset 2", ... for the labels and None for every other
    input. A non-empty list of the wrong length triggers a warning rather
    than being dropped silently.
    """

    def _per_dataset(name: str, values, default: list):
        # `is not None` + len() also works for stacked np.ndarray inputs,
        # whose truth value would be ambiguous.
        if values is not None and len(values) == n_datasets:
            return values
        if values is not None and len(values) > 0:
            warnings.warn(
                f"'{name}' has {len(values)} entries but {n_datasets} "
                "datasets were provided; it is ignored and defaults are "
                "used instead."
            )
        return default

    return {
        "data_labels": _per_dataset(
            "data_labels",
            data_labels,
            [f"Dataset {i + 1}" for i in range(n_datasets)],
        ),
        "supports": _per_dataset("supports", supports, [None] * n_datasets),
        "voxel_sizes": _per_dataset(
            "voxel_sizes", voxel_sizes, [None] * n_datasets
        ),
        "data_centres": _per_dataset(
            "data_centres", data_centres, [None] * n_datasets
        ),
        "slice_shifts": _per_dataset(
            "slice_shifts", slice_shifts, [None] * n_datasets
        ),
    }


def _copy_image_to_axes(
    src_image: AxesImage, target_ax: plt.Axes
) -> AxesImage:
    """
    Copy an image from one axes to another, preserving properties.

    The data array, colour mapping (cmap, norm), geometry (extent, origin),
    opacity (alpha, scalar or per-pixel array) and interpolation are
    carried over.
    """
    return target_ax.imshow(
        src_image.get_array(),
        cmap=src_image.get_cmap(),
        norm=src_image.norm,
        extent=src_image.get_extent(),
        origin=src_image.origin,
        alpha=src_image.get_alpha(),
        interpolation=src_image.get_interpolation(),
    )


def _annotate_beside(ax: plt.Axes, text: str, side: str, **text_kwargs) -> None:
    """Write text just outside the axes, either on its "left" or its top."""
    if side == "left":
        placement = {
            "xy": (0, 0.5),
            "xytext": (-12, 0),
            "ha": "right",
            "va": "center",
        }
    else:  # top
        placement = {
            "xy": (0.5, 1),
            "xytext": (0, 12),
            "ha": "center",
            "va": "bottom",
        }
    ax.annotate(
        text,
        xycoords="axes fraction",
        textcoords="offset points",
        **placement,
        **text_kwargs,
    )


def _add_axis_labels(
    ax: plt.Axes,
    stacking_vertical: bool,
    view_idx: int,
    dataset_idx: int,
    data_label: str,
    view_label: str,
) -> None:
    """
    Add dataset and view labels to the appropriate axes.

    Vertical stacking: dataset labels (bold) on the left of the first
    column, view labels on top of the first row. Horizontal stacking:
    dataset labels (bold) on top of the first row, view labels on the left
    of the first column.
    """
    data_side, view_side = (
        ("left", "top") if stacking_vertical else ("top", "left")
    )
    if view_idx == 0:
        _annotate_beside(ax, data_label, data_side, fontweight="bold")
    if dataset_idx == 0:
        _annotate_beside(ax, view_label, view_side)


def _generate_individual_plots(
    data_arrays: tuple[np.ndarray, ...] | list[np.ndarray],
    input_lists: dict,
    pvs_args: dict,
) -> list[tuple[plt.Figure, np.ndarray]]:
    """Generate one plot_volume_slices figure per dataset."""
    individual_plots = []
    for i, data in enumerate(data_arrays):
        dataset_params = pvs_args.copy()
        dataset_params.update(
            {
                "support": input_lists["supports"][i],
                "voxel_size": input_lists["voxel_sizes"][i],
                "data_centre": input_lists["data_centres"][i],
                "slice_shift": input_lists["slice_shifts"][i],
            }
        )
        fig, axes = plot_volume_slices(data, **dataset_params)
        individual_plots.append((fig, axes))
    return individual_plots
def plot_slices(
    *data: np.ndarray,
    slice_labels: list = None,
    figsize: tuple[float] = None,
    data_stacking: str = "vertical",
    nan_supports: list = None,
    vmin: float = None,
    vmax: float = None,
    alphas: list = None,
    origin: str = "lower",
    cmap: str = "turbo",
    show_cbar: bool = True,
    cbar_title: str = None,
    cbar_location: str = "top",
    cbar_extend: str = "both",
    norm: matplotlib.colors.Normalize = None,
    cbar_ticks: list = None,
    slice_name: str = None,
    suptitle: str = None,
    show: bool = True,
) -> matplotlib.figure.Figure:
    """
    Plot 2D slices of the provided data in a single AxesGrid.

    Args:
        *data (np.ndarray): the 2D arrays to plot, one panel each.
        slice_labels (list, optional): one label per panel. Ignored (with
            a warning) if its length does not match the number of arrays.
        figsize (tuple[float], optional): figure size. Defaults to
            (6, 4 * n) for vertical and (6 * n, 4) for horizontal stacking.
        data_stacking (str, optional): "vertical"/"v" or
            "horizontal"/"h". Defaults to "vertical".
        nan_supports (list, optional): either one support per panel (list
            or tuple) or a single support applied to every panel. Each
            array is multiplied by its support before plotting.
        vmin, vmax (float, optional): colour scale limits.
        alphas (list, optional): one alpha value per panel.
        origin (str, optional): image origin. Defaults to "lower".
        cmap (str, optional): colormap. Defaults to "turbo".
        show_cbar (bool, optional): whether to draw a shared colorbar.
        cbar_title (str, optional): colorbar title.
        cbar_location (str, optional): colorbar location. Defaults to
            "top".
        cbar_extend (str, optional): colorbar extend mode. Defaults to
            "both".
        norm (matplotlib.colors.Normalize, optional): colour normalisation.
        cbar_ticks (list, optional): explicit colorbar ticks and labels.
        slice_name (str, optional): a label for the whole set of slices.
        suptitle (str, optional): the figure title.
        show (bool, optional): whether to call plt.show(). Defaults to
            True.

    Returns:
        matplotlib.figure.Figure: the generated figure.

    Raises:
        ValueError: if no data is given or data_stacking is invalid.
    """
    n_slices = len(data)
    if n_slices == 0:
        raise ValueError("At least one dataset must be provided")

    if data_stacking in ("vertical", "v"):
        vertical = True
    elif data_stacking in ("horizontal", "h"):
        vertical = False
    else:
        raise ValueError("data_stacking should be 'vertical' or 'horizontal'.")

    nrows_ncols = (n_slices, 1) if vertical else (1, n_slices)
    if figsize is None:
        figsize = (6, 4 * n_slices) if vertical else (6 * n_slices, 4)

    if slice_labels is None:
        slice_labels = [""] * n_slices
    elif len(slice_labels) != n_slices:
        warnings.warn(
            "Number of slice_labels should be identical to number of *data.\n"
            "slice_labels won't be displayed."
        )
        slice_labels = [""] * n_slices

    figure = plt.figure(figsize=figsize)
    grid = AxesGrid(
        figure,
        111,
        nrows_ncols=nrows_ncols,
        axes_pad=0.05,
        cbar_mode="single" if show_cbar else None,
        cbar_location=cbar_location,
        cbar_pad=0.25 if show_cbar else None,
        cbar_size=0.2 if show_cbar else None,
    )

    # A list/tuple holds one support per panel; anything else is a single
    # support shared by all panels.
    per_panel_supports = isinstance(nan_supports, (list, tuple))

    for i, to_plot in enumerate(data):
        if nan_supports is not None:
            support = nan_supports[i] if per_panel_supports else nan_supports
            to_plot = to_plot * support
        im = grid[i].matshow(
            to_plot,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            origin=origin,
            norm=norm,
            alpha=None if alphas is None else alphas[i],
        )

        if vertical:
            grid[i].annotate(
                slice_labels[i],
                xy=(0.2, 0.5),
                xytext=(-grid[i].yaxis.labelpad - 2, 0),
                xycoords=grid[i].yaxis.label,
                textcoords="offset points",
                ha="right",
                va="center",
            )
        else:
            grid[i].annotate(
                slice_labels[i],
                xy=(0.5, 0.9),
                xytext=(0, -grid[i].xaxis.labelpad - 2),
                xycoords=grid[i].xaxis.label,
                textcoords="offset points",
                ha="center",
                va="top",
            )

    if vertical:
        last_ax = grid[n_slices - 1]
        last_ax.annotate(
            slice_name,
            xy=(0.5, 0.2),
            xytext=(0, -last_ax.xaxis.labelpad - 2),
            xycoords=last_ax.xaxis.label,
            textcoords="offset points",
            ha="center",
            va="top",
        )
    else:
        grid[0].annotate(
            slice_name,
            xy=(0.2, 0.5),
            xytext=(-grid[0].yaxis.labelpad - 2, 0),
            xycoords=grid[0].yaxis.label,
            textcoords="offset points",
            ha="right",
            va="center",
        )

    for ax in grid:
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])

    if show_cbar:
        ticklocation = (
            "bottom" if cbar_location in ("top", "bottom") else "auto"
        )
        cbar = grid.cbar_axes[0].colorbar(
            im, extend=cbar_extend, ticklocation=ticklocation
        )
        grid.cbar_axes[0].set_title(cbar_title)
        if cbar_ticks:
            cbar.set_ticks(cbar_ticks)
            cbar.set_ticklabels(cbar_ticks)

    figure.suptitle(suptitle)
    figure.tight_layout()

    if show:
        plt.show()
    return figure
def plot_contour(
    ax, support_2d, linewidth=1, color="k", pixel_size=None, data_centre=None
):
    """
    Draw the outline of a 2D support on the given axes.

    Coordinates follow matrix convention: x runs along axis 1 (columns)
    and y along axis 0 (rows). Contours are drawn at levels 0 and 1.

    Args:
        ax: the matplotlib axes to draw on.
        support_2d (np.ndarray): the 2D support. It is passed through
            nan_to_zero before contouring (presumably NaN -> 0).
        linewidth (optional): contour line width. Defaults to 1.
        color (optional): contour line colour. Defaults to "k".
        pixel_size (optional): (row, column) pixel size used to scale the
            coordinates. Defaults to None, i.e. pixel indices.
        data_centre (optional): (row, column) position on which the
            coordinates are centred. Defaults to None, i.e. no recentring.
    """
    axis_coords = []
    for axis in (1, 0):  # x (columns) first, then y (rows)
        coords = np.arange(0, support_2d.shape[axis])
        if pixel_size is not None:
            coords = coords * pixel_size[axis]
        if data_centre is not None:
            # Move the mean coordinate onto the requested centre.
            coords = coords - coords.mean() + data_centre[axis]
        axis_coords.append(coords)
    grid_x, grid_y = np.meshgrid(*axis_coords)

    # matplotlib may emit UserWarnings while contouring (e.g. when a level
    # yields no contour); they are deliberately silenced here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        ax.contour(
            grid_x,
            grid_y,
            nan_to_zero(support_2d),
            levels=[0, 1],
            linewidths=linewidth,
            colors=color,
        )