expand removing consecutive cast to handle meta operations in between #3644

Merged
merged 56 commits into main from preseg_passes_consecutive_cast on Jan 21, 2025

Conversation


@jjsjann123 jjsjann123 commented Dec 24, 2024

The existing ConsecutiveCast optimization pass only optimizes consecutive cast operations. This PR expands the ConsecutiveCast pass to handle cases where a chain of cast operations is broken by a meta operation in the middle.

e.g.

T1 = castOp(T0, fp32)
T2 = squeeze(T1)
T3 = castOp(T2, fp16)

The existing pass wouldn't be able to cancel out the two casts, because they are separated by the squeeze operation.

In this PR, before we trace back from the last CastOp through the chain of casts, we look at the input to the cast operation. If it's a movable meta operation, we swap the order of the meta op and the cast op first, then we resume the chain lookup on consecutive casts.
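
For the example above (a sketch assuming T0 is an fp16 tensor, which the snippet implies but does not state), swapping the squeeze and the down-cast gives

T1 = castOp(T0, fp32)
T2 = castOp(T1, fp16)
T3 = squeeze(T2)

after which the now-adjacent fp16->fp32->fp16 round-trip on T0 is a no-op that the existing consecutive-cast logic removes, leaving just the squeeze in reduced precision.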

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

The benchmark failure is coming from segmentation, i.e. there's a set->cast pattern in NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5_NCHW___GRAPH/NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5_NCHW/64/197/768/manual_time, which now produces some no-op segments after the reorder.

I patched that in #3670
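
A sketch of the pattern being described, on my reading of the comment (the actual benchmark graph is not shown here): the benchmark contains something like

T1 = set(T0)
T2 = castOp(T1, fp16)

which this pass reorders into

T1 = castOp(T0, fp16)
T2 = set(T1)

and the trailing set can then land in a segment of its own, i.e. a no-op segment.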

@wujingyue
Collaborator

propagate downCast to input

Good idea. In addition, you could propagate up-casts to the outputs. Hopefully, after propagating up and down, the cancellable casts will be adjacent and trivial to remove.

(It's certainly fine to leave this for the future.)
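
A sketch of that complementary rewrite (an illustration of the suggestion, not code from this PR): an up-cast followed by a meta op, e.g.

T1 = castOp(T0, fp32)
T2 = squeeze(T1)

could be rewritten as

T1 = squeeze(T0)
T2 = castOp(T1, fp32)

so the up-cast moves toward the outputs and may end up adjacent to a later cast that cancels it.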

// replays meta operation on `new_in`. return the new output from replayed meta
// operation
Val* replayMetaOnNewInput(
Collaborator

Is it possible to reuse/extend Expr* replayExprWithNewInput(Expr* e, Val* new_in)?

continue;
}
// We do not support the replay if expr out has non-trivial transforms
// between its logical_dom and alloc_dom.
Collaborator

Non-permuting allocation domains will be the norm for multi-GPU fusions with DID loop split. Anything you can do to save my future time will be greatly appreciated!

Collaborator Author

Hmmm... that's a good point. I can have a follow-up PR on that. We can/should add some utility for replaying the allocation domain.


// adding prev_expr to visited node so we'll short-cut it.
visited.insert(prev_expr);
Collaborator

I'm unsure we need or will still need visited.

Collaborator Author

good call. I forgot we are traversing through exprs, instead of following the data flow.

Collaborator Author

Realized that I still need the visited set here.

i.e. this function removeChainedCasts(expr, visited) could have removed exprs that the for loop hasn't reached yet, which means we cannot call isCast(expr) on those, since they are dangling pointers now.

    if (visited.count(expr) != 0 || !isCast(expr)) {
      continue;
    }


github-actions bot commented Jan 15, 2025

PR Reviewer Guide 🔍

(Review updated until commit 9f920ca)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The removeChainedCasts function has been modified to handle cases where a chain of cast operations is broken by a meta operation in the middle. The function now checks if the output of the previous expression is used by other operations or is a direct output from fusion, and if so, it skips the casting chaining. Additionally, the function now updates the starting_anchor to the current val if the val is not a no-op, which ensures that the incompatible casts are preserved.

//        a. if `anchor_dtype` is no narrower than `output_dtype`, all previous
//        casts after `starting_anchor` are no-ops, so we re-wire
//        `starting_anchor` directly to `expr`;
//
//        b. otherwise, we can't bypass the `lo_anchor` cast; we rewire this
//        section as `starting_anchor`->`lo_anchor`->`expr->output(0)`
Expr* removeChainedCasts(Expr* expr, std::unordered_set<Expr*>& folded) {
  std::list<Val*> chain_cast_vals;
  auto prev_expr = expr->input(0)->definition();
  while (isCast(prev_expr)) {
    auto intermediate_cast = prev_expr->output(0);
    // 1.2 Note, if the output of prev_expr
    //   is used by other operation(s); or
    //   is a direct output from fusion
    // we skip the casting chaining
    if (intermediate_cast->isFusionOutput() ||
        intermediate_cast->uses().size() > 1) {
      break;
    }

    // adding prev_expr to folded so we'll short-cut it.
    folded.insert(prev_expr);
    // in the loop, we just keep chaining up consecutive casts.
    chain_cast_vals.push_front(intermediate_cast);
    prev_expr = prev_expr->input(0)->definition();
  }

  // skip current expr if there's no chain_cast_vals
  if (chain_cast_vals.empty()) {
    return expr;
  }

  // 1.3.1 Note, chain_cast_vals has a straight-line use without branches
  auto lo_anchor = chain_cast_vals.front()->definition()->input(0);
  auto anchor_dtype = lo_anchor->getDataType().value();
  auto starting_anchor = lo_anchor;
  for (auto val : chain_cast_vals) {
    auto val_dtype = val->getDataType().value();

    // 1.3.2.a short-cut when we are not losing precision
    if (isInclusiveType(anchor_dtype, val_dtype)) {
      continue;
    }

    // 1.3.2.c NOTE: To enter here, we have
    //   !isInclusiveType(anchor_dtype, val_dtype) &&
    //   !isInclusiveType(val_dtype, anchor_dtype)
    //
    // Which means the dtypes of lo_anchor and val aren't compatible and the
    // cast can't be folded away without losing information. So we update the
    // starting_anchor to the current val, which ensures that we preserve the
    // incompatible casts, e.g. for cases where neither type is strictly wider
    // than the other: i.e. bf16 & fp16, int32 & float32, etc.
    if (!isInclusiveType(val_dtype, anchor_dtype)) {
      lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
      val = replaceInputInCast(val, lo_anchor);
      // We need to update the starting_anchor for the fold to be past this
      // current cast.
      starting_anchor = val;
    }
    // 1.3.2.b/c updating new lo_anchor to current val
    lo_anchor = val;
    anchor_dtype = lo_anchor->getDataType().value();
  }

  auto output_dtype = expr->output(0)->getDataType().value();

  if (isInclusiveType(output_dtype, anchor_dtype)) {
    // 1.4.a: if lo_anchor is no narrower than output_dtype, everything is a
    // no-op

    if (starting_anchor->getDataType().value() == output_dtype) {
      // if output dtype is identical to starting_anchor dtype, we can't keep
      // the last cast op and will need to re-write all uses here
      ir_utils::replaceValue(
          expr->fusion(), {{expr->output(0), starting_anchor}});
    } else {
      Val* new_expr_val = replaceInputInCast(expr->output(0), starting_anchor);
      expr = new_expr_val->definition();
    }
  } else {
    // 1.4.b: This is the case where we cannot fold away the cast of
    // lo_anchor; we'll just re-wire input to expr with lo_anchor
    lo_anchor = replaceInputInCast(lo_anchor, starting_anchor);
    Val* new_expr_val = replaceInputInCast(expr->output(0), lo_anchor);
    expr = new_expr_val->definition();
  }
  return expr;
}
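
As a concrete walkthrough of the logic above (an illustrative example, assuming T0 is an fp32 tensor; not part of the generated review), consider the chain

    T1 = castOp(T0, fp16)
    T2 = castOp(T1, fp32)
    T3 = castOp(T2, fp16)   // expr

lo_anchor starts at T0 (fp32) and moves to T1 in step 1.3.2.b because fp16 is narrower; T2 is then a widening no-op (1.3.2.a); and since the output dtype fp16 is inclusive of the final anchor dtype fp16, case 1.4.a rewires the last cast to read T0 directly, collapsing the whole chain into

    T3 = castOp(T0, fp16)
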
Function Signature Change

The castOptimizationPass function now maintains a folded set of expressions removed by removeChainedCasts, and passes it to that function (whose signature gained the extra parameter) so that dangling expressions are skipped on later iterations.

void castOptimizationPass(Fusion* fusion) {
  FusionGuard fusion_guard(fusion);
  auto exprs = fusion->exprs();
  std::unordered_set<Expr*> folded;
  for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
    auto expr = *iter;
    // skip current expr if it's not a foldable cast or it has already been
    // removed in removeChainedCasts and is now a dangling pointer.
    if (folded.count(expr) != 0 || !isCast(expr)) {
      continue;
    }

    // initialize changed to true so we'll enter the loop in initial iteration.
    bool changed = true;
    while (changed) {
      changed = false;
      // when down cast follows a meta operation that's safe to be swapped, we
      // do so for two reasons:
      // 1. lifting a down cast to inputs would reduce intermediate buffer size
      // 2. it might place the cast op next to another cast op that can be
      // optimized away. e.g. for a trivial reduction on reduced precision, the
      // pattern will be
      //    T1 = castOp(T0, fp32)
      //    T2 = squeeze(T1)
      //    T3 = castOp(T2, fp16) // downCast
      // by swapping the last two op, we get
      //    T1 = castOp(T0, fp32)
      //    T2 = castOp(T1, fp16)
      //    T3 = squeeze(T2)      // operation in reduced precision
      // and we can further cancel out the two cast ops.
      if (shouldSwapMetaCast(expr)) {
        // replay [meta -> expr] with
        //        [replayed_expr -> replayed_meta]
        Val* expr_out = expr->output(0);

        // initializing alloc_domain permutation as empty
        std::optional<std::vector<int64_t>> expr_out_allocation_permutation =
            {};

        // compute logical_dom to alloc_dom permutation
        if (expr_out->isA<TensorView>()) {
          TensorView* expr_out_tv = expr_out->as<TensorView>();
          expr_out_allocation_permutation = ir_utils::computePermutation(
              expr_out_tv->getLogicalDomain(),
              expr_out_tv->getMaybeAllocationDomain());
        }

        // We do not support the replay if expr out has non-trivial transforms
        // between its logical_dom and alloc_dom.
        if (expr_out_allocation_permutation.has_value()) {
          Expr* meta = expr->input(0)->definition();

          // replayed expr(cast).
          Val* replayed_expr_out = castOp(expr_out->dtype(), meta->input(0));

          // replay meta operation on replayed expr output.
          Val* replayed_meta_out = replayMetaOnNewInput(
              meta, replayed_expr_out, expr_out_allocation_permutation.value());

          // replace uses of expr output with output of replayed_meta.
          ir_utils::replaceValInAllExprInputsAndFusionOutputs(
              expr_out, replayed_meta_out);

          // update expr to point to the replayed_expr
          expr = replayed_expr_out->definition();
        }
        changed = true;
      }

      // optimize chained cast operations ending at expr
      if (Expr* new_expr = removeChainedCasts(expr, folded); new_expr != expr) {
        expr = new_expr;
        changed = true;
      }
    }
  }
}
Function Addition

A new function isSimpleTVSet has been added to check whether an expression is a simple Set of a TensorView.

//! Checks whether this is a simple Set of a TensorView. If not, then this might
//! represent a scalar set, or a segment_set.
bool isSimpleTVSet(Expr* expr);
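
A minimal sketch of what such a check could look like (this is an assumption about the implementation; the actual code in csrc/ir/utils.cpp, and the exact LoadStoreOp accessors, may differ):

// Hypothetical sketch: a "simple TV set" is a LoadStoreOp of type Set whose
// input is a TensorView, which rules out scalar sets and segment_set
// (LoadStoreOpType::SegmenterSet).
bool isSimpleTVSet(Expr* expr) {
  auto* ldst = dynamic_cast<LoadStoreOp*>(expr);
  if (ldst == nullptr) {
    return false;
  }
  return ldst->opType() == LoadStoreOpType::Set &&
      expr->input(0)->isA<TensorView>();
}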

@jjsjann123
Collaborator Author

!test --diff-bench

@jjsjann123
Collaborator Author

!test --diff-bench

@jjsjann123 jjsjann123 requested a review from wujingyue January 16, 2025 09:01

// replays meta operation on `new_in`. return the new output from replayed meta
// operation
Val* replayMetaOnNewInput(
Collaborator

SGTM. Would you mind leaving a TODO in code so readers will know we intend to extend and reuse?

void castOptimizationPass(Fusion* fusion) {
FusionGuard fusion_guard(fusion);
auto exprs = fusion->exprs();
std::unordered_set<Expr*> visited;
Collaborator

Suggested change
std::unordered_set<Expr*> visited;
std::unordered_set<Expr*> folded;

Or removed

Expr* removeChainedCasts(Expr* expr, std::unordered_set<Expr*>& visited) {
std::list<Val*> chain_cast_vals;
auto prev_expr = expr->input(0)->definition();
while (isCast(prev_expr)) {
Collaborator

Nit: because removeChainedCasts is called in a while loop, we don't really need another while loop here. Removing it may help simplify the logic.

Collaborator Author

I don't see an easy way to simplify the logic, since it relies on analyzing the chain of casts and tries to simplify it based on the high/low water marks of the cast types.

We could embed that into the parent while loop, but I think that's going to make it harder to read, since we'd be interleaving it with the meta/cast swap part.

Collaborator

I believe we can fold iteratively without having to analyze the whole chain. We only need to apply two rules to a fixed point:

  1. when we see dtype A -> (cast) -> dtype B -> (cast) -> dtype C and that B includes A, simplify it to dtype A -> (cast) -> dtype C.
  2. when we see dtype A -> (cast) -> dtype A (i.e. the producer TV and consumer have the same dtype), remove the cast.

Am I understanding this correctly?

(In case I wasn't clear, I'm not requesting changes in this PR)
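
A quick sketch of how those two rules reach a fixed point on the example from the PR description (assuming T0 is fp16; an illustration, not taken from the thread):

    T1 = castOp(T0, fp32)   // fp16 -> fp32
    T2 = castOp(T1, fp16)   // fp32 -> fp16

Rule 1 applies (fp32 includes fp16) and gives T2 = castOp(T0, fp16); rule 2 then removes that fp16 -> fp16 cast, so T2's uses read T0 directly. Neither rule applies anymore, so we have reached the fixed point.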

Collaborator Author

That's a fair thing to try.

The existing approach is a bit more aggressive, but that might not be necessary, and it's a fair ask to change it for a simpler implementation. (Also, I realized that using isInclusiveType to figure out the lowest precision here might not be numerically safe.)

I'll follow up with a quick refactor on that. I need to double-check then whether any existing tests need to be modified.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

Hitting this issue on Thunder again: #3660
But I'm also seeing a failure in matmul. Let me see if I can repro that locally.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

Can't repro the matmul failure on H100... 😕 Trying the CI again.

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 merged commit 7cfeac3 into main Jan 21, 2025
51 checks passed
@jjsjann123 jjsjann123 deleted the preseg_passes_consecutive_cast branch January 21, 2025 21:09