import logging
import warnings
from typing import Optional, Tuple
import lmfit as lf # type: ignore
import numpy as np
import xarray as xr
from lmfit.models import LorentzianModel # type: ignore
from numpy.typing import NDArray
from scipy.signal import find_peaks # type: ignore
from scipy.special import expit # type: ignore
from xrheed.preparation.filters import gaussian_filter_profile
from ..constants import IMAGE_DIMS
logger = logging.getLogger(__name__)
[docs]
def find_horizontal_center(
image: xr.DataArray,
n_stripes: int = 10,
prominence: float = 0.1,
refinement_tolerance: float = 1.0, # default in mm
) -> float:
"""
Estimate horizontal (sx) symmetry center of a diffraction image.
Parameters
----------
image : xr.DataArray
2D image with 'sx' and 'sy' coordinates.
n_stripes : int, optional
Number of horizontal stripes along 'sy' to analyze (default 10).
prominence : float, optional
Minimum prominence for peak detection (relative to normalized profile).
refinement_tolerance : float, optional
Maximum allowed deviation from global center (in sx units, default 0.2 mm).
Returns
-------
float
Estimated sx coordinate of symmetry center.
"""
logger.debug("find_horizontal_profile called")
if set(image.dims) != IMAGE_DIMS:
raise AssertionError(
f"Image dims mismatch: expected {IMAGE_DIMS}, got {set(image.dims)}"
)
# --- Global profile and approximate center ---
global_profile = image.mean(dim="sy")
smooth_sigma = 2.0 * _spot_sigma_from_profile(global_profile)
global_profile_smooth = gaussian_filter_profile(global_profile, sigma=smooth_sigma)
# Normalize
vals = global_profile_smooth.values.astype(float)
vals = (vals - vals.min()) / np.ptp(vals)
# Detect peaks
peaks, _ = find_peaks(vals, prominence=prominence)
x_coords = global_profile_smooth.sx.values[peaks]
heights = vals[peaks]
if x_coords.size == 0:
raise RuntimeError("No peaks found in global profile")
elif x_coords.size < 3:
approx_center = float(global_profile_smooth.idxmax(dim="sx").item())
else:
approx_center = float(np.average(x_coords, weights=heights))
logger.debug("Global approx_center: %.4f", approx_center)
global_max = global_profile_smooth.max()
ny = int(image.sizes["sy"])
stripe_height = max(1, ny // int(n_stripes))
sx_coords = np.asarray(image.sx.values)
centers = []
for i in range(n_stripes):
start = i * stripe_height
end = ny if i == n_stripes - 1 else (i + 1) * stripe_height
stripe = image.isel(sy=slice(start, end))
profile = stripe.mean(dim="sy")
if profile.size == 0:
continue
profile_smooth = gaussian_filter_profile(profile, sigma=smooth_sigma)
if profile_smooth.max() < global_max * 0.7:
continue
vals = profile_smooth.values.astype(float)
vals = (vals - vals.min()) / np.ptp(vals)
peaks, _ = find_peaks(vals, prominence=prominence)
if peaks.size == 0:
continue
x_coords = np.sort(sx_coords[peaks])
# Candidate centers: mean of all + drop-one means
candidates = [np.mean(x_coords)]
if x_coords.size > 1:
for j in range(x_coords.size):
candidates.append(np.mean(np.delete(x_coords, j)))
# Pick candidate closest to global approx_center
center = min(candidates, key=lambda c: abs(c - approx_center))
# Only accept if within fixed tolerance
if abs(center - approx_center) <= refinement_tolerance:
centers.append(center)
else:
logger.debug(
"Stripe %d rejected: candidate %.3f vs approx_center %.3f (tol %.3f)",
i,
center,
approx_center,
refinement_tolerance,
)
if not centers:
raise RuntimeError("No valid peaks found in any stripe.")
center_final = float(np.median(centers))
logger.info(
"Estimated horizontal center: %.4f, using %d stripes",
center_final,
len(centers),
)
return center_final
[docs]
def find_vertical_center(
image: xr.DataArray,
center_x: float = 0.0,
n_stripes: int = 10,
prominence: float = 0.1,
) -> float:
"""
Estimate the vertical (sy) center of a RHEED image using the shadow edge.
The image is divided into vertical stripes along 'sx'; for each stripe,
a profile along 'sy' is extracted and a linear+sigmoid model is fitted
to locate the shadow edge. The final center is the median of valid fits.
Parameters
----------
image : xr.DataArray
2D RHEED image with 'sx' and 'sy' coordinates.
center_x : float, optional
Horizontal center (sx) to subtract from coordinates before analysis
(default 0.0). Useful to align profiles to a previously estimated
horizontal center.
n_stripes : int, optional
Number of vertical stripes along 'sx' to analyze (default 10).
prominence : float, optional
Minimum prominence for peak detection (relative to normalized profile).
Returns
-------
float
Estimated sy coordinate of the vertical center.
"""
logger.debug("find_vertical_profile called center_x=%s", center_x)
if set(image.dims) != IMAGE_DIMS:
raise AssertionError(
f"Image dims mismatch: expected {IMAGE_DIMS}, got {set(image.dims)}"
)
nx: int = int(image.sizes["sx"])
stripe_width: int = max(1, nx // n_stripes)
global_profile = image.mean(dim="sx")
smooth_sigma: float = 1.0 * _spot_sigma_from_profile(global_profile)
centers = []
for i in range(n_stripes):
start = i * stripe_width
end = nx if i == n_stripes - 1 else (i + 1) * stripe_width
stripe = image.isel(sx=slice(start, end))
if stripe.size == 0:
continue
# Collapse stripe into a vertical profile
profile = stripe.sum(dim="sx")
# Smooth profile
profile_smoothed = gaussian_filter_profile(profile, sigma=smooth_sigma)
# Extract coordinates and values
sy_coords = profile_smoothed["sy"].values
vals = profile_smoothed.values.astype(float)
# Find all local maxima
peaks, _ = find_peaks(vals, prominence=prominence)
if peaks.size == 0:
continue
# Filter peaks to only those at negative sy
negative_mask = sy_coords[peaks] < 0
negative_peaks = peaks[negative_mask]
if negative_peaks.size == 0:
continue
# Take the *last* peak among negative sy (i.e. closest to zero from the left)
peak_idx = int(negative_peaks[-1])
# Restrict to the falling edge after that local maximum
subprofile = profile_smoothed.isel(sy=slice(peak_idx, None))
if subprofile.size == 0:
continue
sy_coords = subprofile["sy"].values
vals = subprofile.values.astype(float)
# Add synthetic plateau points before the falling edge
n_extra = 100
sy_step = np.median(np.diff(sy_coords))
sy_extra = sy_coords[0] - sy_step * np.arange(n_extra, 0, -1)
vals_extra = np.full_like(sy_extra, vals[0]) # flat extension
# Concatenate synthetic + original data
sy_coords = np.concatenate([sy_extra, sy_coords])
vals = np.concatenate([vals_extra, vals])
if np.ptp(vals) == 0:
continue
# Normalize
vals = (vals - vals.min()) / np.ptp(vals)
# Fit sigmoid with limited iterations
sigmoid_model = lf.Model(_linear_plus_sigmoid)
params = sigmoid_model.make_params(a=0.0, b=0.1, L=1.0, k=-0.5, x0=0.0)
params["L"].set(min=0.8, max=1.2)
params["k"].set(min=-2.0, max=-0.2)
params["x0"].set(min=-10.0, max=10.0)
params["a"].set(min=-0.2, max=0.2)
params["b"].set(min=0.0, max=0.2)
result = sigmoid_model.fit(
vals,
params=params,
x=sy_coords,
max_nfev=100, # limit number of iterations
)
# Accept only if fit converged and quality is reasonable
redchi = getattr(result, "redchi", np.inf)
if result.success and redchi < 0.01:
x0 = result.params["x0"].value
k = result.params["k"].value
# use the edge of the sigmoid about 16%
center = x0 - np.log(5) / k
centers.append(center)
logger.debug(
"Fit accepted: x0=%.4f, k=%.4f, redchi=%.4g", x0, k, result.redchi
)
else:
logger.debug(
"Fit rejected: success=%s, redchi=%.4g", result.success, result.redchi
)
if not centers:
raise RuntimeError("No valid vertical centers found in any stripe.")
center_y = float(np.median(centers))
logger.info(
"Vertical center estimated at %.4f, using %d edge profiles",
center_y,
len(centers),
)
# --- refinement: adjust using reflected and trismission spots if available ---
try:
sy_mirr, sy_trans = _find_reflection_and_transmission_spots(
image, center_x=center_x, center_y=center_y
)
if sy_trans is not None:
shadow_edge = 0.5 * (sy_trans + sy_mirr)
center_y += shadow_edge
logger.info(
"Adjust using reflected and transmission spots: %.4f", shadow_edge
)
else:
logger.debug("Incident angle refinement skipped (no transmission spot)")
except Exception as e:
logger.debug("Incident angle refinement failed: %s", str(e))
return center_y
[docs]
def find_incident_angle(
image: xr.DataArray,
y_range: tuple[float, float] = (-30, 30),
prominence: float = 0.1,
) -> float:
"""
Find incident angle in degrees using reflection/transmission spots near sx=0.
"""
if set(image.dims) != IMAGE_DIMS:
raise AssertionError(
f"Image dims mismatch: expected {IMAGE_DIMS}, got {set(image.dims)}"
)
screen_sample_distance: float = image.ri.screen_sample_distance
logger.debug(
"find_incident_angle: screen_sample_distance=%.4f, y_range=%s, prominence=%.3f",
screen_sample_distance,
y_range,
prominence,
)
sy_mirr, sy_trans = _find_reflection_and_transmission_spots(
image, y_range=y_range, prominence=prominence
)
logger.info("Mirror spot detected at sy=%.4f", sy_mirr)
if sy_trans is not None:
logger.info("Transmission spot detected at sy=%.4f", sy_trans)
beta_deg = _calculate_incident_angle(sy_mirr, sy_trans, screen_sample_distance)
if sy_trans is not None:
spot_distance = sy_trans - sy_mirr
shadow_edge = 0.5 * (sy_trans + sy_mirr)
logger.info("Spot distance=%.4f, Shadow edge=%.4f", spot_distance, shadow_edge)
logger.info("Incident angle (deg) from reflection+transmission: %.4f", beta_deg)
else:
logger.warning(
"Transmission spot not detected; using reflection-only estimate. "
"Incident angle (deg)=%.4f",
beta_deg,
)
return beta_deg
# Define sigmoid function for fitting
def _sigmoid(x: NDArray, amp: float, k: float, x0: float, back: float) -> NDArray:
"""
Sigmoid function used for fitting shadow edges.
Parameters
----------
x : NDArray
Input values.
amp : float
Amplitude.
k : float
Slope.
x0 : float
Center position.
back : float
Background offset.
Returns
-------
NDArray
Sigmoid function values.
"""
return amp / (1 + np.exp(-k * (x - x0))) + back
# Model: Linear + Sigmoid
def _linear_plus_sigmoid(
x: NDArray, a: float, b: float, L: float, k: float, x0: float
) -> NDArray:
"""
Linear plus sigmoid model for fitting shadow edges.
Parameters
----------
x : NDArray
Input values.
a : float
Linear slope.
b : float
Linear offset.
L : float
Sigmoid amplitude.
k : float
Sigmoid slope.
x0 : float
Sigmoid center.
Returns
-------
NDArray
Model values.
"""
return a * x + b + L * expit(k * (x - x0))
logging.getLogger(__name__)
def _spot_sigma_from_profile(
profile: xr.DataArray,
max_sigma: float = 2.0, # in mm
) -> float:
"""
Fit a Lorentzian around peaks in a 1D diffraction profile.
Iteratively expand window until fit stabilizes.
Returns sigma (HWHM), capped to avoid runaway values.
Parameters
----------
profile : xr.DataArray
1D profile with coordinate 'sx' or 'sy'.
max_sigma : float, optional
Maximum allowed sigma in mm (default 2.0).
Returns
-------
float
Estimated sigma (HWHM) in mm. Falls back to max_sigma if no stable fit.
"""
# --- coordinate extraction ---
if "sx" in profile.coords:
x = profile["sx"].values.astype(float)
elif "sy" in profile.coords:
x = profile["sy"].values.astype(float)
else:
raise AssertionError("Profile must have 'sx' or 'sy' coordinate")
y = profile.values.astype(float)
dx = abs(x[1] - x[0])
n = len(y)
# window sizes in index units
start_window = int((0.5 * max_sigma) // dx)
max_window = int((2.0 * max_sigma) // dx)
# --- find candidate peaks ---
peaks, _ = find_peaks(y, prominence=0.5)
if peaks.size == 0:
warnings.warn("No peaks detected, returning max_sigma")
return max_sigma
# Sort peaks by height (strongest first)
peak_order = peaks[np.argsort(y[peaks])[::-1]]
# --- try each peak until one works ---
for i_max in peak_order:
best_sigma = None
prev_sigma = None
for half in range(start_window, max_window + 1, start_window):
left = max(0, i_max - half)
right = min(n, i_max + half)
xw = x[left:right]
yw = y[left:right]
if len(xw) < 5 or np.ptp(yw) == 0:
continue
# Normalize to [0,1]
yw = (yw - yw.min()) / np.ptp(yw)
model = LorentzianModel(prefix="l_")
params = model.make_params()
params["l_center"].set(
value=x[i_max], min=x[i_max] - max_sigma, max=x[i_max] + max_sigma
)
params["l_sigma"].set(value=0.5, min=dx, max=0.5 * max_window)
params["l_amplitude"].set(value=1.0, min=0.5, max=1.2)
try:
result = model.fit(yw, params, x=xw, max_nfev=100)
except Exception as e:
logger.debug("Fit failed at window %d for peak %d: %s", half, i_max, e)
continue
redchi = getattr(result, "redchi", np.inf)
if not (result.success and redchi < 0.1):
logger.debug("Rejecting poor fit (redchi=%.2f)", redchi)
continue
sigma = float(result.params["l_sigma"].value)
logger.debug("Peak %d, window %d: sigma=%.4f", i_max, half, sigma)
# stability check
if prev_sigma is not None and abs(sigma - prev_sigma) < 0.05 * sigma:
best_sigma = sigma
break
prev_sigma = sigma
best_sigma = sigma
if best_sigma is not None:
capped_sigma = min(best_sigma, max_sigma)
if capped_sigma < best_sigma:
logger.debug("Sigma capped: %.4f → %.4f", best_sigma, capped_sigma)
return capped_sigma
# --- if all peaks fail ---
warnings.warn("All peak fits failed, returning max_sigma")
return max_sigma
def _find_reflection_and_transmission_spots(
image: xr.DataArray,
y_range: tuple[float, float] = (-30, 30),
prominence: float = 0.1,
center_x: float = 0.0,
center_y: float = 0.0,
) -> Tuple[float, Optional[float]]:
"""
Detect reflection (sy<0) and transmission (sy>0) spots near sx=0.
Optionally shift coordinates by center_x and center_y.
Parameters
----------
image : xr.DataArray
RHEED image with 'sx' and 'sy' coordinates.
y_range : tuple(float, float), optional
Range of sy to select for the vertical profile (default -30..30).
prominence : float, optional
Minimum prominence for peak detection.
center_x : float, optional
Horizontal center to subtract from sx (default 0.0).
center_y : float, optional
Vertical center to subtract from sy (default 0.0).
Returns
-------
sy_mirr : float
Position of the reflection spot (always required).
sy_trans : float | None
Position of the transmission spot, or None if not found.
"""
# --- determine sx range dynamically from reflection profile ---
profile_for_sigma = image.sel(sy=slice(-20, 0)).sum(dim="sy")
sigma = _spot_sigma_from_profile(profile_for_sigma) * 0.5
x_range = (center_x - sigma, center_x + sigma)
# --- vertical profile near sx=0 ---
vertical_profile: xr.DataArray = image.sel(
sx=slice(*x_range), sy=slice(*y_range)
).sum("sx")
sigma = _spot_sigma_from_profile(vertical_profile) * 0.5
vertical_profile = gaussian_filter_profile(vertical_profile, sigma=sigma)
sy_coords = vertical_profile.sy.values - center_y
vals = vertical_profile.values.astype(float)
if np.ptp(vals) == 0:
raise RuntimeError("Flat profile: cannot detect spots")
vals -= vals.min()
vals /= vals.max()
peaks, _ = find_peaks(vals, prominence=prominence)
if peaks.size == 0:
raise RuntimeError("No peaks detected in vertical profile")
sy_peaks = sy_coords[peaks]
vals_peaks = vals[peaks]
refl_candidates = sy_peaks[sy_peaks < 0]
trans_candidates = sy_peaks[sy_peaks > 0]
if refl_candidates.size == 0:
raise RuntimeError("No reflection spot detected")
sy_mirr = float(refl_candidates[np.argmax(vals_peaks[sy_peaks < 0])])
sy_trans = None
if trans_candidates.size > 0:
sy_trans = float(trans_candidates[np.argmax(vals_peaks[sy_peaks > 0])])
return sy_mirr, sy_trans
def _calculate_incident_angle(
sy_mirr: float, sy_trans: Optional[float], screen_sample_distance: float
) -> float:
"""
Calculate incident angle beta (deg) from mirror and transmission spot positions.
If sy_trans is None, use reflection-only estimate.
"""
if sy_trans is not None:
spot_distance = sy_trans - sy_mirr
beta_rad = np.arctan(0.5 * spot_distance / screen_sample_distance)
beta_deg = np.degrees(beta_rad)
else:
beta_rad = np.arctan(-sy_mirr / screen_sample_distance)
beta_deg = np.degrees(beta_rad)
return float(beta_deg)