Source code for cdiutils.wavefront

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse
from scipy.fft import fft2, fftshift, ifft2, ifftshift
from scipy.ndimage import center_of_mass
from scipy.optimize import curve_fit
from scipy.signal import find_peaks, peak_widths

from cdiutils.plot.formatting import add_colorbar
from cdiutils.utils import CroppingHandler


def angular_spectrum_propagation(
    wavefront: np.ndarray,
    propagation_distance: float,
    wavelength: float,
    pixel_size: float,
    magnification: float = 1,
    do_fftshift: bool = True,
    verbose: bool = False,
) -> np.ndarray:
    """
    Computes the near-field propagation of a wavefront using the
    Angular Spectrum Method.

    The wavefront is multiplied by a quadratic phase in the source plane
    (Q1), Fourier transformed, multiplied by the propagation kernel
    (Q2), inverse transformed and finally multiplied by a quadratic
    phase in the observation plane (Q3). For ``magnification == 1`` both
    Q1 and Q3 reduce to 1.

    Parameters:
        wavefront (np.ndarray): 2D (ny, nx) or 3D (nz, ny, nx) complex
            array representing the wavefront. A 2D input is handled as
            a one-slice stack. ny and nx may differ. It can be
            fftshifted or not, see do_fftshift.
        propagation_distance (float): Propagation distance.
        wavelength (float): Wavelength of the wavefront.
        pixel_size (float): Pixel size in the spatial domain (the same
            value is used along both transverse axes).
        magnification (float, optional): Magnification factor to handle
            the pixel size of the propagated wavefront (default is 1).
        do_fftshift (bool, optional): whether to apply fftshift to the
            wavefront before propagation. If True, the wavefront is
            fftshifted before propagation and ifftshifted after.
            Defaults to True.
        verbose (bool): whether to print z limits.

    Returns:
        np.ndarray: Propagated wavefront, with the same shape as the
            input. If propagation_distance is 0, the input array itself
            is returned untouched.

    Raises:
        ValueError: if the wavefront is neither 2D nor 3D (only checked
            when propagation_distance is not 0).
    """
    if propagation_distance == 0:
        return wavefront

    # Handle 2D and 3D input wavefronts
    if wavefront.ndim == 2:
        # Convert 2D to 3D with one slice
        wavefront_stack = wavefront[np.newaxis, ...]
    elif wavefront.ndim == 3:
        wavefront_stack = wavefront
    else:
        raise ValueError("Input wavefront must be a 2D or 3D array.")

    if do_fftshift:
        wavefront_stack = fftshift(wavefront_stack, axes=(-2, -1))

    _, ny, nx = wavefront_stack.shape

    min_dist = (
        max(abs((magnification - 1) / magnification), abs(magnification - 1))
        * nx
        * pixel_size**2
        / wavelength
    )
    max_dist = abs(magnification) * nx * pixel_size**2 / wavelength
    if verbose:
        print(
            "Near field magnified propagation: "
            f"{min_dist:.4e} < |{propagation_distance=:.4e}| < {max_dist:.4e}?"
        )

    # Spatial coordinates in the source plane.
    # NOTE: for an odd size n the grid spans -(n + 1) / 2 ... (n - 3) / 2
    # pixels, i.e. it is not symmetric about zero; this only matters for
    # Q1/Q3, hence only when magnification != 1.
    x = fftshift(np.arange(-nx // 2, nx // 2) * pixel_size)
    y = fftshift(np.arange(-ny // 2, ny // 2) * pixel_size)
    # With indexing="ij" the first argument runs along axis 0, so y must
    # come first for the grids to have the (ny, nx) shape of a slice.
    Y, X = np.meshgrid(y, x, indexing="ij")

    # Spatial frequency coordinates (or Fourier coordinates)
    fx = np.fft.fftfreq(nx, pixel_size)
    fy = np.fft.fftfreq(ny, pixel_size)
    FY, FX = np.meshgrid(fy, fx, indexing="ij")

    # Quadratic phase factor in the source plane
    Q1 = np.exp(
        1j
        * np.pi
        / (wavelength * propagation_distance)
        * (1 - magnification)
        * (X**2 + Y**2)
    )

    # Propagation kernel in the Fourier domain
    Q2 = np.exp(
        -1j
        * np.pi
        * wavelength
        * propagation_distance
        / magnification
        * (FX**2 + FY**2)
    )

    # Quadratic phase factor in the observation plane
    Q3 = np.exp(
        1j
        * np.pi
        / (wavelength * propagation_distance)
        * (magnification - 1)
        * magnification
        * (X**2 + Y**2)
    )

    # The (ny, nx) factors broadcast over the leading stack axis, so all
    # slices are propagated in one batched transform over the last two
    # axes instead of a Python loop.
    spectrum = fft2(wavefront_stack * Q1, axes=(-2, -1), norm="ortho")
    propagated_wavefront = (
        ifft2(spectrum * Q2, axes=(-2, -1), norm="ortho") * Q3
    )

    # Apply ifftshift to the propagated wavefront if it was shifted
    if do_fftshift:
        propagated_wavefront = ifftshift(propagated_wavefront, axes=(-2, -1))

    # If the input was 2D, drop only the stack axis that was added above
    # (np.squeeze would also collapse genuine length-1 axes).
    if wavefront.ndim == 2:
        return propagated_wavefront[0]
    return propagated_wavefront
def get_width_metrics(
    profile: np.ndarray, axis_values: np.ndarray, verbose: bool = False
) -> dict:
    """
    Compute the FWHM and other width metrics of a given profile.

    The function uses the `find_peaks` and `peak_widths` functions from
    `scipy.signal` to find the peaks and calculate the full width at
    half maximum (FWHM) and full width at 10% maximum (FW10%M) of the
    profile. A Gaussian with constant offset is additionally fitted to
    the whole profile with `scipy.optimize.curve_fit`.

    Args:
        profile (np.ndarray): the profile data to be analysed.
        axis_values (np.ndarray): the axis values corresponding to the
            profile data. They may be ascending or descending.
        verbose (bool, optional): whether to print out info. Defaults
            to False.

    Returns:
        dict: a dictionary with the keys "highest_peak_idx", "fwhm",
            "fw10m" and "gaussian". "fwhm" and "fw10m" hold the width
            in axis units ("value"), the width in samples ("indices"),
            the evaluation height ("height") and the interpolated
            (left, right) positions ("boundaries"). "gaussian" holds
            the fitted FWHM ("value", NaN if the fit failed), the
            fitted curve ("fit"), the raw fit parameters ("params") and
            a "success" flag.
    """
    # find peaks
    peaks, properties = find_peaks(profile, height=0.5 * np.max(profile))

    # use the highest peak or the middle if no peak is found
    if len(peaks) > 0:
        highest_peak_idx = peaks[np.argmax(properties["peak_heights"])]
    else:
        highest_peak_idx = len(profile) // 2

    # for absolute FWHM - using original peak_widths method
    fwhm_indices, fwhm_height, left_idx, right_idx = peak_widths(
        profile, [highest_peak_idx], rel_height=0.5
    )

    # for absolute FW10%M - using peak_widths with 0.9 relative height
    # (10% from max)
    fw10m_indices, fw10m_height, left10_idx, right10_idx = peak_widths(
        profile, [highest_peak_idx], rel_height=0.9
    )

    # convert (fractional) sample indices to physical units using the
    # axis values
    sample_indices = np.arange(len(axis_values))

    def to_position(fractional_index):
        return np.interp(fractional_index, sample_indices, axis_values)

    # for FWHM
    left_pos = to_position(left_idx[0])
    right_pos = to_position(right_idx[0])
    fwhm = abs(right_pos - left_pos)

    # for FW10%M
    left10_pos = to_position(left10_idx[0])
    right10_pos = to_position(right10_idx[0])
    fw10m = abs(right10_pos - left10_pos)

    # Statistical FWHM using Gaussian fitting
    def gaussian(x, amp, mean, sigma, offset):
        return amp * np.exp(-((x - mean) ** 2) / (2 * sigma**2)) + offset

    try:
        # fit Gaussian to profile
        p0 = [
            np.max(profile),
            axis_values[highest_peak_idx],
            fwhm / 2.355,
            0,
        ]  # initial guess
        params, _ = curve_fit(gaussian, axis_values, profile, p0=p0)
        # FWHM = 2.355 * sigma for Gaussian. sigma only enters the model
        # squared, so the optimiser may return it with either sign: the
        # width is taken from its magnitude.
        gauss_fwhm = 2.355 * abs(params[2])

        # generate fitted curves for plotting
        gauss_fit = gaussian(axis_values, *params)
        fit_success = True
    except Exception as e:
        # The fit is a best-effort extra: any failure is reported through
        # the "success" flag rather than aborting the other metrics.
        if verbose:
            print(f"Gaussian fitting failed: {e}")
        gauss_fwhm = float("nan")
        gauss_fit, params = None, None
        fit_success = False

    if verbose:
        print(
            f"Highest peak at index {highest_peak_idx}, "
            f"pos = {axis_values[highest_peak_idx]:.2e}, "
            f"value = {profile[highest_peak_idx]:.2e}\n"
            f"FWHM of the probe: {fwhm:.2e} m "
            f"({fwhm_indices[0]:.1f} pixels)\n"
        )

    return {
        "highest_peak_idx": highest_peak_idx,
        "fwhm": {
            "value": fwhm,
            "indices": fwhm_indices[0],
            "height": fwhm_height[0],
            "boundaries": (left_pos, right_pos),
        },
        "fw10m": {
            "value": fw10m,
            "indices": fw10m_indices[0],
            "height": fw10m_height[0],
            "boundaries": (left10_pos, right10_pos),
        },
        "gaussian": {
            "value": gauss_fwhm,
            "fit": gauss_fit,
            "success": fit_success,
            "params": params,
        },
    }
def probe_metrics(
    probe: np.ndarray,
    pixel_size: tuple,
    zoom_factor: int | str = "auto",
    probe_convention: str = "pynx",
    centre_at_max: bool = False,
    verbose: bool = False,
) -> tuple:
    """
    Plot the probe along with the line profiles and their width
    estimates.

    The figure contains the probe amplitude and phase as 2D images, the
    central x and y line profiles of the amplitude as 1D plots (with
    FWHM, FW10%M and Gaussian fit overlaid) and a table summarising the
    metrics. Only a single 2D mode is supported.

    > Notes: In the PyNX imshow plot of the probe, the origin is set to
    "lower" and extent is set to (-x_min, x_max, y_min, y_max). This
    means that in the PyNX convention, the probe is stored as a 2D
    (y = y_cxi, x = -x_cxi) array.

    Args:
        probe (np.ndarray): the probe data to be plotted in matrix
            (y, x) convention.
        pixel_size (tuple): the pixel size of the probe data (y, x)
        zoom_factor (int | str, optional): the zoom factor for the
            probe plot. If "auto", the window extends 4 times the FWHM
            on each side of the centre (8 times the FWHM in total).
            Otherwise the full extent is divided by this factor.
            Defaults to "auto".
        probe_convention (str, optional): the convention used for the
            probe data. If "pynx", the probe is stored as a 2D
            (y = y_cxi, x = -x_cxi) array. Defaults to "pynx".
        centre_at_max (bool, optional): if True, the probe is centred
            at its maximum (through
            CroppingHandler.force_centred_cropping) before the
            analysis. Defaults to False.
        verbose (bool, optional): if True, prints additional
            information. Defaults to False.

    Raises:
        ValueError: if the probe is not 2D or if probe_convention is
            not "pynx".

    Returns:
        tuple: the figure and the 2 x 2 array of axes (images on the
            first row, line profiles on the second row).
    """
    if probe.ndim != 2:
        raise ValueError("Probe must be 2D array (2D single-mode probe)")
    if probe_convention != "pynx":
        raise ValueError(
            "Only PyNX convention is supported for now. "
            "Use probe_convention='pynx'."
        )
    # probe amplitude
    probe_amplitude = np.abs(probe)

    # check if the probe is well centred
    com = center_of_mass(probe_amplitude)
    if verbose:
        print(
            f"Probe intensity centre of mass: {com[0]:.2f} (y), "
            f"{com[1]:.2f} (x) pixels\n"
            f"Probe shape: {probe.shape}."
        )
    if centre_at_max:
        if verbose:
            print("Centring the probe at the maximum intensity.")
        probe = CroppingHandler.force_centred_cropping(
            probe, where="max", verbose=verbose
        )
        # recompute the probe amplitude in the new cropped frame
        probe_amplitude = np.abs(probe)

    # initialise the main dictionary containing all the line profile
    # metrics
    metrics = {"x": {"color": "dodgerblue"}, "y": {"color": "lightcoral"}}

    for i, axis in enumerate(["y", "x"]):
        metrics[axis]["axis_values"] = np.linspace(
            -pixel_size[i] * probe.shape[i] / 2,
            pixel_size[i] * probe.shape[i] / 2,
            probe.shape[i],
        )
    # the probe is stored as a (y = y_cxi, x = -x_cxi) array, so we need
    # to flip the x-axis extent.
    metrics["x"]["axis_values"] = np.flip(metrics["x"]["axis_values"])

    metrics["x"]["profile"] = probe_amplitude[probe.shape[0] // 2, :]
    metrics["x"]["centre"] = metrics["x"]["axis_values"][
        probe.shape[1] // 2  # the pos that serves to get the y profile
    ]

    metrics["y"]["profile"] = probe_amplitude[:, probe.shape[1] // 2]
    metrics["y"]["centre"] = metrics["y"]["axis_values"][
        probe.shape[0] // 2  # the pos that serves to get the x profile
    ]

    extent = (
        metrics["x"]["axis_values"][0],
        metrics["x"]["axis_values"][-1],
        metrics["y"]["axis_values"][0],
        metrics["y"]["axis_values"][-1],
    )

    # compute FWHM and other metrics for x and y profiles
    for axis in ["x", "y"]:
        metrics[axis].update(
            get_width_metrics(
                metrics[axis]["profile"],
                metrics[axis]["axis_values"],
                verbose=verbose,
            )
        )
        metrics[axis]["peak_pos"] = metrics[axis]["axis_values"][
            metrics[axis]["highest_peak_idx"]
        ]

    # calculate ROI boundaries (in physical units)
    if zoom_factor == "auto":
        # 8x FWHM. /2 means each side
        zoom_extent = (
            8
            / 2
            * np.array(
                [
                    metrics["x"]["fwhm"]["value"],
                    -metrics["x"]["fwhm"]["value"],
                    -metrics["y"]["fwhm"]["value"],
                    metrics["y"]["fwhm"]["value"],
                ]
            )
        )
    elif zoom_factor == 1:
        zoom_extent = extent
    else:
        zoom_extent = np.array(extent) * 1 / zoom_factor

    # the x entries of zoom_extent follow the flipped x axis
    # (positive first), hence the swapped indices
    metrics["x"]["min"] = metrics["x"]["centre"] + zoom_extent[1]  # inverted!
    metrics["x"]["max"] = metrics["x"]["centre"] + zoom_extent[0]  # inverted!
    metrics["y"]["min"] = metrics["y"]["centre"] + zoom_extent[2]
    metrics["y"]["max"] = metrics["y"]["centre"] + zoom_extent[3]

    fig = plt.figure(figsize=(6, 6), layout="tight")
    gs = gridspec.GridSpec(3, 2, figure=fig, height_ratios=[0.4, 0.4, 0.2])
    axes = np.array(
        [
            [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])],
            [fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1])],
        ]
    )
    table_ax = fig.add_subplot(gs[2, :])

    # Show the probe amplitude
    imshow_kwargs = {
        "extent": extent,
        "origin": "lower",
    }
    axes[0, 0].imshow(probe_amplitude, cmap="viridis", **imshow_kwargs)
    add_colorbar(axes[0, 0])
    axes[0, 0].set_title(r"Probe amplitude ($|\mathcal{P}|$, a. u.)")

    # Show the probe phase, faded out where the amplitude is low
    opacity = np.abs(probe) / np.max(np.abs(probe))
    axes[0, 1].imshow(
        np.angle(probe), alpha=opacity, cmap="cet_CET_C9s_r", **imshow_kwargs
    )
    axes[0, 1].set_facecolor("black")
    add_colorbar(axes[0, 1], extend="both")
    axes[0, 1].set_title(r"Probe phase ($\text{arg}(\mathcal{P})$, rad)")

    # Add FWHM indicators as ellipse
    indicator_params = {"alpha": 0.5, "lw": 0.5, "linestyle": "--"}
    ellipse = Ellipse(
        (metrics["x"]["centre"], metrics["y"]["centre"]),
        width=metrics["x"]["fwhm"]["value"],
        height=metrics["y"]["fwhm"]["value"],
        edgecolor="w",
        facecolor="none",
        **indicator_params,
    )
    axes[0, 0].add_patch(ellipse)

    # Add crosshair at the centre positions used for the line profiles
    axes[0, 0].axvline(x=metrics["x"]["centre"], color="w", **indicator_params)
    axes[0, 0].axhline(y=metrics["y"]["centre"], color="w", **indicator_params)

    # Set limits to zoomed region
    for ax in axes[0, :]:
        ax.set_xlim(metrics["x"]["max"], metrics["x"]["min"])  # inverted!
        ax.set_ylim(metrics["y"]["min"], metrics["y"]["max"])
        ax.set_xlabel(r"$x_{\text{CXI}}$ (m)")
        ax.set_ylabel(r"$y_{\text{CXI}}$ (m)")

    for subplot, axis in zip(axes[1, :], ["x", "y"]):
        subplot.plot(
            metrics[axis]["axis_values"],
            metrics[axis]["profile"],
            color=metrics[axis]["color"],
            label="line profile",
            lw=1,
        )
        subplot.axvspan(
            *metrics[axis]["fwhm"]["boundaries"],
            color=metrics[axis]["color"],
            alpha=0.4,
            label="FWHM width",
            lw=0,
        )
        subplot.axvspan(
            *metrics[axis]["fw10m"]["boundaries"],
            color=metrics[axis]["color"],
            alpha=0.1,
            label="FW10M width",
            lw=0,
        )
        # Gaussian fit
        if metrics[axis]["gaussian"]["success"]:
            subplot.plot(
                metrics[axis]["axis_values"],
                metrics[axis]["gaussian"]["fit"],
                color="k",
                label="gaussian fit",
                marker="o",
                markersize=2,
                markerfacecolor=metrics[axis]["color"],
                markeredgewidth=0.4,
                markeredgecolor="k",
                lw=0.25,
            )

        # other markers
        line_params = {
            "lw": 0.5,
            "linestyle": "--",
            "color": "k",
            "alpha": 0.5,
            "label": "peak position",
        }
        subplot.axvline(
            x=metrics[axis]["axis_values"][metrics[axis]["highest_peak_idx"]],
            **line_params,
        )

        # titles and labels
        formatted_axis = r"$" + axis + r"_{\text{CXI}}$"
        subplot.set_title(f"{formatted_axis} profile")
        subplot.set_xlabel(f"{formatted_axis} (m)")
        subplot.set_ylabel(r"$|\mathcal{P}|$ (a. u.)")
        subplot.legend(frameon=False, fontsize=6)

    # x plot, min and max are inverted
    axes[1, 0].set_xlim(metrics["x"]["max"], metrics["x"]["min"])
    axes[1, 1].set_xlim(metrics["y"]["min"], metrics["y"]["max"])

    # Add a table with the metrics
    table_ax.axis("off")
    metrics_table = [
        [
            "Metric",
            r"$x_{\text{CXI}}$ direction",
            r"$y_{\text{CXI}}$ direction",
        ],
        [
            "Absolute FWHM",
            f"{metrics['x']['fwhm']['value']:.2e} m",
            f"{metrics['y']['fwhm']['value']:.2e} m",
        ],
        [
            "FW10%M",
            f"{metrics['x']['fw10m']['value']:.2e} m",
            f"{metrics['y']['fw10m']['value']:.2e} m",
        ],
        [
            "Gaussian FWHM",
            f"{metrics['x']['gaussian']['value']:.2e} m",
            f"{metrics['y']['gaussian']['value']:.2e} m",
        ],
        [
            "Probe peak",
            f"{metrics['x']['peak_pos']:.2e} m",
            f"{metrics['y']['peak_pos']:.2e} m",
        ],
    ]

    table = table_ax.table(
        cellText=metrics_table, loc="center", cellLoc="center"
    )
    table.auto_set_font_size(False)
    table.set_fontsize(7)
    for (i, _), cell in table.get_celld().items():
        cell.set_linewidth(0.3)
        if i == 0:  # header row
            cell.set_text_props(fontweight="bold")
    return fig, axes
def probe_focus_sweep(
    probe: np.ndarray,
    pixel_size: tuple | float,
    wavelength: float,
    step_nb: int = 100,
    step_size: float = 10e-6,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Propagate the probe through a range of distances using the angular
    spectrum method.

    The function computes the propagated probe at each distance and
    returns the propagated probe and the corresponding propagation
    positions. The positions are ``step_size * k`` for the integers k
    in ``[-step_nb // 2, step_nb // 2)``.

    Args:
        probe (np.ndarray): the probe to be propagated, in the shape
            (height, width) or (modes, height, width).
        pixel_size (tuple | float): the pixel size of the probe in
            meters. If a sequence is given, only its first element is
            used for the propagation (the propagator takes a single
            pixel size).
        wavelength (float): the wavelength of the X-ray beam in meters.
        step_nb (int, optional): the number of steps for the
            propagation sweep. This defines the number of propagation
            positions. Defaults to 100.
        step_size (float, optional): the step size for the propagation
            sweep in meters. This defines the distance between each
            propagation position. Defaults to 10e-6 m.

    Returns:
        tuple[np.ndarray, np.ndarray]: the propagated probe at each
            propagation position, in the shape (step_nb, height, width)
            or (modes, step_nb, height, width), stored as complex64,
            and the propagation positions in meters.
    """
    # accept a bare scalar as well as a (y, x) sequence
    if np.ndim(pixel_size) == 0:
        transverse_pixel_size = pixel_size
    else:
        transverse_pixel_size = pixel_size[0]

    propagation_positions = step_size * np.arange(-step_nb // 2, step_nb // 2)
    propagated_probe = np.empty(
        propagation_positions.shape + probe.shape, dtype=np.complex64
    )
    for i, distance in enumerate(propagation_positions):
        propagated_probe[i] = angular_spectrum_propagation(
            probe,
            propagation_distance=distance,  # in meters
            wavelength=wavelength,
            pixel_size=transverse_pixel_size,
            do_fftshift=True,  # if True, fftshift before and ifftshift after
            verbose=False,
        )
    if probe.ndim == 3:  # multiple modes: put the mode axis first
        propagated_probe = propagated_probe.transpose(1, 0, 2, 3)

    return propagated_probe, propagation_positions
def plot_propagated_probe(
    propagated_probe: np.ndarray,
    pixel_size: tuple | float,
    propagation_step_size: float,
    convert_to_microns: bool = True,
    focal_distances: tuple | None = None,
    plot_phase: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plot the propagated probe as a 2D image with the phase or amplitude
    information.

    The function displays two panels sharing the propagation axis, one
    per transverse direction, with the option to plot the phase or the
    amplitude. It also allows for the conversion of pixel size and
    propagation step size to microns. If focal distances are provided,
    vertical lines are drawn at those positions.

    Args:
        propagated_probe (np.ndarray): the propagated probe data, in
            the shape (step_nb, height, width) or (modes, step_nb,
            height, width). If 3D, the first dimension is considered as
            the propagation axis, and the rest are height and width. If
            4D, will take the first mode only.
        pixel_size (tuple | float): the pixel size of the probe in
            meters. A scalar is used for every direction.
        propagation_step_size (float): the step size for the
            propagation sweep in meters. This defines the distance
            between each propagation position. If convert_to_microns is
            True, this value is converted to microns.
        convert_to_microns (bool, optional): if True, converts the
            pixel size and propagation step size to microns. This is
            useful for displaying the probe in microns instead of
            meters. Defaults to True.
        focal_distances (tuple | None, optional): the focal distances
            where vertical lines are drawn on the plot. If provided, it
            should be a tuple of two values representing the focal
            distances in meters for both directions. If
            convert_to_microns is True, these values are converted to
            microns. If None, no vertical lines (and no legend) are
            drawn. Defaults to None.
        plot_phase (bool, optional): if True, plots the phase of the
            central slice of the propagated probe. If False, plots the
            amplitude summed along the hidden transverse axis. The
            opacity of the phase plot is set to that summed amplitude,
            normalised to its maximum value. Defaults to True.

    Raises:
        ValueError: if the propagated_probe is not 3D or 4D.

    Returns:
        tuple[plt.Figure, plt.Axes]: the figure and axes objects for
            the propagated probe plot.
    """
    if propagated_probe.ndim == 4:
        propagated_probe = propagated_probe[0]
    elif propagated_probe.ndim != 3:
        raise ValueError(
            "Expected propagated_probe to be 3 or 4D (propagation axis, "
            "height, width)."
        )
    # any scalar (Python or NumPy, float or int) is broadcast to all axes
    if np.ndim(pixel_size) == 0:
        pixel_size = (pixel_size, pixel_size, pixel_size)

    if convert_to_microns:
        pixel_size = tuple(p * 1e6 for p in pixel_size)
        propagation_step_size *= 1e6
    unit_label = "μm" if convert_to_microns else "m"
    if focal_distances is not None and convert_to_microns:
        focal_distances = tuple(f * 1e6 for f in focal_distances)

    fig, axes = plt.subplots(2, 1, figsize=(5, 3), layout="tight", sharex=True)

    plot_params = {
        "cmap": "cet_CET_C9s_r" if plot_phase else "turbo",
        "alpha": None,
        "origin": "lower",
        "aspect": "auto",
    }
    step_nb = propagated_probe.shape[0]

    # The panels are visited bottom first: i = 0 collapses array axis 2
    # and is drawn in axes[1], i = 1 collapses array axis 1 and is drawn
    # in axes[0].
    # NOTE(review): per the docstring's (step_nb, height, width) layout,
    # axes[1] therefore shows array axis 1 ("height") while it is
    # labelled x_CXI/width below, and vice versa for axes[0] - confirm
    # the intended convention before changing either.
    for i, ax in enumerate(np.flip(axes).flat):
        collapsed_axis = 2 - i
        if plot_phase:
            opacity = np.abs(propagated_probe).sum(axis=collapsed_axis).T
            opacity /= opacity.max()
            # phase of the central slice along the collapsed axis
            slices = [slice(None)] * propagated_probe.ndim
            slices[collapsed_axis] = (
                propagated_probe.shape[collapsed_axis] // 2
            )
            to_plot = np.angle(propagated_probe[tuple(slices)]).T
            ax.set_facecolor("black")
            plot_params["alpha"] = opacity
        else:
            to_plot = np.abs(propagated_probe).sum(axis=collapsed_axis).T

        ax.imshow(
            to_plot,
            extent=(
                propagation_step_size * -step_nb / 2,
                propagation_step_size * step_nb / 2,
                -pixel_size[i] * propagated_probe.shape[i + 1] / 2,
                pixel_size[i] * propagated_probe.shape[i + 1] / 2,
            ),
            **plot_params,
        )
        ax.axvline(
            x=0,
            color="white",
            linestyle="-",
            linewidth=0.125,
        )
        if focal_distances is not None:
            label = (
                f"focal distance = {int(focal_distances[1 - i])} {unit_label}"
            )
            ax.axvline(
                x=focal_distances[1 - i],
                color="white",
                linestyle="--",
                label=label,
            )

    axes[0].set_ylabel(r"$y_{\text{CXI}}$" + f", height ({unit_label})")
    axes[1].set_ylabel(r"$x_{\text{CXI}}$" + f", width ({unit_label})")
    axes[1].set_xlabel(
        r"$z_{\text{CXI}}$" + f", propagation distance ({unit_label})"
    )

    quantity = "phase" if plot_phase else "amplitude"
    fig.suptitle(f"Propagated probe ({quantity})")

    for ax in axes.flat:
        add_colorbar(ax, extend="both", size="2%")
        # the focal-distance lines are the only labelled artists: without
        # them there is nothing to put in a legend
        if focal_distances is not None:
            legend = ax.legend(frameon=False, fontsize=6)
            for text in legend.get_texts():
                text.set_color("white")
    return fig, axes
def get_focal_distances(
    propagated_probe: np.ndarray,
    propagation_positions: np.ndarray,
    method: str = "max",
) -> tuple[tuple[float, float], tuple[int, int]]:
    """
    Get the focal distances from the propagated probe data.

    For each of the two transverse axes in turn, the probe amplitude is
    summed along that axis, the result is reduced along the remaining
    transverse axis with the specified method (sum or max), and the
    propagation step where this reduced value is the largest is taken
    as the focus.

    Args:
        propagated_probe (np.ndarray): the propagated probe data, in
            the shape (step_nb, height, width) or (modes, step_nb,
            height, width). If 3D, the first dimension is considered as
            the propagation axis, and the rest are height and width. If
            4D, will take the first mode only.
        propagation_positions (np.ndarray): the propagation positions
            in meters.
        method (str, optional): the method to use for reducing the
            projected amplitude. Can be "sum" or "max". Defaults to
            "max".

    Raises:
        ValueError: if the propagated_probe is not 3D or 4D.
        ValueError: if the method is not "sum" or "max".

    Returns:
        tuple[tuple[float, float], tuple[int, int]]: the focal
            distances as a tuple of two values and the corresponding
            indexes in the propagation positions array as a tuple of
            two integers. The first entry of each tuple is obtained by
            summing over array axis 1 (height), the second one by
            summing over array axis 2 (width).
    """
    if propagated_probe.ndim == 4:
        propagated_probe = propagated_probe[0]
    elif propagated_probe.ndim != 3:
        raise ValueError(
            "Expected propagated_probe to be 3 or 4D (propagation axis, "
            "height, width)."
        )

    if method == "sum":
        reducing_function = np.sum
    elif method == "max":
        reducing_function = np.max
    else:
        raise ValueError(f"Unknown method: {method}")

    # the amplitude does not depend on the projection axis: compute once
    amplitude = np.abs(propagated_probe)

    focal_distances, indexes = [], []
    for summed_axis in (1, 2):
        # (step_nb, remaining transverse axis)
        projection = np.sum(amplitude, axis=summed_axis)
        focus_index = int(np.argmax(reducing_function(projection, axis=1)))
        indexes.append(focus_index)
        focal_distances.append(propagation_positions[focus_index])

    return tuple(focal_distances), tuple(indexes)
def focus_probe(
    probe: np.ndarray,
    pixel_size: tuple,
    wavelength: float,
    step_nb: int = 200,
    step_size: float = 10e-6,
    plot: bool = True,
    **plot_kwargs,
) -> tuple:
    """
    Complete analysis of probe focus characteristics by propagating the
    probe through a range of distances and computing the focal
    distances.

    The function performs a probe focus sweep, computes the focal
    distances using the "max" method, and optionally plots the
    propagated probe with the focal distances. It returns the probe at
    the second focal position together with that focal distance.

    Args:
        probe (np.ndarray): the probe to be propagated, in the shape
            (height, width) or (modes, height, width).
        pixel_size (tuple): the pixel size of the probe in meters, as a
            tuple (height, width).
        wavelength (float): the wavelength of the X-ray beam in meters.
        step_nb (int, optional): the number of steps for the
            propagation sweep. This defines the number of propagation
            positions. Defaults to 200.
        step_size (float, optional): the step size for the propagation
            sweep in meters. This defines the distance between each
            propagation position. Defaults to 10e-6.
        plot (bool, optional): whether to plot the propagated probe and
            focal distances. Defaults to True.
        **plot_kwargs: extra keyword arguments forwarded to
            plot_propagated_probe.

    Raises:
        ValueError: if the propagated probe is neither 3D nor 4D.

    Returns:
        tuple: (focused_probe, focal_distance), where focal_distance is
            the second focal distance returned by get_focal_distances
            and focused_probe is the propagated probe at that same
            position, with shape (height, width) or
            (modes, height, width).
    """
    propagated_probe, propagation_positions = probe_focus_sweep(
        probe, pixel_size, wavelength, step_nb, step_size
    )
    focal_distances, indexes = get_focal_distances(
        propagated_probe, propagation_positions, method="max"
    )

    # The returned probe and the returned distance must refer to the
    # same propagation position: both use entry 1 of the focal results.
    focus_index = indexes[1]
    if propagated_probe.ndim == 4:  # (modes, step_nb, height, width)
        focused_probe = propagated_probe[:, focus_index, ...]
    elif propagated_probe.ndim == 3:  # (step_nb, height, width)
        focused_probe = propagated_probe[focus_index]
    else:
        raise ValueError(
            "Expected the propagated probe to be 3D or 4D, got "
            f"{propagated_probe.ndim}D."
        )

    if plot:
        plot_propagated_probe(
            propagated_probe,
            pixel_size,
            propagation_step_size=step_size,
            focal_distances=focal_distances,
            **plot_kwargs,
        )

    return focused_probe, focal_distances[1]