Skip to content

Commit

Permalink
Fix alias analysis (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored Apr 20, 2023
1 parent 255af2b commit d68da6b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
9 changes: 9 additions & 0 deletions csrc/ir_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,15 @@ bool isReductionTvOp(const Expr* expr) {
return ir_utils::isTvOp(expr) && isReductionOp(expr);
}

bool isPointwiseTvOp(const Expr* expr) {
// LoadStoreOp with rfactor domain means transpose, which is not
// considered pointwise
return isTvOp(expr) &&
(expr->isOneOf<UnaryOp, BinaryOp, TernaryOp>() ||
(expr->isA<LoadStoreOp>() &&
!ir_utils::getTvOutput(expr)->hasRFactor()));
}

std::vector<ViewOp*> getViewOps(Fusion* fusion) {
auto all_exprs = fusion->exprs();

Expand Down
3 changes: 3 additions & 0 deletions csrc/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ TORCH_CUDA_CU_API bool isReductionOp(const Expr*);
// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

// Returns if Expr is a pointwise op op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isPointwiseTvOp(const Expr* expr);

// Returns all non-trivial view operations. We shouldn't have trivial view
// operations but this function is to simply make sure if we ever do we don't
// pull them in.
Expand Down
13 changes: 2 additions & 11 deletions csrc/lower_alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,8 @@ class ReusableAllocationFinder : private kir::IrVisitor {
if (!tv_def) {
continue;
}
if (!isPointwiseTvOp(tv_def) && !ir_utils::isReductionTvOp(tv_def)) {
if (!ir_utils::isPointwiseTvOp(tv_def) &&
!ir_utils::isReductionTvOp(tv_def)) {
if (isBroadcastTvOp(tv_def)) {
info.has_broadcast_between = true;
} else {
Expand Down Expand Up @@ -1266,16 +1267,6 @@ class ReusableAllocationFinder : private kir::IrVisitor {
}
}

// Do we have a true pointwise op?
// (ie. a TV op, excluding direct assignments and reductions)
bool isPointwiseTvOp(const Expr* expr) {
if (ir_utils::isTvOp(expr)) {
return expr->isA<UnaryOp>() || expr->isA<BinaryOp>() ||
expr->isA<TernaryOp>();
}
return false;
}

// Utility to capture broadcast ops
bool isBroadcastTvOp(const Expr* expr) {
if (!ir_utils::isTvOp(expr)) {
Expand Down

0 comments on commit d68da6b

Please sign in to comment.