expand removing consecutive cast to handle meta operations in between #3644
!test |
!test |
!test |
The benchmark failure is coming from some segmentation. i.e. there's some I patched that in #3670 |
Good idea. In addition, you could propagate up casts to outputs. Hopefully, after propagating up and down, the cancellable casts will be adjacent and be trivial to remove. (It's certainly fine to leave this for the future.) |
|
||
// replays meta operation on `new_in`. return the new output from replayed meta | ||
// operation | ||
Val* replayMetaOnNewInput( |
Is it possible to reuse/extend
Line 295 in 37e7005
Expr* replayExprWithNewInput(Expr* e, Val* new_in); |
continue; | ||
} | ||
// We do not support the replay if expr out has non-trivial transforms | ||
// between its logical_dom to alloc_dom. |
Non-permuting allocation domains will be the norm for multi-GPU fusions with DID loop split. Anything you can do to save my future time will be greatly appreciated!
hmmm... that's a good point. I can have a follow up PR on that. We can/should add some utility on replaying allocation domain.
|
||
// adding prev_expr to visited node so we'll short-cut it. | ||
visited.insert(prev_expr); |
I'm unsure we need or will still need `visited`.
good call. I forgot we are traversing through exprs, instead of following the data flow.
Realized that I still need `visited` here, i.e. `removeChainedCasts(expr, visited)` could have removed previous exprs in the for loop, which means we cannot call `isCast(expr)` on them since those are dangling pointers now.
if (visited.count(expr) != 0 || !isCast(expr)) {
continue;
}
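To illustrate why the order of the two checks matters, here is a minimal, self-contained sketch. It assumes a toy `Expr` and a hypothetical owning `Graph`; neither is the nvFuser class, and `countLiveCasts` is an illustrative name only.

```cpp
#include <cassert>
#include <memory>
#include <unordered_set>
#include <vector>

// Toy stand-in for an IR expression (hypothetical, not the nvFuser class).
struct Expr {
  bool is_cast;
};

// Hypothetical owner of all expressions; remove() frees the node.
struct Graph {
  std::vector<std::unique_ptr<Expr>> storage;

  Expr* add(bool is_cast) {
    storage.push_back(std::make_unique<Expr>(Expr{is_cast}));
    return storage.back().get();
  }

  void remove(Expr* e) {
    for (auto& p : storage) {
      if (p.get() == e) {
        p.reset();
        break;
      }
    }
  }
};

// Counts casts in a snapshot taken before some expressions were removed.
int countLiveCasts(
    const std::vector<Expr*>& snapshot,
    const std::unordered_set<Expr*>& removed) {
  int n = 0;
  for (Expr* e : snapshot) {
    // The set lookup uses only the pointer value. It runs first, so `||`
    // short-circuits before a removed expression is ever dereferenced.
    if (removed.count(e) != 0 || !e->is_cast) {
      continue;
    }
    ++n;
  }
  return n;
}

void visitedExample() {
  Graph g;
  Expr* a = g.add(true);
  Expr* b = g.add(true);
  Expr* c = g.add(false);
  std::vector<Expr*> snapshot{a, b, c};
  std::unordered_set<Expr*> removed;
  g.remove(a);        // `a` now dangles inside `snapshot`
  removed.insert(a);  // record it so the loop skips it
  assert(countLiveCasts(snapshot, removed) == 1);  // only `b` is a live cast
}
```

Swapping the two operands of `||` in this sketch would read `e->is_cast` through a freed pointer for `a`.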
PR Reviewer Guide 🔍 (Review updated until commit 9f920ca)
Here are some key observations to aid the review process:
|
!test --diff-bench |
!test --diff-bench |
|
||
// replays meta operation on `new_in`. return the new output from replayed meta | ||
// operation | ||
Val* replayMetaOnNewInput( |
SGTM. Would you mind leaving a TODO in code so readers will know we intend to extend and reuse?
void castOptimizationPass(Fusion* fusion) { | ||
FusionGuard fusion_guard(fusion); | ||
auto exprs = fusion->exprs(); | ||
std::unordered_set<Expr*> visited; |
Suggested change:
- std::unordered_set<Expr*> visited;
+ std::unordered_set<Expr*> folded;
Or `removed`.
Expr* removeChainedCasts(Expr* expr, std::unordered_set<Expr*>& visited) { | ||
std::list<Val*> chain_cast_vals; | ||
auto prev_expr = expr->input(0)->definition(); | ||
while (isCast(prev_expr)) { |
Nit: because `removeChainedCasts` is called in a while loop, we don't really need another while loop here. That may help simplify the logic.
I don't see an easy way to simplify the logic, since it relies on analyzing the chain of casts and tries to simplify it based on the high/low water mark of cast types.
We can embed that into the parent while loop, but I think that's going to make it harder to read, since we would be interleaving it with the meta/cast swap part.
I believe we can fold iteratively without having to analyze the whole chain. We only need to apply two rules to a fixed point.
- when we see `dtype A -> (cast) -> dtype B -> (cast) -> dtype C` and that B includes A, simplify it to `dtype A -> (cast) -> dtype C`.
- when we see `dtype A -> (cast) -> dtype A` (i.e. the producer TV and consumer have the same dtype), remove the cast.
Am I understanding this correctly?
(In case I wasn't clear, I'm not requesting changes in this PR)
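The two rules above can be sketched on a toy model, assuming a cast chain is just the list of dtypes it passes through. The `DType` enum, the `includes` helper, and `foldCasts` are all hypothetical names for this illustration; they are not the pass's actual code, and `includes` only approximates what a real inclusiveness check would cover.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy dtype set (hypothetical, far smaller than a real compiler's).
enum class DType { Half, BFloat16, Float, Double };

// Hypothetical helper: true when every value of `narrow` is exactly
// representable in `wide`.
bool includes(DType wide, DType narrow) {
  if (wide == narrow) {
    return true;
  }
  switch (wide) {
    case DType::Double:
      return true;  // holds Half, BFloat16 and Float exactly
    case DType::Float:
      return narrow == DType::Half || narrow == DType::BFloat16;
    default:
      return false;  // Half and BFloat16 do not include each other
  }
}

// `chain` lists the dtypes of the values along a cast chain; one cast sits
// between each pair of neighbors. Applies both rules until nothing changes.
std::vector<DType> foldCasts(std::vector<DType> chain) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 0; i + 1 < chain.size(); ++i) {
      // Rule 2: A -> A is a no-op cast, drop it.
      const bool same = chain[i] == chain[i + 1];
      // Rule 1: A -> B -> C with B including A becomes A -> C.
      const bool widen_then_cast =
          i + 2 < chain.size() && includes(chain[i + 1], chain[i]);
      if (same || widen_then_cast) {
        chain.erase(chain.begin() + static_cast<std::ptrdiff_t>(i + 1));
        changed = true;
        break;  // restart the scan from the front
      }
    }
  }
  return chain;
}

void foldCastsExample() {
  // Half -> Float -> Half: rule 1 gives Half -> Half, rule 2 removes it.
  assert(foldCasts({DType::Half, DType::Float, DType::Half}).size() == 1);
  // Float -> Half -> Float loses precision in the middle, so it must stay.
  assert(foldCasts({DType::Float, DType::Half, DType::Float}).size() == 3);
}
```

Each rule only inspects a window of two or three dtypes, which is the appeal of the fixed-point formulation over analyzing the whole chain at once.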
That's a fair thing to try.
The existing approach is a bit more aggressive, but it might not be necessary to have that and it's a fair ask to change it for simpler implementation. (Also, I realized that using `isInclusiveType` to figure out the lowest precision here might not be numerically safe.)
I'll follow up with a quick refactor on that. Need to double check then if any existing tests need to be modified.
!test |
Hitting this issue on thunder again. #3660 |
!test |
can't repro the matmul on H100... 😕 trying the CI again |
!test |
The existing ConsecutiveCast optimization pass only optimizes consecutive cast operations. This PR expands the ConsecutiveCast pass to handle cases where a chain of cast operations is broken by a meta operation in the middle.
e.g.
The existing pass wouldn't be able to cancel out the two casts, because they are separated by the squeeze operation.
In this PR, before we trace back from the last CastOp along the chain of casts, we look at the input to the cast operation. If it is produced by a movable meta operation, we first swap the order of the meta op and the cast op, then resume the chain lookup on consecutive casts.
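As a rough illustration of the swap-then-fold idea, here is a self-contained sketch on a toy linear op list. Everything in it (`Op`, `DType`, `includes`, `optimize`) is hypothetical and not the nvFuser API. It also assumes every meta op is movable and dtype-agnostic, which the real pass has to check.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

enum class DType { Half, Float };

// Toy op in a linear chain (hypothetical). Meta ops such as squeeze keep the
// dtype of their input; `to` is only meaningful for casts.
struct Op {
  bool is_cast;
  DType to;
  std::string name;
};

// Hypothetical widening check for this two-type toy.
bool includes(DType wide, DType narrow) {
  return wide == narrow || (wide == DType::Float && narrow == DType::Half);
}

std::vector<Op> optimize(DType input, std::vector<Op> ops) {
  // Step 1: swap each cast with the meta ops feeding it, so the cast runs
  // first and casts end up adjacent to each other.
  for (std::size_t i = 1; i < ops.size(); ++i) {
    for (std::size_t j = i;
         j > 0 && ops[j].is_cast && !ops[j - 1].is_cast;
         --j) {
      std::swap(ops[j], ops[j - 1]);
    }
  }
  // Step 2: the casts now form a prefix; collect the dtypes they visit.
  std::vector<DType> chain{input};
  std::size_t k = 0;
  while (k < ops.size() && ops[k].is_cast) {
    chain.push_back(ops[k].to);
    ++k;
  }
  // Step 3: drop no-op casts and widening detours until nothing changes.
  for (std::size_t i = 0; i + 1 < chain.size();) {
    const bool same = chain[i] == chain[i + 1];
    const bool detour =
        i + 2 < chain.size() && includes(chain[i + 1], chain[i]);
    if (same || detour) {
      chain.erase(chain.begin() + static_cast<std::ptrdiff_t>(i + 1));
      i = 0;  // restart the scan
    } else {
      ++i;
    }
  }
  // Step 4: rebuild as the surviving casts followed by the meta ops.
  std::vector<Op> result;
  for (std::size_t i = 1; i < chain.size(); ++i) {
    result.push_back(Op{true, chain[i], "cast"});
  }
  result.insert(
      result.end(), ops.begin() + static_cast<std::ptrdiff_t>(k), ops.end());
  return result;
}

void optimizeExample() {
  // Half input: cast to Float, squeeze, cast back to Half.
  std::vector<Op> ops = {
      Op{true, DType::Float, "cast"},
      Op{false, DType::Half, "squeeze"},
      Op{true, DType::Half, "cast"}};
  std::vector<Op> out = optimize(DType::Half, ops);
  assert(out.size() == 1 && out[0].name == "squeeze");  // both casts cancel
}
```

Without step 1, the two casts in the example are never adjacent, so a pass that only looks at directly consecutive casts would leave both in place.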