# Source code for cdiutils.interactive.volume

"""
Interactive 3D volume visualisation tools for BCDI data.

This module provides interactive widgets for visualising 3D volumetric data
with different backends:
- ThreeDViewer: Plotly-based interactive 3D viewer class
  (requires plotly, ipywidgets - included in: pip install cdiutils[interactive])
- plot_3d_isosurface: Plotly-based isosurface function
  (requires plotly, ipywidgets - included in: pip install cdiutils[interactive])
- VolumeViewer: PyVista/Trame-based visualisation
  (requires pyvista, trame - install with: pip install cdiutils[pyvista])

The Plotly-based functions are recommended for most use cases as they provide
excellent interactive performance and are included in the standard interactive
dependencies. PyVista/Trame is available for specialised workflows.
"""

import os

import matplotlib.pyplot as plt
import numpy as np

# PyVista/Trame availability checked in __init__.py, but we need the imports
try:
    import pyvista as pv
    from pyvista.trame.ui.vuetify3 import divider, select, slider

    IS_PYVISTA_AVAILABLE = True
except ImportError:
    IS_PYVISTA_AVAILABLE = False
    pv = None

# Plotly and related imports
try:
    import plotly.graph_objects as go
    from scipy.interpolate import RegularGridInterpolator
    from scipy.ndimage import map_coordinates
    from skimage import measure

    IS_PLOTLY_AVAILABLE = True
except ImportError:
    IS_PLOTLY_AVAILABLE = False
    go = None


def plot_3d_isosurface(
    amplitude: np.ndarray,
    quantities: dict[str, np.ndarray],
    voxel_size: tuple[float, float, float] | None = None,
    initial_quantity: str | None = None,
    cmap: str | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    convention: str | None = None,
    figsize: tuple[int, int] = (9, 6),
    title: str | None = None,
    lighting_params: dict[str, float] | None = None,
    camera_position: dict | None = None,
    theme: str = "plotly_white",
):
    """
    Plot an interactive 3D isosurface using Plotly and ipywidgets.

    This function creates an interactive 3D visualisation where users can
    adjust the isosurface threshold, switch between different scalar
    quantities, change colormaps, and control colourbar scaling using
    interactive widgets. The camera view is preserved when updating.

    Interactive Controls:
        - Isosurface slider: Adjust threshold level (0-1, as a fraction of
          the maximum of ``amplitude``).
        - Quantity dropdown: Switch between different quantities.
        - Colourmap dropdown: Change the colourmap on-the-fly.
        - Set limits checkbox: Enable manual colourbar limits
          (default: OFF)
            * When OFF: Auto-scales to min/max of current plot
            * When ON: Enables vmin/vmax input fields for manual control
        - vmin/vmax inputs: Set custom colourbar limits (disabled unless
          "Set limits" is checked).
        - Symmetric colourbar: Centre at 0 for strain/phase data
          (default: OFF).
        - Replace NaN with mean: Replace NaN values with mean to avoid
          weird colouring artefacts (default: OFF).

    If a slider position yields no surface (for instance a level outside
    the amplitude range), the previous mesh is kept and the figure title
    is suffixed with "no surface".

    Args:
        amplitude (np.ndarray): 3D array for determining isosurface levels.
        quantities (dict[str, np.ndarray]): Dictionary of 3D arrays to
            visualise, with keys as quantity names and values as numpy
            arrays. Must contain at least one entry.
        voxel_size (tuple[float, float, float] | None, optional): Size of
            voxels (dx, dy, dz) for proper scaling. Defaults to
            (1.0, 1.0, 1.0).
        initial_quantity (str | None, optional): Name of the quantity to
            display initially. Must be a key in quantities dict. If None,
            uses the first key in quantities. Defaults to None.
        cmap (str | None, optional): Initial colourmap name. If it is not
            one of the built-in dropdown options it is appended to them.
            Defaults to "viridis".
        vmin (float | None, optional): Initial minimum value for colour
            scale. If None, auto-scales to data minimum. Defaults to None.
        vmax (float | None, optional): Initial maximum value for colour
            scale. If None, auto-scales to data maximum. Defaults to None.
        convention (str | None, optional): Coordinate convention ("cxi" or
            "xu"). Currently not used by this function. Defaults to "cxi".
        figsize (tuple[int, int], optional): Figure size in inches
            (width, height), converted at 96 px per inch. Defaults to
            (9, 6).
        title (str | None, optional): Plot title. Defaults to
            "Interactive 3D Isosurface".
        lighting_params (dict[str, float] | None, optional): Plotly
            lighting parameters. Defaults to preset values.
        camera_position (dict | None, optional): Initial camera position.
            Defaults to eye=dict(x=1.5, y=1.5, z=1.5).
        theme (str, optional): Plotly theme. Defaults to "plotly_white".

    Returns:
        VBox: ipywidgets VBox containing controls and the figure widget.

    Raises:
        PlotlyImportError: if plotly or required packages are not
            installed.
        ValueError: if quantities is empty, if amplitude and quantities
            have different shapes, or if initial_quantity is not in
            quantities dict.

    Example:
        >>> import numpy as np
        >>> amp = np.random.rand(50, 50, 50)
        >>> strain = np.random.randn(50, 50, 50) * 0.01
        >>> phase = np.random.randn(50, 50, 50) * np.pi
        >>> widget = plot_3d_isosurface(
        ...     amp,
        ...     {"het_strain": strain, "phase": phase},
        ...     voxel_size=(1.0, 1.0, 1.0),
        ...     initial_quantity="het_strain",
        ...     cmap="cet_CET_D13"
        ... )
        >>> display(widget)
    """
    if not IS_PLOTLY_AVAILABLE:
        raise PlotlyImportError()
    try:
        from ipywidgets import (
            Checkbox,
            Dropdown,
            FloatSlider,
            FloatText,
            HBox,
            VBox,
        )
    except ImportError as e:
        raise PlotlyImportError(
            f"Required packages not available: {e}. "
            "Install with: pip install cdiutils[interactive]"
        ) from e

    # validate inputs
    if not quantities:
        raise ValueError("quantities must contain at least one 3D array.")
    for name, quantity in quantities.items():
        if amplitude.shape != quantity.shape:
            raise ValueError(
                f"amplitude and quantity '{name}' must have the same "
                f"shape. Got {amplitude.shape} and {quantity.shape}"
            )

    # set defaults
    if voxel_size is None:
        voxel_size = (1.0, 1.0, 1.0)
    if lighting_params is None:
        lighting_params = dict(
            ambient=0.90,
            diffuse=0.05,
            specular=0.5,
            roughness=0.2,
            fresnel=0.5,
        )
    if camera_position is None:
        camera_position = dict(eye=dict(x=1.5, y=1.5, z=1.5))
    if title is None:
        title = "Interactive 3D Isosurface"
    if convention is None:
        # NOTE: convention is not read anywhere below; kept for API
        # compatibility.
        convention = "cxi"

    # determine initial quantity to display
    quantity_names = list(quantities.keys())
    if initial_quantity is None:
        initial_qty_name = quantity_names[0]
    else:
        if initial_quantity not in quantities:
            raise ValueError(
                f"initial_quantity '{initial_quantity}' not found in "
                f"quantities dict. Available: {quantity_names}"
            )
        initial_qty_name = initial_quantity

    initial_cmap = "viridis" if cmap is None else cmap

    # the amplitude never changes, so its NaN-safe maximum is computed once
    # and reused by every update
    amplitude_max = np.nanmax(amplitude)
    isosurface_default = 0.5
    verts_scaled, faces, quantity_at_verts = _extract_isosurface_with_values(
        amplitude,
        quantities[initial_qty_name],
        isosurface_default * amplitude_max,
        voxel_size,
    )

    # set initial colourbar limits, NaN-safe
    initial_vmin = np.nanmin(quantity_at_verts) if vmin is None else vmin
    initial_vmax = np.nanmax(quantity_at_verts) if vmax is None else vmax

    initial_plotly_cmap = colorcet_to_plotly(initial_cmap)

    # FigureWidget allows in-place updates (preserves the camera)
    fig = go.FigureWidget(
        data=[
            go.Mesh3d(
                x=verts_scaled[:, 0],
                y=verts_scaled[:, 1],
                z=verts_scaled[:, 2],
                i=faces[:, 0],
                j=faces[:, 1],
                k=faces[:, 2],
                intensity=quantity_at_verts,
                colorscale=initial_plotly_cmap,
                cmin=initial_vmin,
                cmax=initial_vmax,
                colorbar=dict(title=initial_qty_name),
                opacity=1.0,
                flatshading=False,
                lighting=lighting_params,
                hovertemplate=(
                    "<b>Position:</b><br>"
                    "x: %{x:.1f}<br>y: %{y:.1f}<br>z: %{z:.1f}<br>"
                    f"<b>{initial_qty_name}:</b> %{{intensity:.3f}}<br>"
                    "<extra></extra>"
                ),
            )
        ]
    )
    fig.update_layout(
        title=f"{title} - {initial_qty_name} (iso={isosurface_default:.2f})",
        template=theme,
        scene=dict(
            xaxis=dict(showbackground=True, title="x"),
            yaxis=dict(showbackground=True, title="y"),
            zaxis=dict(showbackground=True, title="z"),
            aspectmode="data",
            camera=camera_position,
        ),
        width=figsize[0] * 96,
        height=figsize[1] * 96,
        dragmode="orbit",
    )

    # create widgets
    isosurface_slider = FloatSlider(
        value=isosurface_default,
        min=0.0,
        max=1.0,
        step=0.01,
        description="Isosurface:",
        continuous_update=False,
        style={"description_width": "initial"},
    )
    quantity_dropdown = Dropdown(
        options=quantity_names,
        value=initial_qty_name,
        description="Quantity:",
        style={"description_width": "initial"},
    )
    cmap_options = [
        "turbo",
        "viridis",
        "inferno",
        "magma",
        "plasma",
        "cividis",
        "RdBu",
        "coolwarm",
        "twilight",
        "Blues",
        "Greens",
        "Greys",
        "Purples",
        "Oranges",
        "Reds",
        "cet_CET_D13",
        "cet_CET_C9s_r",
        "cet_CET_D1A",
        "jch_const",
        "jch_max",
    ]
    # a Dropdown rejects a value missing from its options, so make sure a
    # user-supplied (already validated) colourmap is selectable
    if initial_cmap not in cmap_options:
        cmap_options.append(initial_cmap)
    colormap_dropdown = Dropdown(
        options=cmap_options,
        value=initial_cmap,
        description="Colormap:",
        style={"description_width": "initial"},
    )
    set_limits_checkbox = Checkbox(
        value=False,
        description="Set limits:",
        tooltip="Enable manual colorbar limits",
        indent=False,
        style={"description_width": "80px"},
    )
    vmin_input = FloatText(
        value=initial_vmin,
        description="vmin:",
        disabled=True,  # disabled by default
        style={"description_width": "50px"},
        layout={"width": "150px"},
    )
    vmax_input = FloatText(
        value=initial_vmax,
        description="vmax:",
        disabled=True,  # disabled by default
        style={"description_width": "50px"},
        layout={"width": "150px"},
    )
    symmetric_checkbox = Checkbox(
        value=False,
        description="Symmetric colorbar",
        tooltip="Center colorbar at 0 (for strain, phase, etc.)",
        indent=False,
        style={"description_width": "initial"},
    )
    replace_nan_checkbox = Checkbox(
        value=False,
        description="Replace NaN with mean",
        tooltip="Replace NaN values with mean (fixes weird colouring)",
        indent=False,
        style={"description_width": "initial"},
    )

    def toggle_limit_inputs(change=None) -> None:
        """
        Enable the vmin/vmax inputs only when "Set limits" is checked.

        Args:
            change: widget change event (not used but required by observer).
        """
        locked = not set_limits_checkbox.value
        vmin_input.disabled = locked
        vmax_input.disabled = locked

    # Re-entrancy guard: _redraw writes vmin_input/vmax_input, whose own
    # observers call update_mesh. Without this flag a single interaction
    # re-ran marching cubes up to three times.
    is_updating = False

    def update_mesh(change=None) -> None:
        """
        Redraw the mesh after any widget change, ignoring nested calls.

        Args:
            change: widget change event (not used but required by observer).
        """
        nonlocal is_updating
        if is_updating:
            return
        is_updating = True
        try:
            _redraw()
        finally:
            is_updating = False

    def _flag_no_surface(qty_name: str, iso_level: float) -> None:
        """Keep the previous mesh and report the empty level in the title."""
        fig.layout.title = (
            f"{title} - {qty_name} (iso={iso_level:.2f}, no surface)"
        )

    def _redraw() -> None:
        """
        Recompute the isosurface and colour limits from the widget state.

        NaN values in the quantity are handled with np.nanmin/np.nanmax and
        can optionally be replaced with the mean to avoid weird colouring.
        """
        iso_level = isosurface_slider.value
        qty_name = quantity_dropdown.value

        try:
            verts_scaled, faces, quantity_at_verts = (
                _extract_isosurface_with_values(
                    amplitude,
                    quantities[qty_name],
                    iso_level * amplitude_max,
                    voxel_size,
                )
            )
        except (ValueError, RuntimeError):
            # marching cubes raises when the level is outside the data
            # range or when no surface exists at that level
            _flag_no_surface(qty_name, iso_level)
            return
        if quantity_at_verts.size == 0:
            _flag_no_surface(qty_name, iso_level)
            return

        # optionally replace NaN values with mean to fix weird colouring
        if replace_nan_checkbox.value and np.any(np.isnan(quantity_at_verts)):
            mean_val = np.nanmean(quantity_at_verts)
            quantity_at_verts = np.where(
                np.isnan(quantity_at_verts), mean_val, quantity_at_verts
            )

        plotly_cmap = colorcet_to_plotly(colormap_dropdown.value)

        # determine colour scale limits based on checkbox settings
        if symmetric_checkbox.value:
            # symmetric around 0, NaN-safe
            max_abs = float(
                max(
                    abs(np.nanmin(quantity_at_verts)),
                    abs(np.nanmax(quantity_at_verts)),
                )
            )
            new_vmin, new_vmax = -max_abs, max_abs
            vmin_input.value = new_vmin
            vmax_input.value = new_vmax
        elif set_limits_checkbox.value:
            # use manual limits from input fields
            new_vmin = vmin_input.value
            new_vmax = vmax_input.value
        else:
            # default: auto-scale to actual data range, NaN-safe
            new_vmin = float(np.nanmin(quantity_at_verts))
            new_vmax = float(np.nanmax(quantity_at_verts))
            vmin_input.value = new_vmin
            vmax_input.value = new_vmax

        # update mesh data in-place (preserves camera view)
        with fig.batch_update():
            mesh = fig.data[0]
            mesh.x = verts_scaled[:, 0]
            mesh.y = verts_scaled[:, 1]
            mesh.z = verts_scaled[:, 2]
            mesh.i = faces[:, 0]
            mesh.j = faces[:, 1]
            mesh.k = faces[:, 2]
            mesh.intensity = quantity_at_verts
            mesh.colorscale = plotly_cmap
            mesh.cmin = new_vmin
            mesh.cmax = new_vmax
            mesh.colorbar.title = qty_name
            mesh.hovertemplate = (
                "<b>Position:</b><br>"
                "x: %{x:.1f}<br>y: %{y:.1f}<br>z: %{z:.1f}<br>"
                f"<b>{qty_name}:</b> %{{intensity:.3f}}<br>"
                "<extra></extra>"
            )
            fig.layout.title = f"{title} - {qty_name} (iso={iso_level:.2f})"

    # attach observers to widgets
    isosurface_slider.observe(update_mesh, names="value")
    quantity_dropdown.observe(update_mesh, names="value")
    colormap_dropdown.observe(update_mesh, names="value")
    set_limits_checkbox.observe(toggle_limit_inputs, names="value")
    set_limits_checkbox.observe(update_mesh, names="value")
    vmin_input.observe(update_mesh, names="value")
    vmax_input.observe(update_mesh, names="value")
    symmetric_checkbox.observe(update_mesh, names="value")
    replace_nan_checkbox.observe(update_mesh, names="value")

    # create layout
    controls_row1 = HBox(
        [isosurface_slider, quantity_dropdown, colormap_dropdown]
    )
    controls_row2 = HBox(
        [
            set_limits_checkbox,
            vmin_input,
            vmax_input,
            symmetric_checkbox,
            replace_nan_checkbox,
        ]
    )
    return VBox([controls_row1, controls_row2, fig])
class PyVistaImportError(ImportError): """Custom exception to handle PyVista/Trame import error.""" def __init__(self, msg: str = None) -> None: """ Initialise PyVistaImportError with informative message. Args: msg (str, optional): additional error message. Defaults to None. """ _msg = ( "PyVista and Trame packages are not installed. " "Install with: pip install cdiutils[pyvista]" ) if msg is not None: _msg += "\n" + msg super().__init__(_msg) class PlotlyImportError(ImportError): """Custom exception to handle Plotly import error.""" def __init__(self, msg: str = None) -> None: """ Initialise PlotlyImportError with informative message. Args: msg (str, optional): additional error message. Defaults to None. """ _msg = ( "Plotly and required packages are not installed. " "Install with: pip install cdiutils[interactive]" ) if msg is not None: _msg += "\n" + msg super().__init__(_msg) def _extract_isosurface_with_values( amplitude: np.ndarray, quantity: np.ndarray, isosurface_level: float, voxel_size: tuple = (1.0, 1.0, 1.0), use_interpolator: bool = False, ): """ Extract isosurface and interpolate quantity values at vertices. This is a shared utility function used by both ThreeDViewer and plot_3d_isosurface to avoid code duplication. Args: amplitude (np.ndarray): 3D array for determining isosurface. quantity (np.ndarray): 3D array of values to interpolate at surface vertices (can be complex). isosurface_level (float): threshold value for marching cubes. voxel_size (tuple, optional): voxel size (dx, dy, dz) for scaling. Defaults to (1.0, 1.0, 1.0). use_interpolator (bool, optional): if True, use RegularGridInterpolator (needed for complex arrays). If False, use map_coordinates (faster for real arrays). Defaults to False. 
Returns: tuple: (verts_scaled, faces, quantity_at_verts) where: - verts_scaled: nx3 array of scaled vertex positions - faces: mx3 array of triangle face indices - quantity_at_verts: length-n array of interpolated values Raises: PlotlyImportError: if required packages not available. """ if not IS_PLOTLY_AVAILABLE: raise PlotlyImportError() # extract isosurface using marching cubes verts, faces, _, _ = measure.marching_cubes( np.abs(amplitude), level=isosurface_level, step_size=1, ) # scale vertices verts_scaled = verts * voxel_size # interpolate quantity values at vertices if use_interpolator or np.iscomplexobj(quantity): # use RegularGridInterpolator for complex arrays nz, ny, nx = quantity.shape grid_z = np.arange(nz) grid_y = np.arange(ny) grid_x = np.arange(nx) rgi = RegularGridInterpolator( (grid_z, grid_y, grid_x), quantity, bounds_error=False, fill_value=0, ) quantity_at_verts = rgi(verts) else: # use map_coordinates for real arrays (faster) quantity_at_verts = map_coordinates( quantity, verts.T, order=1, mode="nearest" ) return verts_scaled, faces, quantity_at_verts def colorcet_to_plotly(cmap_name: str, n_colors: int = 256) -> list[list]: """ Convert a colorcet or matplotlib colormap to a Plotly colorscale. Args: cmap_name (str): name of the colorcet or matplotlib colormap (e.g., 'rainbow', 'fire', 'cet_CET_D13'). n_colors (int, optional): number of colour samples to extract from the colormap. Defaults to 256. Returns: list[list]: Plotly colorscale as a list of [position, 'rgb(r,g,b)'] entries with positions in [0.0, 1.0]. Raises: ValueError: if the specified colormap name is not found in the matplotlib/colorcet colormaps. """ # get the colorcet colormap if cmap_name not in plt.colormaps(): raise ValueError( f"Colormap '{cmap_name}' not found in matplotlib/colorcet " f"colormaps." 
) cmap = plt.get_cmap(cmap_name) # sample colours from the colormap colors = [cmap(i) for i in np.linspace(0, 1, n_colors)] # convert to Plotly format plotly_colorscale = [ [ i / (n_colors - 1), f"rgb({int(c[0] * 255)},{int(c[1] * 255)},{int(c[2] * 255)})", ] for i, c in enumerate(colors) ] return plotly_colorscale class VolumeViewer: """ A class to plot volume in 3D with Trame and PyVista. This class provides interactive 3D visualization of volumetric data using PyVista's Trame backend for Jupyter notebooks. Raises: PyVistaImportError: if Trame or PyVista are not installed. """ generic_params = { "amplitude": {"cmap": "turbo", "centred_clim": False, "clim": [0, 1]}, "support": {"cmap": "viridis", "centred_clim": False}, "phase": {"cmap": "cet_CET_C9s_r", "centred_clim": True}, "displacement": {"cmap": "cet_CET_D1A", "centred_clim": True}, "het_strain": {"cmap": "cet_CET_D13", "centred_clim": True}, "het_strain_from_dspacing": { "cmap": "cet_CET_D13", "centred_clim": True, }, "lattice_parameter": {"cmap": "turbo", "centred_clim": False}, "dspacing": {"cmap": "turbo", "centred_clim": False}, "isosurface": 0.50, "cmap": "turbo", } cmap_options = ( "turbo", "viridis", "spectral", "inferno", "magma", "plasma", "cividis", "RdBu", "coolwarm", "Blues", "Greens", "Greys", "Purples", "Oranges", "Reds", "cet_CET_D13", "cet_CET_C9s_r", "cet_CET_D1A", "jch_const", "jch_max", ) @classmethod def _generate_toolbar_tools( cls, initial_scalar: str, available_scalars: list[str], **kwargs ) -> callable: """ Generate toolbar widgets for the Trame interface. Args: initial_scalar (str): Initial scalar field to display. available_scalars (list[str]): List of available scalar fields. Returns: callable: Toolbar function for Trame UI. 
""" def toolbar_tools() -> None: divider(vertical=True, classes="mx-1") # isosurface slider slider( model=("isosurface_value", cls.generic_params["isosurface"]), tooltip="Adjust isosurface threshold", min=0.0, max=1.0, step=0.01, dense=True, hide_details=False, style="width: 250px", classes="my-0 py-0 ml-1 mr-1", ) divider(vertical=True, classes="mx-1") # scalar field dropdown select( model=("scalar_field", initial_scalar), tooltip="Choose scalar field for coloring", items=("available_scalars", available_scalars), hide_details=True, dense=True, outlined=True, ) divider(vertical=True, classes="mx-1") # colourmap dropdown select( model=("cmap", cls.generic_params[initial_scalar]["cmap"]), tooltip="Choose a colourmap", items=("cmap_options", cls.cmap_options), hide_details=True, dense=True, outlined=True, ) return toolbar_tools @classmethod def contour_plot( cls, data_path: str | None = None, initial_active_scalar: str = "het_strain", **data: np.ndarray, ): """ Generate a contour plot application using PyVista. Args: data_path (str | None, optional): Path to a .vti file containing the data. Defaults to None. initial_active_scalar (str, optional): Initial scalar field to display. Defaults to "het_strain". **data (np.ndarray): Dictionary of numpy arrays to visualize. Raises: PyVistaImportError: if Trame or PyVista are not installed. ValueError: If the path is not a .vti file. ValueError: If initial_active_scalar is not available. NotImplementedError: When parsing np.ndarray directly (reserved for future use). Returns: The widget viewer for display in Jupyter notebooks. """ if not IS_PYVISTA_AVAILABLE: raise PyVistaImportError() if data_path is not None: # ignoring the **data if not data_path.endswith(".vti"): raise ValueError( "The provided data_path should point to a .vti file" ) structure_grid = pv.read(data_path) available_scalars = structure_grid.array_names elif len(data) < 1: raise NotImplementedError( "Either np.ndarray or data_path must be provided." 
) else: initial_active_scalar = list(data.keys())[0] mesh = np.meshgrid( *[np.arange(s) for s in data[initial_active_scalar].shape], indexing="ij", ) structure_grid = pv.StructuredGrid(*mesh) available_scalars = list(data.keys()) for key, d in data.items(): structure_grid.point_data[key] = d.flatten() plotter = pv.Plotter(notebook=True) # generate the initial isosurface contours = structure_grid.contour( [cls.generic_params["isosurface"]], scalars="amplitude" ) if initial_active_scalar not in available_scalars: raise ValueError( f"initial_active_scalar (={initial_active_scalar}) " "cannot be found in the provided data." ) contours.set_active_scalars(initial_active_scalar) initial_clim = cls.generic_params[initial_active_scalar].get("clim") if cls.generic_params[initial_active_scalar]["centred_clim"]: initial_clim = ( -np.max(data[initial_active_scalar]), np.max(data[initial_active_scalar]), ) mesh_actor = plotter.add_mesh( contours, scalars=initial_active_scalar, cmap=cls.generic_params[initial_active_scalar]["cmap"], clim=initial_clim, scalar_bar_args={ "title": initial_active_scalar.replace("_", " ").capitalize() }, ) plotter.add_axes() # get the IPython widget widget = plotter.show( jupyter_kwargs={ "add_menu_items": cls._generate_toolbar_tools( initial_active_scalar, available_scalars ) }, return_viewer=True, ) # connect Trame state with PyVista state = widget.viewer.server.state ctrl = widget.viewer.server.controller state.isosurface_value = cls.generic_params["isosurface"] state.scalar_field = initial_active_scalar state.cmap = cls.generic_params[initial_active_scalar]["cmap"] ctrl.view_update = widget.viewer.update # Trame Callbacks @state.change("isosurface_value") def update_isosurface(isosurface_value, **kwargs): """Update isosurface when slider changes.""" new_contours = structure_grid.contour( [isosurface_value], scalars="amplitude" ) new_contours.set_active_scalars(state.scalar_field) mesh_actor.mapper.dataset = new_contours ctrl.view_update() 
@state.change("scalar_field") def update_scalar_field(scalar_field, **kwargs): """Change the active scalar field dynamically.""" contours.set_active_scalars(scalar_field) mesh_actor.mapper.array_name = scalar_field cmap = cls.generic_params[scalar_field]["cmap"] centred_clim = cls.generic_params[scalar_field]["centred_clim"] clim_range = list(contours.get_data_range(scalar_field)) if centred_clim: max_abs = max(abs(clim_range[0]), abs(clim_range[1])) clim_range = [-max_abs, max_abs] else: clim_range = cls.generic_params[scalar_field].get("clim") mesh_actor.mapper.scalar_range = clim_range state.cmap = cmap mesh_actor.mapper.lookup_table = pv.LookupTable(cmap) plotter.remove_scalar_bar() plotter.add_scalar_bar( title=scalar_field.replace("_", " ").capitalize(), n_labels=5 ) ctrl.view_update() @state.change("cmap") def update_colourmap(cmap, **kwargs): """Update the colourmap dynamically.""" state.cmap = cmap mesh_actor.mapper.lookup_table = pv.LookupTable(cmap) plotter.remove_scalar_bar() plotter.add_scalar_bar( title=state.scalar_field.replace("_", " ").capitalize(), n_labels=5, ) ctrl.view_update() return widget @staticmethod def multi_mesh( scalar_field: np.ndarray, isosurfaces: list[float] | np.ndarray, initial_view: dict[str, float] = None, kwargs_mesh: dict[str, float | str | bool] = None, scalar_field_name: str = "Values", window_size: list[int] = None, plot_title: str = "3D view", interactive: bool = True, jupyter_backend: str = "client", ) -> None: """ Visualise a 3D scalar field using PyVista with isosurfaces. This function creates a structured 3D grid from a scalar field and generates isosurfaces (contours) for the specified values. Args: scalar_field (np.ndarray): 3D array representing the scalar field to visualize. isosurfaces (list[float] | np.ndarray): List of scalar values for which isosurfaces will be generated. initial_view (dict[str, float], optional): Dictionary specifying the initial camera position. Defaults to None. 
kwargs_mesh (dict, optional): Keyword arguments for PyVista's add_mesh function. Defaults to None. scalar_field_name (str, optional): Name for the scalar field. Defaults to "Values". window_size (list[int], optional): Window size in pixels. Defaults to [1100, 700]. plot_title (str, optional): Title for the plot window. Defaults to "3D view". interactive (bool, optional): Enable interactive mode. Defaults to True. jupyter_backend (str, optional): Backend for Jupyter display. Defaults to "client". Returns: None: Displays the 3D plot. Raises: PyVistaImportError: if PyVista is not installed. """ if not IS_PYVISTA_AVAILABLE: raise PyVistaImportError() if window_size is None: window_size = [1100, 700] if kwargs_mesh is None: kwargs_mesh = { "cmap": "viridis", "opacity": 0.2, "show_edges": False, "style": "wireframe", "log_scale": False, } # Create grid for PyVista nx, ny, nz = scalar_field.shape x = np.arange(nx, dtype=np.float32) y = np.arange(ny, dtype=np.float32) z = np.arange(nz, dtype=np.float32) X, Y, Z = np.meshgrid(x, y, z, indexing="ij") grid = pv.StructuredGrid(X, Y, Z) grid.point_data[scalar_field_name] = scalar_field.flatten(order="F") # Generate contours for different isosurfaces contours = grid.contour(isosurfaces=isosurfaces, method="contour") plotter = pv.Plotter() plotter.add_mesh(contours, **kwargs_mesh) # Set the initial view if provided if initial_view: if "azimuth" in initial_view: plotter.camera.Azimuth(initial_view["azimuth"]) if "elevation" in initial_view: plotter.camera.Elevation(initial_view["elevation"]) if "roll" in initial_view: plotter.camera.Roll(initial_view["roll"]) plotter.show( title=plot_title, window_size=window_size, interactive=interactive, jupyter_backend=jupyter_backend, ) # Print the current camera view after interaction current_camera = plotter.camera print( f"Current Camera View - Azimuth: {current_camera.azimuth}, " f"Elevation: {current_camera.elevation}, " f"Roll: {current_camera.roll}" ) @staticmethod def 
save_rotating_contours( scalar_field: np.ndarray, isosurfaces: list[float] | np.ndarray, save_directory: str, scalar_field_name: str = "Values", rotation_axis: str = "z", n_frames: int = 18, initial_view: dict[str, float] = None, kwargs_mesh: dict[str, float | str | bool] = None, window_size: list[int] = None, ) -> None: """ Generate and save rotating 3D contour plot images. Args: scalar_field (np.ndarray): 3D array to visualize. isosurfaces (list[float] | np.ndarray): List of isosurface values. save_directory (str): Directory to save images. scalar_field_name (str, optional): Name for the scalar field. Defaults to "Values". rotation_axis (str, optional): Axis of rotation ("x", "y", or "z"). Defaults to "z". n_frames (int, optional): Number of rotation frames. Defaults to 18. initial_view (dict[str, float], optional): Initial camera position. Defaults to None. kwargs_mesh (dict, optional): PyVista mesh customization. Defaults to None. window_size (list[int], optional): Window size in pixels. Defaults to [1100, 700]. Returns: None: Saves images to disk. Raises: PyVistaImportError: if PyVista is not installed. 
""" if not IS_PYVISTA_AVAILABLE: raise PyVistaImportError() os.makedirs(save_directory, exist_ok=True) if window_size is None: window_size = [1100, 700] if kwargs_mesh is None: kwargs_mesh = { "cmap": "viridis", "opacity": 0.2, "show_edges": False, "style": "wireframe", "log_scale": False, } # Create the grid and contours nx, ny, nz = scalar_field.shape x = np.arange(nx, dtype=np.float32) y = np.arange(ny, dtype=np.float32) z = np.arange(nz, dtype=np.float32) X, Y, Z = np.meshgrid(x, y, z, indexing="ij") grid = pv.StructuredGrid(X, Y, Z) grid.point_data[scalar_field_name] = scalar_field.flatten(order="F") contours = grid.contour(isosurfaces=isosurfaces, method="contour") plotter = pv.Plotter(window_size=window_size) plotter.add_mesh(contours, **kwargs_mesh) # Set the initial view if provided if initial_view: if "azimuth" in initial_view: plotter.camera.Azimuth(initial_view["azimuth"]) if "elevation" in initial_view: plotter.camera.Elevation(initial_view["elevation"]) if "roll" in initial_view: plotter.camera.Roll(initial_view["roll"]) # Determine rotation step angle_step = 360 / n_frames for i in range(n_frames): # Rotate the view if rotation_axis == "x": plotter.camera.Elevation(angle_step) elif rotation_axis == "y": plotter.camera.Azimuth(angle_step) elif rotation_axis == "z": plotter.camera.Roll(angle_step) else: raise ValueError("rotation_axis must be 'x', 'y', or 'z'") plotter.render() filename = os.path.join(save_directory, f"frame_{i:03d}.png") plotter.screenshot(filename) plotter.close()