"""
Xarray accessors for RHEED (Reflection High-Energy Electron Diffraction) data.
Accessors
---------
- **ri**: for manipulating and analyzing RHEED images, including plotting and image centering.
- **rp**: for manipulating RHEED intensity profiles.
These accessors extend xarray's `DataArray` objects with domain-specific methods for RHEED analysis.
"""
import logging
from typing import Literal, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.patches import Rectangle
from scipy import ndimage # type: ignore
from .constants import (
DEFAULT_SCREEN_ROI_HEIGHT,
DEFAULT_SCREEN_ROI_WIDTH,
IMAGE_DIMS,
IMAGE_NDIMS,
K_INV_ANGSTROM,
STACK_NDIMS,
)
from .conversion.base import convert_sx_to_ky
from .plotting.base import plot_image
from .plotting.profiles import plot_profile
from .preparation.alignment import (
find_horizontal_center,
find_incident_angle,
find_vertical_center,
)
logger = logging.getLogger(__name__)
[docs]
@xr.register_dataarray_accessor("ri")
class RHEEDAccessor:
"""
Xarray accessor for RHEED images.
Provides convenient access to RHEED-specific metadata, image manipulation,
centering, rotation, and profile extraction methods.
"""
[docs]
def __init__(self, xarray_obj: xr.DataArray) -> None:
self._obj = xarray_obj
logger.debug(
"Registered RHEEDAccessor for DataArray with shape %s and coords %s",
getattr(xarray_obj, "shape", None),
list(xarray_obj.coords.keys()),
)
# ---- Properties ----
@property
def screen_sample_distance(self) -> float:
"""Distance between sample and screen in mm. Default: 1.0 mm."""
da = self._obj
return float(da.attrs.get("screen_sample_distance", 1.0))
@property
def beta(self) -> Union[float, np.ndarray, None]:
"""Incident angle beta in degrees."""
da = self._obj
if "beta" in da.coords:
values = da.coords["beta"].values
return float(values) if values.ndim == 0 else values
else:
return None
@beta.setter
def beta(self, value: float) -> None:
"""Set incident angle beta in degrees."""
if not isinstance(value, (int, float)):
raise ValueError(f"beta must be numeric, got {value!r}")
da = self._obj
if "beta" in da.coords and da.coords["beta"].ndim != 0:
raise ValueError(
"Cannot set scalar beta on data with varying beta. "
"Assign coordinates explicitly instead."
)
da.coords["beta"] = float(value)
@property
def alpha(self) -> Union[float, np.ndarray, None]:
"""Azimuthal angle alpha in degrees."""
da = self._obj
if "alpha" in da.coords:
values = da.coords["alpha"].values
return float(values) if values.ndim == 0 else values
else:
return None
@alpha.setter
def alpha(self, value: float) -> None:
"""Set azimuthal angle alpha in degrees."""
if not isinstance(value, (int, float)):
raise ValueError(f"alpha must be numeric, got {value!r}")
da = self._obj
if "alpha" in da.coords and da.coords["alpha"].ndim != 0:
raise ValueError(
"Cannot set scalar alpha on data with varying alpha. "
"Assign coordinates explicitly instead."
)
da.coords["alpha"] = float(value)
# ------------------------------------------------------------------
# Semantic aliases
# ------------------------------------------------------------------
@property
def incident_angle(self) -> Union[float, np.ndarray, None]:
"""Alias for beta (incident angle)."""
return self.beta
@incident_angle.setter
def incident_angle(self, value: float) -> None:
self.beta = value
@property
def azimuthal_angle(self) -> Union[float, np.ndarray, None]:
"""Alias for alpha (azimuthal angle)."""
return self.alpha
@azimuthal_angle.setter
def azimuthal_angle(self, value: float) -> None:
self.alpha = value
@property
def screen_scale(self) -> float:
"""Screen scale in px/mm. Default: 1.0."""
da = self._obj
return float(da.attrs.get("screen_scale", 1.0))
@screen_scale.setter
def screen_scale(self, px_to_mm: float) -> None:
"""
Set the screen scale (px/mm) and update coordinate scaling accordingly.
Parameters
----------
px_to_mm : float
New scale in pixels per millimeter. Must be positive.
"""
if px_to_mm <= 0:
raise ValueError("screen_scale must be positive.")
da = self._obj
old_px_to_mm = self.screen_scale
da.attrs["screen_scale"] = float(px_to_mm)
missing = IMAGE_DIMS - da.coords.keys()
if missing:
raise ValueError(f"Missing required coordinate(s): {sorted(missing)}")
da["sx"] = da.sx * old_px_to_mm / px_to_mm
da["sy"] = da.sy * old_px_to_mm / px_to_mm
@property
def screen_width(self) -> Optional[float]:
"""Screen width in mm, if set."""
val = self._obj.attrs.get("screen_width")
return float(val) if val is not None else None
@property
def screen_roi_width(self) -> float:
"""Width of the region of interest (ROI) on the screen in mm."""
return float(self._obj.attrs.get("screen_roi_width", DEFAULT_SCREEN_ROI_WIDTH))
@screen_roi_width.setter
def screen_roi_width(self, value: float) -> None:
"""Set the screen ROI width in mm."""
self._obj.attrs["screen_roi_width"] = float(value)
@property
def screen_roi_height(self) -> float:
"""Height of the region of interest (ROI) on the screen in mm."""
return float(
self._obj.attrs.get("screen_roi_height", DEFAULT_SCREEN_ROI_HEIGHT)
)
@screen_roi_height.setter
def screen_roi_height(self, value: float) -> None:
"""Set the screen ROI height in mm."""
self._obj.attrs["screen_roi_height"] = float(value)
@property
def beam_energy(self) -> Optional[float]:
"""Beam energy in eV, if set."""
val = self._obj.attrs.get("beam_energy")
return float(val) if val is not None else None
@beam_energy.setter
def beam_energy(self, value: float) -> None:
"""Set the beam energy in eV."""
self._obj.attrs["beam_energy"] = float(value)
@property
def ewald_radius(self) -> float:
"""
Compute the Ewald sphere radius in reciprocal space (k-space).
Returns
-------
float
Ewald sphere radius in 1/Å.
"""
beam_energy = self.beam_energy
if beam_energy is None:
raise ValueError("Beam energy is not set.")
return np.sqrt(beam_energy) * K_INV_ANGSTROM
def __repr__(self) -> str:
da = self._obj
beta = self.incident_angle
alpha = self.azimuthal_angle
beta_str = "None" if beta is None else f"{beta:.2f} deg"
alpha_str = "None" if alpha is None else f"{alpha:.2f} deg"
return (
f"<RHEEDAccessor>\n"
f" File name: {da.attrs.get('file_name', 'N/A')}\n"
f" File creation time: {da.attrs.get('file_ctime', 'N/A')}\n"
f" Image shape: {da.shape}\n"
f" Screen scale: {self.screen_scale} px/mm\n"
f" Screen sample distance: {self.screen_sample_distance} mm\n"
f" Incident (beta) angle: {beta_str}\n"
f" Azimuthal (alpha) angle: {alpha_str}\n"
f" Beam Energy: {self.beam_energy} eV\n"
)
[docs]
def rotate(self, angle: float) -> None:
"""
Rotate the image or stack of images by a specified angle.
Parameters
----------
angle : float
Rotation angle in degrees. Positive values rotate counterclockwise.
"""
da = self._obj
logger.debug("rotate called: angle=%s, ndim=%s", angle, da.ndim)
if da.ndim == IMAGE_NDIMS:
da.data = ndimage.rotate(da.data, angle, reshape=False)
elif da.ndim == STACK_NDIMS:
stack_dim = da.dims[0]
da.data = np.stack(
[
ndimage.rotate(da.isel({stack_dim: i}).data, angle, reshape=False)
for i in range(da.sizes[stack_dim])
],
axis=0,
)
else:
logger.error(
"rotate: unsupported ndim=%s (expected %s or %s)",
da.ndim,
IMAGE_NDIMS,
STACK_NDIMS,
)
raise ValueError(
f"Expected {IMAGE_NDIMS}D or {STACK_NDIMS}D, got {da.ndim}D"
)
logger.info("Rotation applied: angle=%.4f degrees", float(angle))
[docs]
def set_center_manual(
self,
center_x: Union[float, list[float], np.ndarray] = 0.0,
center_y: Union[float, list[float], np.ndarray] = 0.0,
method: Literal["linear", "nearest", "cubic"] = "linear",
) -> None:
"""
Manually shift the image center for a single image or a stack.
Parameters
----------
center_x : float or sequence
Horizontal shift(s). Scalar applied to all frames; array-like must match stack length.
center_y : float or sequence
Vertical shift(s). Same logic as center_x.
method : {'linear', 'nearest', 'cubic'}, optional
Interpolation method for per-frame shifts (default='linear').
"""
da = self._obj
missing = IMAGE_DIMS - da.coords.keys()
if missing:
raise ValueError(f"Missing required coordinate(s): {sorted(missing)}")
if da.ndim == IMAGE_NDIMS:
da["sx"] = da.sx - center_x
da["sy"] = da.sy - center_y
elif da.ndim == STACK_NDIMS:
stack_dim = da.dims[0]
n_frames = da.sizes[stack_dim]
cx = np.asarray(center_x).copy()
cy = np.asarray(center_y).copy()
# Broadcast scalars
if cx.size == 1:
cx = np.full(n_frames, cx.item())
if cy.size == 1:
cy = np.full(n_frames, cy.item())
if len(cx) != n_frames or len(cy) != n_frames:
logger.error(
"Invalid center lengths: expected %s, got %s and %s",
n_frames,
len(cx),
len(cy),
)
raise ValueError(
f"center_x/center_y must be scalar or length={n_frames}, got {len(cx)} and {len(cy)}"
)
# Normalize shifts relative to first frame
cx0, cy0 = cx[0], cy[0]
da["sx"] = da.sx - cx0
da["sy"] = da.sy - cy0
sx_origin = da.sx.copy()
sy_origin = da.sy.copy()
cx -= cx0
cy -= cy0
# In-place modification of the underlying numpy array for all frames
for i in range(n_frames):
if i == 0:
continue # first frame already shifted
new_coords = {"sx": sx_origin + cx[i], "sy": sy_origin + cy[i]}
da.data[i] = (
da.isel({stack_dim: i})
.interp(new_coords, method=method, kwargs={"fill_value": 0})
.data
)
logger.info(
"Centering applied to %d frames: center_x=%.4f, center_y=%.4f",
n_frames,
float(cx0),
float(cy0),
)
else:
raise ValueError(
f"Unsupported ndim={da.ndim}, expected {IMAGE_NDIMS} or {STACK_NDIMS}"
)
[docs]
def set_center_auto(
self,
update_incident_angle: bool = False,
stack_index: int = 0,
) -> None:
"""
Automatically determine and apply the image center.
Parameters
----------
update_incident_angle : bool, default False
If True, recomputes and updates the incident angle after centering.
stack_index : int, default 0
Frame index to use if the data is a stack; ignored otherwise.
"""
da = self._obj
image = da[stack_index] if da.ndim == STACK_NDIMS else da
# Compute center from ROI
image_roi = image.ri.get_roi_image()
center_x = find_horizontal_center(image_roi)
center_y = find_vertical_center(image_roi, center_x=center_x)
# Apply center (mutates da)
self.set_center_manual(center_x, center_y)
logger.debug(
"Applied automatic centering: center_x=%.4f, center_y=%.4f",
float(center_x),
float(center_y),
)
if update_incident_angle:
# Re-select after mutation
image = da[stack_index] if da.ndim == STACK_NDIMS else da
image_roi = image.ri.get_roi_image()
incident_angle = find_incident_angle(image_roi)
da.ri.incident_angle = incident_angle
logger.info("Updated incident angle: %.4f", float(incident_angle))
[docs]
def get_roi_image(self) -> xr.DataArray:
"""
Return a copy of the image restricted to the screen ROI.
The ROI is defined by the attributes 'screen_roi_width' and
'screen_roi_height' (in mm).
"""
da = self._obj
roi_width: float = self.screen_roi_width
roi_height: float = self.screen_roi_height
da_roi = da.sel(
sx=slice(-roi_width, roi_width),
sy=slice(-roi_height, None),
).copy()
return da_roi
[docs]
def get_profile(
self,
center: Tuple[float, float],
width: float,
height: float,
stack_index: int = 0,
reduce_over: Literal["sy", "sx", "both"] = "sy",
method: Literal["mean", "sum"] = "mean",
show_origin: bool = False,
**kwargs,
) -> xr.DataArray:
"""
Extract an intensity profile from the RHEED image or stack.
Parameters
----------
center : tuple[float, float]
Center coordinates (sx, sy) in mm.
width : float
Width of the profile window in mm.
height : float
Height of the profile window in mm.
stack_index : int
Frame index for stacks (default=0).
reduce_over : {'sy', 'sx', 'both'}
Dimension(s) over which to reduce intensity (default='sy').
method : {'mean', 'sum'}
Reduction method (default='mean').
show_origin : bool
If True, display a rectangle showing the profile window.
Returns
-------
xr.DataArray
Profile data with metadata preserved.
"""
da = self._obj
logger.debug(
"get_profile called: center=%s width=%s height=%s stack_index=%s reduce_over=%s method=%s",
center,
width,
height,
stack_index,
reduce_over,
method,
)
cropped = da.sel(
sx=slice(center[0] - width / 2, center[0] + width / 2),
sy=slice(center[1] - height / 2, center[1] + height / 2),
)
reduce_func = cropped.mean if method == "mean" else cropped.sum
if reduce_over == "sy":
profile = reduce_func(dim="sy")
elif reduce_over == "sx":
profile = reduce_func(dim="sx")
elif reduce_over == "both":
profile = reduce_func(dim=("sy", "sx"))
else:
raise ValueError("reduce_over must be 'sy', 'sx', or 'both'")
profile.attrs = da.attrs.copy()
profile.attrs.update(
{
"profile_center": center,
"profile_width": width,
"profile_height": height,
"reduce_over": reduce_over,
"reduce_method": method,
}
)
if show_origin:
# Use provided axis or create a new one
ax = kwargs.pop(" ax", None)
if ax is None:
_, ax = plt.subplots() # discard fig, keep ax only
self.plot_image(ax=ax, stack_index=stack_index, **kwargs)
rect = Rectangle(
(center[0] - width / 2, center[1] - height / 2),
width,
height,
linewidth=1,
edgecolor="red",
facecolor="none",
)
ax.add_patch(rect)
logger.debug("Added origin rectangle to plot at center=%s", center)
return profile
[docs]
def plot_image(
self,
ax: Optional[Axes] = None,
auto_levels: float = 0.0,
show_center_lines: bool = False,
show_specular_spot: bool = False,
stack_index: int = 0,
**kwargs,
) -> Axes:
"""
Plot a RHEED image or stack frame.
Parameters
----------
ax : matplotlib.axes.Axes, optional
Axes to plot on. If None, a new figure is created.
auto_levels : float
Automatic contrast adjustment level (default=0.0).
show_center_lines : bool
If True, show horizontal and vertical center lines.
show_specular_spot : bool
If True, highlight the specular spot.
stack_index : int
Frame index for stacks (default=0).
Returns
-------
matplotlib.axes.Axes
The axes containing the plot.
"""
da = self._obj
logger.debug(
"plot_image called: ndim=%s stack_index=%s auto_levels=%s show_center_lines=%s show_specular_spot=%s",
da.ndim,
stack_index,
auto_levels,
show_center_lines,
show_specular_spot,
)
if da.ndim == STACK_NDIMS:
da = da.isel({da.dims[0]: stack_index})
elif da.ndim != IMAGE_NDIMS:
logger.error("plot_image: unsupported ndim=%s", da.ndim)
raise ValueError(
f"Expected {IMAGE_NDIMS}D or {STACK_NDIMS}D, got {da.ndim}D"
)
return plot_image(
rheed_image=da,
ax=ax,
auto_levels=auto_levels,
show_center_lines=show_center_lines,
show_specular_spot=show_specular_spot,
**kwargs,
)
[docs]
@xr.register_dataarray_accessor("rp")
class RHEEDProfileAccessor:
"""
Xarray accessor for RHEED intensity profiles.
Provides profile plotting and conversion to reciprocal space.
"""
[docs]
def __init__(self, xarray_obj: xr.DataArray):
self._obj = xarray_obj
logger.debug(
"Registered RHEEDProfileAccessor for DataArray with shape %s",
getattr(xarray_obj, "shape", None),
)
def __repr__(self):
da = self._obj
return (
f"<RHEEDProfileAccessor>\n"
f" Center: sx, sy [mm]: {da.attrs.get('profile_center', 'N/A')} \n"
f" Width: {da.attrs.get('profile_width', 'N/A')} mm\n"
f" Height: {da.attrs.get('profile_height', 'N/A')} mm\n"
f" Reduce over: {da.attrs.get('reduce_over', 'N/A')}\n"
f" Reduce method: {da.attrs.get('reduce_method', 'N/A')}\n"
)
[docs]
def convert_to_k(self) -> xr.DataArray:
"""
Convert profile coordinates from screen units (sx) to reciprocal space (ky).
Returns
-------
xr.DataArray
Profile with 'sx' replaced by 'ky'.
"""
da = self._obj
if "sx" not in da.coords:
raise ValueError("The profile must have 'sx' coordinate to convert to ky.")
k_e = da.ri.ewald_radius
screen_sample_distance = da.ri.screen_sample_distance
logger.debug(
"convert_to_k: converting sx->ky with ewald_radius=%s, screen_sample_distance=%s",
k_e,
screen_sample_distance,
)
sx = da.coords["sx"].values
ky = convert_sx_to_ky(
sx,
ewald_radius=k_e,
screen_sample_distance_mm=screen_sample_distance,
)
return da.assign_coords(sx=ky).rename({"sx": "ky"})
[docs]
def plot_profile(
self,
ax: Optional[Axes] = None,
transform_to_k: bool = True,
normalize: bool = True,
**kwargs,
) -> Axes:
"""
Plot a RHEED intensity profile.
Parameters
----------
ax : matplotlib.axes.Axes, optional
Axes to plot on. If None, a new figure is created.
transform_to_k : bool
If True, convert sx to ky using the Ewald sphere.
normalize : bool
If True, normalize intensity to [0, 1].
Returns
-------
matplotlib.axes.Axes
The axes containing the plot.
"""
da = self._obj.copy()
return plot_profile(
rheed_profile=da,
ax=ax,
transform_to_k=transform_to_k,
normalize=normalize,
**kwargs,
)