Skip to content

Commit

Permalink
Refactor ExhaustiveOpTestBase associated constants and types to `Ex…
Browse files Browse the repository at this point in the history
…haustiveOpTestTraits`

Factors out all of the associated constants and types from `ExhaustiveOpTestBase` into a "traits" object to ease with sharing the constants between multiple types in future refactorings.

PiperOrigin-RevId: 672123474
  • Loading branch information
Gregory Pataky authored and tensorflower-gardener committed Sep 7, 2024
1 parent ea3e1b9 commit b580857
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 97 deletions.
27 changes: 14 additions & 13 deletions third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) {

template <PrimitiveType T, size_t N>
void ExhaustiveOpTestBase<T, N>::ExpectNear(
const InputLiterals& input_literals, const Literal& result_literal,
const LiteralInputs& input_literals, const Literal& result_literal,
EvaluateOp evaluate_op, ErrorSpecGen error_spec_gen,
OutputRangeCheck check_valid_range) {
// Cache for when all components are subnormal testing values.
Expand All @@ -466,11 +466,12 @@ void ExhaustiveOpTestBase<T, N>::ExpectNear(
// kNumSubnormalSubstitutionValues raised to the num_components.
// num_components = N for the reals, and 2*N for the complex.
int64_t max_cache_size =
pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1));
pow(kNumSubnormalSubstitutionValues, N * (Traits::kIsComplex ? 2 : 1));
pure_subnormal_cache.reserve(max_cache_size);
for (int i = 0; i < max_cache_size; ++i) {
pure_subnormal_cache.push_back(CallOperation(
evaluate_op, FromCacheLocation<kIsComplex, NativeRefT, N>(i)));
evaluate_op,
FromCacheLocation<Traits::kIsComplex, NativeRefT, N>(i)));
}
}

Expand Down Expand Up @@ -498,7 +499,7 @@ void ExhaustiveOpTestBase<T, N>::ExpectNear(
"-----------------------------------------------\n"));
}

NativeInputsList inputs_arr;
NativeListInputs inputs_arr;
for (int i = 0; i < N; ++i) {
const Literal& literal = input_literals[i];
inputs_arr[i] = literal.data<NativeT>();
Expand Down Expand Up @@ -594,8 +595,9 @@ void ExhaustiveOpTestBase<T, N>::ExpectNear(
// more than 1 input.
if constexpr (N == 1) {
int cache_loc =
GetCacheLocation<kIsComplex, typename NativeRefInputs::value_type,
N>(test_value);
GetCacheLocation<Traits::kIsComplex,
typename NativeRefInputs::value_type, N>(
test_value);
if (cache_loc == kInvalidCacheIndex) {
result = CallOperation(evaluate_op, test_value);
} else {
Expand Down Expand Up @@ -624,15 +626,14 @@ void ExhaustiveOpTestBase<T, N>::ExpectNear(

CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size());
for (int i = 0; i < subnormal_test_inputs.size(); ++i) {
using IntegralNativeRefT =
typename ExhaustiveOpTestBase<kRef, N>::ComponentIntegralNativeT;
absl::StrAppend(
&mismatch,
absl::StrFormat(" %10s (evaluated at %s)\n",
StringifyNum<NativeRefT, IntegralNativeRefT>(
subnormal_test_results[i]),
GetSubnormalDescription<kIsComplex, NativeRefT, N>(
subnormal_test_inputs[i], inputs_ref_ty)));
absl::StrFormat(
" %10s (evaluated at %s)\n",
StringifyNum<NativeRefT, ComponentIntegralNativeRefT>(
subnormal_test_results[i]),
GetSubnormalDescription<Traits::kIsComplex, NativeRefT, N>(
subnormal_test_inputs[i], inputs_ref_ty)));
}
absl::StrAppend(
&mismatch,
Expand Down
170 changes: 86 additions & 84 deletions third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,53 +217,6 @@ class ErrorSpecBuilder {
ErrorSpec spec_;
};

// Representations of the reference function passed in by the user.
template <typename NativeRefT, size_t K>
struct EvaluateOpWrapper {};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 1> {
using type = NativeRefT (*)(NativeRefT);
};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 2> {
using type = NativeRefT (*)(NativeRefT, NativeRefT);
};

// Representations of the reference function passed in by the user.
template <typename XlaInputs, size_t K>
struct EnqueueOpWrapper {};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 1> {
using type = std::function<XlaOp(XlaOp)>;
static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
return ty(inputs[0]);
}
};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 2> {
using type = std::function<XlaOp(XlaOp, XlaOp)>;
static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
return ty(inputs[0], inputs[1]);
}
};

// Representations of the ErrorSpecGen function passed in by the user.
template <PrimitiveType T, size_t K>
struct ErrorSpecGenWrapper {};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 1> {
using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
using type = ErrorSpec (*)(NativeT);
};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 2> {
using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
using type = ErrorSpec (*)(NativeT, NativeT);
};

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator();

// The primitive type used to compute the reference output.
constexpr PrimitiveType Ref(PrimitiveType T) {
return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32;
Expand All @@ -277,55 +230,104 @@ constexpr PrimitiveType Component(PrimitiveType T) {
: T;
}

// T: The primitive type being tested.
// N: The number of operands that the function being tested takes.
// Associates constants and types with a PrimitiveType (T) and number of test
// arguments (N) for the exhaustive test infrastructure.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
class ExhaustiveOpTestTraits {
public:
// Definitions depending on the primitive type T.
static constexpr bool kIsComplex = primitive_util::IsComplexType(T);
static constexpr PrimitiveType kComponent = Component(T);
static constexpr PrimitiveType kRef = Ref(T);
// Same as kComponent, but for the kRef primitive type.
static constexpr PrimitiveType kComponentRef = Component(kRef);

// The primitive type of an unsigned integer that can be bitcasted to and
// from ComponentT.
static constexpr PrimitiveType kComponent = Component(T);
static constexpr PrimitiveType kComponentRef = Component(kRef);
// The PrimitiveType of the associated unsigned integer to use T with
// bitcasting.
static constexpr PrimitiveType kComponentIntegral =
primitive_util::UnsignedIntegralTypeForBitWidth(
primitive_util::BitWidth(kComponent));
static constexpr PrimitiveType kComponentIntegralRef =
primitive_util::UnsignedIntegralTypeForBitWidth(
primitive_util::BitWidth(kComponentRef));

// Native types that correspond to the primitive types above.
using NativeT = primitive_util::NativeTypeOf<T>;
using NativeRefT = primitive_util::NativeTypeOf<kRef>;
using ComponentNativeT = primitive_util::NativeTypeOf<kComponent>;
using ComponentNativeRefT = primitive_util::NativeTypeOf<kComponentRef>;
using ComponentIntegralNativeT =
primitive_util::NativeTypeOf<kComponentIntegral>;
using ComponentIntegralNativeRefT =
primitive_util::NativeTypeOf<kComponentIntegralRef>;

using InputLiterals = std::array<Literal, N>;

// N data items representing a single input to an XLA function.
using NativeInputs = std::array<NativeT, N>;

private:
// N spans corresponding to the list of literal data values.
using NativeInputsList = std::array<absl::Span<const NativeT>, N>;

// N data items representing a single input to an interpreter backend
// function.
using NativeListInputs = std::array<absl::Span<const NativeT>, N>;
using NativeRefInputs = std::array<NativeRefT, N>;

// N data items representing a single input to an XLA function.
using LiteralInputs = std::array<Literal, N>;
using XlaInputs = std::array<XlaOp, N>;

public:
using ErrorSpecGen = typename ErrorSpecGenWrapper<T, N>::type;
using EvaluateOp = typename EvaluateOpWrapper<NativeRefT, N>::type;
using EnqueueOp = typename EnqueueOpWrapper<XlaInputs, N>::type;
using EnqueueOp = std::conditional_t<
N == 1, std::function<XlaOp(XlaOp)>,
std::conditional_t<N == 2, std::function<XlaOp(XlaOp, XlaOp)>,
std::enable_if_t<N == 1 || N == 2, void>>>;
using EvaluateOp = std::conditional_t<
N == 1, NativeRefT (*)(NativeRefT),
std::conditional_t<N == 2, NativeRefT (*)(NativeRefT, NativeRefT),
std::enable_if_t<N == 1 || N == 2, void>>>;
using OutputRangeCheck = std::function<bool(NativeInputs, NativeT)>;

explicit ExhaustiveOpTestBase()
using ErrorSpecGen = std::conditional_t<
N == 1, ErrorSpec (*)(NativeT),
std::conditional_t<N == 2, ErrorSpec (*)(NativeT, NativeT),
std::enable_if_t<N == 1 || N == 2, void>>>;

static XlaOp BuildFromInputs(XlaInputs inputs, EnqueueOp op) {
if constexpr (N == 1) {
return op(inputs[0]);
} else if constexpr (N == 2) {
return op(inputs[0], inputs[1]);
} else {
static_assert(
false, "BuildFromInputs only supports N == 1 and N == 2 currently.");
}
}
};

template <PrimitiveType T, size_t N>
typename ExhaustiveOpTestTraits<T, N>::ErrorSpecGen GetDefaultSpecGenerator();

// Base class from which all exhaustive tests should inherit.
//
// Holds a bunch of utility functions to simplify the process of running the
// operation and checking against expectations across multiple values.
//
// Type Parameters:
// - T: The primitive type being tested.
// - N: The number of operands that the function being tested takes.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
public:
using Traits = ExhaustiveOpTestTraits<T, N>;

using NativeT = typename Traits::NativeT;
using NativeRefT = typename Traits::NativeRefT;
using ComponentNativeT = typename Traits::ComponentNativeT;
using ComponentNativeRefT = typename Traits::ComponentNativeRefT;
using ComponentIntegralNativeT = typename Traits::ComponentIntegralNativeT;
using ComponentIntegralNativeRefT =
typename Traits::ComponentIntegralNativeRefT;

using NativeInputs = typename Traits::NativeInputs;
using NativeListInputs = typename Traits::NativeListInputs;
using NativeRefInputs = typename Traits::NativeRefInputs;
using LiteralInputs = typename Traits::LiteralInputs;
using XlaInputs = typename Traits::XlaInputs;

using EvaluateOp = typename Traits::EvaluateOp;
using EnqueueOp = typename Traits::EnqueueOp;
using OutputRangeCheck = typename Traits::OutputRangeCheck;
using ErrorSpecGen = typename Traits::ErrorSpecGen;

ExhaustiveOpTestBase()
: ty_(T),
platform_(client_->platform()->Name()),
eup_version_(xla::exhaustive_op_test::GetEupVersion()),
Expand Down Expand Up @@ -368,7 +370,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
ErrorSpecGen error_spec_gen,
OutputRangeCheck check_valid_range = nullptr) {
InputLiterals input_literals = CreateInputLiterals();
LiteralInputs input_literals = CreateLiteralInputs();
FillInput(&input_literals);

XlaBuilder builder(TestName());
Expand All @@ -377,7 +379,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
xla_inputs[i] =
Parameter(&builder, i, input_literals[i].shape(), "input");
}
EnqueueOpWrapper<XlaInputs, N>::BuildFromInputs(xla_inputs, enqueue_op);
Traits::BuildFromInputs(xla_inputs, enqueue_op);

TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
Expand Down Expand Up @@ -411,7 +413,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// and just needs to be close to one of them.
// check_valid_range can be used to provide a function that is called with
// the result to check whether it is in the expected range.
void ExpectNear(const InputLiterals& input_literals,
void ExpectNear(const LiteralInputs& input_literals,
const Literal& result_literal, EvaluateOp evaluate_op,
ErrorSpecGen error_spec_gen,
OutputRangeCheck check_valid_range = nullptr);
Expand Down Expand Up @@ -483,7 +485,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
virtual int64_t GetInputSize() = 0;

// Fills the literals with values to test for.
virtual void FillInput(InputLiterals* literals) = 0;
virtual void FillInput(LiteralInputs* literals) = 0;

// Replace infinites with max value to help compute errors.
static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) {
Expand Down Expand Up @@ -643,8 +645,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
return test_values;
}

InputLiterals CreateInputLiterals() {
InputLiterals literals;
LiteralInputs CreateLiteralInputs() {
LiteralInputs literals;
for (int i = 0; i < N; ++i) {
literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()});
}
Expand Down Expand Up @@ -1256,7 +1258,7 @@ inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
}

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
typename ExhaustiveOpTestTraits<T, N>::ErrorSpecGen GetDefaultSpecGenerator() {
return DefaultSpecGenerator<T, N>;
}

Expand Down Expand Up @@ -1297,17 +1299,17 @@ inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
public:
using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
static ErrorSpecGen GetDefaultSpecGenerator() {
static typename ExhaustiveOpTestTraits<T, 1>::ErrorSpecGen
GetDefaultSpecGenerator() {
return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
}
};

template <PrimitiveType T>
class ExhaustiveBinaryTest : public ExhaustiveOpTestBase<T, 2> {
public:
using typename ExhaustiveOpTestBase<T, 2>::ErrorSpecGen;
static ErrorSpecGen GetDefaultSpecGenerator() {
static typename ExhaustiveOpTestTraits<T, 2>::ErrorSpecGen
GetDefaultSpecGenerator() {
return exhaustive_op_test::GetDefaultSpecGenerator<T, 2>();
}
};
Expand Down

0 comments on commit b580857

Please sign in to comment.