# Source code for xrheed.kinematics.ewald

import copy
import logging
import warnings
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from numpy.typing import NDArray

from ..constants import IMAGE_NDIMS, K_INV_ANGSTROM, STACK_NDIMS
from ..conversion.base import convert_gx_gy_to_sx_sy
from ..plotting.base import plot_image
from .ewald_matching import (
    calculate_match,
    generate_mask,
    generate_spot_structure,
    match_alpha,
    match_alpha_scale,
    match_scale,
)
from .lattice import Lattice, rotation_matrix

# Module-level logger; all diagnostics from the Ewald class go through it.
logger = logging.getLogger(__name__)


class Ewald:
    """
    Class for calculating and analyzing the Ewald sphere construction in RHEED.

    This class combines experimental RHEED image metadata with a reciprocal
    lattice model to predict diffraction spot positions on the screen.

    Azimuthal angle convention
    --------------------------
    Three azimuthal angles are distinguished:

    * image_azimuthal_angle
        Experimental azimuth of the RHEED image, read from image metadata.
        This value defines the reference frame and is treated as immutable.

    * ewald_azimuthal_rotation
        User-defined relative rotation of the Ewald construction with respect
        to the image azimuth. This does not modify the image metadata.

    * Ewald azimuthal angles (effective)
        The azimuthal angles actually used in the Ewald construction:

            image_azimuthal_angle + ewald_azimuthal_rotation

        When mirror symmetry is enabled (and the rotation is non-zero),
        image_azimuthal_angle - ewald_azimuthal_rotation is used as well.

    This separation preserves the experimental reference frame while allowing
    controlled relative rotation of the theoretical Ewald construction.
    """

    SPOT_WIDTH_MM: float = 1.5
    SPOT_HEIGHT_MM: float = 5.0

    # When the lattice scale moves further than this from the scale the
    # current ROI was generated for, the inverse lattice is regenerated.
    ROI_RESCALE_THRESHOLD: float = 0.5

    # Screen geometry used when no RHEED image is supplied.
    NO_IMAGE_DEFAULTS = {
        "beam_energy": 18_600.0,
        "screen_sample_distance": 309.2,
        "screen_scale": 9.5,
        "screen_roi_width": 60.0,
        "screen_roi_height": 80.0,
        "incident_angle": 1.0,
        "azimuthal_angle": 0.0,
    }

    # RI metadata attributes that must be present on a supplied image.
    REQUIRED_IMAGE_ATTRS = (
        "beam_energy",
        "screen_sample_distance",
        "screen_scale",
        "incident_angle",
        "azimuthal_angle",
    )

    def __init__(
        self,
        lattice: Lattice,
        rheed_data: Optional[xr.DataArray] = None,
        stack_index: int = 0,
    ) -> None:
        """
        Build the Ewald construction for a lattice and, optionally, RHEED data.

        Parameters
        ----------
        lattice : Lattice
            Lattice model; a deep copy is stored so later changes to the
            caller's object do not affect this instance.
        rheed_data : xr.DataArray, optional
            A single RHEED image (``IMAGE_NDIMS`` dimensions) or an image
            stack (``STACK_NDIMS`` dimensions). When omitted,
            ``NO_IMAGE_DEFAULTS`` are used for the screen geometry.
        stack_index : int, optional
            Index of the image to use when ``rheed_data`` is a stack.

        Raises
        ------
        ValueError
            If ``rheed_data`` has an unsupported dimensionality or lacks
            required RI metadata.
        """
        self._image_stack: Optional[xr.DataArray] = None
        self._stack_index: int = stack_index
        self.image: Optional[xr.DataArray] = None
        self._image_data_available: bool = False

        if rheed_data is None:
            self._initialize_without_rheed_data()
        else:
            self._initialize_from_rheed_data(rheed_data, stack_index)

        self._initialize_geometry()
        self._initialize_lattice(lattice)
        self._initialize_cache()
        self.calculate_ewald()

        logger.info(
            "Initialized Ewald: label=%s image_provided=%s beam_energy=%.1f screen_scale=%.2f",
            self.label,
            self._image_data_available,
            self.beam_energy,
            self.screen_scale,
        )

    def _initialize_without_rheed_data(self) -> None:
        """Load the default screen geometry from ``NO_IMAGE_DEFAULTS``."""
        logger.warning("RHEED image not provided, default parameters are loaded.")
        defaults = self.NO_IMAGE_DEFAULTS
        self.beam_energy = defaults["beam_energy"]
        self.screen_sample_distance = defaults["screen_sample_distance"]
        self.screen_scale = defaults["screen_scale"]
        self.screen_roi_width = defaults["screen_roi_width"]
        self.screen_roi_height = defaults["screen_roi_height"]
        self._incident_angle = defaults["incident_angle"]
        self._image_azimuthal_angle = defaults["azimuthal_angle"]

    def _initialize_from_rheed_data(
        self,
        rheed_data: xr.DataArray,
        stack_index: int,
    ) -> None:
        """
        Store the image (or stack) and read the screen geometry from RI metadata.

        Raises
        ------
        ValueError
            If the input is not an image/stack of the expected dimensionality,
            or a required RI attribute is missing.
        """
        # Read ndim defensively so non-array input gets the ValueError below
        # rather than an AttributeError.
        ndim = getattr(rheed_data, "ndim", None)
        if ndim == IMAGE_NDIMS:
            self.image = rheed_data.copy()
        elif ndim == STACK_NDIMS:
            self._image_stack = rheed_data.copy()
            self.image = self._image_stack[stack_index]
        else:
            raise ValueError(
                f"Invalid RHEED image input.\n"
                f"Expected DataArray with ndim={IMAGE_NDIMS} or {STACK_NDIMS}, "
                f"but got ndim={ndim}."
            )

        assert self.image is not None

        missing = [
            attr
            for attr in self.REQUIRED_IMAGE_ATTRS
            if getattr(self.image.ri, attr, None) is None
        ]
        if missing:
            raise ValueError(
                "Invalid RHEED image: missing required RI metadata attributes.\n"
                f"Missing attributes: {', '.join(missing)}\n"
            )

        self.beam_energy = float(self.image.ri.beam_energy)
        self.screen_sample_distance = float(self.image.ri.screen_sample_distance)
        self.screen_scale = float(self.image.ri.screen_scale)
        self.screen_roi_width = float(self.image.ri.screen_roi_width)
        self.screen_roi_height = float(self.image.ri.screen_roi_height)

        # Angles are read from the full input; for a stack they may be arrays,
        # which the incident_angle / image_azimuthal_angle properties index
        # with the current stack index.
        self._incident_angle = rheed_data.ri.incident_angle
        self._image_azimuthal_angle = rheed_data.ri.azimuthal_angle
        self._image_data_available = True

    def _initialize_geometry(self) -> None:
        """Set spot size, screen shifts, Ewald radius and the reciprocal ROI."""
        self._lattice_scale: float = 1.0

        self._spot_w_px = int(self.SPOT_WIDTH_MM * self.screen_scale)
        self._spot_h_px = int(self.SPOT_HEIGHT_MM * self.screen_scale)
        self.spot_structure = generate_spot_structure(
            self._spot_w_px,
            self._spot_h_px,
        )

        self.shift_x = 0.0
        self.shift_y = 0.0
        self.fine_scaling = 1.0

        self.ewald_radius = np.sqrt(self.beam_energy) * K_INV_ANGSTROM
        self._ewald_roi = self._calc_ewald_roi()
        # Lattice scale the current ROI was generated for (_calc_ewald_roi
        # defaults to scale_max=1.0).
        self._roi_lattice_scale: float = 1.0

        self._ewald_azimuthal_rotation = 0.0
        self.mirror_symmetry = False

    def _initialize_lattice(self, lattice: Lattice) -> None:
        """Store a private copy of the lattice and build its inverse lattice."""
        self._lattice = copy.deepcopy(lattice)
        self._inverse_lattice = self._prepare_inverse_lattice()
        self.label = lattice.label
        # Spot coordinates, assigned by calculate_ewald().
        self.ew_sx: NDArray[np.float32]
        self.ew_sy: NDArray[np.float32]

    def _initialize_cache(self) -> None:
        """Initialize cache-related settings."""
        self.use_cache = True
        self.cache_dir = "cache"
        self.cache_key = None

    def __repr__(self) -> str:
        return (
            f"Ewald Class Object: {self.label}\n"
            f"  Ewald Radius             : {self.ewald_radius:.2f} 1/Å\n"
            f"  Image azimuthal angle    : {self.image_azimuthal_angle:.2f}°\n"
            f"  Ewald azimuthal rotation : {self.ewald_azimuthal_rotation:.2f}°\n"
            f"  Incident angle           : {self.incident_angle:.2f}°\n"
            f"  Real Lattice Scale       : {self.lattice_scale:.2f}\n"
            f"  Screen Scale             : {self.screen_scale:.2f} px/mm\n"
            f"  Sample-Screen Distance   : {self.screen_sample_distance:.1f} mm\n"
            f"  Screen Shift X           : {self.shift_x:.2f} mm\n"
            f"  Screen Shift Y           : {self.shift_y:.2f} mm\n"
            f"  Reciprocal Vector b1     : [{self._lattice.b1[0]:.2f}, {self._lattice.b1[1]:.2f}] 1/Å\n"
            f"  Reciprocal Vector b2     : [{self._lattice.b2[0]:.2f}, {self._lattice.b2[1]:.2f}] 1/Å\n"
        )

    def __copy__(self) -> "Ewald":
        """
        Create an independent copy of the Ewald object.

        The copy is rebuilt from the same lattice and RHEED data. When a stack
        is attached, the whole stack is used, at the same stack index. The copy
        then receives this instance's adjustable state:
        azimuthal rotation, mirror symmetry, incident angle, lattice scale,
        ROI, screen shifts, fine scaling, spot size and label. Spot positions
        are recalculated once at the end.

        Returns
        -------
        Ewald
            A new instance with the same parameters.
        """
        rheed_data = (
            self._image_stack if self._image_stack is not None else self.image
        )
        new_ewald = Ewald(self._lattice, rheed_data, stack_index=self._stack_index)

        # Assign private state directly. The public setters would recalculate
        # the construction repeatedly, and the incident_angle setter rejects
        # stacked (per-frame) angles.
        new_ewald._incident_angle = copy.copy(self._incident_angle)
        new_ewald._ewald_azimuthal_rotation = self._ewald_azimuthal_rotation
        new_ewald.mirror_symmetry = self.mirror_symmetry
        new_ewald._lattice_scale = self._lattice_scale
        new_ewald._roi_lattice_scale = self._roi_lattice_scale
        new_ewald._ewald_roi = self._ewald_roi
        new_ewald._inverse_lattice = self._inverse_lattice.copy()

        new_ewald.shift_x = self.shift_x
        new_ewald.shift_y = self.shift_y
        new_ewald.fine_scaling = self.fine_scaling
        new_ewald.label = self.label

        # Keep the pixel size and the structuring element consistent.
        new_ewald._spot_w_px = self._spot_w_px
        new_ewald._spot_h_px = self._spot_h_px
        new_ewald.spot_structure = copy.deepcopy(self.spot_structure)

        new_ewald.calculate_ewald()
        return new_ewald

    @property
    def stack_index(self) -> int:
        """int: Index of the current image in a stack."""
        return self._stack_index

    @stack_index.setter
    def stack_index(self, value: int) -> None:
        if self._image_stack is None:
            raise ValueError("Stack index can only be set for 3D image stacks.")
        n_images = self._image_stack.shape[0]
        # Negative indices are allowed (as in normal sequence indexing), but
        # anything outside [-n, n) is rejected up front.
        if not -n_images <= value < n_images:
            raise ValueError("Stack index out of bounds.")
        self._stack_index = value
        self.image = self._image_stack[self._stack_index]
        self.calculate_ewald()

    @property
    def lattice_scale(self) -> float:
        """float: Real-space lattice scale; reciprocal points are divided by it."""
        return self._lattice_scale

    @lattice_scale.setter
    def lattice_scale(self, value: float) -> None:
        # Compare against the scale the ROI was generated for, not the previous
        # scale. Otherwise many small steps never trigger a regeneration.
        if abs(self._roi_lattice_scale - value) > self.ROI_RESCALE_THRESHOLD:
            self._update_ewald_roi(self._calc_ewald_roi(value))
            self._roi_lattice_scale = value
            logger.info("New Ewald roi: %.2f", self.ewald_roi)
        self._lattice_scale = value
        self.calculate_ewald()

    @property
    def image_azimuthal_angle(self) -> float:
        """float: Azimuthal angle from image metadata (current frame for stacks)."""
        if isinstance(self._image_azimuthal_angle, np.ndarray):
            return self._image_azimuthal_angle[self._stack_index]
        return self._image_azimuthal_angle

    @image_azimuthal_angle.setter
    def image_azimuthal_angle(self, value: float) -> None:
        """
        Deprecated setter.

        Historically this modified the azimuthal angle of the image. The
        image azimuthal angle is now treated as an immutable physical
        property. The requested value is converted into a relative
        ``ewald_azimuthal_rotation`` instead.
        """
        if isinstance(self._image_azimuthal_angle, np.ndarray):
            raise TypeError(
                "image_azimuthal_angle is derived from stacked image metadata "
                "and cannot be set directly."
            )

        warnings.warn(
            "Setting image_azimuthal_angle is deprecated and will be removed "
            "in a future version. Use ewald_azimuthal_rotation instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        # Interpret old behavior as a relative rotation
        self.ewald_azimuthal_rotation = float(value) - self._image_azimuthal_angle

    @property
    def ewald_azimuthal_rotation(self) -> float:
        """float: Rotation of the Ewald construction relative to the image azimuth."""
        return self._ewald_azimuthal_rotation

    @ewald_azimuthal_rotation.setter
    def ewald_azimuthal_rotation(self, value: float) -> None:
        self._ewald_azimuthal_rotation = float(value)
        self.calculate_ewald()

    @property
    def ewald_azimuthal_angle(self) -> float:
        """float: Effective azimuth, image_azimuthal_angle + ewald_azimuthal_rotation."""
        return self.image_azimuthal_angle + self._ewald_azimuthal_rotation

    @property
    def incident_angle(self) -> float:
        """float: Incident angle (current frame for stacks)."""
        if isinstance(self._incident_angle, np.ndarray):
            return self._incident_angle[self._stack_index]
        return self._incident_angle

    @incident_angle.setter
    def incident_angle(self, value: float) -> None:
        if isinstance(self._incident_angle, np.ndarray):
            raise ValueError(
                "Cannot set incident angle individually for stack images."
            )
        self._incident_angle = value
        self.calculate_ewald()

    @property
    def ewald_roi(self) -> float:
        """float: Reciprocal-space extent used to generate the inverse lattice."""
        return self._ewald_roi

    @ewald_roi.setter
    def ewald_roi(self, value: float) -> None:
        self._update_ewald_roi(value)
        # Like every other setter, keep the spot positions in sync.
        self.calculate_ewald()

    def _update_ewald_roi(self, value: float) -> None:
        """Store a new ROI and regenerate the inverse lattice (spots not updated)."""
        self._ewald_roi = value
        self._inverse_lattice = self._prepare_inverse_lattice()

    def set_spot_size(self, width: float, height: float) -> None:
        """
        Set the spot size used for mask generation.

        Parameters
        ----------
        width : float
            Spot width in mm.
        height : float
            Spot height in mm.
        """
        self._spot_w_px = int(width * self.screen_scale)
        self._spot_h_px = int(height * self.screen_scale)
        self.spot_structure = generate_spot_structure(
            self._spot_w_px,
            self._spot_h_px,
        )

    def calculate_ewald(self, **kwargs) -> None:
        """
        Calculate the Ewald construction and update spot positions.

        The inverse lattice is divided by the lattice scale and rotated to the
        effective Ewald azimuthal angle(s). It is then projected onto the screen
        with ``convert_gx_gy_to_sx_sy`` and clipped to the screen ROI.

        Parameters
        ----------
        **kwargs
            Accepted for backward compatibility; currently ignored.

        Updates
        -------
        self.ew_sx : NDArray
            Spot x-coordinates (mm).
        self.ew_sy : NDArray
            Spot y-coordinates (mm).
        """
        inverse_lattice: NDArray[np.float32] = self._inverse_lattice.copy()

        # Fine scaling: reciprocal distances are rescaled by 1 / lattice_scale.
        inverse_lattice /= self._lattice_scale

        image_azimuthal_angle: float = self.image_azimuthal_angle
        rotation: float = self.ewald_azimuthal_rotation

        if np.isclose(rotation, 0.0):
            # No relative rotation: use the image azimuthal angle only.
            azimuthal_angles = [image_azimuthal_angle]
        else:
            # Relative rotation with respect to the image azimuthal angle.
            azimuthal_angles = [image_azimuthal_angle + rotation]
            if self.mirror_symmetry:
                # The mirrored construction is placed first.
                azimuthal_angles.insert(0, image_azimuthal_angle - rotation)

        stacked = np.vstack(
            [inverse_lattice @ rotation_matrix(angle) for angle in azimuthal_angles]
        )
        gx, gy = stacked.T[:2]

        sx, sy = convert_gx_gy_to_sx_sy(
            gx,
            gy,
            self.ewald_radius,
            self.incident_angle,
            self.screen_sample_distance,
            remove_outside=True,
        )

        # NOTE(review): only an upper bound is applied to sy, whereas
        # plot(show_roi=True) draws the ROI between -screen_roi_height and 0.
        # Confirm the intended sign convention of sy.
        in_roi: NDArray[np.bool_] = (
            (sx > -self.screen_roi_width)
            & (sx < self.screen_roi_width)
            & (sy < self.screen_roi_height)
        )
        self.ew_sx = sx[in_roi]
        self.ew_sy = sy[in_roi]

        logger.debug(
            "calculate_ewald: generated %d spots (mirror=%s) ewald_roi=%.3f",
            self.ew_sx.size,
            self.mirror_symmetry,
            getattr(self, "_ewald_roi", float("nan")),
        )

    def plot(
        self,
        ax: Optional[Axes] = None,
        show_image: bool = True,
        show_roi: bool = False,
        auto_levels: float = 0.0,
        show_center_lines: bool = False,
        **kwargs,
    ) -> Axes:
        """
        Plot the calculated spot positions and optionally the RHEED image.

        Parameters
        ----------
        ax : Optional[Axes], optional
            Matplotlib axes to plot on. If None, a new figure is created.
        show_image : bool, optional
            If True, plot the RHEED image (default: True).
        show_roi : bool, optional
            If True, overlay the ROI boundary (default: False). When the image
            is shown, the axis limits are also set to the image extent.
        auto_levels : float, optional
            Contrast enhancement factor for image plotting.
        show_center_lines : bool, optional
            If True, plot center cross lines (default: False).
        **kwargs
            Additional keyword arguments for the scatter plot. When the image
            is shown, ``cmap``, ``vmin`` and ``vmax`` go to the image instead.

        Returns
        -------
        matplotlib.axes.Axes
            The axes with the plotted data.

        Raises
        ------
        ValueError
            If ``show_image`` is True but no RHEED image is attached.
        """
        if ax is None:
            _, ax = plt.subplots()

        logger.debug(
            "plot: show_image=%s show_roi=%s show_center_lines=%s",
            show_image,
            show_roi,
            show_center_lines,
        )

        if show_image:
            if self.image is None:
                raise ValueError("There was no RHEED image attached.")

            imshow_keys = {"cmap", "vmin", "vmax"}
            plot_image_kwargs = {
                k: kwargs.pop(k) for k in list(kwargs) if k in imshow_keys
            }

            rheed_image = self.image
            plot_image(
                rheed_image=rheed_image,
                ax=ax,
                auto_levels=auto_levels,
                show_center_lines=show_center_lines,
                **plot_image_kwargs,
            )

            if show_roi:
                ax.set_xlim(rheed_image.sx.min(), rheed_image.sx.max())
                ax.set_ylim(rheed_image.sy.min(), rheed_image.sy.max())

        if show_roi:
            # Vertical ROI edges at x = ±screen_roi_width.
            ax.axvline(
                x=-self.screen_roi_width, color="red", linestyle="--", linewidth=1
            )
            ax.axvline(
                x=self.screen_roi_width, color="red", linestyle="--", linewidth=1
            )
            # Horizontal ROI edges at y = -screen_roi_height and y = 0.
            ax.axhline(
                y=-self.screen_roi_height, color="red", linestyle="--", linewidth=1
            )
            ax.axhline(y=0.0, color="red", linestyle="--", linewidth=1)

        kwargs.setdefault("marker", "|")

        fine_scaling: float = self.fine_scaling
        ax.scatter(
            (self.ew_sx + self.shift_x) * fine_scaling,
            (self.ew_sy + self.shift_y) * fine_scaling,
            **kwargs,
        )

        logger.info("Plotted %d ewald spots on axes.", getattr(self.ew_sx, "size", 0))
        return ax

    def plot_spots(self, ax=None, show_image: bool = False, **kwargs):
        """
        Plot the spot mask used for spot matching on a RHEED image.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Matplotlib Axes to plot on. If None, a new figure and axes are
            created.
        show_image : bool, default=False
            If True, display the mask multiplied by the original RHEED image.
            If False, only the mask is displayed.
        **kwargs
            Additional keyword arguments passed to `ax.imshow()`, e.g.,
            `cmap`, `alpha`. The colormap defaults to ``"gray"``.

        Returns
        -------
        matplotlib.axes.Axes
            The axes containing the plotted mask (and optionally the image).

        Raises
        ------
        ValueError
            If no RHEED image is attached (`self.image is None`). The image
            coordinates are needed for the plot extent in both modes.

        Notes
        -----
        - The mask is generated by `self._generate_mask()`.
        - The image coordinates (`sx`, `sy`) set the extent of the plot.
        """
        if self.image is None:
            raise ValueError("There was no RHEED image attached.")
        image = self.image

        if ax is None:
            _, ax = plt.subplots()

        mask = self._generate_mask()
        kwargs.setdefault("cmap", "gray")

        displayed = mask * image.data if show_image else mask
        ax.imshow(
            displayed,
            origin="lower",
            extent=(image.sx.min(), image.sx.max(), image.sy.min(), image.sy.max()),
            aspect="equal",
            **kwargs,
        )

        logger.debug(
            "plot_spots: show_image=%s mask_shape=%s",
            show_image,
            getattr(mask, "shape", None),
        )
        logger.info("Displayed spot mask on axes.")
        return ax

    def _prepare_inverse_lattice(self) -> NDArray[np.float32]:
        """
        Generate reciprocal lattice points for the current ROI.

        Returns
        -------
        NDArray[np.float32]
            Inverse lattice points, one per row, generated from the lattice's
            b1/b2 vectors within ``self._ewald_roi``. calculate_ewald uses the
            first two columns as (gx, gy).
        """
        lattice = self._lattice
        return Lattice.generate_lattice(
            lattice.b1, lattice.b2, space_size=self._ewald_roi
        )

    def _calc_ewald_roi(self, scale_max: float = 1.0) -> float:
        """
        Return the reciprocal-space extent for inverse-lattice generation.

        Computed as ewald_radius * (screen_roi_width / screen_sample_distance)
        * scale_max. Pass the lattice scale as ``scale_max`` so the lattice still
        covers the screen after being divided by that scale.
        """
        return float(
            self.ewald_radius
            * (self.screen_roi_width / self.screen_sample_distance)
            * scale_max
        )

    def _generate_mask(self):
        """Return the spot mask built by ``ewald_matching.generate_mask``."""
        return generate_mask(self)

    def calculate_match(self, normalize: bool = True):
        """Delegate to ``ewald_matching.calculate_match`` for this instance."""
        return calculate_match(self, normalize=normalize)

    def match_alpha(
        self,
        alpha_vector,
        normalize: bool = True,
        tqdm_disable: bool = True,
    ):
        """Delegate to ``ewald_matching.match_alpha`` for this instance."""
        return match_alpha(
            self,
            alpha_vector=alpha_vector,
            normalize=normalize,
            tqdm_disable=tqdm_disable,
        )

    def match_scale(
        self,
        scale_vector,
        normalize: bool = True,
        tqdm_disable: bool = True,
    ):
        """Delegate to ``ewald_matching.match_scale`` for this instance."""
        return match_scale(
            self,
            scale_vector=scale_vector,
            normalize=normalize,
            tqdm_disable=tqdm_disable,
        )

    def match_alpha_scale(
        self,
        alpha_vector,
        scale_vector,
        normalize: bool = True,
        flatten: bool = True,
        tqdm_disable: bool = True,
    ):
        """Delegate to ``ewald_matching.match_alpha_scale`` for this instance."""
        return match_alpha_scale(
            self,
            alpha_vector=alpha_vector,
            scale_vector=scale_vector,
            normalize=normalize,
            flatten=flatten,
            tqdm_disable=tqdm_disable,
        )