Source code for xrheed.conversion.image

import logging
import warnings
from typing import cast

import numpy as np
import xarray as xr
from numpy.typing import NDArray
from scipy import ndimage  # type: ignore
from tqdm.auto import tqdm

from ..constants import DEFAULT_K_VECT, IMAGE_NDIMS, MIRROR_ROT_DEG, STACK_NDIMS
from .base import convert_gx_gy_to_sx_sy

logger = logging.getLogger(__name__)


[docs] def transform_image_to_kxky( rheed_image: xr.DataArray, *, k_vect: np.ndarray | None = None, rotate: bool = True, point_symmetry: bool = False, ) -> xr.DataArray: """ Transform a single RHEED image into kx-ky coordinates. Notes ----- - Rotation is applied only if `rotate=True` AND azimuthal angle is available. - kx and ky are assumed identical and defined by `k_vect`. """ # --- BACKWARD COMPATIBILITY: stack support --- if rheed_image.ndim == STACK_NDIMS: warnings.warn( "Passing a stack to transform_image_to_kxky is deprecated. " "Use transform_stack_to_kxky instead.", DeprecationWarning, stacklevel=2, ) return transform_stack_to_kxky( rheed_image, k_vect=k_vect, rotate=rotate, point_symmetry=point_symmetry, ) if rheed_image.ndim != IMAGE_NDIMS: raise ValueError( f"Unsupported ndim={rheed_image.ndim}, expected {IMAGE_NDIMS} (image)" ) # --- k-vector --- if k_vect is None: k_vect = DEFAULT_K_VECT k_vect = np.asarray(k_vect, dtype=np.float32) gx, gy = np.meshgrid(k_vect, k_vect, indexing="ij") # --- geometry parameters --- ri = rheed_image.ri sx_to_kx, sy_to_ky = convert_gx_gy_to_sx_sy( gx, gy, ewald_radius=ri.ewald_radius, incident_angle=ri.incident_angle, screen_sample_distance=ri.screen_sample_distance, remove_outside=False, ) sx = xr.DataArray(sx_to_kx, dims=("kx", "ky"), coords={"kx": k_vect, "ky": k_vect}) sy = xr.DataArray(sy_to_ky, dims=("kx", "ky"), coords={"kx": k_vect, "ky": k_vect}) # --- rotation logic --- rotate_angle: float | None = None if rotate and hasattr(ri, "azimuthal_angle") and ri.azimuthal_angle is not None: rotate_angle = float(ri.azimuthal_angle) return _transform_frame_kxky( rheed_image, sx=sx, sy=sy, rotate_angle=rotate_angle, point_symmetry=point_symmetry, )
[docs] def transform_stack_to_kxky( rheed_stack: xr.DataArray, *, stack_dim: str | None = None, azimuthal_angle_coord: str = "alpha", k_vect: np.ndarray | None = None, rotate: bool = True, point_symmetry: bool = False, show_progress: bool = False, ) -> xr.DataArray: """ Transform a stack of RHEED images into kx-ky coordinates. """ if rheed_stack.ndim < STACK_NDIMS: raise ValueError("Expected stack (>=3D). Use transform_image_to_kxky.") if stack_dim is None: stack_dim = cast(str, rheed_stack.dims[0]) if stack_dim not in rheed_stack.dims: raise ValueError(f"Invalid stack_dim='{stack_dim}'") # --- k-vector --- if k_vect is None: k_vect = DEFAULT_K_VECT k_vect = np.asarray(k_vect, dtype=np.float32) gx, gy = np.meshgrid(k_vect, k_vect, indexing="ij") # --- compute sx, sy once --- ref = rheed_stack.isel({stack_dim: 0}) ri = ref.ri sx_to_kx, sy_to_ky = convert_gx_gy_to_sx_sy( gx, gy, ewald_radius=ri.ewald_radius, incident_angle=ri.incident_angle, screen_sample_distance=ri.screen_sample_distance, remove_outside=False, ) sx = xr.DataArray(sx_to_kx, dims=("kx", "ky"), coords={"kx": k_vect, "ky": k_vect}) sy = xr.DataArray(sy_to_ky, dims=("kx", "ky"), coords={"kx": k_vect, "ky": k_vect}) angles: xr.DataArray | None = None if rotate: if azimuthal_angle_coord not in rheed_stack.coords: raise ValueError( f"Missing coordinate '{azimuthal_angle_coord}' required for rotation" ) angles = rheed_stack[azimuthal_angle_coord] iterator = range(rheed_stack.sizes[stack_dim]) if show_progress: iterator = tqdm(iterator, desc="Transforming stack") results: list[xr.DataArray] = [] for i in iterator: img = rheed_stack.isel({stack_dim: i}) rotate_angle: float | None = None if rotate and angles is not None: rotate_angle = float(angles.isel({stack_dim: i})) transformed = _transform_frame_kxky( img, sx=sx, sy=sy, rotate_angle=rotate_angle, point_symmetry=point_symmetry, ) coord_val = rheed_stack[stack_dim].values[i].item() transformed = transformed.expand_dims({stack_dim: [coord_val]}) results.append(transformed) out = xr.concat( results, dim=stack_dim, coords="minimal", compat="override", join="override", ) out.attrs = rheed_stack.attrs return out
def _transform_frame_kxky( frame: xr.DataArray, *, sx: xr.DataArray, sy: xr.DataArray, rotate_angle: float | None = None, point_symmetry: bool = False, ) -> xr.DataArray: """ Transform a single 2D RHEED frame into kx-ky coordinates. """ if frame.ndim != IMAGE_NDIMS: raise ValueError("_transform_frame_kxky expects a 2D DataArray") if not np.issubdtype(frame.dtype, np.floating): frame = frame.astype(np.float32) transformed = frame.interp(sx=sx, sy=sy, method="linear") if rotate_angle is not None: transformed = _rotate_trans_image(transformed, rotate_angle) if point_symmetry: rotated_180 = _rotate_trans_image(transformed, MIRROR_ROT_DEG) transformed = xr.where(np.isnan(transformed), rotated_180, transformed) transformed.attrs = frame.attrs return transformed def _rotate_trans_image( trans_image: xr.DataArray, angle: float, mode: str = "constant", ) -> xr.DataArray: """ Rotate a 2D xarray.DataArray around its center by a given angle. Notes ----- - Assumes `trans_image` is already floating point (float32). - NaN regions are preserved using an explicit validity mask. """ if trans_image.ndim != IMAGE_NDIMS: raise ValueError("_rotate_trans_image expects a 2D DataArray") logger.debug( "called _rotate_trans_image: angle=%.3f mode=%s input_shape=%s", angle, mode, trans_image.shape, ) if "kx" not in trans_image.coords or "ky" not in trans_image.coords: raise ValueError("Rotation requires 'kx' and 'ky' coordinates") if not np.allclose(trans_image["kx"].values, trans_image["ky"].values): raise ValueError("kx and ky coordinates must be identical for rotation") valid_mask: NDArray[np.bool_] = ~np.isnan(trans_image.values) filled = trans_image.fillna(0.0) rotated_data = ndimage.rotate( filled.values, angle, reshape=False, order=3, mode=mode, ) rotated_mask: NDArray[np.bool_] = ndimage.rotate( valid_mask, angle, reshape=False, order=0, mode=mode ).astype(bool) rotated = xr.DataArray( rotated_data, coords=trans_image.coords, dims=trans_image.dims, attrs=trans_image.attrs, name=trans_image.name, ) return rotated.where(rotated_mask)