Commit 1374a1c
Adds a lowering to generate the Philox algorithm in a linalg.generic op. The lowering numerically matches the existing XLA implementation.

PiperOrigin-RevId: 536232082
tensorflower-gardener committed May 29, 2023
1 parent 0a2aab7 commit 1374a1c
Showing 4 changed files with 571 additions and 1 deletion.
@@ -2346,6 +2346,14 @@ class RngBitGeneratorConverter
return success();
}

if (op.getRngAlgorithm() == mhlo::RngAlgorithm::PHILOX) {
Value random;
if (generateLinalgPhilox(rewriter, loc, resultTy, state, random).failed())
return failure();
rewriter.replaceOp(op, {state, random});
return success();
}

return failure();
}
};
272 changes: 271 additions & 1 deletion tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.cc
@@ -94,6 +94,11 @@ class ArithOp
return ArithOp(builder, loc, res);
}

ArithOp operator*(ArithOp &rhs) {
Value res = builder.create<arith::MulIOp>(loc, value, rhs.value);
return ArithOp(builder, loc, res);
}

ArithOp operator|(ArithOp &rhs) {
Value res = builder.create<arith::OrIOp>(loc, value, rhs.value);
return ArithOp(builder, loc, res);
@@ -322,7 +327,7 @@ std::pair<ShapedType, int64_t> threeFry32Shape(ShapedType resultTy) {

// This implementation generates a 32-bit tensor of ThreeFry random numbers.
// It matches the XLA implementation bit-exact and includes an inefficient
-// method of concatenating / slicing the pairs of generated nunbers.
+// method of concatenating / slicing the pairs of generated numbers.
//
// We should consider dropping the complex slicing and simply generating
// 2x the values, then downcast to 32 bits. It substantially simplifies
@@ -464,6 +469,255 @@ LogicalResult generateLinalgThreeFry64(OpBuilder &builder, Location loc,
return success();
}

using PhiloxKey = std::pair<ArithOp, ArithOp>;
using PhiloxState = std::array<ArithOp, 4>;

// Computes the high and low words from multiplying two 32-bit integers.
// Per the paper, mulhi and mullo of the same arguments can be computed
// simultaneously in a single instruction on x86 architectures.
std::pair<ArithOp, ArithOp> multiplyHilo(ArithOp counter, ArithOp key) {
counter = counter.extendUI(64);
key = key.extendUI(64);
ArithOp product = counter * key;
ArithOp ci64 = counter.constantI(/*value=*/32, /*bits=*/64);
ArithOp hi = product >> ci64;
hi = hi.truncI(32);
product = product.truncI(32);
return std::pair<ArithOp, ArithOp>{hi, product};
}
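
For intuition, here is the same hi/lo split as a minimal scalar C++ sketch, with plain integers standing in for the MLIR values built above (the helper name mulhilo32 is ours, not part of the change):

#include <cstdint>
#include <utility>

// Widen both 32-bit operands to 64 bits, multiply once, then split the
// product into its high and low words, mirroring the extendUI / MulIOp /
// shift / trunc sequence above.
std::pair<uint32_t, uint32_t> mulhilo32(uint32_t counter, uint32_t key) {
  uint64_t product = uint64_t(counter) * uint64_t(key);
  return {uint32_t(product >> 32), uint32_t(product)};
}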

PhiloxState philoxRound(PhiloxState x, PhiloxKey key) {
// These are Philox-specific constants.
ArithOp m0 = x[0].constantI(0xD2511F53, 32);
ArithOp m1 = x[2].constantI(0xCD9E8D57, 32);
std::pair<ArithOp, ArithOp> p0 = multiplyHilo(x[0], m0);
std::pair<ArithOp, ArithOp> p1 = multiplyHilo(x[2], m1);

PhiloxState state = {p1.first ^ x[1] ^ key.first, p1.second,
p0.first ^ x[3] ^ key.second, p0.second};
return state;
}
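
The same round expressed as scalar C++, a sketch only (the names are ours): two hi/lo multiplies feed a butterfly that XORs each high word with the opposite lane and one half of the round key:

#include <array>
#include <cstdint>
#include <utility>

using State4x32 = std::array<uint32_t, 4>;
using Key2x32 = std::pair<uint32_t, uint32_t>;

// One Philox 4x32 round, matching philoxRound above.
State4x32 philoxRoundScalar(State4x32 x, Key2x32 key) {
  uint64_t p0 = uint64_t(x[0]) * 0xD2511F53u;  // M0
  uint64_t p1 = uint64_t(x[2]) * 0xCD9E8D57u;  // M1
  return {uint32_t(p1 >> 32) ^ x[1] ^ key.first, uint32_t(p1),
          uint32_t(p0 >> 32) ^ x[3] ^ key.second, uint32_t(p0)};
}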

PhiloxKey raiseKey(PhiloxKey key) {
// These are Philox-specific constants.
ArithOp w0 = key.first.constantI(0x9E3779B9, 32);
ArithOp w1 = key.first.constantI(0xBB67AE85, 32);
return PhiloxKey{key.first + w0, key.second + w1};
}
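
The scalar counterpart, again only a sketch: between rounds, each half of the key advances by a fixed Weyl-sequence constant:

#include <cstdint>
#include <utility>

// Bump each half of the key by its Weyl constant, matching raiseKey above.
std::pair<uint32_t, uint32_t> raiseKeyScalar(std::pair<uint32_t, uint32_t> k) {
  return {k.first + 0x9E3779B9u, k.second + 0xBB67AE85u};
}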

// Implements the Philox 4x32 counter-based PRNG algorithm.
// The Philox PRNG was proposed in:
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
std::array<ArithOp, 4> runPhilox4x32(PhiloxKey key, ArithOp state) {
ArithOp index = state.linalgIndex(0);
index = index.indexCast(64);
index = index + state;

// Split the 64-bit index into the 2xi32 counter words.
std::pair<ArithOp, ArithOp> input = splitI64(index);
ArithOp input0 = input.first;
ArithOp input1 = input.second;

// Initialize the state this way to match the XLA implementation.
PhiloxState state4 = {input0, input1, key.first, key.second};

// We perform 10 rounds to match the XLA implementation.
static const int kNumRounds = 10;
for (int round = 0; round < kNumRounds; ++round, key = raiseKey(key)) {
state4 = philoxRound(state4, key);
}
return state4;
}
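
Putting the scalar sketches together, the whole generator reads roughly as below, reusing philoxRoundScalar and raiseKeyScalar from above. One assumption is flagged in a comment: splitI64 is not shown in this hunk, so the word order of the split (high word first) is our guess:

#include <array>
#include <cstdint>
#include <utility>

// Scalar analogue of runPhilox4x32: seed the counter with the two halves
// of the 64-bit element index plus the key, then run ten rounds, raising
// the key after each one.
std::array<uint32_t, 4> philox4x32_10(uint64_t index,
                                      std::pair<uint32_t, uint32_t> key) {
  // Assumption: splitI64 yields the high word first; swap if it does not.
  std::array<uint32_t, 4> state = {uint32_t(index >> 32), uint32_t(index),
                                   key.first, key.second};
  for (int round = 0; round < 10; ++round) {
    state = philoxRoundScalar(state, key);
    key = raiseKeyScalar(key);
  }
  return state;
}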

// Generates an array of primitive type U32 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
LogicalResult generateLinalgPhilox32(OpBuilder &builder, Location loc,
ShapedType resultTy, Value &store,
Value &result) {
Type resultETy = resultTy.getElementType();

Value initialState = extractState64(builder, loc, store);
if (!initialState) return failure();

std::pair<Value, Value> keys = extractKey32(builder, loc, store);
if (!keys.first || !keys.second) return failure();

int64_t numElements = resultTy.getNumElements();
int64_t count = (numElements + 3) / 4;
ShapedType intermediateType =
RankedTensorType::get({count, 1}, resultTy.getElementType());
int64_t concatDim = 1;

// Compute the number of Philox invocations required and advance the
// counter state by that amount.
Value countVal =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);

// Set up the four output buffers.
Value dest0 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
resultETy);
Value dest1 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
resultETy);
Value dest2 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
resultETy);
Value dest3 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
resultETy);

ShapedType destTy = dest0.getType().cast<ShapedType>();

SmallVector<AffineMap> indexingMaps(4, builder.getMultiDimIdentityMap(1));
SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);

linalg::GenericOp generic = builder.create<linalg::GenericOp>(
loc, TypeRange{destTy, destTy, destTy, destTy},
/*inputs=*/ValueRange(),
/*outputs=*/ValueRange{dest0, dest1, dest2, dest3},
/*indexingMaps=*/indexingMaps, iterators,
[&](OpBuilder &b, Location nestedLoc, ValueRange) {
auto output =
runPhilox4x32(PhiloxKey{ArithOp(b, nestedLoc, keys.first),
ArithOp(b, nestedLoc, keys.second)},
ArithOp(b, nestedLoc, initialState));
auto out0 = output[0].truncI(resultETy.getIntOrFloatBitWidth());
auto out1 = output[1].truncI(resultETy.getIntOrFloatBitWidth());
auto out2 = output[2].truncI(resultETy.getIntOrFloatBitWidth());
auto out3 = output[3].truncI(resultETy.getIntOrFloatBitWidth());
b.create<linalg::YieldOp>(
loc, ValueRange{out0.val(), out1.val(), out2.val(), out3.val()});
});

if (resultTy.getNumElements() == 1) {
result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0));
store = setState64(builder, loc, store, newState);
return success();
}

Value r0 =
reshapeToTarget(builder, loc, intermediateType, generic.getResult(0));
Value r1 =
reshapeToTarget(builder, loc, intermediateType, generic.getResult(1));
Value r2 =
reshapeToTarget(builder, loc, intermediateType, generic.getResult(2));
Value r3 =
reshapeToTarget(builder, loc, intermediateType, generic.getResult(3));

Value concatenate = builder.create<mhlo::ConcatenateOp>(
loc, ValueRange{r0, r1, r2, r3}, builder.getI64IntegerAttr(concatDim));

// Collapse the concat dimension back into the parent.
llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape());
collapseShape[0] = collapseShape[0] * 4;
Value reshapeIntermediate = builder.create<mhlo::ReshapeOp>(
loc, resultTy.clone(collapseShape), concatenate);

// Slice to only the required results.
collapseShape[0] = resultTy.getNumElements();

llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
Value slice = builder.create<mhlo::SliceOp>(
loc, intermediateType.clone(collapseShape), reshapeIntermediate,
builder.getI64TensorAttr(offset), builder.getI64TensorAttr(collapseShape),
builder.getI64TensorAttr(stride));
Value reshapeResult = builder.create<mhlo::ReshapeOp>(loc, resultTy, slice);

// Set the new tensor values.
store = setState64(builder, loc, store, newState);
result = reshapeResult;

return success();
}
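
The concat(dim=1) / reshape / slice sequence above interleaves the four output streams element-wise and then trims the padding introduced by rounding the count up. A plain C++ picture of the same data movement, a sketch with invented names:

#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar view of concat + collapse + slice: interleave the four streams
// r0..r3 element by element, then keep only numElements values.
std::vector<uint32_t> interleaveAndTrim(const std::vector<uint32_t> &r0,
                                        const std::vector<uint32_t> &r1,
                                        const std::vector<uint32_t> &r2,
                                        const std::vector<uint32_t> &r3,
                                        size_t numElements) {
  std::vector<uint32_t> out;
  out.reserve(r0.size() * 4);
  for (size_t i = 0; i < r0.size(); ++i) {
    out.push_back(r0[i]);
    out.push_back(r1[i]);
    out.push_back(r2[i]);
    out.push_back(r3[i]);
  }
  out.resize(numElements);  // Drop the tail added by rounding count up.
  return out;
}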

LogicalResult generateLinalgPhilox64(OpBuilder &builder, Location loc,
ShapedType resultTy, Value &store,
Value &result) {
Type resultETy = resultTy.getElementType();

Value initialState = extractState64(builder, loc, store);
if (!initialState) return failure();

std::pair<Value, Value> keys = extractKey32(builder, loc, store);
if (!keys.first || !keys.second) return failure();

int64_t numElements = resultTy.getNumElements();
int64_t count = (numElements + 1) / 2;
ShapedType intermediateType =
RankedTensorType::get({count, 1}, resultTy.getElementType());
int64_t concatDim = 1;

// Compute the number of Philox invocations required and advance the
// counter state by that amount.
Value countVal =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);

// Set up the two output buffers.
Value dest0 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
resultETy);
Value dest1 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
resultETy);
ShapedType destTy = dest0.getType().cast<ShapedType>();

SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(1));
SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);

linalg::GenericOp generic = builder.create<linalg::GenericOp>(
loc, TypeRange{destTy, destTy},
/*inputs=*/ValueRange(),
/*outputs=*/ValueRange{dest0, dest1},
/*indexingMaps=*/indexingMaps, iterators,
[&](OpBuilder &b, Location nestedLoc, ValueRange) {
auto output =
runPhilox4x32(PhiloxKey{ArithOp(b, nestedLoc, keys.first),
ArithOp(b, nestedLoc, keys.second)},
ArithOp(b, nestedLoc, initialState));
auto out0 = output[0];
auto out1 = output[1];
auto out2 = output[2];
auto out3 = output[3];
Value result1 = fuseI32s(out0, out1).val();
Value result2 = fuseI32s(out2, out3).val();
b.create<linalg::YieldOp>(loc, ValueRange{result1, result2});
});

if (resultTy.getNumElements() == 1) {
result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0));
store = setState64(builder, loc, store, newState);
return success();
}

Value r0 =
reshapeToTarget(builder, loc, intermediateType, generic.getResult(0));
Value r1 =
reshapeToTarget(builder, loc, intermediateType, generic.getResult(1));
Value concatenate = builder.create<mhlo::ConcatenateOp>(
loc, ValueRange{r0, r1}, builder.getI64IntegerAttr(concatDim));

// Collapse the concat dimension back into the parent.
llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape());
collapseShape[0] = collapseShape[0] * 2;
Value reshapeIntermediate = builder.create<mhlo::ReshapeOp>(
loc, resultTy.clone(collapseShape), concatenate);

// Slice to only the required results.
collapseShape[0] = resultTy.getNumElements();

llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
Value slice = builder.create<mhlo::SliceOp>(
loc, intermediateType.clone(collapseShape), reshapeIntermediate,
builder.getI64TensorAttr(offset), builder.getI64TensorAttr(collapseShape),
builder.getI64TensorAttr(stride));
Value reshapeResult = builder.create<mhlo::ReshapeOp>(loc, resultTy, slice);

// Set the new tensor values.
store = setState64(builder, loc, store, newState);
result = reshapeResult;

return success();
}
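
The 64-bit path differs from the 32-bit one mainly in how it packages the four words: fuseI32s (defined elsewhere in this file) glues pairs of 32-bit outputs into i64 values. Presumably it behaves like the sketch below; which word lands in the upper half is an assumption, since the helper is not part of this hunk:

#include <cstdint>

// Assumed behavior of fuseI32s: pack two 32-bit words into one 64-bit
// value, first argument in the upper half.
uint64_t fuseU32s(uint32_t hi, uint32_t lo) {
  return (uint64_t(hi) << 32) | uint64_t(lo);
}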

} // namespace

LogicalResult generateLinalgThreeFry(OpBuilder &builder, Location loc,
@@ -485,5 +739,21 @@ LogicalResult generateLinalgThreeFry(OpBuilder &builder, Location loc,
return failure();
}

LogicalResult generateLinalgPhilox(OpBuilder &builder, Location loc,
ShapedType resultTy, Value &state,
Value &result) {
Type eTy = resultTy.getElementType();
if (eTy.getIntOrFloatBitWidth() == 64) {
return generateLinalgPhilox64(builder, loc, resultTy, state, result);
}

// The 32-bit implementation truncates to the result element type.
if (eTy.getIntOrFloatBitWidth() == 32 || eTy.getIntOrFloatBitWidth() == 16) {
return generateLinalgPhilox32(builder, loc, resultTy, state, result);
}

return failure();
}

} // namespace mhlo
} // namespace mlir
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_rng_utils.h
@@ -26,6 +26,10 @@ LogicalResult generateLinalgThreeFry(OpBuilder& builder, Location loc,
ShapedType resultTy, Value& state,
Value& result);

LogicalResult generateLinalgPhilox(OpBuilder& builder, Location loc,
ShapedType resultTy, Value& state,
Value& result);

} // namespace mhlo
} // namespace mlir
