Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix buffer reusing #2490

Merged
merged 4 commits into from
Feb 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions third_party/nvfuser/csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,7 @@ Allocate::Allocate(
TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type, "Invalid alias");
}

// FIXME: there is a bug in lower_alias_memory.cpp that causes
// `NVFuserTest.FusionPredicateElimination6_CUDA` to fail if I simplify `5*2`
// into `10`

// size = simplifyExpr(size);
size = simplifyExpr(size);

addInput(size);
addAttribute(buffer);
Expand Down
16 changes: 10 additions & 6 deletions third_party/nvfuser/csrc/lower_alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,22 @@ class BufferReuseDebugPrinter {
//! The first write and last read
//! is based on the position on the linear order within
//! the Kernel IR.
//! The interval is semi-open,
//! i.e. [First_Write, Last_Read)
//! So the buffer is NOT available at exactly First_Write
//! position while it IS available at Last_Read.
//! The interval is closed,
//! i.e. [First_Write, Last_Read]
//! So the buffer is NOT available from First_Write to
//! Last_Read position. For the case where First_Write
//! and Last_Read are identical, we can actually reuse
//! buffer if the read and write has exactly the same
//! index, however, for simplicity, we are not taking
//! advantage of this opportunity yet.
Comment on lines +324 to +331
Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Feb 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a rather big change, we will lose all opportunities like

T1 = set(T0)
T2 = set(T1)

T2 now can not reuse T1

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been wondering if this reuse is really beneficial. I just can't think of cases where any reasonable compiler can't reason about safe reuse. Shared memory is explicitly managed, so that still would be impacted, but I believe it'd be quite rate to have a pattern like above with shared memory.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been wondering if this reuse is really beneficial. I just can't think of cases where any reasonable compiler can't reason about safe reuse. Shared memory is explicitly managed, so that still would be impacted, but I believe it'd be quite rate to have a pattern like above with shared memory.

Yeah, agree

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, if we really need to explicitly reuse registers, we could do something similar to what the predicate elimination does. It checks both a producer and a consumer and see if they have the same transformations. If that's the case it should be safe to reuse, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I am thinking the same.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of the case

T1[index1] = T0[index2]

where index1 and index2 are not the same. We can not reuse T0's allocation for T1.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes. Now I remember it. Thanks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failing test, FusionPersistentNormLocalShared, seems to use both Local and Shared to do a normalization. It has a pattern like:

T8_s[ iS15{i3}, iS86{( ceilDiv(i4, 128) )}, ithreadIdx.x87{128} ] ca_pos( 1 ) produce_pos( 1 )
   = T25_s[ iS48{i3}, iS82{( ceilDiv(i4, 128) )}, ithreadIdx.x83{128} ] ca_pos( 1 )
   - T6_l[ iS11{i0}, bS52{( ceilDiv(1, 128) )}, bthreadIdx.x53{128} ] ca_pos( 1 ) produce_pos( 1 );

Here, previously, T8 and T25 were aliased, so they shared the same buffer, which is not the case now. I thought this pattern would be rare, but maybe not.

I'll disable the test for now. Are you working on improving the alias analysis?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not working on it, but I don't mind start working on it because it is needed for matmul epilogue fusion as well:
#1979
It this urgent? If so, I will start it right after the loop rotation. Otherwise, probably I will work on prologue swizzle first, and return back after prologue swizzle.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's urgent.

class BufferLiveInterval {
public:
// Simple detection of intersection of two intervals
bool intersect(BufferLiveInterval* other) {
if (first_write_pos_ <= other->first_write_pos_) {
return other->first_write_pos_ < last_read_pos_;
return other->first_write_pos_ <= last_read_pos_;
} else {
return first_write_pos_ < other->last_read_pos_;
return first_write_pos_ <= other->last_read_pos_;
}
}

Expand Down
11 changes: 7 additions & 4 deletions third_party/nvfuser/test/test_gpu2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8255,18 +8255,21 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
auto tv1 = add(tv0, IrBuilder::create<Double>(1));
auto tv2 = add(tv1, IrBuilder::create<Double>(1));
auto tv3 = add(tv2, IrBuilder::create<Double>(1));
auto tv4 = add(tv3, IrBuilder::create<Double>(1));

fusion.addOutput(tv3);
fusion.addOutput(tv4);

tv1->setMemoryType(MemoryType::Shared);
tv2->setMemoryType(MemoryType::Shared);
tv3->setMemoryType(MemoryType::Shared);

tv3->split(0, 4);
tv0->computeAt(tv3, 1);
tv4->split(0, 4);
tv0->computeAt(tv4, 1);

tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDy);
tv3->axis(-1)->parallelize(ParallelType::TIDz);
tv4->axis(-1)->parallelize(ParallelType::TIDx);

// Make sure a WAR sync is inserted at the end of the outer loop
GpuLower gpulw(&fusion);
Expand All @@ -8291,7 +8294,7 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
fe.compileFusion(&fusion, aten_inputs);
auto outputs = fe.runFusion(aten_inputs);

auto ref1 = t0 + 3;
auto ref1 = t0 + 4;

testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__);
}
Expand Down