Source code for cdiutils.interactive.plotter

"""
Plotter class for interactive data visualization from files or arrays.

This module provides the Plotter class for loading and visualizing various
data formats commonly used in BCDI experiments.

Note: This module requires ipywidgets. Dependency checking is handled at
the package level (cdiutils.interactive.__init__).
"""

# Standard library
import os
from typing import Literal

# Third-party
import h5py as h5
import ipywidgets
import numpy as np
from ipywidgets import interact

# Import plotting functions from the same package
from .plotting import plot_3d_slices, plot_data


[docs] class Plotter: """ Class to plot data from files, NumPy arrays, or layered 3D datasets. This class provides a unified interface for visualizing data commonly produced in BCDI experiments, supporting both static matplotlib-based plots and interactive multi-layer 3D visualization. Args: data: Data to plot. One of: - str: path to a file (.npy, .npz, .cxi, .h5) - np.ndarray: array to plot - dict[str, np.ndarray]: required when plot='layers' plot: Plot type. One of: - '2D' - '1D' - 'slices' - 'phase_slices' - 'contour_slices' - 'sum_slices' - 'sum_contour_slices' - '3D' - 'layers' (interactive multi-layer 3D viewer) Default is 'slices'. log: Display data in logarithmic scale (static plots only). cmap: Colormap name (used for static plots only). figsize: Figure size in inches. fontsize: Base font size for labels, ticks, and titles. title: Optional plot title. layers_kwargs: Keyword arguments forwarded to `MultiVolumeViewer` (only used when plot='layers'). Supported keys: voxel_size: tuple[float, float, float] Physical voxel size along each axis. Used to convert voxel indices to physical coordinates for slices, planes, and clips. unit: str | None Optional unit label for spatial axes. If provided, axis titles are displayed as "X (unit)", "Y (unit)", "Z (unit)". If None (default), axes are shown without units. PLOT_ORDER: Literal['xyz', 'zyx'] Axis ordering convention used for visualization. CBAR_LEN: float Relative colorbar length. render_workers: int | None Number of parallel render workers (used for animation export). render_in_flight: int | None Maximum number of frames rendered concurrently. rendering_mode: Literal['safe', 'fast', 'process'] Rendering backend strategy used during animation export. Ignored for all other plot modes. Attributes: data_array: NumPy array containing the loaded data (static plot modes). data_dict: Dictionary of named 3D arrays (used when plot='layers'). plot: Selected plot type. log: Whether logarithmic scaling is enabled. cmap: Selected colormap. figsize: Figure size in inches. fontsize: Base font size. title: Plot title. layers_kwargs: Keyword arguments forwarded to the multi-layer viewer. filename: Name of the loaded file when data is provided as a path. """
[docs] def __init__( self, data: str | np.ndarray | dict[str, np.ndarray], plot: Literal[ "2D", "slices", "phase_slices", "contour_slices", "sum_slices", "sum_contour_slices", "3D", "1D", "layers", ] = "slices", log: bool = False, cmap: str = "turbo", figsize: tuple[int, int] = (10, 10), fontsize: int = 15, title: str | None = None, layers_kwargs: dict | None = None, ): """Initialise the Plotter class. Args: data: Data to plot. One of: - str: path to a file (.npy, .npz, .cxi, .h5) - np.ndarray: array to plot - dict[str, np.ndarray]: required when plot='layers' plot: Plot type. One of: - '2D' - '1D' - 'slices' - 'phase_slices' - 'contour_slices' - 'sum_slices' - 'sum_contour_slices' - '3D' - 'layers' (multi-layer 3D viewer) log: Display data in logarithmic scale. cmap: Colormap name. (for static plot) figsize: Figure size in inches. fontsize: Base font size for labels, ticks, and titles. title: Optional plot title. layers_kwargs: Keyword arguments forwarded to `MultiVolumeViewer` (only used when plot='layers'). Supported keys: voxel_size: tuple[float, float, float] Physical voxel size along each axis. PLOT_ORDER: Literal['xyz', 'zyx'] Axis ordering convention. CBAR_LEN: float Relative colorbar length. render_workers: int | None Number of parallel render workers (for animation). render_in_flight: int | None Maximum number of in-flight render tasks. rendering_mode: Literal['safe', 'fast', 'process'] Rendering backend strategy (for animation). """ # ---- legacy behaviour ---- self.data_array = None self.data_dict = None self.plot = plot self.log = log self.cmap = cmap self.figsize = figsize self.fontsize = fontsize self.title = title self.layers_kwargs = layers_kwargs or {} # MultiVolumeViewer only accepts dict if self.plot == "layers": if not isinstance(data, dict): print( "[Plotter] MultiVolumeViewer requires a dict[str, np.ndarray].\n" "Example:\n" " {'density': density_3d, 'phase': phase_3d}" ) return self.data_dict = data self.init_plot() return # Get data array from any of the supported files if isinstance(data, str) and os.path.isfile(data): self.filename = data self.get_data_array() elif isinstance(data, np.ndarray): self.data_array = data self.init_plot() else: print( "Please provide either a valid filename (arg filename)" " or directly an np.ndarray (arg data_array)." )
[docs] def init_plot(self): """ Initialise a plot of the data stored in the `data_array` attribute. The type of plot and the parameters are specified in the class constructor. The plot can be a 2D plot, 3D slices, contour plots of slices, sum of slices, sum of contour plots of slices, or a 3D plot. The specific plot type is determined by the value of the `plot` attribute. If the number of dimensions of the `data_array` is not compatible with the specified plot type, the function simply prints the number of dimensions and shape of the `data_array`. Attributes: data_array: An array containing the data to be plotted. plot: The type of plot to be generated, which can be one of the following: "2D", "slices", "phase_slices", "contour_slices", "sum_slices", "sum_contour_slices", "3D" or "layers". figsize: The size of the plot in inches. fontsize: The font size of the plot. log: If True, plot the data in logarithmic scale. cmap: The colourmap to be used for the plot. title: The title of the plot. """ # Import ThreeDViewer here to avoid circular imports from .viewer_3d import ThreeDViewer if self.plot == "2D": plot_data( data_array=self.data_array, figsize=self.figsize, fontsize=self.fontsize, log=self.log, cmap=self.cmap, title=self.title, ) elif self.plot == "slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=False, sum_over_axis=False, ) elif self.plot == "phase_slices" and self.data_array.ndim == 3: amp = np.abs(self.data_array) phase = np.angle(self.data_array) max_amp = np.max(amp) phase_in_support = np.where(amp > 0.05 * max_amp, phase, np.nan) plot_3d_slices( data_array=phase_in_support, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=False, sum_over_axis=False, ) elif self.plot == "contour_slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=True, sum_over_axis=False, ) elif self.plot == "sum_slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=False, sum_over_axis=True, ) elif self.plot == "sum_contour_slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=True, sum_over_axis=True, ) elif self.plot == "3D" and self.data_array.ndim == 3: viewer = ThreeDViewer(self.data_array) viewer.show() elif self.plot == "layers": from .multiviewer_3d import MultiVolumeViewer self.layers_kwargs.pop("figsize", None) self.layers_kwargs.pop("fontsize", None) viewer = MultiVolumeViewer( self.data_dict, fontsize=self.fontsize, figsize=self.figsize, **self.layers_kwargs, ) viewer.show() return elif self.plot == "1D" and self.data_array.ndim == 1: print(self.data_array) plot_data( data_array=self.data_array, figsize=self.figsize, fontsize=self.fontsize, log=self.log, cmap=self.cmap, title=self.title, ) else: print( "#########################################################" "########################################################\n" f"Loaded data array\n" f"\tNb of dimensions: {self.data_array.ndim}\n" f"\tShape: {self.data_array.shape}\n" "\n#########################################################" "########################################################" )
[docs] def get_data_array(self): """ Return the data array stored in the class instance by reading the specified file. The file must have a .npy, .cxi, .h5, or .npz extension. If the file is a .npy or .h5 file, the data array is directly loaded. If the file is a .cxi file, the data array is loaded from `f.root.entry_1.data_1.data[:]` or `f.root.entry_1.image_1.data[:]`, following cxi conventions. If the file is a .npz file, the user is prompted to select the data array from a dropdown list of arrays stored in the .npz file. If the file extension is supported and the data array is successfully loaded, the `init_plot` function is called. Return: A Numpy array representing the data stored in the class, or None if the file could not be loaded. Raises: KeyError: If the file is a .cxi or .h5 file, and the data could not be found in either `f.root.entry_1.data_1.data[:]` or `f.root.entry_1.image_1.data[:]`. """ # No need to select data array interactively if self.filename.endswith((".npy", ".h5", ".cxi")): if self.filename.endswith(".npy"): try: self.data_array = np.load(self.filename) except ValueError: print("Could not load data ... ") elif self.filename.endswith(".cxi"): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/data_1/data" ][()] except (KeyError, OSError): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/image_1/data" ][()] except (KeyError, OSError): print( "The file could not be loaded, verify that you are" "loading a file with an hdf5 architecture (.nxs, " ".cxi, .h5, ...) and that the file exists." "Otherwise, verify that the data is saved in " "f.root.entry_1.data_1.data[:]," "or f.root.entry_1.image_1.data[:], as it should be" "following cxi conventions." ) elif self.filename.endswith(".h5"): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/data_1/data" ][()] if self.data_array.ndim == 4: self.data_array = self.data_array[0] # Due to labelling of axes x,y,z and not z,y,x self.data_array = np.swapaxes(self.data_array, 0, 2) except (KeyError, OSError): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/image_1/data" ][()] if self.data_array.ndim == 4: self.data_array = self.data_array[0] # Due to labelling of axes x,y,z and not z,y,x self.data_array = np.swapaxes(self.data_array, 0, 2) except (KeyError, OSError): raise KeyError( "The file could not be loaded, verify that you are" "loading a file with an hdf5 architecture (.nxs, " ".cxi, .h5, ...) and that the file exists." "Otherwise, verify that the data is saved in " "f.root.entry_1.data_1.data[:]," "or f.root.entry_1.image_1.data[:], as it should be" "following cxi conventions." ) # Plot data self.init_plot() # Need to select data array interactively elif self.filename.endswith(".npz"): # Open npz file and allow the user to pick an array try: rawdata = np.load(self.filename) @interact( file=ipywidgets.Dropdown( options=rawdata.files, value=rawdata.files[0], description="Pick an array to load:", style={"description_width": "initial"}, ) ) def open_npz(file): # Pick an array try: self.data_array = rawdata[file] except ValueError: print("Key not valid, is this an array ?") # Plot data self.init_plot() except ValueError: print("Could not load data.") else: print("Data type not supported.")