diff --git a/python/geometric_structure/geodesic/fixed_points.py b/python/geometric_structure/geodesic/fixed_points.py index 95a965f7..4c75d177 100644 --- a/python/geometric_structure/geodesic/fixed_points.py +++ b/python/geometric_structure/geodesic/fixed_points.py @@ -3,7 +3,8 @@ from ...upper_halfspace import psl2c_to_o13 # type: ignore from ...upper_halfspace.ideal_point import ideal_point_to_r13 # type: ignore from ...matrix import matrix # type: ignore -from ...math_basics import is_RealIntervalFieldElement # type: ignore +from ...math_basics import (is_RealIntervalFieldElement, + is_ComplexIntervalFieldElement) # type: ignore __all__ = ['r13_fixed_points_of_psl2c_matrix', 'r13_fixed_line_of_psl2c_matrix'] @@ -78,5 +79,14 @@ def _complex_fixed_points_of_psl2c_matrix(m): c = -m[0, 1] # Use usual formula z = (-b +/- sqrt(b^2 - 4 * a * c)) / (2 * a) - d = (b * b - 4 * a * c).sqrt() + d = _safe_complex_sqrt(b * b - 4 * a * c) return [ (-b + s * d) / (2 * a) for s in [+1, -1] ] + +def _safe_complex_sqrt(z): + if is_ComplexIntervalFieldElement(z): + if z.contains_zero(): + CIF = z.parent() + m = z.abs().sqrt().upper() + return CIF((-m, m), (-m, m)) + + return z.sqrt() diff --git a/python/hyperboloid/distances.py b/python/hyperboloid/distances.py index dabdf9ce..52eaa08a 100644 --- a/python/hyperboloid/distances.py +++ b/python/hyperboloid/distances.py @@ -49,13 +49,17 @@ def distance_r13_horoball_line(horoball_defining_vec, # Light-like def distance_r13_horoball_plane(horoball_defining_vec, # Light-like plane_defining_vec): # Unit space-like p = r13_dot(horoball_defining_vec, plane_defining_vec) - return _safe_log_non_neg(p.abs()) + return _safe_log_of_abs(p) def distance_r13_point_line(pt, # Unit time-like line : R13Line): + """ + This also works if line is degenerate and starts and ends at some point. + """ + p = (r13_dot(line.points[0], pt) * r13_dot(line.points[1], pt)) - s = -2 * p / line.inner_product + s = _safe_div(2 * p, -line.inner_product) return _safe_arccosh(_safe_sqrt(s)) def distance_r13_point_plane(pt, # Unit time-like @@ -183,6 +187,9 @@ def _safe_log(p): return RF(-1e20) return p.log() +def _safe_log_of_abs(p): + return _safe_log_non_neg(p.abs()) + def _safe_log_non_neg(p): if p == 0: if is_RealIntervalFieldElement(p): @@ -203,3 +210,23 @@ def _safe_arccosh(p): RF = p.parent() return RF(0) return p.arccosh() + +def _safe_div(a, b): + """ + Compute a / b where be is known to be non-negative and we should + return infinity if b is zero. + """ + + if is_RealIntervalFieldElement(b): + RIF = b.parent() + if b == 0: + return RIF(sage.all.Infinity) + else: + return a / b.intersection(RIF(0, sage.all.Infinity)) + else: + if b <= 0: + RIF = b.parent() + return RIF(1e20) + else: + return a / b + diff --git a/python/math_basics.py b/python/math_basics.py index 80aad5cb..012ca7ae 100644 --- a/python/math_basics.py +++ b/python/math_basics.py @@ -7,23 +7,24 @@ __all__ = ['prod', 'xgcd', 'is_RealIntervalFieldElement', + 'is_ComplexIntervalFieldElement' 'is_Interval', 'correct_min', 'correct_max', 'lower'] +def is_Interval(x): + """ + Returns True is x is either a real or complex interval as constructed + with RealIntervalField or ComplexIntervalField, respectively. + """ + return is_RealIntervalFieldElement(x) or is_ComplexIntervalFieldElement(x) + if _within_sage: from sage.all import prod, xgcd from sage.rings.real_mpfi import is_RealIntervalFieldElement from sage.rings.complex_interval import is_ComplexIntervalFieldElement - def is_Interval(x): - """ - Returns True is x is either a real or complex interval as constructed - with RealIntervalField or ComplexIntervalField, respectively. - """ - return is_RealIntervalFieldElement(x) or is_ComplexIntervalFieldElement(x) - else: def prod(L, initial=None): @@ -77,10 +78,16 @@ def is_RealIntervalFieldElement(x): # so always return False. return False - def is_Interval(x): - return False - + def is_ComplexIntervalFieldElement(x): + """ + is_ComplexIntervalFieldElement returns whether x is a complex + interval (constructed with ComplexIntervalField(precision)(value)). + """ + # We do not support interval arithmetic outside of SnapPy, + # so always return False. + return False + def correct_min(l): """ A version of min that works correctly even when l is a list of