Skip to content

Commit

Permalink
[BACKEND][NFC] Interface for LinearLayout conversion. (#4525)
Browse files Browse the repository at this point in the history
Currently, implementing a LinearLayout conversion from a new layout
requires implementing the conversion method in
`LinearLayoutConversions.cpp`.
The current approach is not ideal because:
- It does not support modular design in that a third-party layout cannot
implement its own conversion externally.
- There is no compile-time method to check whether a layout supports
LinearLayout conversion.

This PR suggests adding a TableGen trait to interface LinearLayout
conversion method.
The process is now simplified as 1) add `LinearLayoutTrait` to the
layout and implement`[Layout]::toLinearLayout` function.


Checklists:
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
- [x] This PR does not need a test because it is an improvement of
programming interface of an existing feature.

- Select one of the following.
  - [x] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)
  • Loading branch information
hwnam831 authored Aug 17, 2024
1 parent f0d9c7a commit 2abfaec
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 191 deletions.
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,10 @@ For the Threads Per Warp and Values Per Thread level, the linear id distribution
InterfaceMethod<"Gets the number of contiguous elements per thread.",
"SmallVector<unsigned>",
"getContigPerThread">,
InterfaceMethod<"Convert to LinearLayout.",
"std::optional<LinearLayout>",
"toLinearLayout",
(ins "ArrayRef<int64_t>":$shape)>
];
}

Expand Down Expand Up @@ -576,6 +580,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},

SmallVector<unsigned> getSizePerThread() const;
SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;

std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
}];
}

Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#ifndef TRITON_GPU_DIALECT_INTERFACES_H
#define TRITON_GPU_DIALECT_INTERFACES_H

#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc"

#endif // TRITON_GPU_DIALECT_INTERFACES_H
Loading

0 comments on commit 2abfaec

Please sign in to comment.