Skip to content

Commit

Permalink
Fix buffer reusing (#2490)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Feb 18, 2023
1 parent 08dc16d commit 3ed70b2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
6 changes: 1 addition & 5 deletions third_party/nvfuser/csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,7 @@ Allocate::Allocate(
TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type, "Invalid alias");
}

// FIXME: there is a bug in lower_alias_memory.cpp that causes
// `NVFuserTest.FusionPredicateElimination6_CUDA` to fail if I simplify `5*2`
// into `10`

// size = simplifyExpr(size);
size = simplifyExpr(size);

addInput(size);
addAttribute(buffer);
Expand Down
16 changes: 10 additions & 6 deletions third_party/nvfuser/csrc/lower_alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,22 @@ class BufferReuseDebugPrinter {
//! The first write and last read
//! is based on the position on the linear order within
//! the Kernel IR.
//! The interval is semi-open,
//! i.e. [First_Write, Last_Read)
//! So the buffer is NOT available at exactly First_Write
//! position while it IS available at Last_Read.
//! The interval is closed,
//! i.e. [First_Write, Last_Read]
//! So the buffer is NOT available from First_Write to
//! Last_Read position. For the case where First_Write
//! and Last_Read are identical, we can actually reuse
//! buffer if the read and write has exactly the same
//! index, however, for simplicity, we are not taking
//! advantage of this opportunity yet.
class BufferLiveInterval {
public:
// Simple detection of intersection of two intervals
bool intersect(BufferLiveInterval* other) {
if (first_write_pos_ <= other->first_write_pos_) {
return other->first_write_pos_ < last_read_pos_;
return other->first_write_pos_ <= last_read_pos_;
} else {
return first_write_pos_ < other->last_read_pos_;
return first_write_pos_ <= other->last_read_pos_;
}
}

Expand Down
11 changes: 7 additions & 4 deletions third_party/nvfuser/test/test_gpu2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8255,18 +8255,21 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
auto tv1 = add(tv0, IrBuilder::create<Double>(1));
auto tv2 = add(tv1, IrBuilder::create<Double>(1));
auto tv3 = add(tv2, IrBuilder::create<Double>(1));
auto tv4 = add(tv3, IrBuilder::create<Double>(1));

fusion.addOutput(tv3);
fusion.addOutput(tv4);

tv1->setMemoryType(MemoryType::Shared);
tv2->setMemoryType(MemoryType::Shared);
tv3->setMemoryType(MemoryType::Shared);

tv3->split(0, 4);
tv0->computeAt(tv3, 1);
tv4->split(0, 4);
tv0->computeAt(tv4, 1);

tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDy);
tv3->axis(-1)->parallelize(ParallelType::TIDz);
tv4->axis(-1)->parallelize(ParallelType::TIDx);

// Make sure a WAR sync is inserted at the end of the outer loop
GpuLower gpulw(&fusion);
Expand All @@ -8291,7 +8294,7 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
fe.compileFusion(&fusion, aten_inputs);
auto outputs = fe.runFusion(aten_inputs);

auto ref1 = t0 + 3;
auto ref1 = t0 + 4;

testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__);
}
Expand Down

0 comments on commit 3ed70b2

Please sign in to comment.