import copy
import logging
import warnings
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from numpy.typing import NDArray
from scipy import ndimage # type: ignore
from tqdm.auto import tqdm
from ..constants import IMAGE_NDIMS, K_INV_ANGSTROM, STACK_NDIMS
from ..conversion.base import convert_gx_gy_to_sx_sy
from ..plotting.base import plot_image
from .cache_utils import smart_cache
from .lattice import Lattice, rotation_matrix
logger = logging.getLogger(__name__)


class Ewald:
    """
    Class for calculating and analyzing the Ewald sphere construction in RHEED.

    This class combines experimental RHEED image metadata with a reciprocal
    lattice model to predict diffraction spot positions on the screen.

    Azimuthal angle convention
    --------------------------
    Three azimuthal angles are distinguished:

    * image_azimuthal_angle
        Experimental azimuth of the RHEED image, read from image metadata.
        This value defines the reference frame and is treated as immutable.
    * ewald_azimuthal_rotation
        User-defined relative rotation of the Ewald construction with respect
        to the image azimuth. This does not modify the image metadata.
    * Ewald azimuthal angles (effective)
        The azimuthal angles actually used in the Ewald construction:
        image_azimuthal_angle + ewald_azimuthal_rotation, and additionally
        image_azimuthal_angle - ewald_azimuthal_rotation when mirror symmetry
        is enabled (and the rotation is non-zero).

    This separation preserves the experimental reference frame while allowing
    controlled relative rotation of the theoretical Ewald construction.
    """

    # Class constants: default spot footprint on the screen (mm). They are
    # converted to pixels with screen_scale to build the spot structure.
    SPOT_WIDTH_MM: float = 1.5
    SPOT_HEIGHT_MM: float = 5.0

    def __init__(
        self,
        lattice: Lattice,
        image: Optional[xr.DataArray] = None,
        stack_index: int = 0,
    ) -> None:
        """
        Initialize an Ewald object for RHEED spot calculations.

        Parameters
        ----------
        lattice : Lattice
            Lattice object representing the crystal structure. It is deep-copied.
        image : Optional[xr.DataArray], optional
            RHEED image data. Can be a single image or a stack of images. If None,
            default values are used.
        stack_index : int, optional
            Index of the image to use from a stack (default is 0).

        Raises
        ------
        ValueError
            If ``image`` has neither the single-image nor the stack
            dimensionality.
        """
        self._image_stack: Optional[xr.DataArray] = None
        self._stack_index: int = stack_index

        if image is None:
            logger.warning("RHEED image not provided, default parameters are loaded.")
            self.image: Optional[xr.DataArray] = None
            self.beam_energy: float = 18_600.0
            self.screen_sample_distance: float = 309.2
            self.screen_scale: float = 9.5
            self.screen_roi_width: float = 60.0
            self.screen_roi_height: float = 80.0
            self._incident_angle: Union[float, NDArray[np.float32]] = 1.0
            self._image_azimuthal_angle: Union[float, NDArray[np.float32]] = 0.0
            self._image_data_available: bool = False
        else:
            if image.ndim == IMAGE_NDIMS:
                self.image = image.copy()
            elif image.ndim == STACK_NDIMS:
                self._image_stack = image.copy()
                self.image = self._image_stack[stack_index]
            else:
                # Without this, self.image would never be assigned and later
                # accesses would fail with an obscure AttributeError.
                raise ValueError(
                    f"Expected an image with {IMAGE_NDIMS} or {STACK_NDIMS} "
                    f"dimensions, got {image.ndim}."
                )
            self.beam_energy = float(image.ri.beam_energy)
            self.screen_sample_distance = float(image.ri.screen_sample_distance)
            self.screen_scale = float(image.ri.screen_scale)
            self.screen_roi_width = float(image.ri.screen_roi_width)
            self.screen_roi_height = float(image.ri.screen_roi_height)
            # For stacks these may be per-image arrays; the properties below
            # select the entry for the current stack index.
            self._incident_angle = image.ri.incident_angle
            self._image_azimuthal_angle = image.ri.azimuthal_angle
            self._image_data_available = True

        self._lattice_scale: float = 1.0
        self._spot_w_px: int = int(self.SPOT_WIDTH_MM * self.screen_scale)
        self._spot_h_px: int = int(self.SPOT_HEIGHT_MM * self.screen_scale)
        self.spot_structure: NDArray[np.bool_] = self._generate_spot_structure()
        self.shift_x: float = 0.0
        self.shift_y: float = 0.0
        self.fine_scaling: float = 1.0
        self.ewald_radius: float = np.sqrt(self.beam_energy) * K_INV_ANGSTROM
        self._ewald_roi: float = self._calc_ewald_roi()
        self._ewald_azimuthal_rotation: float = 0.0
        self._lattice: Lattice = copy.deepcopy(lattice)
        self._inverse_lattice: NDArray[np.float32] = self._prepare_inverse_lattice()
        self.label: Optional[str] = lattice.label
        self.mirror_symmetry: bool = False
        self.ew_sx: NDArray[np.float32]
        self.ew_sy: NDArray[np.float32]
        self.use_cache: bool = True
        self.cache_dir: str = "cache"
        self.cache_key: Optional[str] = None
        self.calculate_ewald()
        logger.info(
            "Initialized Ewald: label=%s image_provided=%s beam_energy=%.1f screen_scale=%.2f",
            self.label,
            self._image_data_available,
            self.beam_energy,
            self.screen_scale,
        )

    def __repr__(self) -> str:
        b1 = self._lattice.b1
        b2 = self._lattice.b2
        return (
            f"Ewald Class Object: {self.label}\n"
            f" Ewald Radius : {self.ewald_radius:.2f} 1/Å\n"
            f" Image azimuthal angle : {self.image_azimuthal_angle:.2f}°\n"
            f" Ewald azimuthal rotation : {self.ewald_azimuthal_rotation:.2f}°\n"
            f" Incident angle : {self.incident_angle:.2f}°\n"
            f" Real Lattice Scale : {self.lattice_scale:.2f}\n"
            f" Screen Scale : {self.screen_scale:.2f} px/mm\n"
            f" Sample-Screen Distance : {self.screen_sample_distance:.1f} mm\n"
            f" Screen Shift X : {self.shift_x:.2f} mm\n"
            f" Screen Shift Y : {self.shift_y:.2f} mm\n"
            f" Reciprocal Vector b1 : [{b1[0]:.2f}, {b1[1]:.2f}] 1/Å\n"
            f" Reciprocal Vector b2 : [{b2[0]:.2f}, {b2[1]:.2f}] 1/Å\n"
        )

    def __copy__(self) -> "Ewald":
        """
        Create a copy of the Ewald object with the same user-set parameters.

        The copy is built from the current lattice and the currently selected
        image. For stacks only the current 2D image is passed on, so the copy
        is a single-image Ewald. User-adjustable parameters are then
        transferred and the spot positions are recomputed.

        Returns
        -------
        Ewald
            A new instance with the same parameters.
        """
        new_ewald = Ewald(self._lattice, self.image)

        # Plain attributes first, so that the recalculations below use them.
        new_ewald.label = self.label
        new_ewald.mirror_symmetry = self.mirror_symmetry
        new_ewald.shift_x = self.shift_x
        new_ewald.shift_y = self.shift_y
        new_ewald.fine_scaling = self.fine_scaling
        new_ewald._spot_w_px = self._spot_w_px
        new_ewald._spot_h_px = self._spot_h_px
        # Keep the structuring element consistent with the copied spot size.
        new_ewald.spot_structure = new_ewald._generate_spot_structure()

        # These property setters recompute derived state.
        new_ewald.ewald_azimuthal_rotation = self.ewald_azimuthal_rotation
        new_ewald.incident_angle = self.incident_angle
        new_ewald.lattice_scale = self.lattice_scale
        new_ewald.ewald_roi = self.ewald_roi
        # The ewald_roi setter rebuilds the inverse lattice but does not
        # recompute the spots, so refresh them explicitly.
        new_ewald.calculate_ewald()
        return new_ewald

    @property
    def stack_index(self) -> int:
        """int: Index of the current image in a stack."""
        return self._stack_index

    @stack_index.setter
    def stack_index(self, value: int) -> None:
        if self._image_stack is None:
            raise ValueError("Stack index can only be set for 3D image stacks.")
        n_images = self._image_stack.shape[0]
        if not -n_images <= value < n_images:
            raise ValueError("Stack index out of bounds.")
        # Normalize negative (from-the-end) indices to a non-negative position.
        self._stack_index = value % n_images
        self.image = self._image_stack[self._stack_index]
        self.calculate_ewald()

    @property
    def lattice_scale(self) -> float:
        """float: Real-space lattice scale; reciprocal points are divided by it."""
        return self._lattice_scale

    @lattice_scale.setter
    def lattice_scale(self, value: float) -> None:
        # `not value > 0` also rejects NaN.
        if not value > 0:
            raise ValueError("lattice_scale must be positive.")
        # A large scale change alters which reciprocal points reach the
        # screen, so the generated lattice (ROI) is rebuilt for the new scale.
        if abs(self._lattice_scale - value) > 0.5:
            self.ewald_roi = self._calc_ewald_roi(value)
            logger.info("New Ewald roi: %.2f", self.ewald_roi)
        self._lattice_scale = value
        self.calculate_ewald()

    @property
    def image_azimuthal_angle(self) -> float:
        """float: Azimuth from the image metadata (current stack entry for stacks)."""
        if isinstance(self._image_azimuthal_angle, np.ndarray):
            return self._image_azimuthal_angle[self._stack_index]
        return self._image_azimuthal_angle

    @image_azimuthal_angle.setter
    def image_azimuthal_angle(self, value: float) -> None:
        """
        Deprecated setter.

        Historically this modified the azimuthal angle of the image.
        The image azimuthal angle is now treated as an immutable physical
        property. The requested value is converted into an equivalent
        ``ewald_azimuthal_rotation`` instead.
        """
        if isinstance(self._image_azimuthal_angle, np.ndarray):
            raise TypeError(
                "image_azimuthal_angle is derived from stacked image metadata "
                "and cannot be set directly."
            )
        warnings.warn(
            "Setting image_azimuthal_angle is deprecated and will be removed "
            "in a future version. Use ewald_azimuthal_rotation instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Interpret old behavior as a relative rotation.
        self.ewald_azimuthal_rotation = float(value) - self._image_azimuthal_angle

    @property
    def ewald_azimuthal_rotation(self) -> float:
        """float: Rotation (degrees) of the Ewald construction relative to the image."""
        return self._ewald_azimuthal_rotation

    @ewald_azimuthal_rotation.setter
    def ewald_azimuthal_rotation(self, value: float) -> None:
        self._ewald_azimuthal_rotation = float(value)
        self.calculate_ewald()

    @property
    def ewald_azimuthal_angle(self) -> float:
        """float: image_azimuthal_angle + ewald_azimuthal_rotation."""
        return self.image_azimuthal_angle + self._ewald_azimuthal_rotation

    @property
    def incident_angle(self) -> float:
        """float: Incident angle (degrees); the current stack entry for stacks."""
        if isinstance(self._incident_angle, np.ndarray):
            return self._incident_angle[self._stack_index]
        return self._incident_angle

    @incident_angle.setter
    def incident_angle(self, value: float) -> None:
        if isinstance(self._incident_angle, np.ndarray):
            raise ValueError("Cannot set incident angle individually for stack images.")
        self._incident_angle = value
        self.calculate_ewald()

    @property
    def ewald_roi(self) -> float:
        """float: Reciprocal-space extent used to generate the inverse lattice."""
        return self._ewald_roi

    @ewald_roi.setter
    def ewald_roi(self, value: float) -> None:
        # Rebuilds the inverse lattice only; spot positions are refreshed on
        # the next calculate_ewald() call.
        self._ewald_roi = value
        self._inverse_lattice = self._prepare_inverse_lattice()

    def set_spot_size(self, width: float, height: float) -> None:
        """
        Set the spot size used for mask generation.

        Parameters
        ----------
        width : float
            Spot width in mm.
        height : float
            Spot height in mm.
        """
        self._spot_w_px = int(width * self.screen_scale)
        self._spot_h_px = int(height * self.screen_scale)
        self.spot_structure = self._generate_spot_structure()

    def calculate_ewald(self, **kwargs) -> None:
        """
        Calculate the Ewald construction and update spot positions.

        The reciprocal lattice generated for the current ROI is divided by the
        lattice scale and rotated by every effective Ewald azimuthal angle.
        It is then projected onto the screen with ``convert_gx_gy_to_sx_sy``,
        and spots outside the screen ROI are discarded.

        Parameters
        ----------
        **kwargs
            Accepted for backward compatibility; ignored.

        Updates
        -------
        self.ew_sx : NDArray
            Spot x-coordinates (mm).
        self.ew_sy : NDArray
            Spot y-coordinates (mm).
        """
        # Scale a copy so the stored (unscaled) inverse lattice is untouched.
        inverse_lattice: NDArray[np.float32] = self._inverse_lattice.copy()
        inverse_lattice /= self._lattice_scale

        # Effective azimuthal angles used for the Ewald construction.
        image_azimuthal_angle: float = self.image_azimuthal_angle
        rotation: float = self.ewald_azimuthal_rotation
        if np.isclose(rotation, 0.0):
            # No relative rotation: use the image azimuthal angle only.
            azimuthal_angles = [image_azimuthal_angle]
        else:
            azimuthal_angles = [image_azimuthal_angle + rotation]
            if self.mirror_symmetry:
                azimuthal_angles.insert(0, image_azimuthal_angle - rotation)

        # Rotate the inverse lattice for each angle and stack all points.
        stacked = np.vstack(
            [inverse_lattice @ rotation_matrix(angle) for angle in azimuthal_angles]
        )
        gx, gy = stacked.T[:2]

        sx, sy = convert_gx_gy_to_sx_sy(
            gx,
            gy,
            self.ewald_radius,
            self.incident_angle,
            self.screen_sample_distance,
            remove_outside=True,
        )
        # NOTE(review): the y filter keeps sy < +screen_roi_height, whereas
        # _generate_mask and plot(show_roi=True) use -screen_roi_height <= sy <= 0.
        # Possibly intended as sy > -screen_roi_height; confirm before changing.
        in_roi: NDArray[np.bool_] = (
            (sx > -self.screen_roi_width)
            & (sx < self.screen_roi_width)
            & (sy < self.screen_roi_height)
        )
        self.ew_sx = sx[in_roi]
        self.ew_sy = sy[in_roi]
        logger.debug(
            "calculate_ewald: generated %d spots (mirror=%s) ewald_roi=%.3f",
            self.ew_sx.size,
            self.mirror_symmetry,
            getattr(self, "_ewald_roi", float("nan")),
        )

    def plot(
        self,
        ax: Optional[Axes] = None,
        show_image: bool = True,
        show_roi: bool = False,
        auto_levels: float = 0.0,
        show_center_lines: bool = False,
        **kwargs,
    ) -> Axes:
        """
        Plot the calculated spot positions and optionally the RHEED image.

        Parameters
        ----------
        ax : Optional[Axes], optional
            Matplotlib axes to plot on. If None, a new figure is created.
        show_image : bool, optional
            If True, plot the RHEED image (default: True).
        show_roi : bool, optional
            If True, overlay the ROI boundary (default: False). When an image
            is attached, the axis limits are also set to the image extent.
        auto_levels : float, optional
            Contrast enhancement factor for image plotting.
        show_center_lines : bool, optional
            If True, plot center cross lines (default: False).
        **kwargs
            Additional keyword arguments for the scatter plot. ``cmap``,
            ``vmin`` and ``vmax`` are routed to the image plot when
            ``show_image`` is True.

        Returns
        -------
        matplotlib.axes.Axes
            The axes with the plotted data.

        Raises
        ------
        ValueError
            If ``show_image`` is True but no RHEED image is attached.
        """
        if ax is None:
            _, ax = plt.subplots()
        logger.debug(
            "plot: show_image=%s show_roi=%s show_center_lines=%s",
            show_image,
            show_roi,
            show_center_lines,
        )

        if show_image:
            if self.image is None:
                raise ValueError("There was no RHEED image attached.")
            imshow_keys = {"cmap", "vmin", "vmax"}
            plot_image_kwargs = {
                key: kwargs.pop(key) for key in list(kwargs) if key in imshow_keys
            }
            plot_image(
                rheed_image=self.image,
                ax=ax,
                auto_levels=auto_levels,
                show_center_lines=show_center_lines,
                **plot_image_kwargs,
            )

        if show_roi:
            rheed_image = self.image
            if rheed_image is not None:
                ax.set_xlim(rheed_image.sx.min(), rheed_image.sx.max())
                ax.set_ylim(rheed_image.sy.min(), rheed_image.sy.max())
            # Vertical ROI boundaries at x = ±screen_roi_width.
            ax.axvline(
                x=-self.screen_roi_width, color="red", linestyle="--", linewidth=1
            )
            ax.axvline(
                x=self.screen_roi_width, color="red", linestyle="--", linewidth=1
            )
            # Horizontal ROI boundaries at y = -screen_roi_height and y = 0.
            ax.axhline(
                y=-self.screen_roi_height, color="red", linestyle="--", linewidth=1
            )
            ax.axhline(y=0.0, color="red", linestyle="--", linewidth=1)

        kwargs.setdefault("marker", "|")
        fine_scaling: float = self.fine_scaling
        ax.scatter(
            (self.ew_sx + self.shift_x) * fine_scaling,
            (self.ew_sy + self.shift_y) * fine_scaling,
            **kwargs,
        )
        logger.info("Plotted %d ewald spots on axes.", getattr(self.ew_sx, "size", 0))
        return ax

    def plot_spots(self, ax=None, show_image: bool = False, **kwargs):
        """
        Plot the spot mask used for spot matching on a RHEED image.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Matplotlib Axes to plot on. If None, a new figure and axes are created.
        show_image : bool, default=False
            If True, overlay the spot mask on the original RHEED image.
            If False, only the mask is displayed.
        **kwargs
            Additional keyword arguments passed to `ax.imshow()`, e.g., `cmap`, `alpha`.

        Returns
        -------
        matplotlib.axes.Axes
            The axes containing the plotted mask (and optionally the image).

        Raises
        ------
        ValueError
            If no RHEED image is attached (`self.image is None`). The mask is
            built on the image grid, so an image is needed in both modes.

        Notes
        -----
        - The mask is generated by `self._generate_mask()`.
        - The image coordinates (`sx`, `sy`) are used to set the extent of the plot.
        - The default colormap is grayscale.
        """
        # Validate before creating a figure or computing the mask.
        if self.image is None:
            raise ValueError("There was no RHEED image attached.")
        if ax is None:
            _, ax = plt.subplots()

        image = self.image
        mask = self._generate_mask()
        kwargs.setdefault("cmap", "gray")
        displayed = mask * image.data if show_image else mask
        ax.imshow(
            displayed,
            origin="lower",
            extent=(image.sx.min(), image.sx.max(), image.sy.min(), image.sy.max()),
            aspect="equal",
            **kwargs,
        )
        logger.debug(
            "plot_spots: show_image=%s mask_shape=%s",
            show_image,
            getattr(mask, "shape", None),
        )
        logger.info("Displayed spot mask on axes.")
        return ax

    def calculate_match(self, normalize: bool = True) -> np.uint32:
        """
        Calculate the match coefficient between predicted and observed spots.

        The coefficient is the summed image intensity under the spot mask.

        Parameters
        ----------
        normalize : bool
            If True, divide the coefficient by the approximate number of spots
            in the mask. That number is the mask pixel count divided by the
            pixel count of one spot, and is at least 1.

        Returns
        -------
        np.uint32
            Match coefficient, clipped to the uint32 range.

        Raises
        ------
        ValueError
            If no RHEED image is attached.
        """
        if self.image is None:
            raise ValueError("There was no RHEED image attached.")
        image = self.image.data
        mask = self._generate_mask()

        # Accumulate in 64 bits: a uint32 sum can overflow for bright images
        # before normalization reduces the value.
        match_coef = int((mask * image).sum(dtype=np.uint64))

        if normalize:
            spot_pixels = np.count_nonzero(self.spot_structure)
            n_spots = np.count_nonzero(mask) // spot_pixels if spot_pixels else 0
            # Avoid division by zero when less than one full spot is visible.
            match_coef //= max(n_spots, 1)

        return np.uint32(min(match_coef, int(np.iinfo(np.uint32).max)))

    @smart_cache
    def match_alpha(
        self,
        alpha_vector: NDArray,
        normalize: bool = True,
        tqdm_disable: bool = True,
    ) -> xr.DataArray:
        """
        Calculate match coefficients over a range of azimuthal angles.

        Each angle is applied as ``ewald_azimuthal_rotation``. The original
        rotation is restored afterwards (also on error).

        Parameters
        ----------
        alpha_vector : NDArray[np.float64]
            Array of alpha (azimuthal) angles in degrees.
        normalize : bool, optional
            If True, normalize the coefficients (default: True).
        tqdm_disable : bool, optional
            If False, show the tqdm progress bar (default: True).

        Returns
        -------
        xr.DataArray
            Match coefficients with alpha as coordinate.
        """
        match_vector = np.zeros_like(alpha_vector, dtype=np.uint32)
        initial_rotation = self._ewald_azimuthal_rotation
        try:
            for i, alpha in enumerate(tqdm(alpha_vector, disable=tqdm_disable)):
                # The property setter already recomputes the spot positions.
                self.ewald_azimuthal_rotation = alpha
                match_vector[i] = self.calculate_match(normalize=normalize)
        finally:
            self.ewald_azimuthal_rotation = initial_rotation
        return xr.DataArray(
            match_vector, dims=["alpha"], coords={"alpha": alpha_vector}
        )

    @smart_cache
    def match_scale(
        self,
        scale_vector: NDArray,
        normalize: bool = True,
        tqdm_disable: bool = True,
    ) -> xr.DataArray:
        """
        Calculate the match coefficient for a series of lattice scale values.

        The inverse lattice is generated once, with an ROI large enough for
        the largest scale, and is kept fixed during the scan. The original
        lattice scale and ROI are restored afterwards (also on error).

        Parameters
        ----------
        scale_vector : NDArray
            Array of scale values to test. All values must be positive.
        normalize : bool, optional
            If True, normalize the match coefficient (default: True).
        tqdm_disable : bool, optional
            If False, show the tqdm progress bar (default: True).

        Returns
        -------
        xr.DataArray
            Match coefficients for each scale value.

        Raises
        ------
        ValueError
            If any scale value is not positive.
        """
        scales = self._as_positive_scales(scale_vector)
        match_vector = np.zeros(scales.shape, dtype=np.uint32)
        initial_scale = self._lattice_scale
        initial_roi = self._ewald_roi
        try:
            if scales.size:
                self.ewald_roi = self._calc_ewald_roi(float(scales.max()))
            for i, scale in enumerate(tqdm(scales, disable=tqdm_disable)):
                # Bypass the lattice_scale setter: its ROI heuristic would
                # shrink the ROI sized above for the largest scale.
                self._lattice_scale = float(scale)
                self.calculate_ewald()
                match_vector[i] = self.calculate_match(normalize=normalize)
        finally:
            self._lattice_scale = initial_scale
            self.ewald_roi = initial_roi
            self.calculate_ewald()
        return xr.DataArray(
            match_vector,
            dims=["scale"],
            coords={"scale": scale_vector},
        )

    @smart_cache
    def match_alpha_scale(
        self,
        alpha_vector: NDArray,
        scale_vector: NDArray,
        normalize: bool = True,
        flatten: bool = True,
        tqdm_disable: bool = True,
    ) -> xr.DataArray:
        """
        Calculate the match coefficient for a grid of alpha angles and scale values.

        The inverse lattice is generated once, with an ROI large enough for
        the largest scale. The original lattice scale, rotation and ROI are
        restored afterwards (also on error).

        Parameters
        ----------
        alpha_vector : NDArray
            Array of azimuthal angles to test.
        scale_vector : NDArray
            Array of scale values to test. All values must be positive.
        normalize : bool, optional
            If True, normalize the match coefficient (default: True).
        flatten : bool, optional
            If True, flatten the result map by subtracting a quadratic
            background fitted along the scale direction (against the scale
            index). The degree is reduced when fewer than 3 scale values are
            given. The result is then shifted so its minimum is 0
            (default: True).
        tqdm_disable : bool, optional
            If False, show the tqdm progress bar (default: True).

        Returns
        -------
        xr.DataArray
            Match coefficients for each (alpha, scale) pair.

        Raises
        ------
        ValueError
            If any scale value is not positive.
        """
        scales = self._as_positive_scales(scale_vector)
        match_matrix: NDArray = np.zeros(
            (len(alpha_vector), scales.size), dtype=np.uint32
        )
        if match_matrix.size == 0:
            # Nothing to scan (and nothing to fit a background to).
            return xr.DataArray(
                match_matrix,
                dims=["alpha", "scale"],
                coords={"alpha": alpha_vector, "scale": scale_vector},
            )

        initial_scale = self._lattice_scale
        initial_rotation = self._ewald_azimuthal_rotation
        initial_roi = self._ewald_roi
        try:
            self.ewald_roi = self._calc_ewald_roi(float(scales.max()))
            for i, scale in enumerate(
                tqdm(scales, disable=tqdm_disable, desc="Matching scales")
            ):
                logger.info(
                    "Matching scale %d/%d: lattice_scale=%.2f",
                    i + 1,
                    scales.size,
                    scale,
                )
                # Set the scale directly (keeps the max-scale ROI); the
                # rotation setter below recomputes the spots for each alpha.
                self._lattice_scale = float(scale)
                for j, alpha in enumerate(alpha_vector):
                    self.ewald_azimuthal_rotation = alpha
                    match_matrix[j, i] = self.calculate_match(normalize=normalize)
        finally:
            self._lattice_scale = initial_scale
            self._ewald_azimuthal_rotation = initial_rotation
            self.ewald_roi = initial_roi
            self.calculate_ewald()

        if flatten:
            # Step 1: mean profile along scale, averaged over all alphas.
            mean_profile = match_matrix.mean(axis=0)
            # Step 2: fit a polynomial background against the scale index.
            scale_index = np.arange(match_matrix.shape[1])
            degree = min(2, scale_index.size - 1)
            coeffs = np.polyfit(scale_index, mean_profile, deg=degree)
            background_fit = np.poly1d(coeffs)(scale_index)
            # Step 3: subtract the background and shift the minimum to 0.
            match_matrix = match_matrix - background_fit
            match_matrix -= match_matrix.min()

        return xr.DataArray(
            match_matrix,
            dims=["alpha", "scale"],
            coords={"alpha": alpha_vector, "scale": scale_vector},
        )

    @staticmethod
    def _as_positive_scales(scale_vector: NDArray) -> NDArray:
        """Return ``scale_vector`` as an array, rejecting non-positive or NaN values."""
        scales = np.asarray(scale_vector)
        if not np.all(scales > 0):
            raise ValueError("Lattice scale values must be positive.")
        return scales

    def _prepare_inverse_lattice(self) -> NDArray[np.float32]:
        """
        Generate reciprocal lattice points for the current ROI.

        Returns
        -------
        NDArray[np.float32]
            Inverse lattice points produced by ``Lattice.generate_lattice``
            from b1/b2 with ``space_size`` equal to the current Ewald ROI.
        """
        lattice = self._lattice
        return Lattice.generate_lattice(
            lattice.b1, lattice.b2, space_size=self._ewald_roi
        )

    def _generate_spot_structure(self) -> NDArray[np.bool_]:
        """
        Generate a binary elliptical spot structure.

        Returns
        -------
        NDArray[np.bool_]
            Boolean array of shape (spot_h_px, spot_w_px). A pixel is True when
            its index lies inside the ellipse inscribed in that rectangle.

        Raises
        ------
        ValueError
            If the spot size in pixels is negative.
        """
        spot_w = self._spot_w_px
        spot_h = self._spot_h_px
        if spot_w < 0 or spot_h < 0:
            raise ValueError("negative dimensions are not allowed")

        # Pixel-center of the ellipse and its semi-axes (in pixels).
        center_x = spot_w / 2 - 0.5
        center_y = spot_h / 2 - 0.5
        radius_x = spot_w / 2
        radius_y = spot_h / 2

        # Vectorized ellipse test over all (row, col) index pairs.
        rows = np.arange(spot_h)[:, np.newaxis]
        cols = np.arange(spot_w)[np.newaxis, :]
        inside = ((cols - center_x) ** 2 / radius_x**2) + (
            (rows - center_y) ** 2 / radius_y**2
        ) <= 1
        return np.broadcast_to(inside, (spot_h, spot_w)).astype(bool)

    def _generate_mask(self) -> NDArray[np.bool_]:
        """
        Generate a mask for predicted spot positions in the image.

        Spots inside the screen ROI (|sx| <= screen_roi_width,
        -screen_roi_height <= sy <= 0) are mapped to pixel indices using the
        image's minimum sx/sy as origin and ``screen_scale`` (px/mm). Spots
        that land outside the image are dropped. The remaining single-pixel
        hits are dilated with ``spot_structure``.

        Returns
        -------
        NDArray[np.bool_]
            Boolean mask of the same shape as the RHEED image.

        Raises
        ------
        ValueError
            If no RHEED image is attached.
        """
        image = self.image
        if image is None:
            raise ValueError("There was no RHEED image attached.")
        screen_scale = self.screen_scale
        n_rows, n_cols = image.shape[0], image.shape[1]

        # Physical origin of the image (bottom-left corner, mm).
        origin_x = image.sx.values.min()
        origin_y = image.sy.values.min()

        # ROI filter first; NaN coordinates fail every comparison and drop out.
        sx = self.ew_sx
        sy = self.ew_sy
        in_roi = (
            (sx >= -self.screen_roi_width)
            & (sx <= self.screen_roi_width)
            & (sy >= -self.screen_roi_height)
            & (sy <= 0)
        )
        sx = sx[in_roi]
        sy = sy[in_roi]

        # Signed indices: spots left of / below the image must become negative
        # (and be rejected), not wrap around as an unsigned cast would.
        ppx = np.round((sx - origin_x) * screen_scale).astype(np.int64)
        ppy = np.round((sy - origin_y) * screen_scale).astype(np.int64)
        in_image = (ppx >= 0) & (ppx < n_cols) & (ppy >= 0) & (ppy < n_rows)

        mask: NDArray[np.bool_] = np.zeros(image.shape, dtype=np.bool_)
        mask[ppy[in_image], ppx[in_image]] = True
        # Grow each spot hit to the full spot footprint.
        return ndimage.binary_dilation(mask, structure=self.spot_structure).astype(
            np.bool_
        )

    def _calc_ewald_roi(self, scale_max: float = 1.0) -> float:
        """
        Return the reciprocal-space extent needed to cover the screen ROI.

        The value is ewald_radius * (screen_roi_width / screen_sample_distance),
        multiplied by ``scale_max``. Larger lattice scales shrink the scaled
        reciprocal lattice, so more points must be generated.
        """
        return float(
            self.ewald_radius
            * (self.screen_roi_width / self.screen_sample_distance)
            * scale_max
        )