Source code for cdiutils.pipeline.bcdi

"""
Definition of the BcdiPipeline class.

Authors:
    * Clément Atlan, clement.atlan@esrf.fr - 09/2024
"""

# Built-in dependencies.
import copy
import glob
import os
import subprocess
from string import Template

# Dependencies.
import h5py
import numpy as np
import yaml
from tabulate import tabulate

from cdiutils.analysis.stats import find_isosurface

# General cdiutils classes, to handle loading/saving, beamline geometry and
# space conversion.
from cdiutils.converter import SpaceConverter
from cdiutils.facetanalysis import FacetAnalysisProcessor
from cdiutils.geometry import Geometry
from cdiutils.io import CXIFile, Loader, load_cxi
from cdiutils.io.vtk import IS_VTK_AVAILABLE, save_as_vti
from cdiutils.plot.colormap import RED_TO_TEAL
from cdiutils.plot.slice import plot_volume_slices
from cdiutils.plot.volume import plot_3d_surface_projections
from cdiutils.process.phaser import PhasingResultAnalyser, PyNXImportError
from cdiutils.process.postprocessor import PostProcessor

# Utility functions
from cdiutils.utils import (
    CroppingHandler,
    ensure_pynx_shape,
    fill_up_support,
    get_oversampling_ratios,
    hot_pixel_filter,
    normalise,
    oversampling_from_diffraction,
)

# Base Pipeline class and pipeline-related functions.
from .base import Pipeline
from .parameters import (
    DEFAULT_PIPELINE_PARAMS,
    convert_np_arrays,
    isparameter,
    validate_and_fill_params,
)

# Plot functions.
from .pipeline_plotter import PipelinePlotter


[docs] class PyNXScriptError(Exception): """Custom exception to handle pynx script failure."""
[docs] def __init__(self, msg: object = None) -> None: """ Initialise PyNXScriptError with an informative message. The `msg` argument may be a string, an exception, or a file-like object (for example an open stderr pipe). This constructor will coerce non-string inputs into a readable string representation and avoid TypeError when concatenating. Args: msg (object, optional): additional error message. Can be a str, Exception, file-like object or None. """ base_message = "PyNX script failed due to the following error:" if msg is None: super().__init__(base_message) return # try to produce a human friendly string from msg without # raising further exceptions extra_message = "" try: # if already a string, use it directly if isinstance(msg, str): extra_message = msg # file-like objects often expose a read() method elif hasattr(msg, "read") and callable(getattr(msg, "read")): try: # read content (may consume the stream) extra_message = msg.read() except Exception: # fallback to representation extra_message = str(msg) else: # for Exceptions and other objects extra_message = str(msg) except Exception: # last-resort fallback if anything goes wrong extra_message = "<unrepresentable error message>" full_message = base_message if extra_message: full_message = base_message + "\n" + extra_message super().__init__(full_message)
[docs] class BcdiPipeline(Pipeline): """ A class to handle the BCDI workflow, from pre-processing to post-processing, including phase retrieval (using PyNX package). Provide either a path to a parameter file or directly the parameter dictionary. Args: param_file_path (str, optional): the path to the parameter file. Defaults to None. parameters (dict, optional): the parameter dictionary. Defaults to None. """ voxel_pos = ("ref", "max", "com") class_isosurface = 0.1
[docs] def __init__( self, params: dict = None, param_file_path: str = None, ): """ Initialisation method. Args: param_file_path (str, optional): the path to the parameter file. Defaults to None. parameters (dict, optional): the parameter dictionary. Defaults to None. """ super().__init__(params, param_file_path) self.params = validate_and_fill_params(self.params) self.scan = self.params["scan"] self.sample_name = self.params["sample_name"] self.pynx_phasing_dir = self.dump_dir + "/pynx_phasing/" # The dictionary of the Voxels Of Interest, in both the "full" # and "cropped" detector frames. self.voi = {"full": {}, "cropped": {}} # The dictionary of the VOI associated q_lab positions self.q_lab_pos = {key: None for key in self.voxel_pos} # The dictionary of the atomic parameters (d-spacing and lattice # parameter) for the various voxel positions. self.atomic_params = { "dspacing": {k: None for k in self.voxel_pos}, "lattice_parameter": {k: None for k in self.voxel_pos}, } # Define base attributes self.detector_data: np.ndarray = None self.cropped_detector_data: np.ndarray = None self.orthogonalised_intensity: np.ndarray = None self.mask: np.ndarray = None self.angles: dict = None self.converter: SpaceConverter = None self.result_analyser: PhasingResultAnalyser = None self.reconstruction: np.ndarray = None self.structural_props: dict = None # the list of phase retrieval results self.phasing_results = [] # For storing data that later saved in the cxi files self.extra_info: dict = {} self.logger.info("BcdiPipeline initialised.")
[docs] @classmethod def from_file(cls, path: str) -> "BcdiPipeline": """Factory method to create a BcdiPipeline instance from a file.""" if path.endswith(".cxi"): params, converter = cls.load_from_cxi(path) elif path.endswith(".yml") or path.endswith(".yaml"): raise ValueError("Loading from yaml files not yet implemented.") else: raise ValueError("File format not supported.") instance = cls(params) instance.converter = converter if "q_lab_ref" not in params: raise ValueError("q_lab_ref is missing in the parameters.") instance.q_lab_pos["ref"] = params["q_lab_ref"] return instance
[docs] def update_from_file(self, path: str) -> None: """ Update pipeline instance with parameters from CXI file. Loads parameters and SpaceConverter from a CXI file and updates the current instance. Useful for resuming analysis from saved reconstruction results. Args: path (str): path to CXI file. Raises: ValueError: if file format is unsupported or q_lab_ref is missing in parameters. """ if path.endswith(".cxi"): params, converter = self.load_from_cxi(path) elif path.endswith(".yml") or path.endswith(".yaml"): raise ValueError("Loading from yaml files not yet implemented.") else: raise ValueError("File format not supported.") self.params.update(params) self.converter = converter if "q_lab_ref" not in params: raise ValueError("q_lab_ref is missing in the parameters.") self.q_lab_pos["ref"] = params["q_lab_ref"]
[docs] @classmethod def load_from_cxi(cls, path: str) -> tuple[dict, SpaceConverter]: """ Load pipeline parameters and SpaceConverter from CXI file. Extracts stored parameters and reconstruction metadata from a CXI file, rebuilding the SpaceConverter for coordinate transformations. Args: path (str): path to CXI file. Returns: tuple[dict, SpaceConverter]: extracted parameters and configured SpaceConverter instance. Raises: ValueError: if path does not end with '.cxi'. """ if not path.endswith(".cxi"): raise ValueError("CXI file expected.") with CXIFile(path, "r") as cxi: params = cxi["entry_1/parameters_1"] converter = cls._build_converter_from_cxi(cxi) return params, converter
@staticmethod def _build_converter_from_cxi(cxi: CXIFile) -> SpaceConverter: """ Reconstruct SpaceConverter from CXI file metadata. Extracts geometry, detector calibration, Q-space matrices, and voxel sizes from CXI file to rebuild a fully configured SpaceConverter. Args: cxi (CXIFile): opened CXI file handle. Returns: SpaceConverter: configured converter with initialised Q-space. """ converter_params = { "geometry": Geometry.from_setup(cxi["entry_1/geometry_1/name"]), "det_calib_params": cxi["entry_1/detector_1/calibration"], "roi": cxi["entry_1/result_1/roi"], "energy": cxi["entry_1/source_1/energy"], "shape": cxi["entry_1/image_1/image_size"], "q_lab_shift": cxi["entry_1/result_2/q_lab_shift"], "q_lab_matrix": cxi[ "entry_1/result_2/transformation_matrices/q_lab" ], "direct_lab_matrix": cxi[ "entry_1/result_2/transformation_matrices/direct_lab" ], "direct_lab_voxel_size": cxi[ "entry_1/result_2/direct_lab_voxel_size" ], } converter = SpaceConverter(**converter_params) converter.init_q_space(**cxi["entry_1/geometry_1/angles"]) return converter
[docs] @Pipeline.process def preprocess(self, **params) -> None: """ Preprocess BCDI detector data for phase retrieval. Handles complete preprocessing workflow: data loading, Bragg peak centring, cropping, filtering (hot pixels, flat-field, background subtraction), and Q-space coordinate system initialisation. Args: **params: optional parameters to override instance params. Common overrides include 'preprocess_shape', 'hot_ pixel_filter', 'flat_field', 'background_level'. Side effects: Updates instance attributes: detector_data, cropped_ detector_data, mask, voi (voxels of interest), converter, q_lab_pos, atomic_params. initialisation. Raises: ValueError: if the requested shape and the voxel reference are not compatible. """ if params: self.logger.info( "Additional parameters provided, will update the current " "dictionary of parameters." ) for p in params: if not isparameter(p): raise ValueError( f"Parameter '{p}' is not recognised. " "Please check the possible parameter names among:\n" + "\n".join( [ f"- {key}" for key in DEFAULT_PIPELINE_PARAMS.keys() ] ) ) self.params.update(params) # If voxel_reference_methods is not a list, make it a list if not isinstance(self.params["voxel_reference_methods"], list): self.params["voxel_reference_methods"] = [ self.params["voxel_reference_methods"] ] # Check whether the requested shape is permitted by pynx self.params["preprocess_shape"] = ensure_pynx_shape( self.params["preprocess_shape"], verbose=True ) if self.params["light_loading"]: roi = self._light_load() # Filter the data self.cropped_detector_data = self._filter( self.cropped_detector_data ) else: self._load() self._from_2d_to_3d_shape() self.logger.info( "The preprocessing output shape is: and " f"{self.params['preprocess_shape']} will be used for the " "determination of the ROI dimensions." ) # Filter, crop and centre the detector data. self.cropped_detector_data, roi = self._crop_centre( self._filter(self.detector_data) ) for r in roi: if r < 0: raise ValueError( "The preprocess_shape and the detector voxel reference are" f" not compatible: {self.params['preprocess_shape'] = }, " # noqa E251, E202 f"{self.voi['full']['ref'] = }." # noqa E251, E202 ) # print out the oversampling ratio and rebin factor suggestion ratios = oversampling_from_diffraction(self.cropped_detector_data) if ratios is None: self.logger.info("Could not estimate the oversampling.") else: self.logger.info( "\nOversampling ratios calculated from diffraction pattern " "are: " + ", ".join( [f"axis{i}: {ratios[i]:.1f}" for i in range(len(ratios))] ) + ". If low-strain crystal, you can set PyNX 'rebin' parameter" " to (" + ", ".join([f"{r // 2}" for r in ratios]) + ")" ) # position of the max and com in the cropped detector frame for pos in ("max", "com"): self.voi["cropped"][pos] = CroppingHandler.get_position( self.cropped_detector_data, pos ) # Initialise SpaceConverter, later used for orthogonalisation geometry = Geometry.from_setup( self.params["beamline_setup"], sample_orientation=self.params.get("sample_orientation", None), sample_surface_normal=self.params.get( "sample_surface_normal", None ), ) self.converter = SpaceConverter( geometry=geometry, det_calib_params=self.params["det_calib_params"], energy=self.params["energy"], roi=roi[2:], ) self.converter.init_q_space(**self.angles) # Initialise the fancy table with the columns table = [ [ "voxel", "uncrop. det. pos.", "crop. det. pos.", "dspacing (A)", "lat. par. (A)", ] ] # get the position of the reference, max and det voxels in the # q lab space for pos in self.voxel_pos: det_voxel = self.voi["full"][pos] cropped_det_voxel = self.voi["cropped"][pos] if any(np.isnan(e) for e in cropped_det_voxel): self.q_lab_pos[pos] = None self.atomic_params["dspacing"][pos] = None self.atomic_params["lattice_parameter"][pos] = None else: self.q_lab_pos[pos] = self.converter.index_det_to_q_lab( cropped_det_voxel ) # compute the corresponding dpsacing and lattice parameter # for printing self.atomic_params["dspacing"][pos] = self.converter.dspacing( self.q_lab_pos[pos] ) self.atomic_params["lattice_parameter"][pos] = ( self.converter.lattice_parameter( self.q_lab_pos[pos], self.params["hkl"] ) ) table.append( [ pos, det_voxel, cropped_det_voxel, f"{self.atomic_params['dspacing'][pos]:.5f}", f"{self.atomic_params['lattice_parameter'][pos]:.5f}", ] ) self._unwrap_logs() # Turn off wrapping for structured output self.logger.info( "\nSummary table:\n" + tabulate(table, headers="firstrow", tablefmt="fancy_grid"), ) self._wrap_logs() # Turn wrapping back on for regular logs if self.params["orthogonalise_before_phasing"]: self.logger.info( "Orthogonalisation required before phasing.\n" "Will use xrayutilities Fuzzy Gridding without linear " "approximation." ) self.orthogonalised_intensity = ( self.converter.orthogonalise_to_q_lab( self.cropped_detector_data, method="xrayutilities" ) ) # we must orthogonalise the mask and orthogonalised_intensity must # be saved as the pynx input self.mask = self.converter.orthogonalise_to_q_lab( self.mask, method="xrayutilities" ) self.cropped_detector_data = self.orthogonalised_intensity q_lab_regular_grid = self.converter.get_xu_q_lab_regular_grid() else: self.logger.info( "Will linearise the transformation between detector and" " lab space." ) # Initialise the interpolator so we won't need to reload raw # data during the post processing. The converter will be saved. self.converter.init_interpolator( direct_lab_voxel_size=self.params["voxel_size"], space="both", verbose=True, ) if self.params["voxel_size"] is not None: self.logger.info( "Voxel size provided by user will be saved for the direct " "space orthogonalisation." ) # Run the interpolation in the reciprocal space so we don't # do it later self.orthogonalised_intensity = ( self.converter.orthogonalise_to_q_lab( self.cropped_detector_data, method="cdiutils" ) ) q_lab_regular_grid = self.converter.get_q_lab_regular_grid() # Update the preprocess_shape and the det_reference_voxel self.params["det_reference_voxel"] = self.voi["full"]["ref"] # plot and save the detector data in the full detector frame and # final frame dump_file_tmpl = ( f"{self.dump_dir}/S{self.scan}_" + "detector_data_{}.png" ) PipelinePlotter.detector_data( self.cropped_detector_data, full_det_data=self.detector_data, voxels=self.voi, title=( "Detector data preprocessing (slices), " f"{self.sample_name}, S{self.scan}" ), save=dump_file_tmpl.format("slices"), ) PipelinePlotter.detector_data( self.cropped_detector_data, full_det_data=self.detector_data, voxels=self.voi, integrate=True, title=( "Detector data preprocessing (projection), " f"{self.sample_name}, S{self.scan}" ), save=dump_file_tmpl.format("sum"), ) # Plot the reciprocal space data in the detector and lab frames PipelinePlotter.ortho_detector_data( self.cropped_detector_data, self.orthogonalised_intensity, q_lab_regular_grid, title=( r"From detector frame to q lab frame" f", {self.sample_name}, S{self.scan}" ), save=dump_file_tmpl.format("orthogonalisation"), ) # Save the data and the parameters in the dump directory, and # save the q_lab reference position as a parameter. self.params["q_lab_ref"] = self.q_lab_pos["ref"] self._save_preprocessed_data() self._save_parameter_file()
def _from_2d_to_3d_shape(self) -> tuple: """ Extend 2D preprocess_shape to 3D using detector data length. Prepends detector_data first dimension to 2D shape tuple. """ if len(self.params["preprocess_shape"]) == 2: self.params["preprocess_shape"] = ( self.detector_data.shape[0], ) + self.params["preprocess_shape"] def _load(self, roi: tuple[slice] = None) -> None: """ Load the raw detector data and motor positions. Args: roi (tuple[slice], optional): the region of interest on the detector frame. Defaults to None. Raises: ValueError: check whether detector data, motor positions and mask have been correctly loaded. """ # Make the loader using the factory method, only parse the # parameters relevant to the Loader class, to make them explicit. loader_keys = ( "beamline_setup", "scan", "sample_name", "experiment_file_path", "experiment_data_dir_path", "detector_data_path", "edf_file_template", "detector_name", "alien_mask", "flat_field", ) loader = Loader.from_setup(**{k: self.params[k] for k in loader_keys}) if self.params.get("detector_name") is None: self.params["detector_name"] = loader.detector_name if self.params.get("detector_name") is None: raise ValueError( "The automatic detection of the detector name is not " "yet implemented for this setup" f"({self.params['beamline_setup']})." ) self.detector_data = loader.load_detector_data( roi=roi, rocking_angle_binning=self.params["rocking_angle_binning"] ) if roi is not None: shape = loader.load_detector_shape() if shape is not None: self.logger.info(f"Raw detector data shape is: {shape}.") else: self.logger.info( f"Raw detector data shape is: {self.detector_data.shape}." ) self.angles = loader.load_motor_positions( roi=roi, rocking_angle_binning=self.params["rocking_angle_binning"] ) self.mask = loader.get_mask( channel=self.detector_data.shape[0], detector_name=self.params["detector_name"], roi=(slice(None), roi[1], roi[2]) if roi else None, ) if loader.get_alien_mask() is not None: self.logger.info("Alien mask provided. Will update detector mask.") alien_mask = loader.get_alien_mask(roi) self.mask = np.where(self.mask + alien_mask > 0, 1, 0) if any(data is None for data in (self.detector_data, self.angles)): raise ValueError("Something went wrong during data loading.") if self.params["energy"] is None: self.params["energy"] = loader.load_energy() if self.params["energy"] is None: raise ValueError( "The automatic loading of energy is not yet implemented " f"for this setup ({self.params['beamline_setup']})." ) if isinstance(self.params["energy"], (list, np.ndarray)): self.logger.info( "The current scan is an energy scan. The average energy is" f" {np.mean(self.params['energy']):3f} eV." ) self.params["energy"] = loader.format_scanned_counters( self.params["energy"], scan_axis_roi=roi[0] if roi is not None else roi, rocking_angle_binning=self.params["rocking_angle_binning"], ) else: self.logger.info( f"Energy successfully loaded ({self.params['energy']} eV)." ) if self.params["det_calib_params"] is None: self.logger.info( "\ndet_calib_params not provided, will try to find them. " "However, for a more accurate calculation, you'd better " "provide them." ) self.params["det_calib_params"] = loader.load_det_calib_params() if self.params["det_calib_params"] is None: raise ValueError( "The automatic loading of det_calib_params is not yet " "implemented for this setup " f"('{self.params['beamline_setup']}'), you must provide " "them." ) self.logger.info( "det_calib_params successfully loaded:\n" f"{self.params['det_calib_params']}" ) def _light_load(self) -> list: """ Light load the detector data according to the provided preprocess_shape and voxel_reference_methods position. This allows to load only the region of interest determined by these two parameters. Raises: ValueError: When final_shape length is not 3 and rocking_angle_binning are not provided. ValueError: If voxel_reference_methods is not provided. Returns: list: the roi associated to the cropped_detector_data. """ if ( len(self.params["preprocess_shape"]) != 3 and self.params["rocking_angle_binning"] is None ): self.params["rocking_angle_binning"] = 1 self.logger.warning( "When light loading is requested, preprocess_shape must " "include 3 axis lengths or rocking_angle_binning must be " "provided. Will set rocking_angle_binning to 1, but this might" " not be very efficient." ) self.voi["full"]["ref"] = self.params["voxel_reference_methods"][-1] if isinstance(self.voi["full"]["ref"], str): raise ValueError( "When light loading, voxel_reference_methods must " "contain a tuple indicating the position of the voxel you " "want to crop the data at. Ex: [(100, 200, 200)]." ) if ( len(self.voi["full"]["ref"]) == 2 and len(self.params["preprocess_shape"]) == 3 ): self.params["preprocess_shape"] = self.params["preprocess_shape"][ 1: ] elif ( len(self.voi["full"]["ref"]) == 3 and len(self.params["preprocess_shape"]) == 2 ): self.voi["full"]["ref"] = self.voi["full"]["ref"][1:] roi = CroppingHandler.get_roi( self.params["preprocess_shape"], self.voi["full"]["ref"] ) if len(roi) == 4: roi = [None, None, roi[0], roi[1], roi[2], roi[3]] self.logger.info( f"\nLight loading requested, will use ROI {roi} and " "bin along rocking curve direction by " f"{self.params['rocking_angle_binning']} during data loading." ) self._load(roi=CroppingHandler.roi_list_to_slices(roi)) # If user only specified position in 2D or provided a 2D # preprocess_shape, we must extend it to 3D and check the first # dimension length. if len(self.voi["full"]["ref"]) == 2: # covers both cases self.voi["full"]["ref"] = ( self.detector_data.shape[0] // 2, ) + tuple(self.voi["full"]["ref"]) self._from_2d_to_3d_shape() shape = ensure_pynx_shape(self.params["preprocess_shape"]) roi = CroppingHandler.get_roi(shape, self.voi["full"]["ref"]) # If the shape has changed due to PyNX conventions, then # only crop the data along the first axis direction. if shape != self.params["preprocess_shape"]: roi_1d = CroppingHandler.roi_list_to_slices(roi[:2]) self.detector_data = self.detector_data[roi_1d] self.mask = self.mask[roi_1d] rocking_angle = Loader.get_rocking_angle(self.angles) self.angles[rocking_angle] = self.angles[rocking_angle][roi_1d] self.params["preprocess_shape"] = shape self.logger.info(f"New ROI is: {roi}.") # find the position in the cropped detector frame self.voi["cropped"]["ref"] = tuple( p - r if r else p # if r is None, p-r must be p for p, r in zip(self.voi["full"]["ref"], roi[::2]) ) self.voi["full"]["max"], self.voi["full"]["com"] = None, None # Since we light load the data, the full detector data do not exist self.cropped_detector_data = self.detector_data.copy() self.detector_data = None return roi def _filter(self, data: np.ndarray) -> np.ndarray: """ Apply hot pixel filtering and background subtraction. Modifies mask in-place when hot pixels detected. """ if self.params["hot_pixel_filter"]: self.logger.info("hot_pixel_filter requested.") if isinstance(self.params["hot_pixel_filter"], tuple): self.logger.info( "Hot pixel filter parameters are : " f"{self.params['hot_pixel_filter']}" ) data, hot_pixel_mask = hot_pixel_filter( data, *self.params["hot_pixel_filter"] ) else: self.logger.info( "Will use defaults parameters: threshold = 1e2, " "kernel_size = 3 " ) data, hot_pixel_mask = hot_pixel_filter(data) self.mask = np.where(hot_pixel_mask + self.mask > 0, 1, 0) if self.params["background_level"]: self.logger.info( f"background_level set to {self.params['background_level']}, " "will remove the background." ) data -= self.params["background_level"] data[data < 0] = 0 return data def _crop_centre(self, detector_data) -> tuple[np.ndarray, list]: """ Crop and centre the self.detector data and return the cropped detector data and the associated roi. """ self.logger.info( "Method(s) employed for the voxel reference determination are " f"{self.params['voxel_reference_methods']}." ) self._unwrap_logs() # Turn off wrapping for structured output ( cropped_detector_data, self.voi["full"]["ref"], self.voi["cropped"]["ref"], roi, ) = CroppingHandler.chain_centring( detector_data, self.params["preprocess_shape"], methods=self.params["voxel_reference_methods"], verbose=True, ) self._wrap_logs() # Turn wrapping back on for regular logs # position of the max and com in the full detector frame for pos in ("max", "com"): self.voi["full"][pos] = CroppingHandler.get_position( self.detector_data, pos ) # convert numpy.int64 to int to make them serializable self.params["det_reference_voxel"] = tuple( int(e) for e in self.voi["full"]["ref"] ) # center and crop the mask self.mask = self.mask[CroppingHandler.roi_list_to_slices(roi)] self.logger.info( "The reference voxel was found at " f"{self.voi['full']['ref']} in the uncropped data frame.\n" "The process_shape being " f"{self.params['preprocess_shape']}, the roi used to crop " f"the data is {roi}.\n" ) # set the q space area with the sample and detector angles # that correspond to the requested roi for key, value in self.angles.items(): if isinstance(value, (list, np.ndarray)) and len(value) > 1: self.angles[key] = value[np.s_[roi[0] : roi[1]]] # if energy scan, crop the energy values according to the roi if isinstance(self.params["energy"], (list, np.ndarray)): self.params["energy"] = self.params["energy"][ np.s_[roi[0] : roi[1]] ] return cropped_detector_data, roi def _save_parameter_file(self) -> None: """ Save analysis parameters to YAML file. Writes parameters to dump_dir/S{scan}_parameters.yml. """ output_file_path = f"{self.dump_dir}/S{self.scan}_parameters.yml" self.params = convert_np_arrays(**self.params) with open(output_file_path, "w", encoding="utf8") as file: yaml.dump(self.params, file) self.logger.info( f"\nScan parameter file saved at:\n{output_file_path}" ) def _save_preprocessed_data(self) -> None: """ Save preprocessed detector data, mask, and metadata to CXI. Creates NPZ files for PyNX input and comprehensive CXI file. """ # Prepare dir in which pynx phasing results will be saved. os.makedirs(self.pynx_phasing_dir, exist_ok=True) # Save the cropped detector data and mask in pynx phasing dir for name, data in zip( ("data", "mask"), (self.cropped_detector_data, self.mask) ): path = f"{self.pynx_phasing_dir}S{self.scan}_pynx_input_{name}.npz" np.savez(path, data=data) # Save all outputs of the preprocessing stage in a .cxi file dump_path = f"{self.dump_dir}/S{self.scan}_preprocessed_data.cxi" with CXIFile(dump_path, "w") as cxi: cxi.stamp() cxi.set_entry() path = cxi.create_cxi_group( "process", command="cdiutils.BcdiPipeline.preprocess()", comment="Pipeline preprocessing step", ) cxi.softlink(f"{path}/program", "/creator") cxi.softlink(f"{path}/version", "/version") geometry_to_parse = self.converter.geometry.to_dict() geometry_to_parse.update({"angles": self.angles}) geo_path = cxi.create_cxi_group("geometry", **geometry_to_parse) detector = { "description": self.params["detector_name"], "mask": self.mask[0], "calibration": self.params["det_calib_params"], } path = cxi.create_cxi_group("detector", **detector) cxi.softlink(f"{path}/distance", f"{path}/calibration/distance") cxi.softlink(f"{path}/x_pixel_size", f"{path}/calibration/pwidth2") cxi.softlink(f"{path}/y_pixel_size", f"{path}/calibration/pwidth1") cxi.softlink(f"{path}/geometry_1", geo_path) msg = """Raw detector data centring and cropping. The Region of Interest is given by the roi entry. The Voxels of interest are given by the voi entry.""" path = cxi.create_cxi_group( "result", voi=self.voi, roi=self.converter.roi, description=msg ) cxi.softlink(path + "/process_1", "entry_1/process_1") msg = """The orthogonalisation allows to make table of correspondence between the detector pixels and their associated positions in the reciprocal space. From this, the average d-spacing and lattice parameter are computed. The matrix to orthogonalise the reconstructed object after phase retrieval is also computed and will be used in the post-processing stage.""" results = self.converter.to_dict() results = { k: results[k] for k in ( "q_lab_shift", "transformation_matrices", "direct_lab_voxel_size", ) } atomic_params = self.atomic_params.copy() atomic_params["units"] = "angstrom" q_lab = self.q_lab_pos.copy() q_lab["units"] = "1/angstrom" if self.params["orthogonalise_before_phasing"]: qx, qy, qz = self.converter.get_xu_q_lab_regular_grid() else: qx, qy, qz = self.converter.get_q_lab_regular_grid() results.update( { "atomic_parameters": atomic_params, "q_lab": q_lab, "description": msg, "qx_xu": qx, "qy_xu": qy, "qz_xu": qz, } ) path = cxi.create_cxi_group("result", **results) cxi.softlink(path + "/process_1", "entry_1/process_1") exp_path = self.params["experiment_file_path"] cxi.create_cxi_group( "sample", "sample_name", sample_name=self.sample_name, experiment_file_path=exp_path, experiment_identifier=( None if exp_path is None else exp_path.split("/")[-1].split(".")[0] ), ) cxi.create_cxi_group("parameters", **self.params) cxi.create_cxi_group( "source", energy=self.params["energy"], units="eV" ) path = cxi.create_cxi_image( self.cropped_detector_data, data_type="cropped detector data", data_space="reciprocal", mask=self.mask[0], process_1="process_1", ) cxi.softlink("entry_1/cropped_detector_data", path) path = cxi.create_cxi_image( self.orthogonalised_intensity, data_type="orthogonalised detector data", data_space="reciprocal", process_1="process_1", ) cxi.softlink("entry_1/orthogonalised_detector_data", path) self.logger.info(f"Pre-processed data file saved at:\n{dump_path}") def _make_slurm_file(self, template: str = None) -> None: """ Generate SLURM batch script for PyNX phase retrieval. Uses template substitution to create job submission file. """ # Make the pynx slurm file if template is None: template = ( f"{os.path.dirname(__file__)}/pynx-id01-cdi_template.slurm" ) self.logger.info( "Pynx slurm file template not provided, will take " f"the default: {template}" ) else: self.logger.info( f"Pynx slurm file template provided {template = }." # noqa: E251, E202 ) with open(template, "r", encoding="utf8") as file: source = Template(file.read()) pynx_slurm_text = source.substitute( { "number_of_nodes": 2, "data_path": self.pynx_phasing_dir, "SLURM_JOBID": "$SLURM_JOBID", "SLURM_NTASKS": "$SLURM_NTASKS", } ) with open( self.pynx_phasing_dir + "/pynx-id01-cdi.slurm", "w", encoding="utf8", ) as file: file.write(pynx_slurm_text)
[docs] def phase_retrieval_gui(self) -> None: """Launch the interactive phase retrieval GUI. This method lazily imports and launches the PhaseRetrievalGUI from cdiutils.interactive.phase_retrieval. The lazy import avoids requiring the optional GUI-related dependencies (pynx and ipywidgets) unless this method is invoked. The GUI is initialized with the pipeline instance and the pipeline's pynx_phasing_dir as the working directory, and it searches for CXI files matching the pattern "*Run*.cxi". After initialization, the GUI's show() method is called to display the interface. Raises: ImportError: If the PhaseRetrievalGUI cannot be imported because the required dependencies ('pynx' and 'ipywidgets') are not installed. The raised error suggests installing them via: pip install pynx ipywidgets Returns: None Example: pipeline.phase_retrieval_gui() """ # lazy import to avoid requiring ipywidgets/pynx when not using GUI try: from cdiutils.interactive.phase_retrieval import PhaseRetrievalGUI except ImportError as e: raise ImportError( "PhaseRetrievalGUI requires both 'pynx' and 'ipywidgets' to be installed. " "Please install them with: pip install pynx ipywidgets" ) from e self.logger.info("Launching interactive GUI.") gui = PhaseRetrievalGUI( work_dir=self.pynx_phasing_dir, pipeline_instance=self, search_pattern="*Run*.cxi", ) gui.show()
[docs] @Pipeline.process def phase_retrieval( self, jump_to_cluster: bool = False, pynx_slurm_file_template: str = None, clear_former_results: bool = False, cmd: str = None, search_pattern: str = "*Run*.cxi", **pynx_params, ) -> None: """ Execute phase retrieval using PyNX. Runs PyNX either locally (direct subprocess) or on a SLURM cluster. Generates PyNX input file from parameters and manages job submission/monitoring if cluster execution is requested. Args: jump_to_cluster (bool, optional): submit job to SLURM cluster. Defaults to False (local execution). pynx_slurm_file_template (str, optional): path to SLURM script template. Defaults to None (uses built-in template). clear_former_results (bool, optional): delete previous reconstruction CXI files. Defaults to False. cmd (str, optional): command for local PyNX execution. Defaults to None (uses "pynx-cdi-id01 pynx-cdi- inputs.txt"). search_pattern (str, optional): glob pattern for finding result CXI files. Defaults to "*Run*.cxi". **pynx_params: PyNX parameters (e.g., nb_run, nb_raar, support_threshold). Override defaults. Raises: PyNXScriptError: if PyNX execution fails. subprocess.CalledProcessError: if subprocess commandlate (str, optional): the template for the pynx slurm file. Defaults to None. clear_former_results (bool, optional): whether ti clear the former results. Defaults to False. cmd (str, optional): the command to run when running pynx on the current machine. Defaults to None. **pynx_params: additional pynx parameters. Raises: PyNXScriptError: if PyNX execution fails. subprocess.CalledProcessError: if subprocess command fails """ if clear_former_results: self.logger.info("Removing former results.\n") files = glob.glob(self.pynx_phasing_dir + "/*Run*.cxi") files += glob.glob(self.pynx_phasing_dir + "/*Run*.png") for f in files: os.remove(f) self.phasing_results = [] pynx_input_path = self.pynx_phasing_dir + "/pynx-cdi-inputs.txt" # handle pynx params, merge defaults + user inputs self.params["pynx"] = self._merge_pynx_params(pynx_params) # dynamically assign data and mask paths if not set by the user for name in ("data", "mask"): if self.params["pynx"][name] is None: self.params["pynx"][name] = ( f"{self.pynx_phasing_dir}S{self.scan}_pynx_input_{name}" ".npz" ) # Make the pynx input file. with open(pynx_input_path, "w", encoding="utf8") as file: for key, value in self.params["pynx"].items(): file.write(f"{key} = {value}\n") if jump_to_cluster: self.logger.info("Jumping to cluster requested.") self._make_slurm_file(pynx_slurm_file_template) job_id, output_file = self.submit_job( job_file="pynx-id01-cdi.slurm", working_dir=self.pynx_phasing_dir, ) self.monitor_job(job_id, output_file) else: self.logger.info( "Assuming the current machine is running PyNX. Will run the " "provided command." ) if cmd is None: cmd = "pynx-cdi-id01 pynx-cdi-inputs.txt" self.logger.info( f"No command provided. Will use the default: {cmd}" ) self._run_cmd(cmd, self.pynx_phasing_dir)
@staticmethod def _merge_pynx_params(user_pynx_params: dict) -> dict: """ Merge user-specified pynx parameters with default pynx parameters. User-defined values override defaults. Args: user_pynx_params (dict): contains user-specified values for pynx. Returns: dict: merged pynx parameters. """ # deep copy to avoid modifying global defaults merged_pynx = copy.deepcopy(DEFAULT_PIPELINE_PARAMS["pynx"]) # override defaults with user-provided values merged_pynx.update(user_pynx_params) return merged_pynx def _run_cmd(self, cmd: str, cwd: str) -> None: """ Run a command in a subprocess and stream output to logger. Args: cmd (str): the command to execute. cwd (str): the working directory for the subprocess. Raises: PyNXScriptError: if the subprocess returns a non-zero exit code. """ try: # accumulate stderr lines for error reporting stderr_lines = [] with subprocess.Popen( ["bash", "-l", "-c", cmd], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd, # change to this directory text=True, # ensures stdout/stderr are str, not bytes env=os.environ.copy(), bufsize=1, ) as proc: # stream stdout for line in iter(proc.stdout.readline, ""): self.logger.info(line.strip()) # stream stderr and capture for error reporting for line in iter(proc.stderr.readline, ""): stripped_line = line.strip() self.logger.error(stripped_line) stderr_lines.append(stripped_line) # wait for the process to complete and check the # return code proc.wait() if proc.returncode != 0: # compile captured stderr into a single message stderr_text = ( "\n".join(stderr_lines) if stderr_lines else ( f"Process exited with code {proc.returncode} " "(no stderr output captured)" ) ) self.logger.error( "PyNX phasing process failed with return code " f"{proc.returncode}" ) raise PyNXScriptError(stderr_text) except subprocess.CalledProcessError as e: # log the error if the job submission fails self.logger.error( "Subprocess process failed with return code " f"{e.returncode}: {e.stderr}" ) raise e
[docs] def analyse_phasing_results( self, sorting_criterion: str = "mean_to_max", search_pattern: str = "*Run*.cxi", plot: bool = True, plot_phasing_results: bool = True, plot_phase: bool = False, init_analyser: bool = True, ) -> None: """ Analyse and sort phase retrieval results by quality metrics. Wrapper for PhasingResultAnalyser that evaluates reconstruction quality using various criteria. Sorts results and generates comparison plots. Args: sorting_criterion (str, optional): quality metric for sorting. Options: - 'mean_to_max': amplitude homogeneity (Gaussian mean vs. max) - 'sharpness': sum of amplitude^4 within support - 'std': amplitude standard deviation - 'llk': log-likelihood - 'llkf': free log-likelihood Defaults to "mean_to_max". search_pattern (str, optional): glob pattern for CXI files. Defaults to "*Run*.cxi". plot (bool, optional): enable/disable all plots. Defaults to True. plot_phasing_results (bool, optional): plot result comparisons. Defaults to True. plot_phase (bool, optional): plot phase (with amplitude as opacity) instead of amplitude. Defaults to False. init_analyser (bool, optional): force reinitialisation of PhasingResultAnalyser. Defaults to True. Raises: ValueError: if sorting_criterion is unknown. """ if self.result_analyser is None or init_analyser: self.result_analyser = PhasingResultAnalyser( result_dir_path=self.pynx_phasing_dir ) self.result_analyser.analyse_phasing_results( sorting_criterion=sorting_criterion, search_pattern=search_pattern, plot=plot, plot_phasing_results=plot_phasing_results, plot_phase=plot_phase, )
[docs] def generate_support_from( self, run: int | str = "best", output_path: str = None, fill: bool = False, verbose: bool = True, search_pattern: str = "*Run*.cxi", ) -> None: """ Extract and save support from a specific reconstruction run. Generates a support mask from a phase retrieval result and saves it as a CXI file. Can be used directly in subsequent phasing by setting: `support = <output_path>` in PyNX params. Args: run (int | str, optional): run selection. Use "best" for top-ranked result or integer for specific run number. Defaults to "best". output_path (str, optional): save path for support CXI file. If None, saves to `pynx_phasing_dir/support.cxi`. Defaults to None. fill (bool, optional): fill holes in support using morphological operations. Defaults to False. verbose (bool, optional): print info and plot support. Defaults to True. search_pattern (str, optional): glob pattern for CXI files. Defaults to "*Run*.cxi". """ if run == "best": selected_path = next( iter(self.result_analyser._sorted_phasing_results) ) else: if not self.result_analyser.result_paths: self.result_analyser.find_phasing_results(search_pattern) for path in self.result_analyser.result_paths: if run == int(path.split("Run")[1][:4]): selected_path = path with CXIFile(selected_path) as cxi: support = cxi["entry_1/image_1/support"] if fill: filled_support = fill_up_support(support) if output_path is None: output_path = self.pynx_phasing_dir + "support.cxi" with CXIFile(output_path, "w") as cxi: cxi.stamp() path = cxi.create_cxi_group("image") cxi.create_cxi_dataset( path + "/mask", filled_support if fill else support ) if verbose: self.logger.info( "Support generated from Run : " f"{int(selected_path.split('Run')[1][:4])}, i.e. file:\n" f"{selected_path}\nand saved to:\n{output_path}" ) plot_volume_slices(support, cmap="gray", title="Support") if fill: plot_volume_slices( filled_support, cmap="gray", title="Filled support" )
[docs] def select_best_candidates( self, nb_of_best_sorted_runs: int = None, best_runs: list = None, search_pattern: str = "*Run*.cxi", ) -> None: """ Select best phase retrieval candidates for mode decomposition. Wrapper for PhasingResultAnalyser.select_best_candidates. Choose candidates either by count (top N sorted) or explicit run numbers. Args: nb_of_best_sorted_runs (int, optional): number of top-sorted runs to select. Requires prior call to analyse_phasing_ results(). Defaults to None. best_runs (list[int], optional): explicit list of run numbers (e.g., [2, 5, 7]). Defaults to None. search_pattern (str, optional): glob pattern for CXI files. Defaults to "*Run*.cxi". Raises: ValueError: if result_analyser not initialised (call analyse_phasing_results() first) Raises: ValueError: If the results have not been analysed yet. """ if not self.result_analyser: raise ValueError( "self.result_analyser not initialised yet. Run" " BcdiPipeline.analyse_phasing_results() first." ) self.result_analyser.select_best_candidates( nb_of_best_sorted_runs, best_runs, search_pattern )
[docs] @Pipeline.process def mode_decomposition( self, cmd: str = None, search_pattern: str = "*Run*.cxi" ) -> None: """ Perform mode decomposition on selected reconstruction candidates. Extracts principal modes from multiple phase retrieval results using PyNX's pynx-cdi-analysis (similar to PCA). Modes represent consistent features across reconstructions. Args: cmd (str, optional): command for mode decomposition if PyNX unavailable locally. Defaults to None (uses "pynx-cdi- analysis candidate_*.cxi --modes 1 --modes_output mode.h5"). search_pattern (str, optional): glob pattern for candidate CXI files. Defaults to "*Run*.cxi". Side effects: Saves modes to S{scan}_pynx_reconstruction_mode.cxi in dump_dir. """ if self.result_analyser is None: self.result_analyser = PhasingResultAnalyser( result_dir_path=self.pynx_phasing_dir ) try: modes, mode_weights = self.result_analyser.mode_decomposition( search_pattern=search_pattern ) self._save_pynx_results(modes=modes, mode_weights=mode_weights) except PyNXImportError: self.logger.info( "PyNX is not installed on the current machine. Will try to " "run the provided command instead." ) if cmd is None: cmd = ( "pynx-cdi-analysis candidate_*.cxi --modes 1 " "--modes_output mode.h5" ) self.logger.info( f"No command provided will use the default: {cmd}" ) self._run_cmd(cmd, self.pynx_phasing_dir) self._save_pynx_results(self.pynx_phasing_dir + "mode.h5")
def _save_pynx_results( self, mode_path: str = None, modes: list = None, mode_weights: list = None, ) -> None: """ Save PyNX phasing results and analysis metadata to CXI file. Includes reconstructions, best candidates, metrics, and modes. """ if mode_path is not None: with h5py.File(mode_path) as file: modes = file["entry_1/image_1/data"][()] mode_weights = file["entry_1/data_2/data"][()] else: if modes is None: raise ValueError("mode_path or modes required.") best_candidates, sorted_results, metrics = None, None, None if self.result_analyser is not None: best_candidates = self.result_analyser.best_candidates sorted_results = self.result_analyser.sorted_phasing_results metrics = self.result_analyser.metrics path = f"{self.dump_dir}/S{self.scan}_pynx_reconstruction_mode.cxi" with CXIFile(path, "w") as cxi: cxi.stamp() cxi.set_entry() # Copy the information in PyNX file path = cxi.create_cxi_group( "process", comment="PyNX phasing.", description="""Process information for the original CDI reconstruction (best solution).""", ) if best_candidates: with h5py.File(best_candidates[0], "r") as f: if "/entry_last/image_1/process_1" in f: for key in f["/entry_last/image_1/process_1"]: f.copy( f"/entry_last/image_1/process_1/{key}", cxi.get_node(path), name=key, ) path = cxi.create_cxi_group( "result", description=f"Check results out in {self.pynx_phasing_dir}", ) cxi.softlink(f"{path}/process_1", "/entry_1/process_1") cxi.create_cxi_group( "process", command="cdiutils.BcdiPipeline.analyse_phasing_results()", comment="Pipeline phase retrieval results analysis step.", ) path = cxi.create_cxi_group( "result", description="Sort the phasing results according to criterion.", sorted_results=sorted_results, metrics=metrics, ) cxi.softlink(f"{path}/process_2", "/entry_1/process_2") cxi.create_cxi_group( "process", command="cdiutils.BcdiPipeline.select_best_candidates()", comment="Best candidate selection", ) path = cxi.create_cxi_group( "result", description="Only selected best candidates", best_candidates=best_candidates, ) cxi.softlink(f"{path}/process_3", "/entry_1/process_3") cxi.create_cxi_group( "process", command=( "pynx-cdi-analysis / " "cdiutils.pipeline.BcdiPipeline.mode_decomposition()" ), comment="Mode decomposition", ) path = cxi.create_cxi_group( "result", description="The weights of each calculated mode.", mode_weights=mode_weights, ) cxi.softlink(f"{path}/process_4", "/entry_1/process_4") path = cxi.create_cxi_image(data=modes, process_4="process_4") cxi.get_node(path).attrs["description"] = "Mode decomposition"
[docs] @Pipeline.process def postprocess(self, **params) -> None: """ Postprocess phase retrieval results to extract physical properties. Comprehensive workflow: loads reconstruction, orthogonalises to lab frame, optionally flips/apodizes, estimates support isosurface, and computes structural properties (phase, displacement, strain, d-spacing, lattice parameter). Args: **params: optional parameters to override instance params. Common overrides: - 'voxel_size': target voxel size (nm) - 'isosurface': support threshold (0-1) - 'apodize': window function ('blackman', 'hann', etc.) - 'flip': flip reconstruction (complex conjugate) - 'convention': 'xu' or 'cxi' - 'handle_defects': enable defect-aware processing Side effects: Updates instance attributes: reconstruction, structural_ props, extra_info. Generates amplitude distribution plot. Raises: ValueError: if unrecognised parameter provided. """ _path = f"{self.dump_dir}S{self.scan}_preprocessed_data.cxi" # Whether to reload the pre-processing .cxi file. if self.q_lab_pos.get("ref") is None or self.converter is None: self.logger.info(f"Loading parameters from:\n{_path}") self.update_from_file(_path) if params: for p in params: if not isparameter(p): raise ValueError( f"Parameter '{p}' is not recognised. " "Please check the possible parameter names among:\n" + "\n".join( [ f"- {key}" for key in DEFAULT_PIPELINE_PARAMS.keys() ] ) ) self.logger.info( f"Additional parameters provided {params}, will update the " "current dictionary of parameters." ) self.params.update(params) if "voxel_size" in params and params["voxel_size"] is None: with CXIFile(_path, "r") as cxi: self.converter = self._build_converter_from_cxi(cxi) # Load the reconstruction mode _path = f"{self.dump_dir}/S{self.scan}_pynx_reconstruction_mode.cxi" self.reconstruction = self._load_reconstruction(_path, centre=True) # Handle the voxel size self._check_voxel_size() if not self.params["orthogonalise_before_phasing"]: self.reconstruction = self.converter.orthogonalise_to_direct_lab( self.reconstruction ) # Change convention of the reconstruction if necessary. if self.params["convention"].lower() == "cxi": self.reconstruction = Geometry.swap_convention(self.reconstruction) self.logger.info( f"Voxel size finally used is: {self.params['voxel_size']} nm in " f"the {self.params['convention'].upper()} convention." ) # Handle flipping and apodization if self.params["flip"]: self.reconstruction = PostProcessor.flip_reconstruction( self.reconstruction ) if self.params["apodize"]: self.logger.info( "Apodizing the complex array using " f"{self.params['apodize']} filter." ) self.reconstruction = PostProcessor.apodize( self.reconstruction, window_type=self.params["apodize"] ) # First compute the histogram of the amplitude to get an # isosurface estimate self.logger.info( "Finding an isosurface estimate based on the " "reconstructed Bragg electron density histogram:", ) isosurface, _ = find_isosurface( np.abs(self.reconstruction), nbins=100, sigma_criterion=3, plot=True, # plot and return the figure, save=f"{self.dump_dir}/S{self.scan}_amplitude_distribution.png", ) # store the estimated isosurface self.extra_info["estimated_isosurface"] = isosurface self.logger.info(f"Isosurface estimated at {isosurface}.") if self.params["isosurface"] is not None: self.logger.info( "Isosurface provided by user will be used: " f"{self.params['isosurface']}." ) elif isosurface < 0 or isosurface > 1: self.params["isosurface"] = self.class_isosurface self.logger.info( f"isosurface estimate has a wrong value ({isosurface}) and " f"will be set set to {self.class_isosurface = }." # noqa: E251 ) else: self.params["isosurface"] = isosurface self.logger.info( "Computing the structural properties:" "\n\t- phase \n\t- displacement\n\t- het. (heterogeneous) strain" "\n\t- d-spacing\n\t- lattice parameter." "\nhet. strain maps are computed using various methods, either" " phase ramp removal or d-spacing method.\n" f"The theoretical Bragg peak is {self.params['hkl']}." ) if self.params["handle_defects"]: self.logger.info("Defect handling requested.") if self.params["convention"].lower() == "cxi": g_vector = Geometry.swap_convention(self.q_lab_pos["ref"]) else: g_vector = self.q_lab_pos["ref"] self.structural_props = PostProcessor.get_structural_properties( self.reconstruction, support_parameters=None, isosurface=self.params["isosurface"], g_vector=g_vector, hkl=self.params["hkl"], voxel_size=self.params["voxel_size"], phase_factor=-1, # it came out of pynx.cdi handle_defects=self.params["handle_defects"], ) to_plot = { k: self.structural_props[k] for k in [ "amplitude", "phase", "displacement", "het_strain", "lattice_parameter", ] } # plot and save the detector data in the full detector frame and # final frame dump_file_tmpl = f"{self.dump_dir}/S{self.scan}_" + "{}.png" sample_scan = f"{self.sample_name}, S{self.scan}" self.extra_info["averaged_lattice_parameter"] = np.nanmean( self.structural_props["lattice_parameter"] ) self.extra_info["averaged_dspacing"] = np.nanmean( self.structural_props["dspacing"] ) table_info = { "Isosurface": self.params["isosurface"], "Averaged Lat. Par. (Å)": ( self.extra_info["averaged_lattice_parameter"] ), "Averaged d-spacing (Å)": self.extra_info["averaged_dspacing"], } PipelinePlotter.summary_plot( title=f"Summary figure, {sample_scan}", support=self.structural_props["support"], table_info=table_info, voxel_size=self.params["voxel_size"], convention=self.params["convention"], save=dump_file_tmpl.format("summary_plot"), **to_plot, ) PipelinePlotter.summary_plot( title=f"Strain check figure, {sample_scan}", support=self.structural_props["support"], voxel_size=self.params["voxel_size"], convention=self.params["convention"], save=dump_file_tmpl.format("strain_methods"), **{ k: self.structural_props[k] for k in [ "het_strain", "het_strain_from_dspacing", "het_strain_with_ramp", ] }, ) axis_names = [r"z_{cxi}", r"y_{cxi}", r"x_{cxi}"] denom = ( "du_" + "{" + f"{''.join([str(e) for e in self.params['hkl']])}" + "}" ) titles = [f"${denom}/d{axis_names[i]}$" for i in range(3)] displacement_gradient_plots = { titles[i]: self.structural_props["displacement_gradient"][i] for i in range(3) } ptp_value = np.nanmax( self.structural_props["displacement_gradient"][0] ) - np.nanmin(self.structural_props["displacement_gradient"][0]) PipelinePlotter.summary_plot( title=f"Shear displacement, {sample_scan}", support=self.structural_props["support"], voxel_size=self.params["voxel_size"], convention=self.params["convention"], save=dump_file_tmpl.format("shear_displacement"), unique_vmin=-ptp_value / 2, unique_vmax=ptp_value / 2, cmap=RED_TO_TEAL, **displacement_gradient_plots, ) _, _, means, fwhms = PipelinePlotter.strain_statistics( self.structural_props["het_strain_from_dspacing"], self.structural_props["support"], title=f"Strain statistics, {sample_scan}", save=dump_file_tmpl.format("strain_statistics"), ) self.extra_info["strain_means"] = means self.extra_info["strain_fwhms"] = fwhms plot_3d_surface_projections( data=self.structural_props["het_strain"], support=self.structural_props["support"], voxel_size=self.params["voxel_size"], convention=self.params["convention"], cmap="cet_CET_D13", vmin=-np.nanmax(np.abs(self.structural_props["het_strain"])), vmax=np.nanmax(np.abs(self.structural_props["het_strain"])), cbar_title=r"Strain (%)", title=f"3D views of the strain, {sample_scan}", save=dump_file_tmpl.format("3d_strain_views"), ) # Load the orthogonalised peak path = f"{self.dump_dir}/S{self.scan}_preprocessed_data.cxi" with CXIFile(path, "r") as cxi: ortho_exp_intensity = cxi["entry_1/data_2/data"][()] exp_q_grid = tuple( cxi[f"entry_1/result_2/{k}_xu"][()] for k in ("qx", "qy", "qz") ) # To compare, we must make sure we are back in XU convention. obj = self.structural_props["amplitude"] * np.exp( -1j * self.structural_props["phase"] ) voxel_size = self.params["voxel_size"] if self.params["convention"].lower() == "cxi": obj = Geometry.swap_convention(obj) voxel_size = Geometry.swap_convention(voxel_size) PipelinePlotter.plot_final_object_fft( obj, voxel_size, self.converter.q_lab_shift, ortho_exp_intensity, exp_q_grid, title=f"FFT of final object vs. experimental data, {sample_scan}", save=f"{self.dump_dir}/S{self.scan}_final_object_fft.png", ) self._save_postprocessed_data() self._save_parameter_file()
def _load_reconstruction( self, path: str, centre: bool = False, isosurface: float = None ) -> np.ndarray: """ Load reconstruction from CXI and optionally centre using support. Calculates oversampling ratios from loaded data. """ with CXIFile(path, "r") as cxi: reconstruction = cxi["entry_1/data_1/data"][0] self._get_oversampling_ratios(reconstruction) if centre: if isosurface is None: isosurface = self.class_isosurface amp = np.abs(reconstruction) / np.max(np.abs(reconstruction)) support = np.where(amp >= isosurface, 1, 0) com = CroppingHandler.get_position(support, "com") reconstruction = CroppingHandler.force_centred_cropping( reconstruction, where=com ) return reconstruction def _get_oversampling_ratios(self, data: np.ndarray) -> np.ndarray: amp = np.abs(data) / np.max(np.abs(data)) isosurface = self.params["isosurface"] # if not provided, isosurface is hardcoded at this stage if isosurface is None: isosurface = self.class_isosurface support = np.where(amp >= isosurface, 1, 0) # Print the oversampling ratio ratios = get_oversampling_ratios(support) self.logger.info( "The oversampling ratios in each direction (original frame) are " + ", ".join( [f"axis{i}: {ratios[i]:.1f}" for i in range(len(ratios))] ) ) self.extra_info["oversampling_ratios"] = ratios if support.shape != tuple(self.params["preprocess_shape"]): self.logger.warning( f"Shapes before {self.params['preprocess_shape']} " f"and after {support.shape} Phase Retrieval are different.\n" "Check out PyNX parameters (ex.: auto_center_resize (now " "deprecated) and roi)." ) def _check_voxel_size(self) -> None: """ Validate and initialise voxel size from converter or params. Handles convention conversion between XU and CXI coordinates. """ self.extra_info["voxel_size_from_extent"] = ( self.converter.direct_lab_voxel_size ) if self.params["convention"].lower() == "cxi": self.extra_info["voxel_size_from_extent"] = ( Geometry.swap_convention(self.converter.direct_lab_voxel_size) ) # if cxi requested, convert the voxel size from extent if self.params["voxel_size"] is None: self.params["voxel_size"] = self.converter.direct_lab_voxel_size # In the SpaceConverter, the convention is XU. if self.params["convention"].lower() == "cxi": self.params["voxel_size"] = Geometry.swap_convention( self.params["voxel_size"] ) else: # We consider voxel_size is given with the same convention # as the one specified in 'convention'. # if 1D, make it 3D if isinstance( self.params["voxel_size"], (float, int, np.floating, np.integer), ): self.params["voxel_size"] = tuple( np.repeat( self.params["voxel_size"], self.reconstruction.ndim ) ) if self.params["convention"].lower() == "cxi": # Set the direct space interpolator voxel size with XU # convention. self.converter.direct_lab_voxel_size = ( Geometry.swap_convention(self.params["voxel_size"]) ) def _save_postprocessed_data(self) -> None: """ Save post-processing results to CXI file. Stores orthogonalised data, strain statistics, and metadata. """ dump_path = f"{self.dump_dir}/S{self.scan}_postprocessed_data.cxi" with CXIFile(dump_path, "w") as cxi: cxi.stamp() msg = """Post-processing of the data including: - orthogonalisation - isosurface estimation - apodization - structural properties computation""" path = cxi.create_cxi_group( "process", description=msg, comment="Data post-processing", command="cdiutils.BcdiPipeline.postprocess()", ) cxi.softlink(f"{path}/program", "/creator") cxi.softlink(f"{path}/version", "/version") path = cxi.create_cxi_group( "result", description="Orthogonalisation procedure", voxel_size_from_reciprocal_space_extent=( self.extra_info["voxel_size_from_extent"] ), voxel_size=self.params["voxel_size"], units="nm", ) cxi.softlink(f"{path}/process_1", "entry_1/process_1") path = cxi.create_cxi_group( "result", description="Surface determination", estimated_isosurface=self.extra_info["estimated_isosurface"], used_isosurface=self.params["isosurface"], ) cxi.softlink(f"{path}/process_1", "entry_1/process_1") path = cxi.create_cxi_group( "result", description="Oversampling estimation", oversampling_ratios=self.extra_info["oversampling_ratios"], ) cxi.softlink(f"{path}/process_1", "entry_1/process_1") path = cxi.create_cxi_group( "result", description="Averaged lattice parameter and d-spacing", dspacing=self.extra_info["averaged_dspacing"], lattice_parameter=self.extra_info[ "averaged_lattice_parameter" ], units="angstrom", ) cxi.softlink(f"{path}/process_1", "entry_1/process_1") path = cxi.create_cxi_group( "result", description="Strain statistics", strain_averages=self.extra_info["strain_means"], strain_fwhms=self.extra_info["strain_fwhms"], units="%", ) cxi.softlink(f"{path}/process_1", "entry_1/process_1") # Copy entries from the preprocessed_data file prep_path = f"{self.dump_dir}/S{self.scan}_preprocessed_data.cxi" if os.path.isfile(prep_path): with CXIFile(prep_path, "r") as f: for group in ( "detector_1", "geometry_1", "sample_1", "source_1", ): try: f.copy( f"/entry_1/{group}", cxi, f"/entry_1/{group}" ) except KeyError: print(f"Cannot found {group} in {prep_path} file.") # We do not copy the parameters from preprocessed_data # because they might have changed. cxi.create_cxi_group("parameters", **self.params) for key, data in self.structural_props.items(): if isinstance(data, np.ndarray): path = cxi.create_cxi_image( data=data, data_space="Direct space", title=key, process_1="process_1", ) cxi.softlink(f"entry_1/{key}", path) # save as npz np.savez_compressed( f"{self.dump_dir}/S{self.scan}_structural_properties.npz", **self.structural_props, ) # Save as vti if IS_VTK_AVAILABLE: to_save_as_vti = { k: self.structural_props[k] for k in [ "amplitude", "support", "phase", "displacement", "het_strain", "het_strain_from_dspacing", "lattice_parameter", "dspacing", ] } to_save_as_vti["amplitude"] = normalise( to_save_as_vti["amplitude"] ) # add the dspacing average and lattice constant average around # the NP to avoid nan values that are annoying for 3D # visualisation for k in ( "het_strain", "het_strain_from_dspacing", "dspacing", "lattice_parameter", "displacement", ): to_save_as_vti[k] = np.where( np.isnan(to_save_as_vti[k]), np.nanmean(to_save_as_vti[k]), to_save_as_vti[k], ) # save to vti file save_as_vti( f"{self.dump_dir}/S{self.scan}_structural_properties.vti", voxel_size=self.params["voxel_size"], cxi_convention=(self.params["convention"].lower() == "cxi"), **to_save_as_vti, ) else: self.logger.info( "vtk package not available, will not save the vti file." ) self.logger.info(f"Post-processed data file saved at:\n{dump_path}")
[docs] def facet_analysis(self) -> None: facet_anlysis_processor = FacetAnalysisProcessor( self.params["facets"], self.params["support"]["support_method"], self.dump_dir, ) facet_anlysis_processor.facet_analysis()
[docs] def show_3d_final_result(self): """ Show plotly interactive figure of the final post-processed reconstruction. """ from cdiutils.interactive import plot_3d_isosurface path = f"{self.dump_dir}S{self.scan}_postprocessed_data.cxi" # first check whether the file exists if not os.path.exists(path): raise FileNotFoundError( f"The result file ({path}) does not exist yet, have you " "run the post-processing method?" ) loaded_data = load_cxi( path, "amplitude", "support", "phase", "displacement", "het_strain", "het_strain_from_dspacing", "dspacing", "lattice_parameter", ) loaded_voxel_size = load_cxi(path, "voxel_size") # the default quantity to visualise is the heterogeneous strain # with cmap cet_CET_D13 # calculate symmetric limits based on absolute max data_abs_max = np.abs(loaded_data["het_strain_from_dspacing"]).max() vmin = -data_abs_max vmax = data_abs_max return plot_3d_isosurface( loaded_data["amplitude"], loaded_data, voxel_size=loaded_voxel_size, initial_quantity="het_strain_from_dspacing", cmap="cet_CET_D13", vmin=vmin, vmax=vmax, )