"""
3D viewer widget for interactive visualisation of CDI reconstruction data.
This module provides the ThreeDViewer class for displaying 3D objects
from CDI optimisation results using Plotly.
"""
import ipywidgets as widgets
import numpy as np
from IPython.display import HTML, display
try:
import plotly.graph_objects as go
from scipy.interpolate import RegularGridInterpolator
from skimage.measure import marching_cubes
IS_PLOTLY_AVAILABLE = True
except ImportError:
IS_PLOTLY_AVAILABLE = False
# check if volume module is available for shared utilities
try:
from .volume import (
_extract_isosurface_with_values,
colorcet_to_plotly,
)
HAS_VOLUME_UTILS = True
except ImportError:
HAS_VOLUME_UTILS = False
import matplotlib.pyplot as plt
def colorcet_to_plotly(cmap_name: str, n_colors: int = 256):
"""Fallback colormap converter."""
if cmap_name not in plt.colormaps():
raise ValueError(f"Colormap '{cmap_name}' not found.")
cmap = plt.get_cmap(cmap_name)
colors = [cmap(i) for i in np.linspace(0, 1, n_colors)]
return [
[
i / (n_colors - 1),
f"rgb({int(c[0] * 255)},{int(c[1] * 255)},{int(c[2] * 255)})",
]
for i, c in enumerate(colors)
]
[docs]
class ThreeDViewer(widgets.Box):
"""
Widget to display 3D objects from CDI optimisation using Plotly.
This class provides interactive 3D visualisation of volumetric data
with controls for threshold, phase/amplitude display, and colormap
selection.
Interactive controls:
- Threshold slider: controls the isosurface level
- Phase/Amplitude toggle: switches between phase and amplitude
display
- Colormap dropdown: selects the colormap for the surface colour
- Auto-scale checkbox: automatically scales the colorbar to data
range
- Symmetric checkbox: forces the colorbar to be symmetric around
zero
- Set limits checkbox: enables manual vmin/vmax input fields
- Replace NaN with mean checkbox: replaces NaN values in the
displayed quantity with the mean value to avoid weird
colouring artefacts
"""
# colormaps (1D - standard matplotlib/colorcet)
cmap_options = (
"turbo",
"viridis",
"inferno",
"magma",
"plasma",
"cividis",
"RdBu",
"coolwarm",
"twilight",
"Blues",
"Greens",
"Greys",
"Purples",
"Oranges",
"Reds",
"cet_CET_D13",
"cet_CET_C9s_r",
"cet_CET_D1A",
"jch_const",
"jch_max",
)
[docs]
def __init__(
self,
input_file: np.ndarray | None = None,
html_width: int | None = None,
voxel_size: tuple = (1, 1, 1),
figsize: tuple = (9, 6),
):
"""
Initialise the 3D viewer with Plotly backend.
Args:
input_file (np.ndarray | None, optional): 3D complex array
to visualise. Defaults to None.
html_width (int | None, optional): HTML width in %. If
given, the width of the notebook will be changed to that
value (e.g. full width with 100). Defaults to None.
voxel_size (tuple, optional): voxel size (dx, dy, dz) for
proper scaling. Defaults to (1, 1, 1).
figsize (tuple, optional): figure size in inches
(width, height). Defaults to (9, 6).
Raises:
ImportError: if plotly or required packages are not
installed.
"""
if not IS_PLOTLY_AVAILABLE:
raise ImportError(
"ThreeDViewer requires plotly, scikit-image, and scipy. "
"Install with: pip install cdiutils[interactive]"
)
super().__init__()
if html_width is not None:
html_code = f"""
<style>.container {{
width:{int(html_width)}%
!important; }}</style>
"""
display(HTML(html_code))
# store parameters
self.voxel_size = np.array(voxel_size)
self.figsize = figsize
# create plotly figure
self.fig = go.FigureWidget()
self.fig.update_layout(
template="plotly_white",
scene=dict(
xaxis=dict(showbackground=True, title="x"),
yaxis=dict(showbackground=True, title="y"),
zaxis=dict(showbackground=True, title="z"),
aspectmode="data",
camera=dict(
eye=dict(x=1.5, y=1.5, z=1.5),
# improve zoom sensitivity
projection=dict(type="perspective"),
),
),
width=figsize[0] * 96,
height=figsize[1] * 96,
dragmode="orbit",
)
# create control widgets
self.threshold = widgets.FloatSlider(
value=5,
min=0,
max=20,
step=0.02,
description="Threshold:",
disabled=False,
continuous_update=False,
orientation="horizontal",
readout=True,
readout_format=".2f",
)
self.toggle_phase = widgets.ToggleButtons(
options=["Amplitude", "Phase"],
description="Display:",
disabled=False,
value="Phase",
button_style="",
)
self.toggle_rotate = widgets.ToggleButton(
value=False,
description="Rotate",
tooltips="Rotate view",
)
# colormap dropdown - default depends on mode
self.colormap = widgets.Dropdown(
options=self.cmap_options,
value="cet_CET_C9s_r",
description="Colormap:",
disabled=False,
)
self.theme_toggle = widgets.ToggleButton(
value=False,
description="Dark Theme",
tooltips="Toggle dark/light theme",
)
# colorbar control checkboxes
self.auto_scale = widgets.Checkbox(
value=True,
description="Auto-scale colorbar",
tooltips="Scale colorbar to min/max of current plot",
indent=False,
)
self.symmetric_scale = widgets.Checkbox(
value=False,
description="Symmetric colorbar",
tooltips="Center colorbar at 0 (for strain, phase, etc.)",
indent=False,
)
self.replace_nan = widgets.Checkbox(
value=False,
description="Replace NaN with mean",
tooltips="Replace NaN values with mean (fixes weird colouring)",
indent=False,
)
# set observers
self.threshold.observe(self._on_update_plot, names="value")
self.toggle_phase.observe(self._on_change_display, names="value")
self.colormap.observe(self._on_update_plot, names="value")
self.theme_toggle.observe(self._on_update_style, names="value")
self.toggle_rotate.observe(self._on_animate, names="value")
self.auto_scale.observe(self._on_update_plot, names="value")
self.symmetric_scale.observe(self._on_update_plot, names="value")
self.replace_nan.observe(self._on_update_plot, names="value")
# internal state
self.data = None
self.rgi = None # interpolator
self._rotation_angle = 0
self._rotation_callback = None
# create layout
controls_row1 = widgets.HBox([self.threshold])
controls_row2 = widgets.HBox([self.toggle_phase, self.toggle_rotate])
controls_row3 = widgets.HBox([self.colormap, self.theme_toggle])
controls_row4 = widgets.HBox(
[self.auto_scale, self.symmetric_scale, self.replace_nan]
)
self.vbox = widgets.VBox(
[controls_row1, controls_row2, controls_row3, controls_row4]
)
# load data if provided
if isinstance(input_file, np.ndarray):
self.set_data(input_file)
# set children for the Box widget
self.children = [self.fig, self.vbox]
[docs]
def show(self) -> None:
"""Display the 3D viewer widget."""
display(self)
[docs]
def set_data(
self, data: np.ndarray, threshold: float | None = None
) -> None:
"""
Set the 3D data to visualise.
Args:
data (np.ndarray): 3D complex array to visualise.
threshold (float | None, optional): initial threshold value.
If None, uses current slider value. Defaults to None.
"""
self.data = data
# create interpolator for getting values at mesh vertices
nz, ny, nx = data.shape
grid_z = np.arange(nz)
grid_y = np.arange(ny)
grid_x = np.arange(nx)
self.rgi = RegularGridInterpolator(
(grid_z, grid_y, grid_x),
data,
bounds_error=False,
fill_value=0,
)
# update threshold range if needed, handle NaN values
amp = np.abs(data)
self.threshold.max = float(np.nanmax(amp))
if threshold is not None:
self.threshold.value = threshold
# initial plot
self._on_update_plot()
def _on_update_plot(self, change=None) -> None:
"""
Update the plot according to parameters.
Args:
change: widget change event (not used but required by
observer).
"""
if self.data is None:
return
try:
# use shared helper function if available
if HAS_VOLUME_UTILS:
verts_scaled, faces, vals = _extract_isosurface_with_values(
self.data,
self.data,
self.threshold.value,
self.voxel_size,
use_interpolator=True, # needed for complex
)
else:
# fallback: inline implementation
verts, faces, _, _ = marching_cubes(
np.abs(self.data),
level=self.threshold.value,
step_size=1,
)
verts_scaled = verts * self.voxel_size
vals = self.rgi(verts)
# optionally replace NaN values with mean to fix weird
# colouring
if self.replace_nan.value and np.any(np.isnan(vals)):
mean_val = np.nanmean(vals)
vals = np.where(np.isnan(vals), mean_val, vals)
# determine colours based on display mode
if self.toggle_phase.value == "Phase":
# get phase values
phase_vals = np.angle(vals)
# determine colour range based on settings
if self.symmetric_scale.value:
# symmetric around 0, use actual phase values
intensity = phase_vals
cmin, cmax = -np.pi, np.pi
colorbar = dict(
title="Phase (rad)",
tickmode="array",
tickvals=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi],
ticktext=["-π", "-π/2", "0", "π/2", "π"],
len=0.7,
x=0.85,
showticklabels=True,
thickness=20,
lenmode="fraction",
xanchor="left",
)
elif self.auto_scale.value:
# auto-scale to actual data range, handle NaN
intensity = phase_vals
cmin, cmax = (
float(np.nanmin(phase_vals)),
float(np.nanmax(phase_vals)),
)
colorbar = dict(
title="Phase (rad)",
len=0.7,
x=0.85,
showticklabels=True,
thickness=20,
lenmode="fraction",
xanchor="left",
)
else:
# normalise to [0, 1] for full range
intensity = (phase_vals + np.pi) / (2 * np.pi)
cmin, cmax = 0, 1
colorbar = dict(
title="Phase (rad)",
tickmode="array",
tickvals=[0, 0.25, 0.5, 0.75, 1],
ticktext=["-π", "-π/2", "0", "π/2", "π"],
len=0.7,
x=0.85,
showticklabels=True,
thickness=20,
lenmode="fraction",
xanchor="left",
)
vertex_colors = None
colorscale = colorcet_to_plotly(self.colormap.value)
else: # amplitude
# use actual amplitude values (not normalised)
intensity = np.abs(vals)
# determine colour range based on settings
if self.symmetric_scale.value:
# symmetric around 0 - doesn't make much sense
# for amplitude but keep for consistency; centre
# at mean
mean_val = float(np.nanmean(intensity))
max_dev = float(
max(
np.nanmax(intensity) - mean_val,
mean_val - np.nanmin(intensity),
)
)
cmin = mean_val - max_dev
cmax = mean_val + max_dev
elif self.auto_scale.value:
# auto-scale to actual data range, handle NaN
cmin, cmax = (
float(np.nanmin(intensity)),
float(np.nanmax(intensity)),
)
else:
# use full range from data, handle NaN
cmin, cmax = (
float(np.nanmin(intensity)),
float(np.nanmax(intensity)),
)
vertex_colors = None
colorscale = colorcet_to_plotly(self.colormap.value)
colorbar = dict(
title="Amplitude",
len=0.7,
x=0.85,
showticklabels=True,
thickness=20,
lenmode="fraction",
xanchor="left",
)
# update or create mesh
with self.fig.batch_update():
if len(self.fig.data) == 0:
# first time - add mesh
mesh_args = dict(
x=verts_scaled[:, 0],
y=verts_scaled[:, 1],
z=verts_scaled[:, 2],
i=faces[:, 0],
j=faces[:, 1],
k=faces[:, 2],
intensity=intensity,
vertexcolor=vertex_colors,
colorscale=colorscale,
colorbar=colorbar,
cmin=cmin,
cmax=cmax,
opacity=1.0,
flatshading=False,
lighting=dict(
ambient=0.85,
diffuse=0.1,
specular=0.5,
roughness=0.2,
fresnel=0.5,
),
)
self.fig.add_trace(go.Mesh3d(**mesh_args))
else:
# update existing mesh
self.fig.data[0].x = verts_scaled[:, 0]
self.fig.data[0].y = verts_scaled[:, 1]
self.fig.data[0].z = verts_scaled[:, 2]
self.fig.data[0].i = faces[:, 0]
self.fig.data[0].j = faces[:, 1]
self.fig.data[0].k = faces[:, 2]
self.fig.data[0].intensity = intensity
self.fig.data[0].vertexcolor = vertex_colors
self.fig.data[0].colorscale = colorscale
self.fig.data[0].colorbar = colorbar
self.fig.data[0].cmin = cmin
self.fig.data[0].cmax = cmax
except Exception as e:
print(f"Error updating plot: {e}")
def _on_change_display(self, change) -> None:
"""
Handle display mode change (amplitude/phase).
Args:
change: widget change event.
"""
if change["name"] == "value":
# switch default colormap based on mode
if change["new"] == "Phase":
# use cyclic colormap for phase (good defaults)
if self.colormap.value not in [
"twilight",
"cet_CET_C9s_r",
"jch_const",
"jch_max",
]:
self.colormap.value = "cet_CET_C9s_r"
# enable symmetric scale for phase (centered at 0)
if not self.symmetric_scale.value:
self.symmetric_scale.value = True
else:
# use sequential colormap for amplitude
if self.colormap.value in [
"twilight",
"cet_CET_C9s_r",
"jch_const",
"jch_max",
]:
self.colormap.value = "turbo"
# disable symmetric scale for amplitude (usually not needed)
if self.symmetric_scale.value:
self.symmetric_scale.value = False
# update plot
self._on_update_plot()
def _on_update_style(self, change) -> None:
"""
Update the plot style (theme).
Args:
change: widget change event.
"""
if change["name"] == "value":
if self.theme_toggle.value:
self.fig.update_layout(template="plotly_dark")
else:
self.fig.update_layout(template="plotly_white")
def _on_animate(self, change) -> None:
"""
Handle rotation animation toggle.
Args:
change: widget change event.
"""
if change["name"] == "value":
if change["new"]: # start rotation
self._start_rotation()
else: # stop rotation
self._stop_rotation()
def _start_rotation(self) -> None:
"""Start continuous rotation animation."""
import asyncio
async def rotate():
"""Async rotation loop."""
while self.toggle_rotate.value:
self._rotation_angle += 2
# update camera azimuth
eye_x = 1.5 * np.cos(np.radians(self._rotation_angle))
eye_y = 1.5 * np.sin(np.radians(self._rotation_angle))
eye_z = 1.5
with self.fig.batch_update():
self.fig.layout.scene.camera.eye = dict(
x=eye_x, y=eye_y, z=eye_z
)
await asyncio.sleep(0.05) # ~20 FPS
# create and run task
loop = asyncio.get_event_loop()
self._rotation_callback = loop.create_task(rotate())
def _stop_rotation(self) -> None:
"""Stop rotation animation."""
if self._rotation_callback is not None:
self._rotation_callback.cancel()
self._rotation_callback = None