From b580857596d7b692da232773ab8f535490b73629 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Sat, 7 Sep 2024 13:52:21 -0700 Subject: [PATCH] Refactor `ExhaustiveOpTestBase` associated constants and types to `ExhaustiveOpTestTraits` Factors out all of the associated constants and types from `ExhaustiveOpTestBase` into a "traits" object to ease with sharing the constants between multiple types in future refactorings. PiperOrigin-RevId: 672123474 --- .../exhaustive/exhaustive_op_test_utils.cc | 27 +-- .../exhaustive/exhaustive_op_test_utils.h | 170 +++++++++--------- 2 files changed, 100 insertions(+), 97 deletions(-) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index ae50fe0c630009..5e85347fdd0fef 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -452,7 +452,7 @@ void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) { template void ExhaustiveOpTestBase::ExpectNear( - const InputLiterals& input_literals, const Literal& result_literal, + const LiteralInputs& input_literals, const Literal& result_literal, EvaluateOp evaluate_op, ErrorSpecGen error_spec_gen, OutputRangeCheck check_valid_range) { // Cache for when all components are subnormal testing values. @@ -466,11 +466,12 @@ void ExhaustiveOpTestBase::ExpectNear( // kNumSubnormalSubstitutionValues raised to the num_components. // num_components = N for the reals, and 2*N for the complex. int64_t max_cache_size = - pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1)); + pow(kNumSubnormalSubstitutionValues, N * (Traits::kIsComplex ? 2 : 1)); pure_subnormal_cache.reserve(max_cache_size); for (int i = 0; i < max_cache_size; ++i) { pure_subnormal_cache.push_back(CallOperation( - evaluate_op, FromCacheLocation(i))); + evaluate_op, + FromCacheLocation(i))); } } @@ -498,7 +499,7 @@ void ExhaustiveOpTestBase::ExpectNear( "-----------------------------------------------\n")); } - NativeInputsList inputs_arr; + NativeListInputs inputs_arr; for (int i = 0; i < N; ++i) { const Literal& literal = input_literals[i]; inputs_arr[i] = literal.data(); @@ -594,8 +595,9 @@ void ExhaustiveOpTestBase::ExpectNear( // more than 1 input. if constexpr (N == 1) { int cache_loc = - GetCacheLocation(test_value); + GetCacheLocation( + test_value); if (cache_loc == kInvalidCacheIndex) { result = CallOperation(evaluate_op, test_value); } else { @@ -624,15 +626,14 @@ void ExhaustiveOpTestBase::ExpectNear( CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size()); for (int i = 0; i < subnormal_test_inputs.size(); ++i) { - using IntegralNativeRefT = - typename ExhaustiveOpTestBase::ComponentIntegralNativeT; absl::StrAppend( &mismatch, - absl::StrFormat(" %10s (evaluated at %s)\n", - StringifyNum( - subnormal_test_results[i]), - GetSubnormalDescription( - subnormal_test_inputs[i], inputs_ref_ty))); + absl::StrFormat( + " %10s (evaluated at %s)\n", + StringifyNum( + subnormal_test_results[i]), + GetSubnormalDescription( + subnormal_test_inputs[i], inputs_ref_ty))); } absl::StrAppend( &mismatch, diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index f55d77e817fd11..4c264555659f0f 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -217,53 +217,6 @@ class ErrorSpecBuilder { ErrorSpec spec_; }; -// Representations of the reference function passed in by the user. -template -struct EvaluateOpWrapper {}; -template -struct EvaluateOpWrapper { - using type = NativeRefT (*)(NativeRefT); -}; -template -struct EvaluateOpWrapper { - using type = NativeRefT (*)(NativeRefT, NativeRefT); -}; - -// Representations of the reference function passed in by the user. -template -struct EnqueueOpWrapper {}; -template -struct EnqueueOpWrapper { - using type = std::function; - static XlaOp BuildFromInputs(XlaInputs inputs, type ty) { - return ty(inputs[0]); - } -}; -template -struct EnqueueOpWrapper { - using type = std::function; - static XlaOp BuildFromInputs(XlaInputs inputs, type ty) { - return ty(inputs[0], inputs[1]); - } -}; - -// Representations of the ErrorSpecGen function passed in by the user. -template -struct ErrorSpecGenWrapper {}; -template -struct ErrorSpecGenWrapper { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - using type = ErrorSpec (*)(NativeT); -}; -template -struct ErrorSpecGenWrapper { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - using type = ErrorSpec (*)(NativeT, NativeT); -}; - -template -typename ErrorSpecGenWrapper::type GetDefaultSpecGenerator(); - // The primitive type used to compute the reference output. constexpr PrimitiveType Ref(PrimitiveType T) { return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32; @@ -277,55 +230,104 @@ constexpr PrimitiveType Component(PrimitiveType T) { : T; } -// T: The primitive type being tested. -// N: The number of operands that the function being tested takes. +// Associates constants and types with a PrimitiveType (T) and number of test +// arguments (N) for the exhaustive test infrastructure. template -class ExhaustiveOpTestBase : public ClientLibraryTestBase { +class ExhaustiveOpTestTraits { public: - // Definitions depending on the primitive type T. static constexpr bool kIsComplex = primitive_util::IsComplexType(T); - static constexpr PrimitiveType kComponent = Component(T); static constexpr PrimitiveType kRef = Ref(T); - // Same as kComponent, but for the kRef primitive type. - static constexpr PrimitiveType kComponentRef = Component(kRef); - // The primitive type of an unsigned integer that can be bitcasted to and - // from ComponentT. + static constexpr PrimitiveType kComponent = Component(T); + static constexpr PrimitiveType kComponentRef = Component(kRef); + // The PrimitiveType of the associated unsigned integer to use T with + // bitcasting. static constexpr PrimitiveType kComponentIntegral = primitive_util::UnsignedIntegralTypeForBitWidth( primitive_util::BitWidth(kComponent)); + static constexpr PrimitiveType kComponentIntegralRef = + primitive_util::UnsignedIntegralTypeForBitWidth( + primitive_util::BitWidth(kComponentRef)); - // Native types that correspond to the primitive types above. using NativeT = primitive_util::NativeTypeOf; using NativeRefT = primitive_util::NativeTypeOf; using ComponentNativeT = primitive_util::NativeTypeOf; using ComponentNativeRefT = primitive_util::NativeTypeOf; using ComponentIntegralNativeT = primitive_util::NativeTypeOf; + using ComponentIntegralNativeRefT = + primitive_util::NativeTypeOf; - using InputLiterals = std::array; - - // N data items representing a single input to an XLA function. using NativeInputs = std::array; - - private: // N spans corresponding to the list of literal data values. - using NativeInputsList = std::array, N>; - - // N data items representing a single input to an interpreter backend - // function. + using NativeListInputs = std::array, N>; using NativeRefInputs = std::array; - - // N data items representing a single input to an XLA function. + using LiteralInputs = std::array; using XlaInputs = std::array; - public: - using ErrorSpecGen = typename ErrorSpecGenWrapper::type; - using EvaluateOp = typename EvaluateOpWrapper::type; - using EnqueueOp = typename EnqueueOpWrapper::type; + using EnqueueOp = std::conditional_t< + N == 1, std::function, + std::conditional_t, + std::enable_if_t>>; + using EvaluateOp = std::conditional_t< + N == 1, NativeRefT (*)(NativeRefT), + std::conditional_t>>; using OutputRangeCheck = std::function; - explicit ExhaustiveOpTestBase() + using ErrorSpecGen = std::conditional_t< + N == 1, ErrorSpec (*)(NativeT), + std::conditional_t>>; + + static XlaOp BuildFromInputs(XlaInputs inputs, EnqueueOp op) { + if constexpr (N == 1) { + return op(inputs[0]); + } else if constexpr (N == 2) { + return op(inputs[0], inputs[1]); + } else { + static_assert( + false, "BuildFromInputs only supports N == 1 and N == 2 currently."); + } + } +}; + +template +typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator(); + +// Base class from which all exhaustive tests should inherit. +// +// Holds a bunch of utility functions to simplify the process of running the +// operation and checking against expectations across multiple values. +// +// Type Parameters: +// - T: The primitive type being tested. +// - N: The number of operands that the function being tested takes. +template +class ExhaustiveOpTestBase : public ClientLibraryTestBase { + public: + using Traits = ExhaustiveOpTestTraits; + + using NativeT = typename Traits::NativeT; + using NativeRefT = typename Traits::NativeRefT; + using ComponentNativeT = typename Traits::ComponentNativeT; + using ComponentNativeRefT = typename Traits::ComponentNativeRefT; + using ComponentIntegralNativeT = typename Traits::ComponentIntegralNativeT; + using ComponentIntegralNativeRefT = + typename Traits::ComponentIntegralNativeRefT; + + using NativeInputs = typename Traits::NativeInputs; + using NativeListInputs = typename Traits::NativeListInputs; + using NativeRefInputs = typename Traits::NativeRefInputs; + using LiteralInputs = typename Traits::LiteralInputs; + using XlaInputs = typename Traits::XlaInputs; + + using EvaluateOp = typename Traits::EvaluateOp; + using EnqueueOp = typename Traits::EnqueueOp; + using OutputRangeCheck = typename Traits::OutputRangeCheck; + using ErrorSpecGen = typename Traits::ErrorSpecGen; + + ExhaustiveOpTestBase() : ty_(T), platform_(client_->platform()->Name()), eup_version_(xla::exhaustive_op_test::GetEupVersion()), @@ -368,7 +370,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op, ErrorSpecGen error_spec_gen, OutputRangeCheck check_valid_range = nullptr) { - InputLiterals input_literals = CreateInputLiterals(); + LiteralInputs input_literals = CreateLiteralInputs(); FillInput(&input_literals); XlaBuilder builder(TestName()); @@ -377,7 +379,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { xla_inputs[i] = Parameter(&builder, i, input_literals[i].shape(), "input"); } - EnqueueOpWrapper::BuildFromInputs(xla_inputs, enqueue_op); + Traits::BuildFromInputs(xla_inputs, enqueue_op); TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, @@ -411,7 +413,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // and just needs to be close to one of them. // check_valid_range can be used to provide a function that is called with // the result to check whether it is in the expected range. - void ExpectNear(const InputLiterals& input_literals, + void ExpectNear(const LiteralInputs& input_literals, const Literal& result_literal, EvaluateOp evaluate_op, ErrorSpecGen error_spec_gen, OutputRangeCheck check_valid_range = nullptr); @@ -483,7 +485,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { virtual int64_t GetInputSize() = 0; // Fills the literals with values to test for. - virtual void FillInput(InputLiterals* literals) = 0; + virtual void FillInput(LiteralInputs* literals) = 0; // Replace infinites with max value to help compute errors. static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) { @@ -643,8 +645,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { return test_values; } - InputLiterals CreateInputLiterals() { - InputLiterals literals; + LiteralInputs CreateLiteralInputs() { + LiteralInputs literals; for (int i = 0; i < N; ++i) { literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()}); } @@ -1256,7 +1258,7 @@ inline ErrorSpec DefaultSpecGenerator(bfloat16, bfloat16) { } template -typename ErrorSpecGenWrapper::type GetDefaultSpecGenerator() { +typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator() { return DefaultSpecGenerator; } @@ -1297,8 +1299,8 @@ inline std::function AddEmptyBroadcastDimension( template class ExhaustiveUnaryTest : public ExhaustiveOpTestBase { public: - using typename ExhaustiveOpTestBase::ErrorSpecGen; - static ErrorSpecGen GetDefaultSpecGenerator() { + static typename ExhaustiveOpTestTraits::ErrorSpecGen + GetDefaultSpecGenerator() { return exhaustive_op_test::GetDefaultSpecGenerator(); } }; @@ -1306,8 +1308,8 @@ class ExhaustiveUnaryTest : public ExhaustiveOpTestBase { template class ExhaustiveBinaryTest : public ExhaustiveOpTestBase { public: - using typename ExhaustiveOpTestBase::ErrorSpecGen; - static ErrorSpecGen GetDefaultSpecGenerator() { + static typename ExhaustiveOpTestTraits::ErrorSpecGen + GetDefaultSpecGenerator() { return exhaustive_op_test::GetDefaultSpecGenerator(); } };