"""Utility functions for phase processing in a BCDI framework.
Author:
Clément Atlan - 27.10.2023
"""
import copy
import numpy as np
from scipy.fft import fftn, fftshift, ifftn, ifftshift
from scipy.ndimage import binary_erosion
from skimage.filters import window
from skimage.restoration import unwrap_phase
from sklearn.linear_model import LinearRegression
from cdiutils.process.support_processor import SupportProcessor
from cdiutils.utils import (
CroppingHandler,
find_suitable_array_shape,
hybrid_gradient,
make_support,
nan_to_zero,
normalise,
zero_to_nan,
)
[docs]
class PostProcessor:
"""
A class to bundle all functions needed to post-process BCDI data.
"""
[docs]
@staticmethod
def prepare_volume(
complex_object: np.ndarray,
isosurface,
support_parameters: dict = None,
final_shape: np.ndarray | tuple | list = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Prepare the volume by finding a smaller array shape, centering
at the center of mass of the support, and cropping
Args:
complex_object (np.ndarray): the complex object
(rho e^{i phi})
isosurface (bool): the isosurface that determines the
support
final_shape (np.ndarray | tuple | list, optional): the
final shape of the array requested. Defaults to None.
Returns:
tuple[np.ndarray, np.ndarray]: the cropped complex_object
and the associated support.
"""
if support_parameters is None:
support = make_support(
normalise(np.abs(complex_object)),
isosurface=isosurface,
nan_values=False,
)
if final_shape is None:
final_shape = find_suitable_array_shape(
support, pad=np.repeat(10, support.ndim)
)
# center the arrays at the center of mass of the support
com = CroppingHandler.get_position(support, "com")
complex_object = CroppingHandler.force_centred_cropping(
complex_object, where=com, output_shape=final_shape
)
support = CroppingHandler.force_centred_cropping(
support, where=com, output_shape=final_shape
)
bulk = binary_erosion(support.astype(bool))
surface = (support.astype(bool) & ~bulk).astype(int)
return complex_object, support, surface
support_pre_crop = make_support(
normalise(np.abs(complex_object)), isosurface=0.2
)
final_shape_pre_crop = copy.copy(final_shape)
if final_shape_pre_crop is None:
final_shape_pre_crop = find_suitable_array_shape(
support_pre_crop, pad=[6, 6, 6]
)
com_pre_crop = CroppingHandler.get_position(support_pre_crop, "com")
complex_object_pre_crop = CroppingHandler.force_centred_cropping(
complex_object,
where=com_pre_crop,
output_shape=final_shape_pre_crop,
)
support_processor = SupportProcessor(
params=support_parameters,
data=normalise(np.abs(complex_object_pre_crop)),
isosurface=isosurface,
)
support, surface = support_processor.support_calculation()
if final_shape is None:
final_shape = find_suitable_array_shape(support, pad=[6, 6, 6])
print(f"[INFO] new array shape is {final_shape}")
# center the arrays at the center of mass of the support
com = CroppingHandler.get_position(support, "com")
complex_object = CroppingHandler.force_centred_cropping(
complex_object_pre_crop, where=com, output_shape=final_shape
)
support = CroppingHandler.force_centred_cropping(
support, where=com, output_shape=final_shape
)
surface = CroppingHandler.force_centred_cropping(
surface, where=com, output_shape=final_shape
)
return complex_object, support, surface
[docs]
@staticmethod
def flip_reconstruction(data: np.ndarray) -> np.ndarray:
"""
Flip a direct space reconstruction.
Args:
data (np.ndarray): the 3D direct space reconstruction
Returns:
np.ndarray: the flipped reconstruction
"""
return ifftshift(ifftn(np.conj(fftn(fftshift(data)))))
[docs]
@staticmethod
def apodize(
direct_space_data: np.ndarray,
window_type: str = "blackman",
scale: float = 1,
) -> np.ndarray:
"""
Apodization in the direct space data using Blackman window.
Args:
direct_space_data (np.ndarray): the 3D volume data to
apodize.
scale (float, optional): value of the integral of the
Blackman window. Defaults to None.
Returns:
np.ndarray: Apodized 3D array.
"""
blackman_window = window(window_type, direct_space_data.shape)
blackman_window = blackman_window / blackman_window.max() * scale
q_space_data = ifftshift(fftn(fftshift(direct_space_data)))
q_space_data = q_space_data * blackman_window
return ifftshift(ifftn(fftshift(q_space_data)))
[docs]
@staticmethod
def unwrap_phase(
phase: np.ndarray, support: np.ndarray = None
) -> np.ndarray:
"""
Unwrap phase for voxels that belong to the given support.
Args:
phase (np.ndarray): the phase to unwrap
support (np.ndarray): the support where voxels of interest are
Returns:
np.ndarray: the unwrapped phase
"""
if support is None:
return unwrap_phase(phase, wrap_around=False)
support = nan_to_zero(support)
mask = np.where(support == 0, 1, 0)
phase = np.ma.masked_array(phase, mask=mask)
return unwrap_phase(phase, wrap_around=False).data
[docs]
@staticmethod
def remove_phase_ramp(phase: np.ndarray) -> np.ndarray:
"""
Remove the phase ramp of a 2 | 3D phase object.
Args:
phase (np.ndarray): the 2 | 3D phase object
Returns:
np.ndarray: the phase without the computed ramp.
"""
non_nan_coordinates = np.where(np.logical_not(np.isnan(phase)))
non_nan_phase = phase[non_nan_coordinates]
indices = np.indices(phase.shape)
filtered_indices = []
for i in range(len(indices)):
filtered_indices.append(indices[i][non_nan_coordinates])
indices = np.swapaxes(filtered_indices, 0, 1)
reg = LinearRegression().fit(indices, non_nan_phase)
indices = np.indices(phase.shape)
ramp = 0
for i in range(len(indices)):
ramp += reg.coef_[i] * indices[i]
ramp += reg.intercept_
return phase - ramp
[docs]
@staticmethod
def phase_offset_to_zero(
phase: np.ndarray,
support: np.ndarray = None,
) -> np.ndarray:
"""
Set the phase offset to the mean phase value.
"""
if support is None:
return phase - np.nanmean(phase)
return phase - np.nanmean(phase * support)
[docs]
@staticmethod
def get_displacement(
phase: np.ndarray,
g_vector: np.ndarray | tuple | list,
) -> np.ndarray:
"""
Calculate the displacement from phase and g_vector.
"""
return phase / np.linalg.norm(g_vector)
[docs]
@staticmethod
def get_displacement_gradient(
displacement: np.ndarray,
voxel_size: np.ndarray | tuple | list,
gradient_method: str = "hybrid",
) -> np.ndarray:
"""
Calculate the gradient of the displacement.
Args:
displacement (np.ndarray): displacement array.
voxel_size (np.ndarray | tuple | list): the voxel size of
the array.
gradient_method (str, optional): the method employed to
compute the gradient. "numpy" is the traditional gradient.
"hybrid" compute first order gradient at the surface and
second order within the bulk of the reconstruction.
Defaults to "hybrid".
Raises:
ValueError: If parsed method is unknown.
Returns:
np.ndarray: the gradient of the volume in the three
directions.
"""
if gradient_method == "numpy":
grad_function = np.gradient
elif gradient_method in ("hybrid", "h"):
grad_function = hybrid_gradient
else:
raise ValueError("Unknown method for normal strain computation.")
return grad_function(displacement, *voxel_size)
[docs]
@classmethod
def get_het_normal_strain(
cls,
displacement: np.ndarray,
g_vector: np.ndarray | tuple | list,
voxel_size: np.ndarray | tuple | list,
gradient_method: str = "hybrid",
) -> np.ndarray:
"""
Compute the heterogeneous normal strain, i.e. the gradient of
the displacement projected along the measured Bragg peak
direction.
Args:
displacement (np.ndarray): the displacement array
g_vector (np.ndarray | tuple | list): the position of the
measured Bragg peak (com | max of the intensity).
voxel_size (np.ndarray | tuple | list): voxel size of the
array
gradient_method (str, optional): the method employed to
compute the gradient. "numpy" is the traditional gradient.
"hybrid" compute first order gradient at the surface and
second order within the bulk of the reconstruction.
Defaults to "hybrid".
Returns:
np.ndarray: the heterogeneous normal strain
"""
displacement_gradient = cls.get_displacement_gradient(
displacement, voxel_size, gradient_method
)
displacement_gradient = np.moveaxis(
np.asarray(displacement_gradient),
source=0,
destination=displacement.ndim,
)
return np.dot(
displacement_gradient, g_vector / np.linalg.norm(g_vector)
)
[docs]
@classmethod
def get_structural_properties(
cls,
complex_object: np.ndarray,
isosurface: np.ndarray,
g_vector: np.ndarray | tuple | list,
hkl: tuple | list,
voxel_size: np.ndarray | tuple | list,
phase_factor: int = -1,
handle_defects: bool = False,
support_parameters: dict = None,
) -> dict:
"""
Main method used in the post-processing workflow. The method
computes all the structural properties of interest in BCDI
(amplitude, phase, displacement, displacement gradient,
heterogeneous strain d-spacing and lattice parameter maps.)
Args:
complex_object (np.ndarray): the reconstructed object
(rho e^(i phi))
g_vector (np.ndarray | tuple | list): the reciprocal space
node on which the displacement gradient must be
projected.
hkl (tuple | list): the probed Bragg reflection.
voxel_size (np.ndarray | tuple | list): the voxel size
of the 3D array.
phase_factor (int, optional): the factor the phase should
should be multiplied by, depending on the FFT convention
used. Defaults to -1 (PyNX convention in Phase
Retrieval, in PyNX scattering, use 1).
handle_defects (bool, optional): whether a defect is present
in the reconstruction, in this case phasing processing
and strain computation is different. Defaults to False.
Returns:
dict: the structural properties of the object in the form of
a dictionary. Each key corresponds to one quantity of'
interest, including: amplitude, support, phase,
displacement, displacement_gradient (in all three
directions), het. (heterogeneous) strain using various
methods, d-spacing, lattice parameter 3D maps. hkl,
g_vector and voxel size are also returned.
"""
complex_object, support, surface = cls.prepare_volume(
complex_object,
support_parameters=support_parameters,
isosurface=isosurface,
final_shape=None,
)
# extract phase and amplitude
amplitude = np.abs(complex_object)
phase = np.angle(complex_object) * phase_factor
support = zero_to_nan(support) # 0 values must be nan now
phase = cls.unwrap_phase(phase, support)
phase = phase * support
phase_with_ramp = phase.copy() # save the 'ramped' phase for later
phase = cls.remove_phase_ramp(phase)
phase = cls.phase_offset_to_zero(phase)
# compute the displacement
displacement = cls.get_displacement(phase, g_vector)
displacement_with_ramp = cls.get_displacement(
phase_with_ramp, g_vector
)
if handle_defects:
het_strain_with_ramp = np.zeros(amplitude.shape)
displacement_gradient = np.zeros((3,) + amplitude.shape)
phases = [
np.mod(phase_with_ramp * support + i * (np.pi / 2), 2 * np.pi)
for i in range(3)
]
strains = [
cls.get_het_normal_strain(
cls.get_displacement(phases[i], g_vector) * 1e-1,
g_vector,
voxel_size,
gradient_method="hybrid",
)
for i in range(3)
]
displacement_gradients = [
cls.get_displacement_gradient(
cls.get_displacement(phases[i], g_vector) * 1e-1,
voxel_size,
gradient_method="hybrid",
)
for i in range(3)
]
for i in range(3):
# strain case
mask = np.isclose(strains[i % 3], strains[(i + 1) % 3])
het_strain_with_ramp[mask == 1] = strains[i][mask == 1]
# displacement gradient case
for k in range(3):
mask = np.isclose(
displacement_gradients[k][i % 3],
displacement_gradients[k][(i + 3) % 3],
)
displacement_gradient[k][mask == 1] = (
displacement_gradients[k][i][mask == 1]
)
numpy_het_strain = np.array([np.nan])
else:
# compute the various strain quantities
numpy_het_strain = cls.get_het_normal_strain(
displacement * 1e-1, # displacement values converted in nm.
g_vector,
voxel_size,
gradient_method="numpy",
)
het_strain = cls.get_het_normal_strain(
displacement * 1e-1,
g_vector,
voxel_size,
gradient_method="hybrid",
)
het_strain_with_ramp = cls.get_het_normal_strain(
displacement_with_ramp * 1e-1,
g_vector,
voxel_size,
gradient_method="hybrid",
)
# compute the displacement gradient
displacement_gradient = cls.get_displacement_gradient(
displacement, voxel_size, gradient_method="hybrid"
)
# compute the dspacing and lattice_parameter
dspacing = (
2 * np.pi / np.linalg.norm(g_vector) * (1 + het_strain_with_ramp)
)
lattice_parameter = (
np.sqrt(hkl[0] ** 2 + hkl[1] ** 2 + hkl[2] ** 2) * dspacing
)
dspacing_mean = np.nanmean(dspacing)
het_strain_from_dspacing = (dspacing - dspacing_mean) / dspacing_mean
if handle_defects:
het_strain = het_strain_from_dspacing.copy()
# all strains are saved in percent
return {
"amplitude": amplitude,
"support": nan_to_zero(support),
"surface": nan_to_zero(surface),
"phase": nan_to_zero(phase),
"displacement": nan_to_zero(displacement),
"displacement_gradient": nan_to_zero(displacement_gradient),
"het_strain": nan_to_zero(het_strain) * 100,
"het_strain_with_ramp": nan_to_zero(het_strain_with_ramp) * 100,
"het_strain_from_dspacing": nan_to_zero(het_strain_from_dspacing)
* 100,
"numpy_het_strain": numpy_het_strain * 100,
"dspacing": dspacing,
"lattice_parameter": lattice_parameter,
"hkl": hkl,
"g_vector": g_vector,
"voxel_size": voxel_size,
}