Source code for cdiutils.utils

import inspect
import numbers
import warnings

import matplotlib
import numpy as np
import scipy.constants as cts
from scipy.fft import fftn, fftshift, ifftn, ifftshift
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import center_of_mass, convolve, median_filter


def transform_volume(
    data: np.ndarray,
    transform_matrix: np.ndarray,
    output_shape: tuple | None = None,
    voxel_size: tuple | float | None = None,
    method: str = "direct",
    preserve_norm: bool = False,
) -> np.ndarray:
    """Resample a 3D volume through a linear (3x3) transform.

    Args:
        data: 3D array describing the source volume.
        transform_matrix: 3x3 matrix mapping source coordinates to target
            coordinates. Its inverse is used to look up source positions.
        output_shape: shape of the returned volume; ``data.shape`` when
            None.
        voxel_size: target voxel spacing, either a scalar applied to the
            three axes or a length-3 sequence; unit spacing when None.
            Only the ``"direct"`` method receives it.
        method: ``"direct"`` (real-space interpolation) or ``"fourier"``
            (resampling of the Fourier transform).
        preserve_norm: if True, the result is multiplied by
            ``|det(transform_matrix)| ** -0.5``. This factor is 1 for
            orthogonal matrices.

    Raises:
        ValueError: if ``data`` is not 3D, ``transform_matrix`` is not
            3x3, ``output_shape`` or ``voxel_size`` do not have three
            entries, ``method`` is unknown, or ``preserve_norm`` is
            requested with a zero determinant.

    Returns:
        np.ndarray: the resampled volume on the target grid.
    """
    if data.ndim != 3:
        raise ValueError("data must be a 3D array")
    if transform_matrix.shape != (3, 3):
        raise ValueError("transform_matrix must be 3x3")

    target_shape = data.shape if output_shape is None else output_shape
    if len(target_shape) != 3:
        raise ValueError("output_shape must have length three")

    if method not in ("direct", "fourier"):
        raise ValueError("method must be 'direct' or 'fourier'")

    # normalise the spacing to a float array of three entries
    if voxel_size is None:
        spacing = np.ones(3, dtype=float)
    elif isinstance(voxel_size, numbers.Number):
        spacing = np.full(3, float(voxel_size), dtype=float)
    elif len(voxel_size) != 3:
        raise ValueError("voxel_size must have length three")
    else:
        spacing = np.asarray(voxel_size, dtype=float)

    inverse = np.linalg.inv(transform_matrix)
    if method == "fourier":
        volume = _transform_fourier(
            data, transform_matrix, inverse, target_shape
        )
    else:
        volume = _transform_direct(data, inverse, target_shape, spacing)

    if not preserve_norm:
        return volume

    determinant = float(np.linalg.det(transform_matrix))
    if determinant == 0.0:
        raise ValueError(
            "Cannot preserve norm for a singular transform matrix",
        )
    return volume * abs(determinant) ** (-0.5)
def _transform_direct( data: np.ndarray, matrix_inverse: np.ndarray, output_shape: tuple, voxel_size_array: np.ndarray, ) -> np.ndarray: """Resample 3D data by direct interpolation on a target grid. Args: data: input 3D volume. matrix_inverse: inverse of the 3x3 transform matrix. output_shape: shape of the output volume. voxel_size_array: spacing of output voxels along each axis. Returns: np.ndarray: interpolated volume with the given output shape. """ target_axes = [ # physical coordinates of target grid, centred at zero (np.arange(size) - size // 2) * spacing for size, spacing in zip(output_shape, voxel_size_array) ] target_mesh = np.meshgrid(*target_axes, indexing="ij") target_coords = np.stack( [axis.ravel() for axis in target_mesh], axis=1, ) source_coords = target_coords @ matrix_inverse.T source_axes = [ np.arange(-dim // 2, dim // 2, 1, dtype=float) for dim in data.shape ] interpolator = RegularGridInterpolator( tuple(source_axes), data, method="linear", bounds_error=False, fill_value=0.0, ) interpolated = interpolator(source_coords).reshape(output_shape) return interpolated.astype(data.dtype, copy=False) def _transform_fourier( data: np.ndarray, transform_matrix: np.ndarray, matrix_inverse: np.ndarray, output_shape: tuple, ) -> np.ndarray: """Resample 3D data using Fourier-domain rotation. Args: data: input 3D volume. transform_matrix: 3x3 transform matrix in spatial domain. matrix_inverse: inverse of the transform matrix. output_shape: shape of the output volume. Returns: np.ndarray: rotated and interpolated volume in spatial domain. 
""" orthogonality = transform_matrix @ transform_matrix.T if not np.allclose(orthogonality, np.eye(3), atol=1e-6): raise ValueError( "fourier method requires an orthogonal transform matrix" ) complex_data = data.astype(np.complex128, copy=False) frequency_data = fftshift(fftn(complex_data)) rotated_frequency = _transform_direct( frequency_data, matrix_inverse, output_shape, np.ones(3, dtype=float), ) rotated_frequency = rotated_frequency.astype(np.complex128, copy=False) spatial_data = ifftn(ifftshift(rotated_frequency)) if np.iscomplexobj(data): return spatial_data.astype(data.dtype, copy=False) return np.real(spatial_data).astype(data.dtype, copy=False)
[docs] def get_reciprocal_voxel_size( real_voxel_size: float | tuple | np.ndarray, shape: tuple | list | np.ndarray, convention: str | None = None, ) -> tuple: """ Calculate reciprocal-space voxel size from real-space parameters. Uses the generic relation dq = factor / (N * dx) for each axis, where ``N`` is the number of pixels and ``dx`` the corresponding real-space voxel size. This implementation is dimension-agnostic and does not assume any particular axis ordering. ``real_voxel_size`` and ``shape`` are only required to have the same length. This makes it compatible with both CXI- and xrayutilities-style conventions as long as the ordering of ``real_voxel_size`` matches that of ``shape``. Args: real_voxel_size: Real-space voxel size (scalar or sequence). If scalar, the same value is used for all axes; otherwise it must have the same length as ``shape``. shape: Array shape, e.g. ``(nz, ny, nx)``. Any dimensionality is accepted. convention: Optional convention for calculation. Currently only controls an overall 2pi factor when set to ``"xu"`` or ``"crystallography"`` or None. For other values ``factor=1``. Defaults to None. Returns: Tuple of reciprocal voxel sizes (dq0, dq1, ..., dq_{ndim-1}) in inverse units, with the same ordering as ``shape``. """ # normalise real_voxel_size to a 1D float array matching ndim ndim = len(shape) if isinstance(real_voxel_size, (int, float)): real_sizes = np.full(ndim, float(real_voxel_size), dtype=float) else: real_sizes = np.asarray(real_voxel_size, dtype=float).ravel() if real_sizes.size != ndim: raise ValueError( "real_voxel_size must be scalar or have the same " "length as shape", ) # choose overall factor factor = 1.0 if convention is None or convention in ("xu", "crystallography"): factor = 2 * np.pi dq = factor / (shape * real_sizes) return tuple(float(v) for v in dq)
[docs] def bin_along_axis( data: np.ndarray | list, binning_factor: int, binning_method: str = "sum", axis: int = 0, ) -> np.ndarray: """ Bin n-dimensional data along a specified axis using the specified method (mean, sum, median, max). Args: data (np.ndarray | list): n-dimensional data to be binned. binning_factor (int): the number of elements per bin. binning_method (str, optional): the binning method, either "mean", "sum", "median", or "max". Defaults to "sum". axis (int, optional): the axis along which to perform binning. Defaults to 0. Raises: ValueError: if the binning_method is unknown. Returns: np.ndarray: binned data. """ if binning_factor == 1 or binning_factor is None: return data # No binning required if isinstance(data, list): data = np.array(data) axis = 0 original_dim = data.shape[axis] nb_of_bins = original_dim // binning_factor remaining = original_dim % binning_factor # Reshape data for easy binning, ignore leftover data for now. # Move the axis to the front for easier manipulation. reshaped_data = np.moveaxis(data, axis, 0) full_bins_data = reshaped_data[: nb_of_bins * binning_factor].reshape( nb_of_bins, binning_factor, *reshaped_data.shape[1:] ) # Get the binning operation if binning_method == "mean": operation = np.mean binned_data = np.mean(full_bins_data, axis=1) elif binning_method == "sum": operation = np.sum binned_data = np.sum(full_bins_data, axis=1) elif binning_method == "max": operation = np.max elif binning_method == "median": operation = np.median else: raise ValueError(f"Unsupported binning method: {binning_method}") binned_data = operation(full_bins_data, axis=1) # Handle the remaining (leftover data that doesn't fit perfectly # into a bin) if remaining > 0: remaining_data = reshaped_data[-remaining:] remaining_bin = operation(remaining_data, axis=0) binned_data = np.concatenate( [binned_data, np.expand_dims(remaining_bin, axis=0)], axis=0 ) # Move the axis back to its original position return np.moveaxis(binned_data, 0, axis)
def get_prime_factors(n: int) -> list[int]:
    """Return the prime factors of ``n`` in ascending order.

    Repeated factors appear once per multiplicity. Values of ``n`` below
    2 yield an empty list.
    """
    found = []
    divisor = 2
    # trial division: a divisor is only retried after it stops dividing
    while divisor * divisor <= n:
        quotient, remainder = divmod(n, divisor)
        if remainder:
            divisor += 1
            continue
        found.append(divisor)
        n = quotient
    # whatever is left above 1 is itself prime
    if n > 1:
        found.append(n)
    return found
def is_valid_shape(
    n: int, maxprime: int = 13, required_dividers: tuple[int] = (2,)
) -> bool:
    """Check if n meets Pynx shape constraints.

    Args:
        n: the size to test.
        maxprime: largest prime factor allowed in ``n``.
        required_dividers: every entry must divide ``n``.

    Returns:
        bool: True when no prime factor of ``n`` exceeds ``maxprime``
        and all ``required_dividers`` divide ``n``. Sizes below 1 are
        never valid.
    """
    if n < 1:
        return False
    # 1 has no prime factors; ``default`` avoids max() failing on the
    # empty list so the decision is left to required_dividers
    largest_prime = max(get_prime_factors(n), default=1)
    return largest_prime <= maxprime and all(
        n % k == 0 for k in required_dividers
    )
def adjust_to_valid_shape(
    n: int,
    maxprime: int = 13,
    required_dividers: tuple[int] = (2,),
    decrease: bool = True,
) -> int:
    """Find the nearest valid shape value.

    The search only happens when ``n`` is larger than ``maxprime``;
    smaller or equal values are returned unchanged.

    Args:
        n: the size to adjust.
        maxprime: largest prime factor allowed.
        required_dividers: numbers that must divide the result.
        decrease: step downwards (True) or upwards (False) by one until
            ``is_valid_shape`` accepts the value.

    Raises:
        ValueError: if ``n <= 1`` while above ``maxprime``, or if the
            search reaches 0.

    Returns:
        int: the adjusted size.
    """
    if not maxprime < n:
        return n
    if n <= 1:
        raise ValueError("n<=1, cannot be adjusted.")

    step = -1 if decrease else 1
    while not is_valid_shape(n, maxprime, required_dividers):
        n = n + step
        if n == 0:
            raise ValueError("No valid shape found")
    return n
def ensure_pynx_shape(
    shape: int | tuple | list | np.ndarray,
    maxprime: int = 13,
    required_dividers: tuple[int] = (2,),
    decrease: bool = True,
    verbose: bool = False,
) -> tuple | np.ndarray | list:
    """
    Ensure shape dimensions comply with Pynx constraints.

    This function has been adapted from the PyNX library. See:
    pynx.utils.math.smaller_primes.

    Args:
        shape (int | tuple | list | np.ndarray): the shape to adjust.
        maxprime (int, optional): the maximum prime factor allowed.
            Defaults to 13.
        required_dividers (tuple, optional): the required dividers.
            Defaults to (2,).
        decrease (bool, optional): whether to decrease the shape.
            Defaults to True.
        verbose (bool, optional): whether to print messages.

    Raises:
        TypeError: if the shape is not an int, list, tuple, or numpy
            array.

    Returns:
        tuple | np.ndarray | list: the adjusted shape, of the same kind
        as the input (an int input gives back a single adjusted value).
    """
    is_scalar = isinstance(shape, (int, np.integer))
    if is_scalar:
        adjusted_shape = adjust_to_valid_shape(
            shape, maxprime, required_dividers, decrease
        )
    elif isinstance(shape, (list, tuple, np.ndarray)):
        adjusted_dims = [
            adjust_to_valid_shape(dim, maxprime, required_dividers, decrease)
            for dim in shape
        ]
        adjusted_shape = (
            np.array(adjusted_dims)
            if isinstance(shape, np.ndarray)
            else type(shape)(adjusted_dims)
        )
    else:
        raise TypeError("Shape must be an int, list, tuple, or numpy array")

    if verbose:
        # Compare element-wise on arrays: ``adjusted_shape != tuple(shape)``
        # failed for int and ndarray inputs and was always True for lists.
        unchanged = np.array_equal(
            np.atleast_1d(shape), np.atleast_1d(adjusted_shape)
        )
        if unchanged:
            print("Shape already in agreement with pynx shape conventions.")
        else:
            requested = shape if is_scalar else tuple(int(d) for d in shape)
            print(
                f"PyNX needs a different shape to comply with gpu-based FFT, "
                f"requested shape {requested} will be cropped to "
                f"{adjusted_shape}."
            )
    return adjusted_shape
def energy_to_wavelength(energy: float) -> float:
    """
    Convert a photon energy in eV to a wavelength in metre (not
    angstrom!).

    Args:
        energy (float): the energy in eV to convert.

    Returns:
        float: the wavelength in metre, ``h * c / (e * energy)``.
    """
    numerator = cts.c * cts.h
    denominator = cts.e * energy
    return numerator / denominator
def wavelength_to_energy(wavelength: float) -> float:
    """
    Convert a wavelength in metre (not angstrom!) to a photon energy in
    eV.

    Args:
        wavelength (float): the wavelength in metre to convert.

    Returns:
        float: the energy in eV, ``h * c / (e * wavelength)``.
    """
    numerator = cts.c * cts.h
    denominator = cts.e * wavelength
    return numerator / denominator
def fill_up_support(support: np.ndarray) -> np.ndarray:
    """
    Fill the gaps of a support along each axis.

    For every axis, a voxel is set to 1 when the product of the forward
    and backward cumulative sums of ``support`` along that axis is
    non-zero at that voxel. The results of all axes are merged.

    Args:
        support (np.ndarray): The binary support array to fill up.

    Returns:
        np.ndarray: The filled support, a float array of 0 and 1 with
        the shape of ``support``.
    """
    filled = np.zeros(support.shape)
    for axis in range(support.ndim):
        forward = np.cumsum(support, axis=axis)
        flipped = np.flip(support, axis=axis)
        backward = np.flip(np.cumsum(flipped, axis=axis), axis=axis)
        enclosed = (forward * backward) != 0
        filled[enclosed] = 1
    return filled
def size_up_support(support: np.ndarray) -> np.ndarray:
    """
    Grow a 3D support using a 3x3x3 neighbourhood sum.

    The support is convolved with a 3x3x3 kernel of ones (zero padding
    outside the array); voxels whose sum is strictly greater than 3 are
    set to 1, the others to 0.

    Args:
        support (np.ndarray): the 3D support to enlarge.

    Returns:
        np.ndarray: integer array of 0 and 1 with the shape of
        ``support``.
    """
    neighbourhood = np.ones(shape=(3, 3, 3))
    neighbour_sum = convolve(
        support, neighbourhood, mode="constant", cval=0.0
    )
    return (neighbour_sum > 3).astype(int)
def find_hull(
    volume: np.ndarray,
    threshold: float = 18,
    kernel_size: int = 3,
    boolean_values: bool = False,
    nan_value: bool = False,
) -> np.ndarray:
    """
    Find the hull (outer shell) of a 2D or 3D object.

    The volume is convolved with a cubic kernel of ones (zero padding
    outside the array). A voxel belongs to the hull when its convolved
    value is strictly positive and not larger than ``threshold``.

    Args:
        volume (np.ndarray): the volume to get the hull from.
        threshold (float, optional): upper bound of the convolved value
            for a voxel to be part of the hull. Defaults to 18.
        kernel_size (int, optional): edge length of the kernel along
            every axis. Defaults to 3.
        boolean_values (bool, optional): if True hull voxels are set to
            1, otherwise they keep their convolved value.
        nan_value (bool, optional): if True voxels outside the hull are
            np.nan, otherwise 0.

    Returns:
        np.ndarray: the hull, with the shape of ``volume``.
    """
    kernel = np.ones(shape=(kernel_size,) * volume.ndim)
    coordination = convolve(volume, kernel, mode="constant", cval=0.0)

    on_hull = (coordination > 0) & (coordination <= threshold)
    inside_value = 1 if boolean_values else coordination
    outside_value = np.nan if nan_value else 0
    return np.where(on_hull, inside_value, outside_value)
def make_support(
    data: np.ndarray, isosurface: float = 0.5, nan_values: bool = False
) -> np.ndarray:
    """Create a support using the provided isosurface value.

    ``data`` is first rescaled with ``normalise``; values greater than
    or equal to ``isosurface`` are inside the support.

    Args:
        data (np.ndarray): the data to threshold.
        isosurface (float, optional): threshold applied to the
            normalised data. Defaults to 0.5.
        nan_values (bool, optional): if True the outside of the support
            is np.nan instead of 0.

    Returns:
        np.ndarray: 1 inside the support, 0 (integer array) or np.nan
        (float array) outside.
    """
    inside = normalise(data) >= isosurface
    if not nan_values:
        return inside.astype(int)
    return np.where(inside, 1, np.nan)
[docs] def unit_vector(vector: tuple | list | np.ndarray) -> np.ndarray: """Return a unit vector.""" return np.array(vector) / np.linalg.norm(vector)
def angle(v1: np.ndarray, v2: np.ndarray) -> float:
    """Compute the angle (in radians) between two vectors.

    Args:
        v1 (np.ndarray): first vector.
        v2 (np.ndarray): second vector.

    Returns:
        float: the angle in ``[0, pi]``. The result is NaN when one of
        the vectors has zero norm.
    """
    cosine = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    # Rounding can push the cosine of (anti)parallel vectors slightly
    # outside [-1, 1], for which arccos returns NaN.
    return np.arccos(np.clip(cosine, -1.0, 1.0))
def v1_to_v2_rotation_matrix(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """
    Build the rotation matrix about the axis ``v1 x v2`` by the angle
    between ``v1`` and ``v2``.

    Note that for parallel or antiparallel inputs the cross product
    vanishes, so the axis cannot be normalised (division by zero).

    Args:
        v1 (np.ndarray): 3-component start vector.
        v2 (np.ndarray): 3-component end vector.

    Returns:
        np.ndarray: the 3x3 rotation matrix.
    """
    axis = unit_vector(np.cross(v1, v2))
    theta = angle(v1, v2)

    n1, n2, n3 = axis
    ct = np.cos(theta)
    st = np.sin(theta)

    # axis-angle (Rodrigues) form, written out entry by entry
    first_row = (
        ct + n1**2 * (1 - ct),
        n1 * n2 * (1 - ct) - n3 * st,
        n1 * n3 * (1 - ct) + n2 * st,
    )
    second_row = (
        n1 * n2 * (1 - ct) + n3 * st,
        ct + n2**2 * (1 - ct),
        n2 * n3 * (1 - ct) - n1 * st,
    )
    third_row = (
        n1 * n3 * (1 - ct) - n2 * st,
        n2 * n3 * (1 - ct) + n1 * st,
        ct + n3**2 * (1 - ct),
    )
    return np.array((first_row, second_row, third_row))
def normalise(data: np.ndarray, zero_centered: bool = False) -> np.ndarray:
    """Normalise a np.ndarray so the values are between 0 and 1.

    Args:
        data (np.ndarray): the data to rescale.
        zero_centered (bool, optional): if True the rescaling is
            symmetric about zero, using ``[-m, m]`` with ``m`` the
            largest absolute extremum, so that 0 maps to 0.5. Otherwise
            the minimum maps to 0 and the maximum to 1.

    Returns:
        np.ndarray: the rescaled data.
    """
    if not zero_centered:
        lowest = np.min(data)
        span = np.ptp(data)
        return (data - lowest) / span

    bound = np.max(np.abs([np.min(data), np.max(data)]))
    lowest = -bound
    span = bound - lowest
    return (data - lowest) / span
def normalise_complex_array(array: np.ndarray) -> np.ndarray:
    """Normalise an array of complex numbers.

    The minimum real part and the minimum imaginary part are subtracted
    from every element, then the result is divided by its largest
    modulus.
    """
    real_floor = array.real.min()
    imag_floor = array.imag.min()
    shifted = array - real_floor
    shifted = shifted - 1j * imag_floor
    largest_modulus = np.abs(shifted).max()
    return shifted / largest_modulus
def find_max_pos(data: np.ndarray) -> tuple:
    """Return the index coordinates of the (first) maximum of ``data``."""
    flat_index = data.argmax()
    return np.unravel_index(flat_index, data.shape)
[docs] def shape_for_safe_centred_cropping( data_shape: tuple | np.ndarray | list, position: tuple | np.ndarray | list, final_shape: tuple = None, ) -> tuple: """ Utility function that finds the smallest shape that allows a safe cropping, i.e. without moving data from one side to another when using the np.roll() function. """ if not isinstance(data_shape, np.ndarray): data_shape = np.array(data_shape) if not isinstance(position, np.ndarray): position = np.array(position) secured_shape = 2 * np.min([position, data_shape - position], axis=0) secured_shape = tuple(round(e) for e in secured_shape) if final_shape is None: return tuple(secured_shape) return tuple(np.min([secured_shape, final_shape], axis=0))
def _center_at_com(data: np.ndarray):
    """
    Roll a 3D volume repeatedly until its centre of mass rounds to the
    centre voxel.

    Each pass calls ``center`` with the current centre of mass and
    recurses on the result; there is no iteration limit.

    Args:
        data (np.ndarray): 3D volume to centre.

    Returns:
        tuple: the centred data and the centre of mass measured on it.
    """
    shape = data.shape
    com = tuple(center_of_mass(data))
    # integer shift still needed along each of the three axes
    com_to_center = np.array(
        [int(np.rint(shape[i] / 2 - com[i])) for i in range(3)]
    )
    if not com_to_center.any():
        return data, com
    return _center_at_com(center(data, where=com))
[docs] def center( data: np.ndarray, where: str | tuple | list | np.ndarray = "com", return_former_center: bool = False, ) -> np.ndarray | tuple[np.ndarray, tuple]: """ Center 3D volume data such that the center of mass or max of data is at the very center of the 3D matrix. :param data: volume data (np.array). 3D numpy array which will be centered. :param com: center of mass coordinates(list, np.array). If no com is provided, com of the given data is computed (default: None). :param where: what region to place at the center (str), either com or max, or a tuple of the coordinates where to place the center at. :returns: centered 3D numpy array. """ shape = data.shape if isinstance(where, (tuple, list, np.ndarray)) and len(where) == 3: reference_position = tuple(where) elif where == "max": reference_position = find_max_pos(data) elif where == "com": reference_position = tuple(e for e in center_of_mass(data)) else: raise ValueError( "where must be 'max', 'com' or tuple or list of 3 floats " f"coordinates, can't be type: {type(where)} ({where}) " ) xcenter, ycenter, zcenter = reference_position centered_data = np.roll(data, int(np.rint(shape[0] / 2 - xcenter)), axis=0) centered_data = np.roll( centered_data, int(np.rint(shape[1] / 2 - ycenter)), axis=1 ) centered_data = np.roll( centered_data, int(np.rint(shape[2] / 2 - zcenter)), axis=2 ) if return_former_center: return centered_data, (xcenter, ycenter, zcenter) return centered_data
[docs] def symmetric_pad( data: np.ndarray, output_shape: tuple | list | np.ndarray, values: float = 0, ) -> np.ndarray: """Return padded data so it matches the provided final_shape""" if data.ndim != len(output_shape): raise ValueError( f"output_shape length ({len(output_shape)}) should match of input " f"data dimension ({data.ndim})." ) pad_widths = [] for current_s, output_s in zip(data.shape, output_shape): if output_s < current_s: pad_widths.append((0, 0)) else: pad_left = (output_s - current_s) // 2 pad_right = pad_left + (output_s - current_s) % 2 pad_widths.append((pad_left, pad_right)) return np.pad( data, pad_width=pad_widths, mode="constant", constant_values=values )
[docs] def crop_at_center( data: np.ndarray, final_shape: list | tuple | np.ndarray ) -> np.ndarray: """ Crop 3D array data to match the final_shape. Center of the input data remains the center of cropped data. :param data: 3D array data to be cropped (np.array). :param final_shape: the targetted shape (list). If None, nothing happens. :returns: cropped 3D array (np.array). """ shape = data.shape final_shape = np.array(final_shape) if not (final_shape <= data.shape).all(): print( "One of the axis of the final shape is larger than " f"the initial axis (initial shape: {shape}, final shape: " f"{tuple(final_shape)}).\nDid not proceed to cropping." ) return data c = np.array(shape) // 2 # coordinates of the center to_crop = final_shape // 2 # indices to crop at both sides # if final_shape[i] is odd, center[i] must be at # final_shape[i] + 1 // 2 plus_one = np.where((final_shape % 2 == 0), 0, 1) cropped = data[ c[0] - to_crop[0] : c[0] + to_crop[0] + plus_one[0], c[1] - to_crop[1] : c[1] + to_crop[1] + plus_one[1], c[2] - to_crop[2] : c[2] + to_crop[2] + plus_one[2], ] return cropped
[docs] def compute_distance_from_com( data: np.ndarray, com: tuple | list | np.ndarray = None ) -> np.ndarray: """ Return a np.ndarray of the same shape of the provided data. (i, j, k) Value will correspond to the distance of the (i, j, k) voxel in data to the center of mass if that voxel is not null. """ nonzero_coordinates = np.nonzero(data) distance_matrix = np.zeros(shape=data.shape) if com is None: com = center_of_mass(data) for x, y, z in zip( nonzero_coordinates[0], nonzero_coordinates[1], nonzero_coordinates[2] ): distance = np.sqrt( (x - com[0]) ** 2 + (y - com[1]) ** 2 + (z - com[2]) ** 2 ) distance_matrix[x, y, z] = distance return distance_matrix
[docs] def num_to_nan(data: np.ndarray, num: int | float = 0): """ Replace all occurrences of 'num' in the array with np.nan. Args: data (np.ndarray): NumPy array. num (int | float, optional): The number to replace with np.nan Defaults to 0. Returns: A new NumPy array with 'num' replaced by np.nan. """ return np.where(data == num, np.nan, data)
def zero_to_nan(data: np.ndarray, boolean_values: bool = False) -> np.ndarray:
    """Convert zero values to np.nan.

    Non-zero values are kept, or replaced by 1 when ``boolean_values``
    is True.
    """
    kept = 1 if boolean_values else data
    is_zero = data == 0
    return np.where(is_zero, np.nan, kept)
def nan_to_zero(data: np.ndarray, boolean_values: bool = False) -> np.ndarray:
    """Convert np.nan values to 0.

    Other values are kept, or replaced by 1 when ``boolean_values`` is
    True.
    """
    kept = 1 if boolean_values else data
    is_nan = np.isnan(data)
    return np.where(is_nan, 0, kept)
def to_bool(data: np.ndarray, nan_value: bool = False) -> np.ndarray:
    """Convert values to 1 (True) if not nan otherwise to 0 (False).

    When ``nan_value`` is True, nan entries stay np.nan instead of
    becoming 0.
    """
    missing = np.nan if nan_value else 0
    is_nan = np.isnan(data)
    return np.where(is_nan, missing, 1)
def nan_center_of_mass(
    data: np.ndarray, return_int: bool = False
) -> tuple:
    """
    Compute the center of mass of a np.ndarray that may contain nan
    values.

    Without any nan the computation is delegated to
    ``scipy.ndimage.center_of_mass``; otherwise the nan voxels are left
    out and the coordinates of the remaining voxels are averaged with
    their values as weights.

    Args:
        data (np.ndarray): the data, possibly containing nan.
        return_int (bool, optional): whether to round the coordinates
            to integers. Defaults to False.

    Returns:
        tuple: the center of mass coordinates, one per dimension.
    """
    nan_mask = np.isnan(data)
    if not nan_mask.any():
        com = center_of_mass(data)
    else:
        # this branch used to run unconditionally and overwrite the
        # result of center_of_mass above
        valid = np.nonzero(~nan_mask)
        com = np.average(np.array(valid), axis=1, weights=data[valid])

    if return_int:
        return tuple(int(round(e)) for e in com)
    return tuple(com)
[docs] def hybrid_gradient( data: np.ndarray, *d: list, ) -> list[np.ndarray] | np.ndarray: """ Compute the gradient of a n-dim volume in each axis direction, 2nd order in the interior of the non-nan object, 1st order at the interface between the non-nan object and the surrounding nan values. Args: data (np.ndarray): the n-dim volume i to be derived d (list): the spacing in each direction Returns: list[np.ndarray] | np.ndarray: the gradients (in each direction) with the same shape as the input data """ if isinstance(d, (int, float)): d = [d] if data.ndim != len(d): raise ValueError(f"Invalid shape for d ({d}), must match data.ndim") gradient = [] for i in range(data.ndim): upper_slice = [slice(None)] * data.ndim upper_slice[i] = np.s_[1:] upper_slice = tuple(upper_slice) lower_slice = [slice(None)] * data.ndim lower_slice[i] = np.s_[:-1] lower_slice = tuple(lower_slice) # compute the first order gradient gradient.append((data[upper_slice] - data[lower_slice]) / d[i]) # some warning is expecting here as mean of empty slices might occur with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) # here is the trick, using the np.nanmean allows keeping the # first order derivative at the interface. The other values # correspond to second order gradient gradient[i] = np.nanmean( [ gradient[i][upper_slice], gradient[i][lower_slice], ], axis=0, ) pad = [(0, 0) for _ in range(data.ndim)] pad[i] = (1, 1) gradient[i] = np.pad(gradient[i], pad, constant_values=np.nan) if len(gradient) == 1: return gradient[0] return gradient
[docs] class CroppingHandler: """ A class to handle data cropping. The class allows finding the requested position of the center and crop the data accordingly. """
[docs] @staticmethod def get_position(data: np.ndarray, method: str | tuple[int]) -> tuple[int]: """ Get the position of the reference voxel based on the centering method. Args: data: Input data array. method: Centering method. Can be "max" for maximum intensity, "com" for center of mass, or a tuple of coordinates representing the voxel position. Returns: The position of the reference voxel as a tuple of coordinates. Raises: ValueError: If an invalid method is provided. """ # if the data is complex, we take the absolute value if np.iscomplexobj(data): data = np.abs(data) if method == "max": return np.unravel_index(np.argmax(data), data.shape) elif method == "com": if isinstance(data, np.ma.MaskedArray): # data must be filled to account for the mask values data = data.filled(0) com = center_of_mass(data) return tuple(np.nan if np.isnan(e) else int(round(e)) for e in com) elif isinstance(method, (list, tuple)) and all( isinstance(e, (int, np.int64)) for e in method ): return tuple(method) else: raise ValueError( "Invalid method provided. Can be 'max', 'com' or a tuple of " "coordinates." )
[docs] @classmethod def get_masked_data( cls, data: np.ndarray, roi: list[int] ) -> np.ma.MaskedArray: """ Get the masked data array based on the region of interest (ROI). Args: data: Input data array. roi: Region of interest as a list of representing the cropped region ex: [start, end, start, end]. Returns: The masked data array with the specified ROI. """ mask = np.ones_like(data) mask[cls.roi_list_to_slices(roi)] = 0 return np.ma.array(data, mask=mask)
[docs] @staticmethod def roi_list_to_slices(roi: list[int]) -> tuple[slice, ...]: """ Convert a ROI to a tuple of slices. Args: roi: Region of interest as a list of start and end values for each dimension. Returns: The ROI converted to a tuple of slices. """ if len(roi) % 2 != 0: raise ValueError( "ROI should have start and end values for each dimension." ) return tuple( slice(start, end) for start, end in zip(roi[::2], roi[1::2]) )
[docs] @classmethod def get_roi( cls, output_shape: tuple, where: tuple, input_shape: tuple = None ) -> list[int]: """ Calculate the region of interest (ROI) for cropping the data based on the input shape, desired output shape, and reference voxel position. Args: output_shape: Desired output shape after cropping. where: Reference voxel position as a tuple of coordinates. input_shape: Shape of the input data array. Returns: The region of interest as a list of start and end values for each dimension. """ # define how much to crop data # plus_one is whether or not to add one to the bounds. plus_one = np.where((np.array(output_shape) % 2 == 0), 0, 1) crop = [ [e // 2, e // 2 + plus_one[i]] for i, e in enumerate(output_shape) ] roi = [] if input_shape is None: for i in range(len(output_shape)): roi.append(where[i] - crop[i][0]) roi.append(where[i] + crop[i][1]) return roi for i, s in enumerate(input_shape): # extend the roi to comply with the output_shape add_left = ( where[i] + crop[i][1] - s if where[i] + crop[i][1] > s else 0 ) add_right = ( crop[i][0] - where[i] if where[i] - crop[i][0] < 0 else 0 ) roi.append(np.max([where[i] - crop[i][0], 0]) - add_left) roi.append(np.min([where[i] + crop[i][1], s]) + add_right) # for i in range(0, len(roi), 2): # if roi[i] < 0: # warnings.warn( # f"The calculated roi contains a negative value " # "({roi[i]}), will set it to 0, this might give " # "inconsistent results." # ) # roi[i+1] -= roi[i] # roi[i] = 0 return roi
[docs] @classmethod def chain_centring( cls, data: np.ndarray, output_shape: tuple[int, ...], methods: list[str | tuple[int, ...]], verbose: bool = False, ) -> tuple[np.ndarray, tuple[int, ...]]: """ Apply sequential centring methods to the input data and return the cropped and centred data along with the position of the reference voxel in the newly cropped data frame. Args: data: Input data array. output_shape: Desired output shape after cropping. methods: list of sequential centring methods. Each method can be "max" for maximum intensity, "com" for center of mass, or a tuple of coordinates representing the voxel position. verbose: (bool, optional) whether to print out messages. Returns: A tuple containing the cropped and centered data array, the position of the reference voxel in the original data frame, and the position of the reference voxel in the newly cropped data frame and the roi. Raises: ValueError: If an invalid method is provided. """ # For the first method the data are not masked masked_data = data position = None msg = "" for method in methods: # position is found in the masked data position = cls.get_position(masked_data, method) msg += f"\t- {method}: {position}, value: {data[position]}\n" # get the roi roi = cls.get_roi(output_shape, position, data.shape) # mask the data values which are outside roi masked_data = cls.get_masked_data(data, roi=roi) if verbose: print("Chain centring:\n" + msg) # actual position along which the data are centered using roi position = tuple( (start + stop) // 2 for start, stop in zip(roi[::2], roi[1::2]) ) cropped_data = data[cls.roi_list_to_slices(roi)] cropped_position = tuple(p - r for p, r in zip(position, roi[::2])) return cropped_data.copy(), position, cropped_position, roi
[docs] @classmethod def force_centred_cropping( cls, data: np.ndarray, where: str | tuple = "centre", output_shape: tuple = None, verbose: bool = False, ) -> np.ndarray: """ Crop the data so the given reference position (where) is at the center of the final data frame no matter the output_shape. Therefore the real output shape might be different to output_shape. """ if output_shape is None: output_shape = data.shape output_shape = np.array(output_shape) if where == "centre": where = tuple(e // 2 for e in data.shape) position = cls.get_position(data, where) shape = data.shape safe_shape = np.array( shape_for_safe_centred_cropping(shape, position, output_shape) ) if verbose: print(f"Safe shape for cropping: {tuple(safe_shape)}") if np.all(safe_shape == output_shape): print("Does not require forced-centered cropping.") else: print( "Required shape for cropping at the center is " f"{tuple(safe_shape)}." ) plus_one = np.where((safe_shape % 2 == 0), 0, 1) crop = [ [safe_shape[i] // 2, safe_shape[i] // 2 + plus_one[i]] for i in range(len(safe_shape)) ] roi = [] for i, s in enumerate(shape): roi.append(np.max([position[i] - crop[i][0], 0])) roi.append(np.min([position[i] + crop[i][1], s])) return data[cls.roi_list_to_slices(roi)]
def compute_corrected_angles(
    inplane_angle: float,
    outofplane_angle: float,
    detector_coordinates: tuple,
    detector_distance: float,
    direct_beam_position: tuple,
    pixel_size: float = 55e-6,
    verbose=False,
) -> tuple[float, float]:
    """
    Compute the detector corrected angles given the angles saved in the
    experiment data file and the position of interest in the detector
    frame.

    Args:
        inplane_angle (float): in-plane detector angle in degrees.
        outofplane_angle (float): out-of-plane detector angle in
            degrees.
        detector_coordinates (tuple): the detector coordinates of the
            point of interest.
        detector_distance (float): the sample to detector distance, in
            the same unit as ``pixel_size``.
        direct_beam_position (tuple): the direct beam position in the
            detector frame.
        pixel_size (float, optional): the pixel size. Defaults to 55e-6.
        verbose (bool, optional): whether or not to print the
            corrections.

    Returns:
        tuple[float, float]: the corrected in-plane and out-of-plane
        angles in degrees.
    """
    # NOTE(review): the in-plane offset pairs detector_coordinates[1]
    # with direct_beam_position[0] and the out-of-plane offset pairs
    # index 0 with index 1, so the two tuples presumably use opposite
    # axis orders - confirm against the callers.
    inplane_correction = np.rad2deg(
        np.arctan(
            (detector_coordinates[1] - direct_beam_position[0])
            * pixel_size
            / detector_distance
        )
    )
    outofplane_correction = np.rad2deg(
        np.arctan(
            (detector_coordinates[0] - direct_beam_position[1])
            * pixel_size
            / detector_distance
        )
    )

    corrected_inplane_angle = float(inplane_angle - inplane_correction)
    corrected_outofplane_angle = float(
        outofplane_angle - outofplane_correction
    )

    if verbose:
        # each label is now followed by the quantity it names
        print(
            f"current in-plane angle: {inplane_angle}\n"
            f"in-plane angle correction: {inplane_correction}\n"
            f"corrected in-plane angle: {corrected_inplane_angle}\n\n"
            f"current out-of-plane angle: {outofplane_angle}\n"
            f"out-of-plane angle correction: {outofplane_correction}\n"
            f"corrected out-of-plane angle: {corrected_outofplane_angle}"
        )
    return corrected_inplane_angle, corrected_outofplane_angle
def find_suitable_array_shape(
    support: np.ndarray, pad: list = None, symmetrical: bool = True
) -> np.ndarray:
    """Find a smaller array shape of 2 or 3D array containing a support.

    Along each axis the size is the distance between the first and the
    last index where the projected support is non-zero, plus twice the
    padding of that axis.

    Args:
        support (np.ndarray): the 2D or 3D support.
        pad (list, optional): padding per axis. Defaults to 4 along
            every axis.
        symmetrical (bool, optional): if True, the largest size is used
            for every axis. Defaults to True.

    Raises:
        ValueError: if the support sums to 0 or less, or if a 2D
            support comes with a ``pad`` that does not have two entries.

    Returns:
        The shape: a tuple when ``symmetrical`` is True or the support
        is 3D, a list for a non-symmetrical 2D result.
    """
    pad = np.array(np.repeat(4, support.ndim) if pad is None else pad)
    if support.sum() <= 0:
        raise ValueError("support must contain 0 and 1.")

    def extents_of_projection(projection, margins):
        # size along both axes of a 2D projection of the support
        sizes = []
        for k in (0, 1):
            occupied = np.nonzero(projection.sum(axis=1 - k))[0]
            bounds = occupied[[0, -1]]
            sizes.append(np.ptp(bounds) + margins[k] * 2)
        return sizes

    shape = []
    if support.ndim == 3:
        per_projection = []
        for dropped in range(3):
            kept_axes = tuple(k for k in range(3) if k != dropped)
            per_projection.append(
                extents_of_projection(
                    support.sum(axis=dropped),
                    np.squeeze(pad[[kept_axes]]),
                )
            )
        # axis 0 is read from the projection along axis 1, axes 1 and 2
        # from the projection along axis 0
        shape = (
            per_projection[1][0],
            per_projection[0][0],
            per_projection[0][1],
        )
    elif support.ndim == 2:
        if len(pad) != 2:
            raise ValueError(
                f"Length of pad ({len(pad) = }) must be equal to support "  # noqa E251
                f"dimensions ({support.ndim = })"  # noqa E251
            )
        shape = extents_of_projection(support, pad)

    if not symmetrical:
        return shape
    return tuple(np.repeat(np.max(shape), support.ndim))
[docs] def extract_reduced_shape( support: np.ndarray, pad: tuple | list | np.ndarray = None, symmetric: bool = False, ) -> tuple: if pad is None: pad = np.array([-10, 10]) support_limits = [] for i in range(support.ndim): limit = np.nonzero(support.sum(axis=i))[0][[0, -1]] limit += pad # padding support_limits.append(limit) shape = [np.ptp(limit) for limit in support_limits] if symmetric: return tuple(np.repeat(np.max(shape), support.ndim)) return shape
def get_oversampling_ratios(
    support: np.ndarray = None,
    direct_space_object: np.ndarray = None,
    isosurface: float = 0.3,
    plot: bool = False,
) -> np.ndarray:
    """
    Compute the oversampling ratio of a reconstruction.
    Function proposed by Ewen Bellec (ewen.bellec@esrf.fr)

    For every axis, the ratio is the array length along that axis
    divided by the distance between the lowest and the highest index
    where the support equals 1.

    Args:
        support (np.ndarray, optional): the support of the
            reconstruction. Defaults to None.
        direct_space_object (np.ndarray, optional): the reconstructed
            object, only used to build the support when ``support`` is
            not provided. Defaults to None.
        isosurface (float, optional): the isosurface to determine the
            support. Defaults to .3.
        plot (bool, optional): whether to plot or not. Defaults to
            False.

    Raises:
        ValueError: If support is not provided, requires
            direct_space_object and isosurface (default to 0.3) value.

    Returns:
        np.ndarray: the oversampling ratio, one value per axis, or None
        when the sum of the support is lower than 1.
    """
    if support is None:
        if direct_space_object is None:
            raise ValueError(
                "If support is not provided, provide direct_space_object and "
                "isosurface (default to 0.3) value"
            )
        support = make_support(np.abs(direct_space_object), isosurface)

    if support.sum() < 1:
        return None

    support_indices = np.where(support == 1)
    size_per_dim = np.max(support_indices, axis=1) - np.min(
        support_indices, axis=1
    )
    oversampling_ratio = np.divide(np.array(support.shape), size_per_dim)

    if plot:
        # `import matplotlib` alone does not guarantee that the pyplot
        # submodule is loaded, hence the explicit local import.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(1, support.ndim, figsize=(5 * support.ndim, 4))
        # subplots returns a bare Axes (not an array) for a single panel.
        ax = np.atleast_1d(ax)
        for n in range(support.ndim):
            # Project onto axis n, whatever the support dimensionality.
            other_axes = tuple(k for k in range(support.ndim) if k != n)
            proj = np.max(support, axis=other_axes)
            ax[n].plot(proj)
            title = (
                f"oversampling along axis {n}\n"
                f"{round(oversampling_ratio[n], 2)}"
            )
            ax[n].set_title(title, fontsize=15)

    return oversampling_ratio
def oversampling_from_diffraction(
    data: np.ndarray,
    support_threshold: float = 0.1,
) -> np.ndarray:
    """
    Estimate the oversampling ratios directly from diffraction data.

    The modulus of the shifted Fourier transform of the intensity is
    used as an autocorrelation estimate. Voxels whose value exceeds
    ``support_threshold`` times the maximum form a support, which is
    then passed to ``get_oversampling_ratios``.

    Args:
        data (np.ndarray): intensity (diffracted data).
        support_threshold (float, optional): the relative threshold
            (isosurface) used to build the support from the
            autocorrelation, a la pynx. Defaults to 0.1.

    Returns:
        np.ndarray: the oversampling ratio, as returned by
        ``get_oversampling_ratios``.
    """
    spectrum = fftn(fftshift(data))
    magnitude = np.abs(ifftshift(spectrum))
    cutoff = support_threshold * np.max(magnitude)
    return get_oversampling_ratios(support=magnitude > cutoff)
[docs] def get_centred_slices( shape: tuple | list | np.ndarray, shift: tuple | list = None ) -> list: """ Compute the slices that allows to select the centre of each axis. It returns a list of len(shape) tuples made of len(shape) slices. The shift allows to shift the centred the slices by the amount provided. Ex: * if shape = (25, 48), will return:: [(12, slice(None, None, None)), (slice(None, None, None), 24)] * if shape = (25, 48, 50), will return:: [(12, slice(None, None, None), slice(None, None, None)), (slice(None, None, None), 24, slice(None, None, None)), (slice(None, None, None), slice(None, None, None), 25)] Args: shape (tuple | list | np.ndarray): the shape of the np.ndarray of which you want to select the centre. shift (tuple | list): a shift in the slice selection with respect to the centre. Returns: list: the list of tuples made of slices. """ if shift is None: shift = (0,) * len(shape) slices = [] for i, element in enumerate(shape): s = [slice(None)] * len(shape) s[i] = element // 2 + shift[i] slices.append(tuple(s)) return slices
def hot_pixel_filter(
    data: np.ndarray, threshold: float = 1e2, kernel_size: int = 3
) -> tuple[np.ndarray, np.ndarray]:
    """
    Zero out hot pixels detected with a median filter.

    A pixel is kept when its value is strictly lower than
    ``threshold * (local_median + 1)``, where ``local_median`` is the
    median-filtered data. Every other pixel is flagged as hot.

    Args:
        data (np.ndarray): the input data.
        threshold (float, optional): the factor applied to the local
            median (plus one) to decide whether a pixel is hot.
            Defaults to 1e2.
        kernel_size (int, optional): the size of the kernel to compute
            the median filter. It corresponds to the size parameter of
            scipy.ndimage.median_filter function. Defaults to 3.

    Returns:
        tuple[np.ndarray, np.ndarray]: the cleaned data, where hot
        pixels are set to 0, and the mask (True for hot pixel, False
        otherwise).
    """
    local_median = median_filter(data, size=kernel_size)
    ceiling = threshold * (local_median + 1)
    keep = data < ceiling
    hot_pixels = np.logical_not(keep)
    return data * keep, hot_pixels
def valid_args_only(params: dict, function: callable) -> dict:
    """Keep only the entries of ``params`` that ``function`` accepts.

    A key is kept when it matches the name of a positional-or-keyword
    or keyword-only parameter of ``function``, as reported by
    ``inspect.getfullargspec``.

    Args:
        params (dict): the candidate keyword arguments.
        function (callable): the callable whose signature is inspected.

    Returns:
        dict: a new dictionary restricted to the accepted names.
    """
    # Inspect the signature once rather than once per key.
    spec = inspect.getfullargspec(function)
    accepted = {*spec.args, *spec.kwonlyargs}
    return {key: value for key, value in params.items() if key in accepted}