Skip to content

Commit

Permalink
Simpler treatment of variadic operands with zero occurrences in Rewri…
Browse files Browse the repository at this point in the history
…teToCallKernelOpPass.

- Some DaphneIR operations have variadic operands, i.e., operands which can occur an arbitrary number of times (including zero).
- When lowering DaphneIR operations to the DaphneIR CallKernelOp, this poses a problem as the C++ kernels cannot be pre-compiled for any number of operands.
- Thus, the kernels expect an array of operands (pointer plus size); in DaphneIR this is represented by a value of type daphne::VariadicPack, which stores all occurrences of a variadic operand.
- So RewriteToCallKernelOpPass needs to convert variadic operands of DaphneIR operations to DaphneIR's VariadicPack.
- In the past there have been problems in case of variadic operands with zero occurrences.
- As GroupOp was the only problematic op so far, a workaround has been provided in PRs daphne-eu#564 and daphne-eu#543.
- This commit improves the treatment of variadic operands with zero occurrences in RewriteToCallKernelOpPass:
  - First, the code is more general now and clearly expresses that this problem is not about GroupOp, but about instances of ops with variadic operands with zero occurrences, in general.
  - Second, the code for treating individual such ops is significantly simpler and more extensible now, since one only has to provide a default MLIR type for the variadic operand.
- Furthermore, a simplified version of the example DaphneDSL code mentioned in issue daphne-eu#562, that failed before PR daphne-eu#564, was added as a test case now.
  • Loading branch information
pdamme committed Apr 8, 2024
1 parent 67e2705 commit 92fbc7d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 29 deletions.
67 changes: 39 additions & 28 deletions src/compiler/lowering/RewriteToCallKernelOpPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ namespace
for(size_t i = 0; i < resultTypes.size(); i++)
callee << "__" << CompilerUtils::mlirTypeToCppTypeName(resultTypes[i], false);

// Append names of operand types to the kernel name. Variadic
// Append names of operand types to the kernel name. Variadic ODS
// operands, which can have an arbitrary number of occurrences, are
// treated specially.
Operation::operand_type_range operandTypes = op->getOperandTypes();
// The operands of the CallKernelOp may differ from the operands
// of the given operation, if it has a variadic operand.
// of the given operation, if it has a variadic ODS operand.
std::vector<Value> newOperands;

if(
Expand All @@ -197,40 +197,51 @@ namespace
op->hasTrait<OpTrait::AtLeastNOperands<1>::Impl>() ||
op->hasTrait<OpTrait::AtLeastNOperands<2>::Impl>()
) {
// For operations with variadic operands, we replace all
// occurrences of a variadic operand by a single operand of
// For operations with variadic ODS operands, we replace all
// occurrences of a variadic ODS operand by a single operand of
// type VariadicPack as well as an operand for the number of
// occurrences. All occurrences of the variadic operand are
// occurrences. All occurrences of the variadic ODS operand are
// stored in the VariadicPack.
// Note that a variadic ODS operand may have zero occurrences.
// In that case, there is no operand corresponding to the
// variadic ODS operand.
const size_t numODSOperands = getNumODSOperands(op);
for(size_t i = 0; i < numODSOperands; i++) {
auto odsOpInfo = getODSOperandInfo(op, i);
const unsigned idx = std::get<0>(odsOpInfo);
const unsigned len = std::get<1>(odsOpInfo);
const bool isVariadic = std::get<2>(odsOpInfo);

// TODO The group operation currently expects at least four inputs due to the
// expectation of a aggregation. To make the group operation possible without aggregations,
// we have to use this workaround to create the correct name and skip the creation
// of the variadic pack ops. Should be changed when reworking the lowering to kernels.
if(llvm::dyn_cast<daphne::GroupOp>(op) && idx >= operandTypes.size()) {
callee << "__char_variadic__size_t";
auto cvpOp = rewriter.create<daphne::CreateVariadicPackOp>(
loc,
daphne::VariadicPackType::get(
rewriter.getContext(),
daphne::StringType::get(rewriter.getContext())
),
rewriter.getI64IntegerAttr(0)
);
newOperands.push_back(cvpOp);
newOperands.push_back(rewriter.create<daphne::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0))
);
continue;
} else {
callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandTypes[idx], generalizeInputTypes);

// Determine the MLIR type of the current ODS operand.
Type odsOperandTy;
if(len > 0) {
// If the current ODS operand has occurrences, then
// we use the type of the first operand belonging to
// the current ODS operand.
odsOperandTy = operandTypes[idx];
}
else { // len == 0
// If the current ODS operand does not have any occurrences
// (e.g., a variadic ODS operand with zero concrete operands
// provided), then we cannot derive the type of the
// current ODS operand from any given operand. Instead,
// we use a default type depending on which ODS operand of
// which operation it is.
// Note that we cannot simply omit the type, since the
// underlying kernel expects an "empty list" (represented
// in the DAPHNE compiler by an empty VariadicPack).
if(llvm::dyn_cast<daphne::GroupOp>(op) && i == 2)
// A GroupOp may have zero aggregation column names.
odsOperandTy = daphne::StringType::get(rewriter.getContext());
else
throw std::runtime_error(
"RewriteToCallKernelOpPass encountered a variadic ODS operand with zero occurrences, "
"but does not know how to handle it: ODS operand " + std::to_string(i) +
" of operation " + op->getName().getStringRef().str()
);
}

callee << "__" << CompilerUtils::mlirTypeToCppTypeName(odsOperandTy, generalizeInputTypes);

if(isVariadic) {
// Variadic operand.
Expand All @@ -239,7 +250,7 @@ namespace
loc,
daphne::VariadicPackType::get(
rewriter.getContext(),
op->getOperand(idx).getType()
odsOperandTy
),
rewriter.getI64IntegerAttr(len)
);
Expand Down
2 changes: 1 addition & 1 deletion test/api/cli/sql/SQLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ MAKE_SUCCESS_TEST_CASE("join", 1);
MAKE_SUCCESS_TEST_CASE("group", 3);
MAKE_PASS_FAILURE_TEST_CASE("group", 1);

MAKE_TEST_CASE("group", 4)
MAKE_TEST_CASE("group", 5)


MAKE_TEST_CASE("thetaJoin_equal", 4)
Expand Down
10 changes: 10 additions & 0 deletions test/api/cli/sql/group_5.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Group with and without aggregation in the same DaphneDSL script.

f = createFrame([1, 2, 3, 2], "a");
registerView("f", f);

res1 = sql("SELECT f.a FROM f GROUP BY f.a;");
print(res1);

res2 = sql("SELECT sum(f.a) FROM f GROUP BY f.a;");
print(res2);
8 changes: 8 additions & 0 deletions test/api/cli/sql/group_5.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Frame(3x1, [f.a:int64_t])
1
2
3
Frame(3x1, [sum(f.a):int64_t])
1
4
3

0 comments on commit 92fbc7d

Please sign in to comment.