Skip to content

Commit

Permalink
[xla:gpu] Support creating ThunkInfo from HLO tensorflow#6224
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574728456
  • Loading branch information
anlunx authored and tensorflower-gardener committed Oct 19, 2023
1 parent 9e4a43c commit 3fbde8e
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ cc_library(
deps = [
":buffer_allocations",
":gpu_executable_run_options",
"//xla/hlo/ir:hlo",
"//xla/service:executable",
"//xla/stream_executor",
"//xla/translate/mhlo_to_hlo:location_exporter",
Expand Down
8 changes: 3 additions & 5 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1418,11 +1418,9 @@ Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {

ThunkSequence thunks;

// TODO(b/304613751): We need to construct real ThunkInfo from instr for Xprof
// integration.
if (operand_buffer != a_buffer) {
thunks.push_back(std::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(nullptr),
Thunk::ThunkInfo::WithProfileAnnotation(instr),
/*source_buffer=*/operand_buffer,
/*destination_buffer=*/a_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape),
Expand All @@ -1431,7 +1429,7 @@ Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
}

thunks.push_back(std::make_unique<CholeskyThunk>(
Thunk::ThunkInfo(nullptr), options,
Thunk::ThunkInfo::WithProfileAnnotation(instr), options,
PtxOptsFromDebugOptions(ir_emitter_context_->debug_options()), a_buffer,
workspace_buffer, info_buffer, shape.element_type(), batch_size, n));

Expand All @@ -1440,7 +1438,7 @@ Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(std::make_unique<SequentialThunk>(
Thunk::ThunkInfo(nullptr), std::move(thunks)));
Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(thunks)));
}

return OkStatus();
Expand Down
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/gpu/thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <string>

#include "absl/strings/str_format.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/translate/mhlo_to_hlo/location_exporter.h"

namespace xla {
Expand Down Expand Up @@ -130,5 +131,13 @@ Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation(mlir::Operation* op) {
return thunk_info;
}

Thunk::ThunkInfo Thunk::ThunkInfo::WithProfileAnnotation(
const HloInstruction* instr) {
ThunkInfo thunk_info(nullptr);
thunk_info.profile_annotation =
absl::StrFormat("Thunk:#hlo_op=%s#", instr->name());
return thunk_info;
}

} // namespace gpu
} // namespace xla
14 changes: 6 additions & 8 deletions third_party/xla/xla/service/gpu/thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.

#include "absl/types/span.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/service_executable_run_options.h"
Expand Down Expand Up @@ -115,19 +116,20 @@ class Thunk {

struct ThunkInfo {
explicit ThunkInfo(mlir::Operation* op) : op(op) {}
std::optional<int64_t> profile_index;
static ThunkInfo WithProfileAnnotation(mlir::Operation* op);
static ThunkInfo WithProfileAnnotation(const HloInstruction* instr);

std::string profile_annotation;
// TODO(b/304613751): This is only needed by the LMHLO. Remove this when
// LMHLO is removed from the runtime pipeline.
mlir::Operation* op;

static ThunkInfo WithProfileAnnotation(mlir::Operation* op);
};

// The hlo_instruction argument is meant to be the instruction this thunk was
// generated from, but Thunk never uses this argument other than to save it
// to Thunk::hlo_instruction, so it can be null.
Thunk(Kind kind, ThunkInfo thunk_info)
: kind_(kind),
profile_index_(thunk_info.profile_index),
profile_annotation_(thunk_info.profile_annotation),
op_(thunk_info.op) {}
virtual ~Thunk() = default;
Expand Down Expand Up @@ -177,12 +179,8 @@ class Thunk {

static absl::string_view KindToString(Thunk::Kind kind);

protected:
std::optional<int64_t> profile_index() const { return profile_index_; }

private:
Kind kind_;
std::optional<int64_t> profile_index_;
std::string profile_annotation_;
mlir::Operation* op_;
};
Expand Down

0 comments on commit 3fbde8e

Please sign in to comment.