diff --git a/optika/_tests/test_sags.py b/optika/_tests/test_sags.py index 1ff81d5..4015b04 100644 --- a/optika/_tests/test_sags.py +++ b/optika/_tests/test_sags.py @@ -170,6 +170,23 @@ class TestSphericalSag( pass +@pytest.mark.parametrize( + argnames="a", + argvalues=[ + optika.sags.CylindricalSag( + radius=radius, + transformation=transformation, + ) + for radius in radius_parameterization() + for transformation in test_mixins.transformation_parameterization + ], +) +class TestCylindricalSag( + AbstractTestAbstractSag, +): + pass + + class AbstractTestAbstractConicSag( AbstractTestAbstractSag, ): diff --git a/optika/sags.py b/optika/sags.py index 8afb044..c5e2b19 100644 --- a/optika/sags.py +++ b/optika/sags.py @@ -10,6 +10,7 @@ "AbstractSag", "NoSag", "SphericalSag", + "CylindricalSag", "AbstractConicSag", "ConicSag", "ParabolicSag", @@ -286,6 +287,128 @@ def normal( return result / result.length +@dataclasses.dataclass(eq=False, repr=False) +class CylindricalSag( + AbstractSag, + Generic[RadiusT], +): + r""" + A cylindrical sag function, where the local :math:`y` axis is the axis of + symmetry for the cylinder. + + The sag (:math:`z` coordinate) of a spherical surface is calculated using + the expression + + .. math:: + + z(x, y) = \frac{c x^2}{1 + \sqrt{1 - c^2 x^2}} + + where :math:`c` is the :attr:`curvature`, + and :math:`x` is the horizontal component of the evaluation point. + + Examples + -------- + Plot a slice through the sag surface + + .. jupyter-execute:: + + import matplotlib.pyplot as plt + import astropy.units as u + import astropy.visualization + import named_arrays as na + import optika + + sag = optika.sags.SphericalSag( + radius=na.linspace(100, 300, axis="radius", num=3) * u.mm, + ) + + position = na.Cartesian3dVectorArray( + x=na.linspace(-90, 90, axis="x", num=101) * u.mm, + y=0 * u.mm, + z=0 * u.mm + ) + + z = sag(position) + + with astropy.visualization.quantity_support(): + plt.figure() + plt.gca().set_aspect("equal") + na.plt.plot(position.x, z, axis="x", label=sag.radius) + plt.legend(title="radius") + """ + + radius: RadiusT = np.inf * u.mm + """The radius of the cylinder.""" + + transformation: None | na.transformations.AbstractTransformation = None + """ + The transformation between the surface coordinate system and the sag + coordinate system. + """ + + parameters_slope_error: None | optika.metrology.SlopeErrorParameters = None + """A set of parameters describing the slope error of the sag function.""" + + parameters_roughness: None | optika.metrology.RoughnessParameters = None + """A set of parameters describing the roughness of the sag function.""" + + parameters_microroughness: None | optika.metrology.RoughnessParameters = None + """A set of parameters describing the microroughness of the sag function.""" + + @property + def shape(self) -> dict[str, int]: + return na.broadcast_shapes( + optika.shape(self.radius), + optika.shape(self.transformation), + optika.shape(self.parameters_slope_error), + optika.shape(self.parameters_roughness), + optika.shape(self.parameters_microroughness), + ) + + def __call__( + self, + position: na.AbstractCartesian3dVectorArray, + ) -> na.AbstractScalar: + + c = 1 / self.radius + transformation = self.transformation + if transformation is not None: + position = transformation.inverse(position) + + shape = na.shape_broadcasted(c, position) + c = na.broadcast_to(c, shape) + position = na.broadcast_to(position, shape) + + r2 = np.square(position.x) + sz = c * r2 / (1 + np.sqrt(1 - np.square(c) * r2)) + return sz + + def normal( + self, + position: na.AbstractCartesian3dVectorArray, + ) -> na.AbstractCartesian3dVectorArray: + + c = 1 / self.radius + transformation = self.transformation + if transformation is not None: + position = transformation.inverse(position) + + shape = na.shape_broadcasted(c, position) + c = na.broadcast_to(c, shape) + position = na.broadcast_to(position, shape) + + x2 = np.square(position.x) + c2 = np.square(c) + g = np.sqrt(1 - c2 * x2) + dzdx = c * position.x / g + result = na.Cartesian3dVectorArray( + x=dzdx, + y=0, + z=-1 * u.dimensionless_unscaled, + ) + return result / result.length + + @dataclasses.dataclass(eq=False, repr=False) class AbstractConicSag( AbstractSag,