diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
index 274a5431fa6b4f..3ce52b01cf8d66 100644
--- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
+++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
@@ -508,13 +508,40 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
   Status HandlePower(const HloInstruction* power) override {
     TF_ASSIGN_OR_RETURN(
         parent_->evaluated_[power],
-        ElementWiseBinaryOp(
-            power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
-              return lhs_el == ElementwiseT(1) ? static_cast<ElementwiseT>(1)
-                     : lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0)
-                         ? static_cast<ElementwiseT>(1)
-                         : std::pow(lhs_el, rhs_el);
-            }));
+        ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
+                                      ElementwiseT rhs_el) {
+          // Case 0: 1^x = 1
+          if (lhs_el == ElementwiseT(1)) {
+            return static_cast<ElementwiseT>(1);
+          }
+          // Case 1: 0^0 = 1
+          if (lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0)) {
+            return static_cast<ElementwiseT>(1);
+          }
+          // Case 2 (complex types only), for a base of +inf + 0i:
+          //   1. inf^(a + 0i) = inf, if a > 0.
+          //   2. inf^(a + 0i) = 0, if a < 0.
+          if constexpr (is_complex_v<ElementwiseT>) {
+            auto is_positive_infinity = [](auto c) {
+              return c.imag() == 0 && c.real() > 0 && std::isinf(c.real());
+            };
+            auto is_positive_real = [](ElementwiseT c) {
+              return c.real() > 0 && c.imag() == 0;
+            };
+            auto is_negative_real = [](ElementwiseT c) {
+              return c.real() < 0 && c.imag() == 0;
+            };
+            if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el)) {
+              return static_cast<ElementwiseT>(lhs_el);
+            }
+            if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) {
+              return static_cast<ElementwiseT>(0);
+            }
+          }
+          // Case 3:
+          // Fallback to std::pow.
+          return static_cast<ElementwiseT>(std::pow(lhs_el, rhs_el));
+        }));
     return OkStatus();
   }
 
diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc
index b5e5faed879857..8f6d656923b927 100644
--- a/third_party/xla/xla/service/elemental_ir_emitter.cc
+++ b/third_party/xla/xla/service/elemental_ir_emitter.cc
@@ -1733,6 +1733,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
     llvm::Value* d) {
   PrimitiveType component_type =
       primitive_util::ComplexComponentType(op->shape().element_type());
+  llvm::Value* inf = llvm::ConstantFP::getInfinity(a->getType());
   auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
   auto zero = llvm::ConstantFP::get(a->getType(), 0);
   auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
@@ -1753,15 +1754,36 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
   auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
   TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
   TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
+
+  // Case 0:
   // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
   // Branch Cuts for Complex Elementary Functions or Much Ado About
   // Nothing's Sign Bit, W. Kahan, Section 10.
-  return Select(
-      And(FCmpOEQ(a, one), FCmpOEQ(b, zero)), EmitComposeComplex(op, one, zero),
-      Select(
-          And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
-          EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
-          EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))));
+  auto cutoff_0 = Select(
+      And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
+      EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
+      EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
+
+  // Case 1:
+  // 1^(c + d*i) = 1 + 0*i
+  auto cutoff_1 = Select(And(FCmpOEQ(a, one), FCmpOEQ(b, zero)),
+                         EmitComposeComplex(op, one, zero), cutoff_0);
+
+  // Case 2:
+  // inf^(c + 0*i) = inf + 0*i, c > 0
+  auto cutoff_2 = Select(
+      And(FCmpOEQ(a, inf),
+          And(FCmpOEQ(b, zero), And(FCmpOEQ(d, zero), FCmpOGT(c, zero)))),
+      EmitComposeComplex(op, inf, zero), cutoff_1);
+
+  // Case 3:
+  // inf^(c + 0*i) = 0 + 0*i, c < 0
+  auto cutoff_3 = Select(
+      And(FCmpOEQ(a, inf),
+          And(FCmpOEQ(b, zero), And(FCmpOEQ(d, zero), FCmpOLT(c, zero)))),
+      EmitComposeComplex(op, zero, zero), cutoff_2);
+
+  return cutoff_3;
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
diff --git a/third_party/xla/xla/tests/array_elementwise_ops_test.cc b/third_party/xla/xla/tests/array_elementwise_ops_test.cc
index 4cc72ce2a118c4..d932c5c540566a 100644
--- a/third_party/xla/xla/tests/array_elementwise_ops_test.cc
+++ b/third_party/xla/xla/tests/array_elementwise_ops_test.cc
@@ -1622,10 +1622,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
 XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
   SetFastMathDisabled(true);
   XlaBuilder builder(TestName());
-  auto lhs = ConstantR1<complex64>(
-      &builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f, 1.0f});
-  auto rhs = ConstantR1<complex64>(
-      &builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f, INFINITY});
+  auto lhs = ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f,
+                                              0.0f, 1.0f, INFINITY, INFINITY});
+  auto rhs = ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f,
+                                              0.0f, INFINITY, 1.0f, -1.1234f});
   Pow(lhs, rhs);
 
   ComputeAndCompareR1<complex64>(&builder,
@@ -1637,6 +1637,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
                                      {0, 0},
                                      {1, 0},
                                      {1, 0},
+                                     {INFINITY, 0},
+                                     {0, 0},
                                  },
                                  {}, error_spec_);
 }