Skip to content

Commit

Permalink
Clean up relative methods into classes
Browse files Browse the repository at this point in the history
The classes keep track of the current values of the relative parameters
(before they were modified by lmfit). This is necessary for modifying
all detectors by the diff of the change.

Signed-off-by: Patrick Avery <patrick.avery@kitware.com>
  • Loading branch information
psavery committed Oct 15, 2024
1 parent 090402d commit 7c7a8c6
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 83 deletions.
4 changes: 2 additions & 2 deletions hexrd/fitting/calibration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .instrument import InstrumentCalibrator
from .laue import LaueCalibrator
from .lmfit_param_handling import RelativeConstraints
from .multigrain import calibrate_instrument_from_sx, generate_parameter_names
from .powder import PowderCalibrator
from .relative_constraints import RelativeConstraintsType
from .structureless import StructurelessCalibrator

# For backward-compatibility, since it used to be named this:
Expand All @@ -14,7 +14,7 @@
'InstrumentCalibrator',
'LaueCalibrator',
'PowderCalibrator',
'RelativeConstraints',
'RelativeConstraintsType',
'StructurelessCalibrator',
'StructureLessCalibrator',
]
30 changes: 25 additions & 5 deletions hexrd/fitting/calibration/instrument.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Optional

import lmfit
import numpy as np
Expand All @@ -9,7 +10,11 @@
DEFAULT_EULER_CONVENTION,
update_instrument_from_params,
validate_params_list,
)
from .relative_constraints import (
create_relative_constraints,
RelativeConstraints,
RelativeConstraintsType,
)

logger = logging.getLogger()
Expand All @@ -24,7 +29,7 @@ class InstrumentCalibrator:
def __init__(self, *args, engineering_constraints=None,
set_refinements_from_instrument_flags=True,
euler_convention=DEFAULT_EULER_CONVENTION,
relative_constraints=RelativeConstraints.none):
relative_constraints_type=RelativeConstraintsType.none):
"""
Model for instrument calibration class as a function of
Expand All @@ -47,7 +52,8 @@ def __init__(self, *args, engineering_constraints=None,
assert calib.instr is self.instr, \
"all calibrators must refer to the same instrument"
self._engineering_constraints = engineering_constraints
self._relative_constraints = relative_constraints
self._relative_constraints = create_relative_constraints(
relative_constraints_type, self.instr)
self.euler_convention = euler_convention

self.params = self.make_lmfit_params()
Expand Down Expand Up @@ -164,18 +170,32 @@ def engineering_constraints(self, v):
self._engineering_constraints = v
self.params = self.make_lmfit_params()

@property
def relative_constraints_type(self):
return self._relative_constraints.type

@relative_constraints_type.setter
def relative_constraints_type(self, v: Optional[RelativeConstraintsType]):
v = v if v is not None else RelativeConstraintsType.none

current = getattr(self, '_relative_constraints', None)
if current is None or current.type != v:
self.relative_constraints = create_relative_constraints(
v, self.instr)

@property
def relative_constraints(self) -> RelativeConstraints:
return self._relative_constraints

@relative_constraints.setter
def relative_constraints(self, v: RelativeConstraints):
if v == self._relative_constraints:
return

self._relative_constraints = v
self.params = self.make_lmfit_params()

def reset_relative_constraint_params(self):
# Set them back to zero.
self.relative_constraints.reset()

def run_calibration(self, odict):
resd0 = self.residual()
nrm_ssr_0 = _normalized_ssqr(resd0)
Expand Down
112 changes: 71 additions & 41 deletions hexrd/fitting/calibration/lmfit_param_handling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from enum import Enum
from typing import Optional

import lmfit
import numpy as np
Expand All @@ -17,24 +17,18 @@
rotMatOfExpMap,
)
from hexrd.material.unitcell import _lpname
from .relative_constraints import (
RelativeConstraints,
RelativeConstraintsType,
)


# First is the axes_order, second is extrinsic
DEFAULT_EULER_CONVENTION = ('zxz', False)


class RelativeConstraints(Enum):
"""These are relative constraints between the detectors"""
# 'none' means no relative constraints
none = 'None'
# 'group' means constrain tilts/translations within a group
group = 'Group'
# 'system' means constrain tilts/translations within the whole system
system = 'System'


def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
relative_constraints=RelativeConstraints.none):
relative_constraints=None):
# add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP)
parms_list = []

Expand Down Expand Up @@ -62,23 +56,27 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
parms_list.append(('instr_tvec_y', instr.tvec[1], False, -np.inf, np.inf))
parms_list.append(('instr_tvec_z', instr.tvec[2], False, -np.inf, np.inf))

if relative_constraints == RelativeConstraints.none:
if (
relative_constraints is None or
relative_constraints.type == RelativeConstraintsType.none
):
add_unconstrained_detector_parameters(
instr,
euler_convention,
parms_list,
)
elif relative_constraints == RelativeConstraints.group:
elif relative_constraints.type == RelativeConstraintsType.group:
# This should be implemented soon
raise NotImplementedError(relative_constraints)
elif relative_constraints == RelativeConstraints.system:
raise NotImplementedError(relative_constraints.type)
elif relative_constraints.type == RelativeConstraintsType.system:
add_system_constrained_detector_parameters(
instr,
euler_convention,
parms_list,
relative_constraints,
)
else:
raise NotImplementedError(relative_constraints)
raise NotImplementedError(relative_constraints.type)

return parms_list

Expand Down Expand Up @@ -122,10 +120,24 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list):
-np.inf, np.inf))


def add_system_constrained_detector_parameters(instr, euler_convention,
parms_list):
mean_center = instr.mean_detector_center
mean_tilt = instr.mean_detector_tilt
def add_system_constrained_detector_parameters(
instr, euler_convention,
parms_list, relative_constraints: RelativeConstraints):
system_params = relative_constraints.params
system_tvec = system_params['translation']
system_tilt = system_params['tilt']

if euler_convention is not None:
# Convert the tilt to the specified Euler convention
normalized = normalize_euler_convention(euler_convention)
rme = RotMatEuler(
np.zeros(3,),
axes_order=normalized[0],
extrinsic=normalized[1],
)

rme.rmat = _tilt_to_rmat(system_tilt, None)
system_tilt = np.degrees(rme.angles)

tvec_names = [
'system_tvec_x',
Expand All @@ -138,12 +150,12 @@ def add_system_constrained_detector_parameters(instr, euler_convention,
tilt_deltas = [2, 2, 2]

for i, name in enumerate(tvec_names):
value = mean_center[i]
value = system_tvec[i]
delta = tvec_deltas[i]
parms_list.append((name, value, True, value - delta, value + delta))

for i, name in enumerate(tilt_names):
value = mean_tilt[i]
value = system_tilt[i]
delta = tilt_deltas[i]
parms_list.append((name, value, True, value - delta, value + delta))

Expand All @@ -160,8 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
return param_names


def update_instrument_from_params(instr, params, euler_convention,
relative_constraints):
def update_instrument_from_params(
instr, params,
euler_convention=DEFAULT_EULER_CONVENTION,
relative_constraints: Optional[RelativeConstraints] = None):
"""
this function updates the instrument from the
lmfit parameter list. we don't have to keep track
Expand Down Expand Up @@ -196,23 +210,27 @@ def update_instrument_from_params(instr, params, euler_convention,
params['instr_tvec_z'].value]
instr.tvec = np.r_[instr_tvec]

if relative_constraints == RelativeConstraints.none:
if (
relative_constraints is None or
relative_constraints.type == RelativeConstraintsType.none
):
update_unconstrained_detector_parameters(
instr,
params,
euler_convention,
)
elif relative_constraints == RelativeConstraints.group:
elif relative_constraints.type == RelativeConstraintsType.group:
# This should be implemented soon
raise NotImplementedError(relative_constraints)
elif relative_constraints == RelativeConstraints.system:
raise NotImplementedError(relative_constraints.type)
elif relative_constraints.type == RelativeConstraintsType.system:
update_system_constrained_detector_parameters(
instr,
params,
euler_convention,
relative_constraints,
)
else:
raise NotImplementedError(relative_constraints)
raise NotImplementedError(relative_constraints.type)


def update_unconstrained_detector_parameters(instr, params, euler_convention):
Expand Down Expand Up @@ -245,10 +263,15 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention):
)


def update_system_constrained_detector_parameters(instr, params, euler_convention):
# We will always rotate/translate about the center of the group
def update_system_constrained_detector_parameters(
instr, params, euler_convention,
relative_constraints: RelativeConstraints):
# We will always rotate about the center of the detectors
mean_center = instr.mean_detector_center
mean_tilt = instr.mean_detector_tilt

system_params = relative_constraints.params
system_tvec = system_params['translation']
system_tilt = system_params['tilt']

tvec_names = [
'system_tvec_x',
Expand All @@ -263,11 +286,11 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
if any(params[x].vary for x in tilt_names):
# Find the change in tilt, create an rmat, then apply to detector tilts
# and translations.
new_mean_tilt = np.array([params[x].value for x in tilt_names])
new_system_tilt = np.array([params[x].value for x in tilt_names])

# The old mean tilt was in the None convention
old_rmat = _tilt_to_rmat(mean_tilt, None)
new_rmat = _tilt_to_rmat(new_mean_tilt, euler_convention)
# The old system tilt was in the None convention
old_rmat = _tilt_to_rmat(system_tilt, None)
new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention)

# Compute the rmat used to convert from old to new
rmat_diff = new_rmat @ old_rmat.T
Expand All @@ -276,19 +299,26 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
for panel in instr.detectors.values():
panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat)

# Also rotate the detectors about the center
# Also rotate the detectors about the mean center
panel.tvec = rmat_diff @ (panel.tvec - mean_center) + mean_center

# Update the system tilt
system_tilt[:] = _rmat_to_tilt(new_rmat)

if any(params[x].vary for x in tvec_names):
# Find the change in center and shift all tvecs
new_mean_center = np.array([params[x].value for x in tvec_names])
new_system_tvec = np.array([params[x].value for x in tvec_names])

diff = new_mean_center - mean_center
diff = new_system_tvec - system_tvec
for panel in instr.detectors.values():
panel.tvec += diff

# Update the system tvec
system_tvec[:] = new_system_tvec


def _tilt_to_rmat(tilt: np.ndarray, euler_convention: dict | tuple) -> np.ndarray:
def _tilt_to_rmat(tilt: np.ndarray,
euler_convention: dict | tuple) -> np.ndarray:
# Convert the tilt to exponential map parameters, and then
# to the rotation matrix, and return.
if euler_convention is None:
Expand Down
35 changes: 26 additions & 9 deletions hexrd/fitting/calibration/structureless.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
from typing import Optional

import lmfit
import numpy as np

Expand All @@ -9,10 +11,14 @@
create_instr_params,
create_tth_parameters,
DEFAULT_EULER_CONVENTION,
RelativeConstraints,
tth_parameter_prefixes,
update_instrument_from_params,
)
from .relative_constraints import (
create_relative_constraints,
RelativeConstraints,
RelativeConstraintsType,
)


class StructurelessCalibrator:
Expand All @@ -39,14 +45,15 @@ def __init__(self,
data,
tth_distortion=None,
engineering_constraints=None,
relative_constraints=RelativeConstraints.none,
relative_constraints_type=RelativeConstraintsType.none,
euler_convention=DEFAULT_EULER_CONVENTION):

self._instr = instr
self._data = data
self._tth_distortion = tth_distortion
self._engineering_constraints = engineering_constraints
self._relative_constraints = relative_constraints
self._relative_constraints = create_relative_constraints(
relative_constraints_type, self.instr)
self.euler_convention = euler_convention
self._update_tth_distortion_panels()
self.make_lmfit_params()
Expand Down Expand Up @@ -163,16 +170,26 @@ def _update_tth_distortion_panels(self):
obj.panel = self.instr.detectors[det_key]

@property
def relative_constraints(self):
def relative_constraints_type(self):
return self._relative_constraints.type

@relative_constraints_type.setter
def relative_constraints_type(self, v: Optional[RelativeConstraintsType]):
v = v if v is not None else RelativeConstraintsType.none

current = getattr(self, '_relative_constraints', None)
if current is None or current.type != v:
self.relative_constraints = create_relative_constraints(
v, self.instr)

@property
def relative_constraints(self) -> RelativeConstraints:
return self._relative_constraints

@relative_constraints.setter
def relative_constraints(self, v):
if v == self._relative_constraints:
return

def relative_constraints(self, v: RelativeConstraints):
self._relative_constraints = v
self.make_lmfit_params()
self.params = self.make_lmfit_params()

@property
def engineering_constraints(self):
Expand Down
Loading

0 comments on commit 7c7a8c6

Please sign in to comment.