import copy
import glob
import os
import re
import shutil
from typing import Type
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import silx.io
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.fft import fftn, fftshift, ifftshift
from scipy.ndimage import fourier_shift
from scipy.stats import gaussian_kde
from skimage.registration import phase_cross_correlation
try:
from pynx.cdi import (
CDI,
ER,
HIO,
RAAR,
AutoCorrelationSupport,
DetwinHIO,
DetwinRAAR,
FourierApplyAmplitude,
InitPSF,
InitSupportShape,
ScaleObj,
SupportTooLarge,
SupportTooSmall,
SupportUpdate,
)
from pynx.cdi.selection import match2
from pynx.utils.math import ortho_modes
IS_PYNX_AVAILABLE = True
CDI_Type = Type[CDI]
except ImportError:
IS_PYNX_AVAILABLE = False
CDI_Type = None
from cdiutils.plot import add_colorbar, get_plot_configs
from cdiutils.process.postprocessor import PostProcessor
from cdiutils.utils import CroppingHandler, get_centred_slices, valid_args_only
# Generic phasing parameters. PyNXPhaser deep-copies this mapping and
# overrides it with the keyword arguments given at construction.
DEFAULT_PYNX_PARAMS = dict(
    # support-related params
    support_threshold=(0.15, 0.40),
    smooth_width=(2, 0.5, 900),  # (start, end, period)
    post_expand=None,  # (-1, 1)
    support_update_period=50,
    method="rms",
    force_shrink=False,
    update_border_n=0,
    support=None,
    obj=None,
    amp_range=(0, 1),
    scale_obj="I",
    phase_range=(-np.pi, np.pi),
    all_random=False,
    support_shape=None,
    support_formula=None,
    support_size=None,
    # object initialisation
    # support="auto",
    support_autocorrelation_threshold=(0.09, 0.11),
    # operator parameters
    calc_llk=50,
    show_cdi=0,
    update_psf=100,
    fig_num=-1,
    zero_mask=False,
    # others
    psf="pseudo-voigt,1,0.05",  # model,fwhm,eta
    compute_free_llk=True,
    positivity=False,
    confidence_interval_factor_mask_min=0.5,
    confidence_interval_factor_mask_max=1.2,
    detwin=False,
)
[docs]
class PyNXImportError(ImportError):
    """Import error raised when PyNX is needed but is not installed."""

    def __init__(self, msg: str = None) -> None:
        """
        Build the error with a fixed headline and an optional detail.

        Args:
            msg (str, optional): extra text appended on a new line
                after the generic headline. Defaults to None.
        """
        headline = "'PyNX' is not installed on the current machine."
        if msg is None:
            super().__init__(headline)
        else:
            super().__init__(headline + "\n" + msg)
[docs]
class PyNXPhaser:
"""
A class for using PyNX's phasing algorithms without worrying
too much about the initialisation and the parameters. This class
uses generic PyNX parameters but they can also be provided upon
instantiation.
"""
[docs]
def __init__(
    self,
    iobs: np.ndarray,
    mask: np.ndarray = None,
    operators: dict = None,
    **params,
) -> None:
    """
    Initialisation method.

    Args:
        iobs (np.ndarray): the observed intensity.
        mask (np.ndarray, optional): the mask, can include detector
            gap mask/hot pixels/aliens. Defaults to None.
        operators (dict, optional): the operators you might want to
            use, keyed by lower-case name. If not provided, the
            generic set built by _init_operators() is used.
            Defaults to None.
        **params: the PyNX parameters overriding the entries of
            DEFAULT_PYNX_PARAMS.

    Raises:
        PyNXImportError: if PyNX is not installed, this class is of
            no use.
    """
    if not IS_PYNX_AVAILABLE:
        raise PyNXImportError

    # Both arrays are stored fftshift-ed.
    self.iobs = fftshift(iobs)
    self.mask = None
    if mask is not None:
        self.mask = fftshift(mask)

    # Start from a private copy of the defaults so that the module
    # level dictionary is never mutated, then apply the user values.
    self.params = copy.deepcopy(DEFAULT_PYNX_PARAMS)
    self.params.update(params)

    if operators is None:
        self.operators = self._init_operators()
    else:
        self.operators = operators

    self.support_update: SupportUpdate
    self._init_support_update()
    # Multiplicative step applied to the support threshold when a
    # run fails because the support diverged.
    self.support_threshold_auto_tune_factor: float = 1.1

    self.cdi: CDI_Type = None
    # An empty mapping, not the ``dict`` type itself.
    self.init_cdi_params: dict = {}
    self.cdi_list: list[CDI_Type] = []
    self.wrong_support_failure_tolerance = 5
    self.figure: matplotlib.figure.Figure = None
def _init_operators(self) -> dict:
    """
    Build the default set of PyNX operators.

    Returns:
        dict: the operators keyed by their lower-case name ("er",
            "hio", "detwinhio", "raar", "detwinraar" and "faa").
    """
    shared_names = (
        "calc_llk",
        "show_cdi",
        "update_psf",
        "fig_num",
        "zero_mask",
        "confidence_interval_factor_mask_max",
        "confidence_interval_factor_mask_min",
    )
    shared_kwargs = {name: self.params[name] for name in shared_names}

    operators = {}
    operators["er"] = ER(**shared_kwargs)
    operators["hio"] = HIO(**shared_kwargs)
    operators["detwinhio"] = DetwinHIO(detwin_axis=1)
    operators["raar"] = RAAR(**shared_kwargs)
    operators["detwinraar"] = DetwinRAAR(detwin_axis=1)
    # FourierApplyAmplitude only receives the keyword arguments that
    # valid_args_only() keeps for it.
    operators["faa"] = FourierApplyAmplitude(
        **valid_args_only(shared_kwargs, FourierApplyAmplitude)
    )
    return operators
def _init_support_update(self) -> None:
"""
Initialise the SupportUpdate class from pynx, depending on the
parameters of the instance.
"""
# load default pynx values for missing parameters
support_params = {
k: self.params.get(k, default)
for k, default in (
("threshold_relative", 0.2),
("smooth_width", 3),
("force_shrink", False),
("method", "rms"),
("post_expand", None),
("verbose", False),
("update_border_n", 0),
("min_fraction", 0),
("max_fraction", 1),
("lazy", False),
)
}
# if support_threshold is provided, use it to set the
# threshold_relative, unless threshold_relative is already provided
if (
"support_threshold" in self.params
and "threshold_relative" not in self.params
):
# take care of different types
if isinstance(self.params["support_threshold"], (float, int)):
support_params["threshold_relative"] = self.params[
"support_threshold"
]
# if given as a string or list/tuple, assume it's a range
elif isinstance(self.params["support_threshold"], str):
# assuming it's a range in string format "min, max"
support_params["threshold_relative"] = np.random.uniform(
*(
float(x)
for x in self.params["support_threshold"].split(",")
)
)
elif isinstance(self.params["support_threshold"], (list, tuple)):
support_params["threshold_relative"] = np.random.uniform(
*self.params["support_threshold"]
)
print("Using support update parameters:", support_params)
if self.params["support_update_period"]:
self.support_update = SupportUpdate(**support_params)
else:
self.support_update = 1
[docs]
def init_cdi(
    self, verbose: bool = True, init_main_cdi: bool = True, **params
) -> CDI_Type:
    """
    Build a CDI object from the observed data and the parameters.

    The keyword parameters are merged into self.params, so they are
    still in effect for later calls.

    Args:
        verbose (bool, optional): whether to print the
            initialisation steps. Defaults to True.
        init_main_cdi (bool, optional): whether to store the new
            object in the cdi attribute. Defaults to True.
        **params (optional): the parameters to update the
            current ones with.

    Returns:
        CDI_Type: the initialised CDI object.
    """
    # Arrays handed over by the caller are fftshift-ed before use.
    for name in ("support", "obj"):
        if isinstance(params.get(name), np.ndarray):
            if verbose:
                print(f"fftshifting {name}")
            params[name] = fftshift(params[name])
    self.params.update(params)

    if self.params["positivity"]:
        self.params["phase_range"] = 0

    # A single number is expanded into a (low, high) pair: symmetric
    # around zero for the phase, starting at zero for the amplitude.
    for name in ("amp_range", "phase_range"):
        upper = self.params[name]
        if not isinstance(upper, (int, float)):
            continue
        lower = -upper if name == "phase_range" else 0
        self.params[name] = (lower, upper)

    def draw_uniform(bounds, count):
        # Uniform random values within bounds[0] and bounds[1].
        return np.random.uniform(low=bounds[0], high=bounds[1], size=count)

    support = self.params["support"]
    obj = self.params["obj"]
    all_random = self.params["all_random"]
    cdi = CDI(self.iobs, support, obj, self.mask)

    if support is None and obj is None:
        if all_random:
            if verbose:
                print("Full random initialisation requested.")
            shape = self.iobs.shape
            amp = draw_uniform(
                self.params["amp_range"], self.iobs.size
            ).reshape(shape)
            phase = draw_uniform(
                self.params["phase_range"], self.iobs.size
            ).reshape(shape)
            cdi.set_obj(amp * np.exp(1j * phase))
        elif self.params["support_shape"] is not None:
            shape_initialiser = InitSupportShape(
                shape=self.params["support_shape"],
                size=self.params.get("support_size"),
                formula=self.params.get("support_formula"),
            )
            cdi = shape_initialiser * cdi
        else:
            if verbose:
                print("Support will be initialised using autocorrelation.")
            autocorrelation = AutoCorrelationSupport(
                threshold=self.params["support_autocorrelation_threshold"]
            )
            cdi = autocorrelation * cdi

    if obj is None and not all_random:
        if verbose:
            if self.params["scale_obj"]:
                amp_text = "scale to observed I"
            else:
                amp_text = f"{self.params['amp_range']}"
            print(
                "obj is None, will initialise object with support, and:"
                "\n\t-amp_range: "
                + amp_text
                + f"\n\t-phase_range: {self.params['phase_range']}"
            )
        # Fill the support with random amplitude and phase values.
        current_support = cdi.get_support()
        obj = current_support.astype(np.complex64)
        inside_count = np.flatnonzero(current_support).size
        amp = draw_uniform(self.params["amp_range"], inside_count)
        phase = draw_uniform(self.params["phase_range"], inside_count)
        obj[current_support > 0] *= amp * np.exp(1j * phase)
        cdi.set_obj(obj)

        if self.params["scale_obj"]:
            scaler = ScaleObj(
                method=self.params["scale_obj"], verbose=verbose
            )
            cdi = scaler * cdi

    if self.params["psf"] is not None:
        # the psf parameter reads "model,fwhm,eta"
        model, fwhm, eta = self.params["psf"].split(",")
        cdi = InitPSF(model, float(fwhm), float(eta)) * cdi

    if self.params["compute_free_llk"]:
        cdi.init_free_pixels()

    if init_main_cdi:
        self.cdi = cdi
    return cdi
[docs]
@classmethod
def read_instructions(cls, recipe: str) -> list[str]:
"""Read the instructions given in the recipe."""
recipe = recipe.replace(" ", "")
instructions = recipe.split(",")
if "=" in instructions:
pass
return instructions
[docs]
def run(
    self, recipe: str, cdi: CDI_Type = None, init_cdi: bool = True
) -> CDI_Type:
    """
    Run the reconstruction algorithm.

    For an instruction such as "ER**200", the support is updated
    after every "support_update_period" iterations; the iterations
    left over once the last full period is done are run without a
    further support update.

    Args:
        recipe (str): the instruction to run, i.e. the sequence of
            projection algorithms (ex:
            "ER**200, HIO**400, RAAR**800")
        cdi (CDI_Type, optional): a cdi object you might want to
            work on, if not provided, will use the attribute
            self.cdi. Defaults to None.
        init_cdi (bool, optional): whether to initialise a new cdi
            object (which then replaces any cdi given). Defaults to
            True.

    Returns:
        CDI_Type: the cdi object the recipe was applied to.

    Raises:
        ValueError: if there is no cdi object to work on.
        ValueError: if instruction is not authorised.
        ValueError: if instruction is invalid
    """
    if cdi is None:
        cdi = self.cdi
    if init_cdi:
        cdi = self.init_cdi(init_main_cdi=False)
    if cdi is None:
        raise ValueError(
            "No CDI object to work on, provide one or set init_cdi=True."
        )

    # A falsy period means a period of one iteration, as before.
    period = self.params["support_update_period"] or 1
    tolerance = self.wrong_support_failure_tolerance
    factor = self.support_threshold_auto_tune_factor

    for i, instruction in enumerate(self.read_instructions(recipe)):
        print(f"Instruction #{i + 1}: {instruction}")

        if instruction.lower() == "faa":
            cdi = self.operators["faa"] * cdi
            continue
        if "**" not in instruction:
            raise ValueError(f"Invalid instruction ({instruction}).")

        algo_name, iteration = instruction.split("**")
        iteration = int(iteration)
        try:
            algo = self.operators[algo_name.lower()]
        except KeyError as exc:
            raise ValueError(
                f"Invalid operator name '{algo_name}', should be:\n"
                f"{list(self.operators.keys())}."
            ) from exc

        cycles, leftover = divmod(iteration, period)
        if cycles:
            for attempt in range(1, tolerance + 1):
                try:
                    # The parentheses matter: the whole
                    # "iterate, then update the support" cycle is
                    # repeated, not only the iterations.
                    cdi = (
                        (self.support_update * algo**period) ** cycles
                    ) * cdi
                except SupportTooLarge:
                    # A higher relative threshold keeps fewer points.
                    self.support_update.threshold_relative *= factor
                    notice = "Support is too large, shrinking it. "
                except SupportTooSmall:
                    # A lower relative threshold keeps more points.
                    self.support_update.threshold_relative /= factor
                    notice = "Support is too small, making it bigger. "
                else:
                    break
                if attempt < tolerance:
                    print(notice + f"Attempt #{attempt + 1}")
            else:
                # Best effort: report the failure and carry on with
                # the next instruction instead of stopping the run.
                print(
                    f"[WARNING] Instruction '{instruction}' failed "
                    f"{tolerance} times because of the support, "
                    "skipping it."
                )
                continue

        if leftover:
            cdi = algo**leftover * cdi

    return cdi
[docs]
def run_multiple_instances(
self, run_nb: int, recipe: str, init_cdi: bool = True
) -> None:
"""
Run several reconstructions.
Args:
run_nb (int): the number of reconstructions
recipe (str): the instruction to run, i.e. the sequence of
projection algorithms (ex:
"ER**200, HIO**400, RAAR**800")
init_cdi (bool, optional): whether to initialise the cdi
object or not. Defaults to True.
Raises:
ValueError: if init_cdi is False but no initialisation was
done with the init_cdi() method beforehand.
"""
if init_cdi:
self.cdi_list = []
for i in range(run_nb):
self.cdi_list.append(
self.init_cdi(init_main_cdi=False, verbose=False)
)
else:
if self.cdi_list is None or not self.cdi_list:
raise ValueError(
"CDI object are not initialised, init_cdi should be True."
)
# if support_threshold is given as a list/tuple/string, set a different
# threshold for each run
support_relative_must_be_updated = False
if isinstance(self.params["support_threshold"], (list, tuple, str)):
support_relative_must_be_updated = True
for i, cdi in enumerate(self.cdi_list):
print(f"Run #{i + 1}")
# set support threshold if needed
if support_relative_must_be_updated:
if isinstance(self.params["support_threshold"], str):
threshold = np.random.uniform(
*(
float(x)
for x in self.params["support_threshold"].split(
","
)
)
)
else:
threshold = np.random.uniform(
*self.params["support_threshold"]
)
self.support_update.threshold_relative = threshold
print(f"Support threshold set to {threshold}")
# run the recipe
self.run(recipe, cdi, init_cdi=False)
[docs]
def genetic_phasing(
self,
run_nb: int,
genetic_pass_nb: int,
recipe: str,
selection_method: str = "sharpness",
init_cdi: bool = True,
):
"""
Run 'genetic'-like phasing. It runs a number of reconstruction
independently (defined by the recipe). The support of the best
result based on the selection_method is used to replace other
reconstruction support. This is repeated genetic_pass_nb number
of times.
Args:
run_nb (int): the number of reconstructions.
genetic_pass_nb (int): the number of genetic pass, i.e. the
number of times the recipe is applied.
recipe (str): the instruction to run, i.e. the sequence of
projection algorithms (ex:
"ER**200, HIO**400, RAAR**800")
selection_method (str, optional): The metric used to select
the best reconstruction. Defaults to "sharpness".
init_cdi (bool, optional): whether to initialise the cdi
object or not. Defaults to True.
Raises:
ValueError: if init_cdi is False but no initialisation was
done with the init_cdi() method beforehand.
ValueError: selection method is unknown/invalid.
"""
if init_cdi:
self.cdi_list = []
for i in range(run_nb):
self.cdi_list.append(
self.init_cdi(init_main_cdi=False, verbose=False)
)
else:
if self.cdi_list is None or self.cdi_list == []:
raise ValueError(
"CDI object are not initialised, init_cdi should be True."
)
if selection_method not in ("sharpness", "mean_to_max"):
raise ValueError(
f"Invalid selection_method ({selection_method}), can be"
"'sharpness' or 'mean_to_max'."
)
metrics = [None for _ in range(run_nb)]
for i in range(genetic_pass_nb + 1):
if i == 0:
print("First reconstruction pass")
else:
print(f"\nGenetic pass #{i}.")
for i in range(run_nb):
try:
metrics[i] = (
PhasingResultAnalyser.amplitude_based_metrics(
np.abs(self.cdi_list[i].get_obj(shift=True)),
self.cdi_list[i].get_support(shift=True),
)[selection_method]
)
except ValueError:
metrics[i] = np.nan
indice = np.argsort(metrics)[0]
print(
"Updating cdi objects with the best reconstruction "
f"(run # {indice + 1}).\n"
)
amplitude_reference = np.abs(self.cdi_list[indice].get_obj())
for i in range(run_nb):
if i == indice:
continue
new_obj = self.cdi_list[i].get_obj()
new_obj = np.sqrt(
np.abs(new_obj) * amplitude_reference
) * np.exp(1j * np.angle(new_obj))
self.cdi_list[i].set_obj(new_obj)
self.run_multiple_instances(run_nb, recipe, init_cdi=False)
[docs]
@staticmethod
def plot_cdi(
    cdi: CDI_Type,
    spaces: str = "both",
    axis: int = 0,
    title: str = None,
    axes: plt.Axes = None,
) -> np.ndarray:
    """
    A staticmethod to plot cdi object main quantity.
    See ShowCDI operator to see how Vincent plots CDI objects.
    https://gitlab.esrf.fr/favre/PyNX/-/blob/master/pynx/cdi/cpu_operator.py?ref_type=heads

    Args:
        cdi (CDI_Type): the cdi object to plot the quantity from.
        spaces (str, optional): Whether reciprocal or direct space.
            Defaults to "both".
        axis (int, optional): What slice to plot. Defaults to 0.
        title (str, optional): A title for the plot. Defaults to None.
        axes (optional): axes to draw on, as an array of 3 axes when
            spaces is "direct" and of 4 axes otherwise. Axes that
            already hold an image have their data refreshed instead
            of being drawn again. If None, a new figure is created.
            Defaults to None.

    Returns:
        np.ndarray: the array of axes used for the plot.

    Raises:
        ValueError: if the number of provided axes is wrong.
    """
    obj = cdi.get_obj()
    # Check the dimension of the object, whether 2D or 3D
    if obj.ndim == 2:
        the_slice = slice(None)
    else:
        # The slice to plot according to the given axis
        the_slice = get_centred_slices(obj.shape)[axis]

    # Get the support and amplitude
    support = cdi.get_support(shift=True)[the_slice]
    direct_space_obj = cdi.get_obj(shift=True)[the_slice]
    amplitude = np.abs(direct_space_obj)

    # Vincent's style to manage vmax value for amplitude plotting.
    # Basically: "max is at the 99 percentile relative to the number
    # of points inside the support"
    percent = 100 * (1 - 0.01 * support.sum() / support.size)
    max99 = np.percentile(amplitude, percent)

    # Initialise the quantities to plot, in any case direct space
    # amplitude and phase are required.
    quantities = {
        "amplitude": amplitude,
        "phase": np.ma.masked_array(
            np.angle(direct_space_obj), mask=(support == 0)
        ),
    }
    if spaces == "direct":
        quantities["support"] = support
    elif spaces == "both":
        quantities["calculated_intensity"] = (
            np.abs(ifftshift(fftn(obj.copy()))[the_slice]) ** 2
        )
        iobs = cdi.get_iobs(shift=True).copy()[the_slice]
        tmp = np.logical_and(iobs > -1e19, iobs < 0)
        if tmp.sum() > 0:
            # change back free pixels to their real intensity
            iobs[tmp] = -iobs[tmp] - 1
        iobs[iobs < 0] = 0
        quantities["observed_intensity"] = iobs

    nrows, ncols = (1, 3) if spaces == "direct" else (2, 2)
    if axes is None:
        fig, axes = plt.subplots(
            nrows, ncols, layout="tight", figsize=(3, 3)
        )
    else:
        axes = np.asarray(axes)
        # size, not len(): len() of a (2, 2) array is 2
        if axes.size != nrows * ncols:
            raise ValueError(
                "Provided axes should have shape (1, 3) or (2, 2)."
            )
        fig = axes.flat[0].figure

    for ax, key in zip(axes.flat, quantities):
        # work on a copy, the configuration is modified below
        plot_params = dict(get_plot_configs(key))

        # Refine the generic plot configs
        if key == "amplitude":
            plot_params["vmax"] = max99
            plot_params["cmap"] = "gray"
        if key == "calculated_intensity":
            plot_params["title"] = "Calculated Int. (a.u.)"
        if key == "observed_intensity":
            plot_params["title"] = "Observed Int. (a.u.)"
        if "intensity" in key:
            # the log norm replaces explicit colour limits
            plot_params["norm"] = LogNorm(1)
            plot_params.pop("vmin", None)
            plot_params.pop("vmax", None)

        # Set the ax title and remove it from the plotting params
        ax.set_title(plot_params.pop("title"))

        if ax.images:
            # the axis was drawn before, only refresh its data
            ax.images[0].set_data(quantities[key])
        else:
            image = ax.matshow(
                quantities[key], origin="lower", **plot_params
            )
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="7.5%", pad=0.05)
            fig.colorbar(
                image, cax=cax, extend="both" if key == "phase" else None
            )
            ax.locator_params(nbins=5)
            ax.xaxis.set_ticks_position("bottom")
            ax.yaxis.set_ticks_position("left")

        # Zoom on the support, with a margin of 10 pixels
        if key in ("amplitude", "support", "phase") and support.sum() > 0:
            ax.set_xlim(
                np.nonzero(support.sum(axis=0))[0][[0, -1]]
                + np.array([-10, 10])
            )
            ax.set_ylim(
                np.nonzero(support.sum(axis=1))[0][[0, -1]]
                + np.array([-10, 10])
            )

    fig.suptitle(title)
    return axes
def __repr__(self) -> str:
text = (
"PyNXPhaser object:\n"
f"PyNXPhaser with iobs shape: {self.iobs.shape}, "
f"mask shape: {self.mask.shape if self.mask is not None else None}"
)
parameters_text = "\nPyNX parameters:"
for key in sorted(self.params):
parameters_text += f"\n\t{key}: {self.params[key]}"
text += parameters_text
return text
[docs]
class PhasingResultAnalyser:
"""
This class provides utility functions for the analysis of phase
retrieval results.
"""
[docs]
def __init__(
self, cdi_results: list = None, result_dir_path: str = None
) -> None:
"""
Init method.
Args:
cdi_results (list, optional): results as CDI objects.
Defaults to None.
result_dir_path (str, optional): the path of directory
containing the .cxi phase retrieval results. Defaults
to None.
Raises:
ValueError: if non of the parameters is provided or if both
are.
"""
message = (
"cdi_results or result_dir_path cannot be provided simultaneously"
", but one of the two must be."
)
if cdi_results is None and result_dir_path is None:
raise ValueError(message)
if cdi_results is not None and result_dir_path is not None:
raise ValueError(message)
# Convert the parsed list into a dict whose keys are run numbers
self.cdi_results = []
if cdi_results:
self.cdi_results = {
f"Run{i + 1:04d}": cdi for i, cdi in enumerate(cdi_results)
}
self.result_dir_path = result_dir_path
self._metrics = None
self._sorted_phasing_results = None
self.result_paths = []
self.best_candidates = []
@property
def metrics(self) -> dict:
if self._metrics is None:
return None
metrics = {m: {} for m in self._metrics}
for m in self._metrics:
for path in self._metrics[m]:
if os.path.exists(path):
run = path.split("/")[-1]
metrics[m][run] = self._metrics[m][path]
else:
metrics[m][path] = self._metrics[m][path]
return metrics
@property
def sorted_phasing_results(self) -> dict:
if self._sorted_phasing_results is None:
return None
sorted_results = {}
for path in self._sorted_phasing_results:
if os.path.exists(path):
run = path.split("/")[-1]
sorted_results[run] = self._sorted_phasing_results[path]
else:
sorted_results[path] = self._sorted_phasing_results[path]
return sorted_results
[docs]
def find_phasing_results(self, search_pattern: str = "*Run*.cxi") -> None:
"""
Find phasing results (.cxi files) that match the given pattern.
Args:
search_pattern (str, optional): Pattern to search for files.
Uses glob syntax (not regex). Defaults to "*Run*.cxi".
Raises:
ValueError: If no files match the given pattern.
"""
self.result_paths = []
# Handle the path joining properly without duplicate slashes
full_path_pattern = os.path.join(self.result_dir_path, search_pattern)
fetched_paths = glob.glob(full_path_pattern)
for path in fetched_paths:
if os.path.isfile(path):
self.result_paths.append(path)
if not self.result_paths:
raise ValueError(
f"No result found that match the pattern '{search_pattern}' "
f"in the result_dir_path ({self.result_dir_path})"
)
[docs]
@staticmethod
def amplitude_based_metrics(
amplitude: np.ndarray, support: np.ndarray
) -> dict:
"""
Compute the criteria based on amplitude analysis, namely:
* the mean_to_max, mean of the gaussian fitting to max
distance.
* the std, the standard deviation of the amplitude
* the sharpness, the sum or mean of the amplitude to the
power of 4.
Args:
amplitude (np.ndarray): the amplitude to work on.
support (np.ndarray): the associated support.
Returns:
dict: the dictionary containing the metrics.
"""
if amplitude.shape != support.shape:
raise ValueError("Amplitude and support must have the same shape.")
if amplitude.size == 0 or support.size == 0:
raise ValueError("Amplitude and support must not be empty arrays.")
sharpness = np.mean((amplitude * support) ** 4)
amplitude = amplitude[support > 0]
std = np.std(amplitude)
amplitude /= np.max(amplitude)
# fit the amplitude distribution
kernel = gaussian_kde(amplitude)
x = np.linspace(0, 1, 100)
fitted_counts = kernel(x)
max_index = np.argmax(fitted_counts)
return {
"mean_to_max": 1 - x[max_index],
"std": std,
"sharpness": sharpness,
}
[docs]
def analyse_phasing_results(
self,
sorting_criterion: str = "mean_to_max",
search_pattern: str = "*Run*.cxi",
plot: bool = True,
plot_phasing_results: bool = True,
plot_phase: bool = False,
) -> None:
"""
Analyse the phase retrieval results by sorting them according to
the sorting_criteion, which must be selected in among:
* mean_to_max the difference between the mean of the
Gaussian fitting of the amplitude histogram and the maximum
value of the amplitude. We consider the closest to the max
the mean is, the most homogeneous is the amplitude of the
reconstruction, hence the best.
* the sharpness the sum of the amplitude within support to
the power of 4. For reconstruction with similar support,
lowest values means graeter amplitude homogeneity.
* std the standard deviation of the amplitude.
* llk the log-likelihood of the reconstruction.
* llkf the free log-likelihood of the reconstruction.
Args:
sorting_criterion (str, optional): the criterion to sort the
results with. Defaults to "mean_to_max".
search_pattern (str, optional): Pattern to search for files.
Uses glob syntax (not regex). Defaults to "*Run*.cxi".
plot (bool, optional): whether or not to disable all plots.
plot_phasing_results (bool, optional): whether to plot the
phasing results. Defaults to True.
plot_phase (bool, optional): whether the phase must be
plotted. If True, will the phase is plotted with
amplitude as opacity. If False, amplitude is plotted
instead. Defaults to False.
Raises:
ValueError: if sorting_criterion is unknown.
"""
criteria = ["mean_to_max", "std", "llk", "llkf", "sharpness", "all"]
if sorting_criterion not in criteria:
raise ValueError(
f"Provided criterion ({sorting_criterion}) is unknown. "
f"Possible criteria are:\n{criteria}."
)
if self._sorted_phasing_results is None:
print("[INFO] Computing metrics...")
self._metrics = {
m: {f: None for f in self.result_paths} for m in criteria
}
if self.cdi_results:
for run in self.cdi_results:
self._metrics["llk"][run] = self.cdi_results[
run
].llk_poisson
self._metrics["llkf"][run] = self.cdi_results[
run
].llk_poisson_free
amplitude = np.abs(self.cdi_results[run].get_obj())
support = self.cdi_results[run].get_support()
# update the metric dictionary
try:
amplitude_based_metrics = self.amplitude_based_metrics(
amplitude, support
)
except ValueError:
amplitude_based_metrics = {
"mean_to_max": np.nan,
"std": np.nan,
"sharpness": np.nan,
}
for m in ["mean_to_max", "std", "sharpness"]:
self._metrics[m][run] = amplitude_based_metrics[m]
else:
self.find_phasing_results(search_pattern)
for p in self.result_paths:
with silx.io.h5py_utils.File(p, "r") as file:
self._metrics["llk"][p] = file[
"entry_1/image_1/process_1/results/llk_poisson"
][()]
self._metrics["llkf"][p] = file[
"entry_1/image_1/process_1/results/"
"free_llk_poisson"
][()]
support = file["entry_1/image_1/support"][()]
amplitude = np.abs(file["entry_1/data_1/data"][()])
amplitude_based_metrics = self.amplitude_based_metrics(
amplitude, support
)
for m in ["mean_to_max", "std", "sharpness"]:
self._metrics[m][p] = amplitude_based_metrics[m]
# normalise all the values so they can be compared in the same plot
for m in ["mean_to_max", "std", "llk", "llkf", "sharpness"]:
minimum_value = min(list(self._metrics[m].values()))
ptp_value = np.ptp(np.array(list(self._metrics[m].values())))
for k in self._metrics[m].keys():
self._metrics[m][k] = (
self._metrics[m][k] - minimum_value
) / ptp_value
# compute the average value over all metrics
for run in self._metrics["llk"]:
self._metrics["all"][run] = 0
for m in ["mean_to_max", "std", "llk", "llkf", "sharpness"]:
self._metrics["all"][run] += self._metrics[m][run]
self._metrics["all"][run] /= 5
# now sort the files according to the provided criterion
self._sorted_phasing_results = dict(
sorted(
self._metrics[sorting_criterion].items(),
key=lambda item: item[1],
reverse=False,
)
)
if self.cdi_results is not None:
runs = [
str(extract_run_info(run)[0]).zfill(2)
for run in self._sorted_phasing_results
]
else:
runs = [
str(extract_run_info(file)[0]).zfill(2)
for file in self._sorted_phasing_results
]
if plot:
figure, ax = plt.subplots(1, 1, layout="tight", figsize=(6, 3))
colors = {
m: c
for m, c in zip(
self._metrics,
[
"lightcoral",
"mediumslateblue",
"dodgerblue",
"plum",
"teal",
"crimson",
],
)
}
for m in self._metrics:
ax.plot(
runs,
[
self._metrics[m][f]
for f in self._sorted_phasing_results
],
label=m,
color=colors[m],
marker="o",
markersize=4,
markerfacecolor=colors[m],
markeredgewidth=0.5,
markeredgecolor="k",
)
ax.set_ylabel("Normalised metric")
ax.set_xlabel("Run number")
# Shrink current axis by 20%
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
# Put a legend to the right of the current axis
ax.legend(
loc="center left", bbox_to_anchor=(1, 0.5), frameon=False
)
figure.suptitle("Phasing result analysis (the lower the better)\n")
print(
f"[INFO] the sorted list of runs using '{sorting_criterion}' "
f"sorting_criterion is:\n{runs}."
)
if plot_phasing_results and plot:
print("[INFO] Plotting phasing results...")
for result in self._sorted_phasing_results:
if self.cdi_results:
data = self.cdi_results[result].get_obj(shift=True)
support = self.cdi_results[result].get_support(shift=True)
else:
with silx.io.h5py_utils.File(result, "r") as file:
data = file["entry_1/data_1/data"][()]
support = file["entry_1/image_1/support"][()]
run = extract_run_info(result)[0]
title = f"Phasing results, run {extract_run_info(result)[0]}"
self.plot_phasing_result(data, support, title, plot_phase)
[docs]
def select_best_candidates(
self,
nb_of_best_sorted_runs: int = None,
best_runs: list[int] = None,
search_pattern: str = "*Run*.cxi",
) -> None:
"""
Select the best candidates, two methods are possible. Either
select a specific number of runs, provided they were alaysed and
sorted beforehand. Or simply provide a list of integers
corresponding to the digit numbers of the best runs.
Args:
nb_of_best_sorted_runs (int, optional): the number of best
runs to select, provided they were analysed beforehand.
Defaults to None.
best_runs (list[int], optional): the best runs to select.
Defaults to None.
search_pattern (str, optional): Pattern to search for files.
Uses glob syntax (not regex). Defaults to "*Run*.cxi".
Raises:
ValueError: If nb_of_best_sorted_runs but reconstructions
were not analysed before.
"""
if nb_of_best_sorted_runs is None and best_runs is None:
print(
"Neither nb_of_best_sorted_runs nor best_runs are provided. "
"Will select only the first best run."
)
nb_of_best_sorted_runs = 1
# If selection is made by hand, i.e. best_runs is provided
if best_runs:
if self.cdi_results:
# This is the notebook mode
self.best_candidates = [f"Run{r:04d}" for r in best_runs]
else:
# This is the script/pipeline mode
self.best_candidates = []
if not self.result_paths:
self.find_phasing_results(search_pattern)
for path in self.result_paths:
run_nb = extract_run_info(path)[0]
if run_nb in best_runs:
self.best_candidates.append(path)
elif nb_of_best_sorted_runs:
if self._sorted_phasing_results is None:
raise ValueError("Phasing results have not been analysed yet.")
self.best_candidates = list(self._sorted_phasing_results.keys())
self.best_candidates = self.best_candidates[
:nb_of_best_sorted_runs
]
if not self.cdi_results:
# This is the script/pipeline mode
# Remove the previous candidate files
for f in glob.glob(self.result_dir_path + "/candidate_*.cxi"):
os.remove(f)
printout_list = [
str(extract_run_info(f)[0]).zfill(2)
for f in self.best_candidates
]
print(f"[INFO] Best candidates selected:\n{printout_list}")
for i, f in enumerate(self.best_candidates):
dir_name, file_name = os.path.split(f)
run_nb = str(extract_run_info(file_name)[0]).zfill(2)
scan_nb = file_name.split("_")[0]
file_name = (
f"/candidate_{i + 1}-{len(self.best_candidates)}"
f"_{scan_nb}_run_{run_nb}.cxi"
)
shutil.copy(f, dir_name + file_name)
# Note: this method only works if PyNX is installed. If not, use
# the BcdiPipeline method.
[docs]
def mode_decomposition(
    self, verbose: bool = True, search_pattern: str = "*Run*.cxi"
) -> tuple:
    """
    Run a mode decomposition à la PyNX. See pynx_cdi_analysis.py
    script. Note that this method only works if PyNX is installed.
    If not, use the BcdiPipeline method.

    The reconstructions used are, by order of preference: the best
    candidates, the sorted results, or all the available results.
    Every reconstruction is matched against the first one before
    the decomposition.

    Args:
        verbose (bool, optional): whether to print some logs.
            Defaults to True.
        search_pattern (str, optional): Pattern to search for files.
            Uses glob syntax (not regex). Defaults to "*Run*.cxi".

    Raises:
        PyNXImportError: if PyNX is not installed.
        ValueError: if no result is available, e.g. none is found
            in the reconstruction folder.

    Returns:
        tuple: the modes and the mode weights, as returned by
            ortho_modes() called with nb_mode=1.
    """
    if not IS_PYNX_AVAILABLE:
        raise PyNXImportError

    if self.best_candidates:
        result_keys = list(self.best_candidates)
    else:
        print(
            "No best candidates selected, computation will be conducted "
            "on all reconstructions."
        )
        if self._sorted_phasing_results is not None:
            result_keys = list(self._sorted_phasing_results)
        else:
            if verbose:
                print(
                    "Results are not sorted, therefore matching won't be "
                    "done against the best result."
                )
            if self.cdi_results:
                # notebook mode: the results are already in memory
                result_keys = list(self.cdi_results)
            else:
                if not self.result_paths:
                    self.find_phasing_results(search_pattern)
                result_keys = list(self.result_paths)

    results = []
    for key in result_keys:
        if self.cdi_results:
            results.append(self.cdi_results[key].get_obj(shift=True))
        else:
            with silx.io.h5py_utils.File(key, "r") as file:
                results.append(file["entry_1/data_1/data"][()])
    if not results:
        raise ValueError("No phasing result to run the decomposition on.")

    reference = results[0]
    match2_results = [reference]
    for i, candidate in enumerate(results[1:], start=1):
        _, d2c, r_match = match2(reference, candidate)
        match2_results.append(d2c)  # don't know if needed, that's PyNX way
        if verbose:
            print(f"R_match({i}) = {r_match * 100:6.3f} %")

    modes, mode_weights = ortho_modes(
        match2_results, nb_mode=1, return_weights=True
    )
    if verbose:
        print(f"First mode represents {mode_weights[0] * 100:6.3f} %")
    return modes, mode_weights
[docs]
@staticmethod
def plot_phasing_result(
data: np.ndarray,
support: np.ndarray,
title: str = None,
plot_phase: bool = False,
) -> None:
"""
Plot the reconstructed object in reciprocal and direct spaces.
Args:
data (np.ndarray): the reconstruction data to plot.
support (np.ndarray): the support of the reconstruction.
title (str, optional): the title of the plot. Defaults to None.
plot_phase (bool, optional): whether to plot the phase,
if True, will plot the phase whit amplitude as opacity.
If False, amplitude will be plotted.
Defaults to False.
"""
reciprocal_space_data = np.abs(ifftshift(fftn(fftshift(data)))) ** 2
direct_space_amplitude = np.abs(data)
normalised_direct_space_amplitude = (
direct_space_amplitude - np.min(direct_space_amplitude)
) / np.ptp(direct_space_amplitude)
direct_space_phase = np.angle(data)
if data.ndim == 3:
figure, axes = plt.subplots(
2, 3, figsize=(6, 4), layout="constrained"
)
com = CroppingHandler.get_position(support, "com")
shift = tuple(com[i] - support.shape[i] // 2 for i in range(3))
slices = get_centred_slices(data.shape, shift)
for i in range(3):
rcp_im = axes[0, i].matshow(
np.sum(reciprocal_space_data, axis=i),
cmap="turbo",
norm=LogNorm(1),
)
if plot_phase:
direct_space_im = axes[1, i].matshow(
direct_space_phase[slices[i]],
vmin=-np.pi,
vmax=np.pi,
alpha=normalised_direct_space_amplitude[slices[i]],
cmap="cet_CET_C9s_r",
)
else:
direct_space_im = axes[1, i].matshow(
direct_space_amplitude[slices[i]],
vmin=0,
vmax=direct_space_amplitude.max(),
cmap="turbo",
)
if support[slices[i]].sum() > 0:
axes[1, i].set_xlim(
np.nonzero(support[slices[i]].sum(axis=0))[0][[0, -1]]
+ np.array([-5, 5])
)
axes[1, i].set_ylim(
np.nonzero(support[slices[i]].sum(axis=1))[0][[0, -1]]
+ np.array([-5, 5])
)
figure.colorbar(rcp_im, ax=axes[0, 2], extend="both")
figure.colorbar(direct_space_im, ax=axes[1, 2], extend="both")
axes[0, 1].set_title("Intensity projection (a.u.)")
axes[1, 1].set_title(
"Phase (rad)" if plot_phase else "Amplitude (a.u.)"
)
if data.ndim == 2:
figure, axes = plt.subplots(1, 2, figsize=(4, 2), layout="tight")
rcp_im = axes[0].matshow(
reciprocal_space_data, cmap="turbo", norm=LogNorm(1)
)
if plot_phase:
direct_space_im = axes[1].matshow(
direct_space_phase,
vmin=-np.pi,
vmax=np.pi,
alpha=normalised_direct_space_amplitude,
cmap="cet_CET_C9s_r",
)
else:
direct_space_im = axes[1].matshow(
direct_space_amplitude,
vmin=0,
vmax=direct_space_amplitude.max(),
cmap="turbo",
)
if support.sum() > 0:
axes[1].set_xlim(
np.nonzero(support.sum(axis=0))[0][[0, -1]]
+ np.array([-5, 5])
)
axes[1].set_ylim(
np.nonzero(support.sum(axis=1))[0][[0, -1]]
+ np.array([-5, 5])
)
figure.colorbar(rcp_im, ax=axes[0])
figure.colorbar(
direct_space_im,
ax=axes[1],
extend="both" if plot_phase else None,
)
axes[0].set_title("Intensity sum (a.u.)")
axes[1].set_title(
"Phase (rad)" if plot_phase else "Amplitude (a.u.)"
)
for ax in axes.ravel():
ax.set_xticks([])
ax.set_yticks([])
figure.suptitle(title)
plt.show()
[docs]
@staticmethod
def twin_image_checkup(
ref: np.ndarray | str,
reconstruction: np.ndarray | str,
phase_unwrap: bool = True,
**plot_params,
) -> tuple[plt.Figure, plt.Axes]:
"""
Implementation of the check-up for the twin image problem as
described by Manuel Sicairos (see https://www.researchgate.net/publication/233828110_Understanding_the_twin-image_problem_in_phase_retrieval) # noqa
The function takes two reconstructions, one is considered a
reference. After the two reconstuctions being registered, their
phase in the Fourier space are compared (difference). The same
procedure is applied to the twin image of the second
reconstruction.
To cite the paper:
'The smoothness of the retrieved Fourier phase differences can
give an indication of the twin image problem. If either of the
Fourier phase differences is smooth everywhere, this would be an
indication that the reconstructions are free of the twin-image
problem. In this test, the twin-image problem would be
characterized by the appearance of smooth and random regions
that appear complementary when comparing upright with twin
images.'
Args:
ref (np.ndarray | str): the reconstruction considered as the
reference.
reconstruction (np.ndarray | str): the reconstruction that
will be twinned and compared to the reference.
phase_unwrap (bool, optional): whether to unwrap the phase
upon plotting. Defaults to True.
Returns:
tuple[plt.Figure, plt.Axes]: the matpltolib figure and axes
objects.
"""
def check_load(data: str) -> np.ndarray:
if isinstance(data, str):
if data.endswith(".cxi"):
with silx.io.h5py_utils.File(data) as file:
return file["entry_1/data_1/data"][()]
if data.endswith(".npz"):
with np.load(data) as file:
return file["arr_0"][()]
raise NotImplementedError(
"File extension handling not implemented."
)
if isinstance(data, np.ndarray):
return data
raise ValueError(
"ref and reconstruction must be either strings (path to file) "
f"or np.ndarrays, not {type(data)}."
)
ref = check_load(ref)
reconstruction = check_load(reconstruction)
if ref.shape != reconstruction.shape:
raise ValueError(
f"Shapes of reference (ref.shape = {ref.shape}) and "
"reconstruction (reconstruction.shape = "
f"{reconstruction.shape}) must be "
"identical."
)
# Register the two datasets
shift, _, _ = phase_cross_correlation(
np.abs(ref), np.abs(reconstruction), upsample_factor=100
)
print(
"The shift calculated from the "
f"phase_cross_correlation is: {shift}."
)
fourier_data = {
"ref": ifftshift(fftn(fftshift(ref))),
"reconstruction": ifftshift(
fourier_shift(fftn(fftshift(reconstruction)), shift)
),
}
fourier_data["twin"] = np.conj(fourier_data["reconstruction"])
fourier_phases = {}
for k in fourier_data:
phase = np.angle(fourier_data[k])
centre_index = tuple(s // 2 for s in phase.shape)
fourier_phases[k] = phase - phase[centre_index]
phase_diffs = [
PostProcessor.unwrap_phase(
fourier_phases[k] - fourier_phases["ref"]
)
if phase_unwrap
else fourier_phases[k] - fourier_phases["ref"]
for k in ("reconstruction", "twin")
]
_plot_params = {"cmap": "cet_CET_C9s_r"}
if phase_unwrap:
_plot_params["vmax"] = np.max(np.abs(np.asarray(phase_diffs)))
else:
_plot_params["vmax"] = 2 * np.pi
_plot_params["vmin"] = -_plot_params["vmax"]
if plot_params:
_plot_params.update(plot_params)
if ref.ndim == 3:
figure, axes = plt.subplots(2, 3, layout="tight", figsize=(6, 3))
slices = get_centred_slices(ref.shape)
for i in range(2):
for j, ax in enumerate(axes[i].flat):
ax.imshow(phase_diffs[i][slices[j]], **_plot_params)
axes[0, 1].set_title(r"$\varphi - \varphi_{\text{ref}}$", y=1.1)
axes[1, 1].set_title(
r"$\varphi_{\text{twin}} - \varphi_{\text{ref}} "
r"= -\varphi - \varphi_{\text{ref}}$",
y=1.1,
)
if ref.ndim == 2:
figure, axes = plt.subplots(1, 2, layout="tight")
axes[0].imshow(phase_diffs[0], **_plot_params)
axes[1].imshow(phase_diffs[1], **_plot_params)
axes[0].set_title(r"$\varphi - \varphi_{\text{ref}}$")
axes[1].set_title(
r"$\varphi_{\text{twin}} - \varphi_{\text{ref}} "
r"= -\varphi - \varphi_{\text{ref}}$",
)
for ax in axes.flat:
ax.set_xticks([])
ax.set_yticks([])
add_colorbar(ax, extend="both")
return figure, axes