From 3ed70b23c9e0f94cba9c33b305d51853dee8345c Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Fri, 17 Feb 2023 17:35:50 -0800 Subject: [PATCH] Fix buffer reusing (#2490) --- third_party/nvfuser/csrc/kernel_ir.cpp | 6 +----- third_party/nvfuser/csrc/lower_alias_memory.cpp | 16 ++++++++++------ third_party/nvfuser/test/test_gpu2.cpp | 11 +++++++---- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/third_party/nvfuser/csrc/kernel_ir.cpp b/third_party/nvfuser/csrc/kernel_ir.cpp index c1c669794f0ab..faeeeaadc5267 100644 --- a/third_party/nvfuser/csrc/kernel_ir.cpp +++ b/third_party/nvfuser/csrc/kernel_ir.cpp @@ -157,11 +157,7 @@ Allocate::Allocate( TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type, "Invalid alias"); } - // FIXME: there is a bug in lower_alias_memory.cpp that causes - // `NVFuserTest.FusionPredicateElimination6_CUDA` to fail if I simplify `5*2` - // into `10` - - // size = simplifyExpr(size); + size = simplifyExpr(size); addInput(size); addAttribute(buffer); diff --git a/third_party/nvfuser/csrc/lower_alias_memory.cpp b/third_party/nvfuser/csrc/lower_alias_memory.cpp index 77cd30a9b9c6c..6f932c3c2b69c 100644 --- a/third_party/nvfuser/csrc/lower_alias_memory.cpp +++ b/third_party/nvfuser/csrc/lower_alias_memory.cpp @@ -321,18 +321,22 @@ class BufferReuseDebugPrinter { //! The first write and last read //! is based on the position on the linear order within //! the Kernel IR. -//! The interval is semi-open, -//! i.e. [First_Write, Last_Read) -//! So the buffer is NOT available at exactly First_Write -//! position while it IS available at Last_Read. +//! The interval is closed, +//! i.e. [First_Write, Last_Read] +//! So the buffer is NOT available from First_Write to +//! Last_Read position. For the case where First_Write +//! and Last_Read are identical, we can actually reuse +//! buffer if the read and write has exactly the same +//! index, however, for simplicity, we are not taking +//! advantage of this opportunity yet. class BufferLiveInterval { public: // Simple detection of intersection of two intervals bool intersect(BufferLiveInterval* other) { if (first_write_pos_ <= other->first_write_pos_) { - return other->first_write_pos_ < last_read_pos_; + return other->first_write_pos_ <= last_read_pos_; } else { - return first_write_pos_ < other->last_read_pos_; + return first_write_pos_ <= other->last_read_pos_; } } diff --git a/third_party/nvfuser/test/test_gpu2.cpp b/third_party/nvfuser/test/test_gpu2.cpp index c7c778b4017da..6afdcdfb7aedd 100644 --- a/third_party/nvfuser/test/test_gpu2.cpp +++ b/third_party/nvfuser/test/test_gpu2.cpp @@ -8255,18 +8255,21 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = add(tv2, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); - fusion.addOutput(tv3); + fusion.addOutput(tv4); tv1->setMemoryType(MemoryType::Shared); tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); - tv3->split(0, 4); - tv0->computeAt(tv3, 1); + tv4->split(0, 4); + tv0->computeAt(tv4, 1); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDy); tv3->axis(-1)->parallelize(ParallelType::TIDz); + tv4->axis(-1)->parallelize(ParallelType::TIDx); // Make sure a WAR sync is inserted at the end of the outer loop GpuLower gpulw(&fusion); @@ -8291,7 +8294,7 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); - auto ref1 = t0 + 3; + auto ref1 = t0 + 4; testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__); }