import numpy as np
from scipy.fft import fftn, fftshift, ifftn
[docs]
def make_box(
shape: tuple[int, int, int],
dimensions: int | tuple[int, int, int] = 15,
centre: tuple[int, int, int] | None = None,
rotation: np.ndarray | tuple[float, float, float] | None = None,
value: float = 1.0,
) -> np.ndarray:
"""
Create a 3D parallelepiped (rectangular cuboid) binary array.
A cube is a special case where all dimensions are equal.
Args:
shape: 3D array shape (nz, ny, nx).
dimensions: Side lengths in pixels. If scalar, creates a
cube. If tuple, order is (length_z, length_y, length_x).
centre: Centre position (z, y, x). If None, uses array
centre.
rotation: Rotation to apply. Can be:
- None: no rotation
- (3,3) array: rotation matrix
- tuple of 3 floats: Euler angles (deg) as
(alpha, beta, gamma) combined as
Rz(alpha) @ Ry(beta) @ Rx(gamma)
value: Value to fill inside the parallelepiped.
Returns:
Binary array with `value` inside parallelepiped and 0
outside.
Raises:
ValueError: If shape is not 3D, dimensions invalid, or
rotation has invalid shape.
"""
if len(shape) != 3:
raise ValueError(f"shape must be 3D, got {len(shape)}D")
# parse dimensions
dims_array = np.asarray(dimensions)
if dims_array.size == 1:
dim_z = dim_y = dim_x = int(dims_array)
elif dims_array.size == 3:
dim_z, dim_y, dim_x = [int(d) for d in dims_array]
else:
raise ValueError("dimensions must be scalar or sequence of 3 elements")
if any(d <= 0 for d in [dim_z, dim_y, dim_x]):
raise ValueError("all dimensions must be positive")
# parse centre
if centre is None:
centre = tuple(np.array(shape) // 2)
# optimised case: axis-aligned parallelepiped (no rotation)
rotation_matrix = _parse_rotation_matrix(rotation)
if rotation_matrix is None:
parallelepiped = np.zeros(shape, dtype=float)
half_dims = np.array([dim_z, dim_y, dim_x]) // 2
# compute bounds using vectorised operations
starts = np.maximum(0, np.array(centre) - half_dims)
ends = np.minimum(shape, np.array(centre) + half_dims)
parallelepiped[
starts[0] : ends[0],
starts[1] : ends[1],
starts[2] : ends[2],
] = value
return parallelepiped
# rotated case: use coordinate transformation
coords_z, coords_y, coords_x = _get_centred_coordinates(shape, centre)
# apply inverse rotation (to transform world coords to local)
rotation_inv = rotation_matrix.T
rotated_z, rotated_y, rotated_x = _apply_rotation_to_coordinates(
coords_z, coords_y, coords_x, rotation_inv
)
# check if inside box
half_z, half_y, half_x = dim_z / 2, dim_y / 2, dim_x / 2
inside = (
(np.abs(rotated_z) <= half_z)
& (np.abs(rotated_y) <= half_y)
& (np.abs(rotated_x) <= half_x)
)
parallelepiped = np.zeros(shape, dtype=float)
parallelepiped[inside] = value
return parallelepiped
[docs]
def make_ellipsoid(
    shape: tuple[int, int, int],
    radii: float | tuple[float, float, float] = 15,
    centre: tuple[float, float, float] | None = None,
    rotation: np.ndarray | tuple[float, float, float] | None = None,
    value: float = 1.0,
) -> np.ndarray:
    """
    Create a 3D ellipsoid binary array.

    A sphere is a special case where all radii are equal. A voxel is
    inside when ``(z/rz)**2 + (y/ry)**2 + (x/rx)**2 <= 1`` in the
    ellipsoid's local frame (boundary included).

    Args:
        shape: 3D array shape (nz, ny, nx).
        radii: Radii in pixels. If scalar, creates a sphere.
            If tuple, order is (rz, ry, rx).
        centre: Centre position (z, y, x). If None, uses array
            centre.
        rotation: Rotation to apply. Can be:
            - None: no rotation
            - (3,3) array: rotation matrix
            - tuple of 3 floats: Euler angles (deg) as
              (alpha, beta, gamma) combined as
              Rz(alpha) @ Ry(beta) @ Rx(gamma)
            The ellipsoid is rotated by this matrix: world coordinates
            are mapped into the local frame with its transpose
            (inverse), the same convention as ``make_box``.
        value: Value to fill inside the ellipsoid.

    Returns:
        3D array with `value` inside ellipsoid and 0 outside.

    Raises:
        ValueError: If radii or rotation have invalid shape.

    Notes:
        Radii are not range-checked. A zero radius yields an empty
        array because the 0/0 (NaN) and x/0 (inf) terms fail the
        ``<= 1`` test. Negative radii behave like their absolute value
        because each term is squared.
    """
    # parse radii
    radii_array = np.asarray(radii)
    if radii_array.size == 1:
        # .item() avoids NumPy's deprecated ndim>0 array -> scalar cast
        radius_z = radius_y = radius_x = float(radii_array.item())
    elif radii_array.size == 3:
        radius_z, radius_y, radius_x = (float(r) for r in radii_array.ravel())
    else:
        raise ValueError(
            "radii must be scalar or sequence of 3 elements (rz, ry, rx)"
        )

    # get centred coordinates
    coords_z, coords_y, coords_x = _get_centred_coordinates(shape, centre)

    # apply rotation if requested
    rotation_matrix = _parse_rotation_matrix(rotation)
    if rotation_matrix is not None:
        # inverse rotation (transpose) maps world coords to the local frame,
        # so the same ``rotation`` turns every shape in this module alike
        local_z, local_y, local_x = _apply_rotation_to_coordinates(
            coords_z, coords_y, coords_x, rotation_matrix.T
        )
    else:
        local_z, local_y, local_x = coords_z, coords_y, coords_x

    # ellipsoid equation: (z/rz)^2 + (y/ry)^2 + (x/rx)^2 <= 1
    # errstate silences the warnings produced by zero radii (see Notes)
    with np.errstate(divide="ignore", invalid="ignore"):
        inside = (
            (local_z / radius_z) ** 2
            + (local_y / radius_y) ** 2
            + (local_x / radius_x) ** 2
        ) <= 1.0
    array = np.zeros(shape, dtype=float)
    array[inside] = value
    return array
[docs]
def make_cylinder(
    shape: tuple[int, int, int],
    radius: float = 10.0,
    height: float = 25.0,
    centre: tuple[float, float, float] | None = None,
    axis: int = 0,
    rotation: np.ndarray | tuple[float, float, float] | None = None,
    value: float = 1.0,
) -> np.ndarray:
    """
    Create a 3D cylinder binary array.

    A voxel is inside when its squared distance from the cylinder axis
    is ``<= radius**2`` and ``|axial coordinate| <= height / 2``. Both
    bounds are closed, so an even integer ``height`` covers
    ``height + 1`` voxel planes along the axis.

    Args:
        shape: 3D array shape (nz, ny, nx).
        radius: Cylinder radius in pixels.
        height: Cylinder height in pixels.
        centre: Centre position (z, y, x). If None, uses array
            centre.
        axis: Cylinder axis direction as integer (0=z, 1=y, 2=x).
        rotation: Additional rotation applied to the axis-aligned
            cylinder. Same format as other shape functions. World
            coordinates are mapped into the local frame with the
            transpose (inverse) of the matrix, the same convention as
            ``make_box``.
        value: Value to fill inside the cylinder.

    Returns:
        3D array with `value` inside cylinder and 0 outside.

    Raises:
        ValueError: If axis is invalid, radius or height is not
            positive, or rotation has invalid shape.
    """
    if radius <= 0 or height <= 0:
        raise ValueError("radius and height must be positive")
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0 (z), 1 (y), or 2 (x)")

    # get centred coordinates as list for easier indexing by axis
    coords = list(_get_centred_coordinates(shape, centre))

    # apply rotation if requested
    rotation_matrix = _parse_rotation_matrix(rotation)
    if rotation_matrix is not None:
        # inverse rotation (transpose) maps world coords to the local frame
        coords = list(
            _apply_rotation_to_coordinates(
                coords[0], coords[1], coords[2], rotation_matrix.T
            )
        )

    # the two axes other than ``axis`` span the circular cross-section:
    # axis=0 -> (1, 2), axis=1 -> (2, 0), axis=2 -> (0, 1)
    radial_a, radial_b = (axis + 1) % 3, (axis + 2) % 3
    radial_dist_sq = coords[radial_a] ** 2 + coords[radial_b] ** 2
    axial_coord = coords[axis]

    # cylinder equation
    inside = (radial_dist_sq <= radius**2) & (
        np.abs(axial_coord) <= height / 2
    )
    array = np.zeros(shape, dtype=float)
    array[inside] = value
    return array
[docs]
def add_linear_phase(
    obj: np.ndarray,
    phase_gradient: tuple[float, float, float] = (1.0, 1.0, 1.0),
    apply_to_support: bool = True,
) -> np.ndarray:
    """
    Multiply an object by a linear phase ramp.

    The ramp is ``sum_i phase_gradient[i] * (r_i - shape_i // 2)`` over
    the first three axes, i.e. it is zero at the array centre.

    Args:
        obj: Real-valued object (amplitude).
        phase_gradient: Phase slope in radians per pixel along
            (z, y, x).
        apply_to_support: When True, the phase is zeroed wherever
            ``obj <= 0`` so only the support is modulated.

    Returns:
        Complex array ``obj * exp(1j * phase)``.
    """
    dims = obj.shape
    origin = np.array(dims) // 2
    axes = np.ogrid[: dims[0], : dims[1], : dims[2]]

    # accumulate the ramp axis by axis, starting from a scalar zero
    phase = 0
    for axis in range(3):
        offset = axes[axis] - origin[axis]
        phase = phase + phase_gradient[axis] * offset

    if apply_to_support:
        phase = phase * (obj > 0)
    return obj * np.exp(1j * phase)
[docs]
def add_quadratic_phase(
    obj: np.ndarray,
    curvature: tuple[float, float, float] = (2, 2, 2),
    apply_to_support: bool = True,
) -> np.ndarray:
    """
    Multiply an object by a separable quadratic phase (defocus/strain).

    The phase is ``sum_i curvature[i] * (r_i - shape_i // 2)**2`` over
    the first three axes, so it vanishes at the array centre.

    Args:
        obj: Real-valued object (amplitude).
        curvature: Quadratic coefficients in radians per pixel² along
            (z, y, x).
        apply_to_support: When True, the phase is zeroed wherever
            ``obj <= 0``.

    Returns:
        Complex array ``obj * exp(1j * phase)``.
    """
    extent = obj.shape
    middle = np.array(extent) // 2
    grids = np.ogrid[: extent[0], : extent[1], : extent[2]]

    phase = 0
    for axis in range(3):
        distance = grids[axis] - middle[axis]
        phase = phase + curvature[axis] * distance**2

    if apply_to_support:
        phase = phase * (obj > 0)
    return obj * np.exp(1j * phase)
[docs]
def add_displacement_field(
    obj: np.ndarray,
    displacement_field: np.ndarray,
    q_bragg: tuple[float, float, float],
) -> np.ndarray:
    """
    Encode a displacement field as phase on an object.

    Uses the Bragg-CDI relation φ = 2π * Q · u(r), with Q the Bragg
    vector and u(r) the per-voxel displacement. The phase is only
    applied where ``obj > 0``.

    Args:
        obj: Real-valued object (amplitude).
        displacement_field: Displacement components stacked on the
            first axis, shape (3, nz, ny, nx), ordered (uz, uy, ux).
        q_bragg: Bragg vector (qz, qy, qx) in reciprocal space units.

    Returns:
        Complex object ``obj * exp(1j * φ)``.

    Raises:
        ValueError: If displacement_field does not have 3 components
            or its spatial shape differs from ``obj.shape``.
    """
    field_shape = displacement_field.shape
    if field_shape[0] != 3:
        raise ValueError("displacement_field must have shape (3, nz, ny, nx)")
    if field_shape[1:] != obj.shape:
        raise ValueError(
            "displacement_field spatial shape must match obj shape"
        )

    # Q · u accumulated component by component
    q_dot_u = 0
    for axis in range(3):
        q_dot_u = q_dot_u + q_bragg[axis] * displacement_field[axis]

    support = obj > 0
    return obj * np.exp(1j * (2 * np.pi * q_dot_u * support))
[docs]
def add_random_phase(
obj: np.ndarray,
phase_std: float = 2.0,
correlation_length: float | None = 10,
apply_to_support: bool = True,
) -> np.ndarray:
"""
Add random phase noise to an object.
Args:
obj: Real-valued object (amplitude).
phase_std: Standard deviation of phase noise (radians).
correlation_length: Correlation length for spatially
correlated noise (pixels). If None, uses uncorrelated
noise.
apply_to_support: If True, apply phase only where obj > 0.
Returns:
Complex object with random phase noise.
"""
shape = obj.shape
# generate random phase
phase = np.random.normal(0, phase_std, shape)
# apply spatial correlation if requested
if correlation_length is not None and phase_std > 0:
# Gaussian smoothing in Fourier space
kz = np.fft.fftfreq(shape[0])
ky = np.fft.fftfreq(shape[1])
kx = np.fft.fftfreq(shape[2])
kz, ky, kx = np.meshgrid(kz, ky, kx, indexing="ij")
k_sq = kz**2 + ky**2 + kx**2
# Gaussian filter
filter_func = np.exp(-2 * (np.pi * correlation_length) ** 2 * k_sq)
phase_fft = fftn(phase)
phase = np.real(ifftn(phase_fft * filter_func))
# renormalise to maintain std
phase = phase / phase.std() * phase_std
if apply_to_support:
phase = phase * (obj > 0)
return obj * np.exp(1j * phase)
def _parse_rotation_matrix(
rotation: np.ndarray | tuple[float, float, float] | None,
) -> np.ndarray | None:
"""
Parse rotation input and return a rotation matrix.
Helper function to standardise rotation matrix creation across
shape generation functions.
Args:
rotation: Rotation specification. Can be:
- None: returns None (no rotation)
- (3,3) array: returns as-is (assumed rotation matrix)
- tuple of 3 floats: Euler angles (deg) as
(alpha, beta, gamma) combined as
Rz(alpha) @ Ry(beta) @ Rx(gamma)
Returns:
3x3 rotation matrix or None if rotation is None.
Raises:
ValueError: If rotation has invalid shape.
"""
if rotation is None:
return None
rotation_matrix = np.asarray(rotation)
if rotation_matrix.shape == (3,):
# convert Euler angles (degrees) to rotation matrix
alpha, beta, gamma = np.deg2rad(rotation_matrix)
rot_z = np.array(
[
[np.cos(alpha), -np.sin(alpha), 0],
[np.sin(alpha), np.cos(alpha), 0],
[0, 0, 1],
]
)
rot_y = np.array(
[
[np.cos(beta), 0, np.sin(beta)],
[0, 1, 0],
[-np.sin(beta), 0, np.cos(beta)],
]
)
rot_x = np.array(
[
[1, 0, 0],
[0, np.cos(gamma), -np.sin(gamma)],
[0, np.sin(gamma), np.cos(gamma)],
]
)
return rot_z @ rot_y @ rot_x
elif rotation_matrix.shape == (3, 3):
return rotation_matrix
else:
raise ValueError(
"rotation must be None, (3,3) matrix or "
"length-3 sequence of Euler angles (deg)"
)
def _get_centred_coordinates(
shape: tuple[int, int, int],
centre: tuple[float, float, float] | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Create coordinate grids centred at specified position.
Helper function to standardise coordinate grid creation across
shape generation functions.
Args:
shape: 3D array shape (nz, ny, nx).
centre: Centre position (z, y, x). If None, uses array
centre.
Returns:
Tuple of (coords_z, coords_y, coords_x) as centred
coordinate grids suitable for broadcasting.
"""
if centre is None:
centre = tuple(np.array(shape) // 2)
# create coordinate grids (broadcast-friendly)
grids = np.ogrid[: shape[0], : shape[1], : shape[2]]
# centre coordinates and return them
return tuple(grids[i] - centre[i] for i in range(3))
def _apply_rotation_to_coordinates(
coords_z: np.ndarray,
coords_y: np.ndarray,
coords_x: np.ndarray,
rotation_matrix: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Apply rotation matrix to coordinate grids.
Helper function to apply rotation transformation to centred
coordinates in a vectorised manner.
Args:
coords_z: Z-coordinates relative to centre.
coords_y: Y-coordinates relative to centre.
coords_x: X-coordinates relative to centre.
rotation_matrix: 3x3 rotation matrix to apply.
Returns:
Tuple of (rotated_z, rotated_y, rotated_x) coordinate grids.
"""
rotated_z = (
rotation_matrix[0, 0] * coords_z
+ rotation_matrix[0, 1] * coords_y
+ rotation_matrix[0, 2] * coords_x
)
rotated_y = (
rotation_matrix[1, 0] * coords_z
+ rotation_matrix[1, 1] * coords_y
+ rotation_matrix[1, 2] * coords_x
)
rotated_x = (
rotation_matrix[2, 0] * coords_z
+ rotation_matrix[2, 1] * coords_y
+ rotation_matrix[2, 2] * coords_x
)
return rotated_z, rotated_y, rotated_x
[docs]
def simulate_diffraction(
obj: np.ndarray,
photon_budget: float | None = None,
max_intensity: float | None = None,
scale: float = 1.0,
poisson_statistics: bool = False,
convention: str | None = None,
) -> np.ndarray:
"""
Simulate diffraction pattern from a real-space object.
This uses a single, consistent numerical convention: a forward
n-dimensional FFT followed by fftshift to place the Bragg peak
at the centre of the array::
reciprocal_obj = fftshift(fftn(obj))
The ``convention`` argument is kept for future extension (e.g.
axis re-ordering or additional phase factors) but does not
change the underlying FFT operator. This keeps the
implementation simple and avoids mixing ``fftn``/``ifftn``
conventions. If you need to match a sign convention used in
analytical Bragg-CDI derivations or in external tools, do so
via the definition of the q-grid (e.g. flipping an axis or
using ``-phase`` consistently), not by swapping FFT directions
here.
Intensity is computed as ``|reciprocal_obj|**2``. Optionally,
the result can be scaled to ``photon_budget`` and/or
``max_intensity``. Note that first scaling to
``max_intensity`` will discard photon budget scaling if both
are provided. Scale is applied in any case as a multiplicative
factor.
For noise modelling refer to :func:`add_noise`.
Args:
obj: Real-space complex object to simulate.
photon_budget: Total photon budget for the exposure. If
provided, intensity is scaled so that the sum equals
this value (before Poisson sampling).
max_intensity: Maximum intensity value for final scaling.
If None, no scaling is applied.
scale: Multiplicative scale factor applied to intensity.
Defaults to 1.0.
poisson_statistics: If True, apply Poisson statistics to
the final intensity (photon counting). Default is False.
convention: Placeholder for future FFT/q-space
conventions. Currently ignored except for being accepted
for API compatibility.
Returns:
Simulated diffraction pattern (intensity).
Raises:
ValueError: If max_intensity or photon_budget are negative.
Example:
>>> # simulate diffraction from a 3D object
>>> obj = make_box((64, 64, 64), dimensions=20)
>>> obj = add_random_phase(obj, amplitude=0.1)
>>> intensity = simulate_diffraction(obj, photon_budget=1e9)
>>> intensity.shape
(64, 64, 64)
Notes:
The FFT convention is fixed to avoid confusion. The forward
FFT is always used. For crystallographic sign conventions,
adjust the phase of your object (e.g., use ``np.conj(obj)``
if needed) rather than changing the FFT direction.
"""
# validate inputs
if photon_budget is not None and photon_budget < 0:
raise ValueError(
f"photon_budget must be non-negative, got {photon_budget}"
)
if max_intensity is not None and max_intensity < 0:
raise ValueError(
f"max_intensity must be non-negative, got {max_intensity}"
)
# compute diffraction pattern
reciprocal_obj = fftshift(fftn(obj))
intensity = np.abs(reciprocal_obj) ** 2
# apply scaling
if photon_budget is not None:
scale *= photon_budget / intensity.sum()
if max_intensity is not None:
scale *= max_intensity / intensity.max()
intensity = intensity * scale
# apply Poisson statistics (photon counting)
if poisson_statistics:
intensity = np.random.poisson(intensity)
return intensity