From d68da6b4021bcb9ceae29fb1d1823ee49b51b887 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 19 Apr 2023 17:11:24 -0700 Subject: [PATCH] Fix alias analysis (#185) --- csrc/ir_utils.cpp | 9 +++++++++ csrc/ir_utils.h | 3 +++ csrc/lower_alias_memory.cpp | 13 ++----------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/csrc/ir_utils.cpp b/csrc/ir_utils.cpp index cba2c1557b8..206e5f67192 100644 --- a/csrc/ir_utils.cpp +++ b/csrc/ir_utils.cpp @@ -572,6 +572,15 @@ bool isReductionTvOp(const Expr* expr) { return ir_utils::isTvOp(expr) && isReductionOp(expr); } +bool isPointwiseTvOp(const Expr* expr) { + // LoadStoreOp with rfactor domain means transpose, which is not + // considered pointwise + return isTvOp(expr) && + (expr->isOneOf() || + (expr->isA() && + !ir_utils::getTvOutput(expr)->hasRFactor())); +} + std::vector getViewOps(Fusion* fusion) { auto all_exprs = fusion->exprs(); diff --git a/csrc/ir_utils.h b/csrc/ir_utils.h index b7f9b53eac1..74bb52803ad 100644 --- a/csrc/ir_utils.h +++ b/csrc/ir_utils.h @@ -338,6 +338,9 @@ TORCH_CUDA_CU_API bool isReductionOp(const Expr*); // Returns if Expr is a reduction op with TensorView or TensorIndex TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*); +// Returns if Expr is a pointwise op op with TensorView or TensorIndex +TORCH_CUDA_CU_API bool isPointwiseTvOp(const Expr* expr); + // Returns all non-trivial view operations. We shouldn't have trivial view // operations but this function is to simply make sure if we ever do we don't // pull them in. diff --git a/csrc/lower_alias_memory.cpp b/csrc/lower_alias_memory.cpp index 9523d6e3ab9..58036e46dbb 100644 --- a/csrc/lower_alias_memory.cpp +++ b/csrc/lower_alias_memory.cpp @@ -1225,7 +1225,8 @@ class ReusableAllocationFinder : private kir::IrVisitor { if (!tv_def) { continue; } - if (!isPointwiseTvOp(tv_def) && !ir_utils::isReductionTvOp(tv_def)) { + if (!ir_utils::isPointwiseTvOp(tv_def) && + !ir_utils::isReductionTvOp(tv_def)) { if (isBroadcastTvOp(tv_def)) { info.has_broadcast_between = true; } else { @@ -1266,16 +1267,6 @@ class ReusableAllocationFinder : private kir::IrVisitor { } } - // Do we have a true pointwise op? - // (ie. a TV op, excluding direct assignments and reductions) - bool isPointwiseTvOp(const Expr* expr) { - if (ir_utils::isTvOp(expr)) { - return expr->isA() || expr->isA() || - expr->isA(); - } - return false; - } - // Utility to capture broadcast ops bool isBroadcastTvOp(const Expr* expr) { if (!ir_utils::isTvOp(expr)) {