Skip to content

Commit

Permalink
[XLA] Get rid of NaNs when base equals to infinity for complex types …
Browse files Browse the repository at this point in the history
…exponentiation.

The following cutoffs are implemented in this change:
1. inf^(a + 0i) = inf, if a > 0.
2. inf^(a + 0i) = 0, if a < 0.

PiperOrigin-RevId: 574770628
  • Loading branch information
tensorflower-gardener committed Oct 19, 2023
1 parent 2ff1abf commit 822079e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 17 deletions.
41 changes: 34 additions & 7 deletions third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,13 +508,40 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
Status HandlePower(const HloInstruction* power) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[power],
ElementWiseBinaryOp(
power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
return lhs_el == ElementwiseT(1) ? static_cast<ElementwiseT>(1)
: lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0)
? static_cast<ElementwiseT>(1)
: std::pow(lhs_el, rhs_el);
}));
ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
ElementwiseT rhs_el) {
// Case 0: 1^x = 1
if (lhs_el == ElementwiseT(1)) {
return static_cast<ElementwiseT>(1);
}
// Case 1: 0^0 = 1
if (lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0)) {
return static_cast<ElementwiseT>(1);
}
// Case 2:
// 1. inf^(a + 0i) = inf, if a > 0.
// 2. inf^(a + 0i) = 0, if a < 0.
if constexpr (is_complex_v<ElementwiseT>) {
auto is_positive_infinity = [](auto c) {
return c.imag() == 0 && c.real() > 0 && std::isinf(c.real());
};
auto is_positive_real = [](ElementwiseT c) {
return c.real() > 0 && c.imag() == 0;
};
auto is_negative_real = [](ElementwiseT c) {
return c.real() < 0 && c.imag() == 0;
};
if (is_positive_infinity(lhs_el) && is_positive_real(rhs_el) > 0) {
return static_cast<ElementwiseT>(lhs_el);
}
if (is_positive_infinity(lhs_el) && is_negative_real(rhs_el)) {
return static_cast<ElementwiseT>(0);
}
}
// Case 3:
// Fallback to std::pow.
return static_cast<ElementwiseT>(std::pow(lhs_el, rhs_el));
}));
return OkStatus();
}

Expand Down
34 changes: 28 additions & 6 deletions third_party/xla/xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
llvm::Value* d) {
PrimitiveType component_type =
primitive_util::ComplexComponentType(op->shape().element_type());
llvm::Value* inf = llvm::ConstantFP::getInfinity(a->getType());
auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
auto zero = llvm::ConstantFP::get(a->getType(), 0);
auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
Expand All @@ -1753,15 +1754,36 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));

// Case 0:
// d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
return Select(
And(FCmpOEQ(a, one), FCmpOEQ(b, zero)), EmitComposeComplex(op, one, zero),
Select(
And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))));
auto cutoff_0 = Select(
And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));

// Case 1:
// 1^(c + d*i) = 1 + 0*i
auto cutoff_1 = Select(And(FCmpOEQ(a, one), FCmpOEQ(b, zero)),
EmitComposeComplex(op, one, zero), cutoff_0);

// Case 2:
// inf^(c + 0*i) = inf + 0*i, c > 0
auto cutoff_2 = Select(
And(FCmpOEQ(a, inf),
And(FCmpOEQ(b, zero), And(FCmpOEQ(d, zero), FCmpOGT(c, zero)))),
EmitComposeComplex(op, inf, zero), cutoff_1);

// Case 3:
// inf^(c + 0*i) = 0 + 0*i, c < 0
auto cutoff_3 = Select(
And(FCmpOEQ(a, inf),
And(FCmpOEQ(b, zero), And(FCmpOEQ(d, zero), FCmpOLT(c, zero)))),
EmitComposeComplex(op, zero, zero), cutoff_2);

return cutoff_3;
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
Expand Down
10 changes: 6 additions & 4 deletions third_party/xla/xla/tests/array_elementwise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1622,10 +1622,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto lhs = ConstantR1<complex64>(
&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f, 1.0f});
auto rhs = ConstantR1<complex64>(
&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f, INFINITY});
auto lhs = ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f,
0.0f, 1.0f, INFINITY, INFINITY});
auto rhs = ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f,
0.0f, INFINITY, 1.0f, -1.1234f});
Pow(lhs, rhs);

ComputeAndCompareR1<complex64>(&builder,
Expand All @@ -1637,6 +1637,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
{0, 0},
{1, 0},
{1, 0},
{INFINITY, 0},
{0, 0},
},
{}, error_spec_);
}
Expand Down

0 comments on commit 822079e

Please sign in to comment.