Source code for cdiutils.interactive.plotter

"""
Plotter class for interactive data visualization from files or arrays.

This module provides the Plotter class for loading and visualizing various
data formats commonly used in BCDI experiments.

Note: This module requires ipywidgets. Dependency checking is handled at
the package level (cdiutils.interactive.__init__).
"""

import os

import h5py as h5
import ipywidgets
import numpy as np
from ipywidgets import interact

# Import plotting functions from the same package
from .plotting import plot_3d_slices, plot_data


[docs] class Plotter: """ Class to plot data from files or Numpy arrays. Args: data: The data to plot. This can either be the path to a file, or a Numpy array directly. plot: Specifies the type of plot to create. Available options are: '2D', 'slices', "phase_slices", 'contour_slices', 'sum_slices', 'sum_contour_slices', '3D', "1D", by default 'slices'. log: Whether to display the plot in log scale, by default False. cmap: The colourmap to use for the plot, by default 'turbo'. figsize: The size of the figure in inches, by default (10, 10). fontsize: The font size to use in the plot, by default 15. title: The title of the plot, by default None. Attributes: data_array: The Numpy array with the data. plot: The type of plot specified. log: Whether to display the plot in log scale. cmap: The colourmap specified. figsize: The size of the figure specified. fontsize: The font size specified. title: The title of the plot. filename: The name of the file if data is given as a path. """
[docs] def __init__( self, data: str | np.ndarray, plot: str = "slices", log: bool = False, cmap: str = "turbo", figsize: tuple[int, int] = (10, 10), fontsize: int = 15, title: str = None, ): """Initialise the Plotter class. Args: data: The data to plot. This can either be the path to a file, or a Numpy array directly. plot: Specifies the type of plot to create. Available options are: '2D', 'slices', "phase_slices", 'contour_slices', 'sum_slices', 'sum_contour_slices', '3D', "1D", by default 'slices'. log: Whether to display the plot in log scale, by default False. cmap: The colourmap to use for the plot, by default 'turbo'. figsize: The size of the figure in inches, by default (10, 10). fontsize: The font size to use in the plot, by default 15. title: The title of the plot, by default None. """ self.data_array = None self.plot = plot self.log = log self.cmap = cmap self.figsize = figsize self.fontsize = fontsize self.title = title # Get data array from any of the supported files if isinstance(data, str) and os.path.isfile(data): self.filename = data self.get_data_array() elif isinstance(data, np.ndarray): self.data_array = data self.init_plot() else: print( "Please provide either a valid filename (arg filename)" " or directly an np.ndarray (arg data_array)." )
[docs] def init_plot(self): """ Initialise a plot of the data stored in the `data_array` attribute. The type of plot and the parameters are specified in the class constructor. The plot can be a 2D plot, 3D slices, contour plots of slices, sum of slices, sum of contour plots of slices, or a 3D plot. The specific plot type is determined by the value of the `plot` attribute. If the number of dimensions of the `data_array` is not compatible with the specified plot type, the function simply prints the number of dimensions and shape of the `data_array`. Attributes: data_array: An array containing the data to be plotted. plot: The type of plot to be generated, which can be one of the following: "2D", "slices", "phase_slices", "contour_slices", "sum_slices", "sum_contour_slices", or "3D". figsize: The size of the plot in inches. fontsize: The font size of the plot. log: If True, plot the data in logarithmic scale. cmap: The colourmap to be used for the plot. title: The title of the plot. """ # Import ThreeDViewer here to avoid circular imports from .viewer_3d import ThreeDViewer if self.plot == "2D": plot_data( data_array=self.data_array, figsize=self.figsize, fontsize=self.fontsize, log=self.log, cmap=self.cmap, title=self.title, ) elif self.plot == "slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=False, sum_over_axis=False, ) elif self.plot == "phase_slices" and self.data_array.ndim == 3: amp = np.abs(self.data_array) phase = np.angle(self.data_array) max_amp = np.max(amp) phase_in_support = np.where(amp > 0.05 * max_amp, phase, np.nan) plot_3d_slices( data_array=phase_in_support, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=False, sum_over_axis=False, ) elif self.plot == "contour_slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=True, sum_over_axis=False, ) elif self.plot == "sum_slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=False, sum_over_axis=True, ) elif self.plot == "sum_contour_slices" and self.data_array.ndim == 3: plot_3d_slices( data_array=self.data_array, fontsize=self.fontsize, title=self.title, figsize=None, log=self.log, cmap=self.cmap, contour=True, sum_over_axis=True, ) elif self.plot == "3D" and self.data_array.ndim == 3: viewer = ThreeDViewer(self.data_array) viewer.show() elif self.plot == "1D" and self.data_array.ndim == 1: print(self.data_array) plot_data( data_array=self.data_array, figsize=self.figsize, fontsize=self.fontsize, log=self.log, cmap=self.cmap, title=self.title, ) else: print( "#########################################################" "########################################################\n" f"Loaded data array\n" f"\tNb of dimensions: {self.data_array.ndim}\n" f"\tShape: {self.data_array.shape}\n" "\n#########################################################" "########################################################" )
[docs] def get_data_array(self): """ Return the data array stored in the class instance by reading the specified file. The file must have a .npy, .cxi, .h5, or .npz extension. If the file is a .npy or .h5 file, the data array is directly loaded. If the file is a .cxi file, the data array is loaded from `f.root.entry_1.data_1.data[:]` or `f.root.entry_1.image_1.data[:]`, following cxi conventions. If the file is a .npz file, the user is prompted to select the data array from a dropdown list of arrays stored in the .npz file. If the file extension is supported and the data array is successfully loaded, the `init_plot` function is called. Return: A Numpy array representing the data stored in the class, or None if the file could not be loaded. Raises: KeyError: If the file is a .cxi or .h5 file, and the data could not be found in either `f.root.entry_1.data_1.data[:]` or `f.root.entry_1.image_1.data[:]`. """ # No need to select data array interactively if self.filename.endswith((".npy", ".h5", ".cxi")): if self.filename.endswith(".npy"): try: self.data_array = np.load(self.filename) except ValueError: print("Could not load data ... ") elif self.filename.endswith(".cxi"): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/data_1/data" ][()] except (KeyError, OSError): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/image_1/data" ][()] except (KeyError, OSError): print( "The file could not be loaded, verify that you are" "loading a file with an hdf5 architecture (.nxs, " ".cxi, .h5, ...) and that the file exists." "Otherwise, verify that the data is saved in " "f.root.entry_1.data_1.data[:]," "or f.root.entry_1.image_1.data[:], as it should be" "following cxi conventions." ) elif self.filename.endswith(".h5"): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/data_1/data" ][()] if self.data_array.ndim == 4: self.data_array = self.data_array[0] # Due to labelling of axes x,y,z and not z,y,x self.data_array = np.swapaxes(self.data_array, 0, 2) except (KeyError, OSError): try: self.data_array = h5.File(self.filename, mode="r")[ "entry_1/image_1/data" ][()] if self.data_array.ndim == 4: self.data_array = self.data_array[0] # Due to labelling of axes x,y,z and not z,y,x self.data_array = np.swapaxes(self.data_array, 0, 2) except (KeyError, OSError): raise KeyError( "The file could not be loaded, verify that you are" "loading a file with an hdf5 architecture (.nxs, " ".cxi, .h5, ...) and that the file exists." "Otherwise, verify that the data is saved in " "f.root.entry_1.data_1.data[:]," "or f.root.entry_1.image_1.data[:], as it should be" "following cxi conventions." ) # Plot data self.init_plot() # Need to select data array interactively elif self.filename.endswith(".npz"): # Open npz file and allow the user to pick an array try: rawdata = np.load(self.filename) @interact( file=ipywidgets.Dropdown( options=rawdata.files, value=rawdata.files[0], description="Pick an array to load:", style={"description_width": "initial"}, ) ) def open_npz(file): # Pick an array try: self.data_array = rawdata[file] except ValueError: print("Key not valid, is this an array ?") # Plot data self.init_plot() except ValueError: print("Could not load data.") else: print("Data type not supported.")