From 92fbc7d258532df71bd9e576566f6eb6b80a464d Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Mon, 8 Apr 2024 22:49:25 +0200 Subject: [PATCH] Simpler treatment of variadic operands with zero occurrences in RewriteToCallKernelOpPass. - Some DaphneIR operations have variadic operands, i.e., operands which can occur an arbitrary number of times (including zero). - When lowering DaphneIR operations to the DaphneIR CallKernelOp, this poses a problem as the C++ kernels cannot be pre-compiled for an arbitrary number of operands. - Thus, the kernels expect an array of operands (pointer plus size); in DaphneIR this is represented by a value of type daphne::VariadicPack, which stores all occurrences of a variadic operand. - So RewriteToCallKernelOpPass needs to convert variadic operands of DaphneIR operations to DaphneIR's VariadicPack. - In the past there have been problems in the case of variadic operands with zero occurrences. - As GroupOp was the only problematic op so far, a workaround was provided in PRs #564 and #543. - This commit improves the treatment of variadic operands with zero occurrences in RewriteToCallKernelOpPass: - First, the code is more general now and clearly expresses that this problem is not about GroupOp, but about instances of ops with variadic operands with zero occurrences, in general. - Second, the code for treating individual such ops is significantly simpler and more extensible now, since one only has to provide a default MLIR type for the variadic operand. - Furthermore, a simplified version of the example DaphneDSL code mentioned in issue #562, which failed before PR #564, has now been added as a test case. 
--- .../lowering/RewriteToCallKernelOpPass.cpp | 67 +++++++++++-------- test/api/cli/sql/SQLTest.cpp | 2 +- test/api/cli/sql/group_5.daphne | 10 +++ test/api/cli/sql/group_5.txt | 8 +++ 4 files changed, 58 insertions(+), 29 deletions(-) create mode 100644 test/api/cli/sql/group_5.daphne create mode 100644 test/api/cli/sql/group_5.txt diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index 42633a965..db06779d0 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -181,12 +181,12 @@ namespace for(size_t i = 0; i < resultTypes.size(); i++) callee << "__" << CompilerUtils::mlirTypeToCppTypeName(resultTypes[i], false); - // Append names of operand types to the kernel name. Variadic + // Append names of operand types to the kernel name. Variadic ODS // operands, which can have an arbitrary number of occurrences, are // treated specially. Operation::operand_type_range operandTypes = op->getOperandTypes(); // The operands of the CallKernelOp may differ from the operands - // of the given operation, if it has a variadic operand. + // of the given operation, if it has a variadic ODS operand. std::vector newOperands; if( @@ -197,40 +197,51 @@ namespace op->hasTrait::Impl>() || op->hasTrait::Impl>() ) { - // For operations with variadic operands, we replace all - // occurrences of a variadic operand by a single operand of + // For operations with variadic ODS operands, we replace all + // occurrences of a variadic ODS operand by a single operand of // type VariadicPack as well as an operand for the number of - // occurrences. All occurrences of the variadic operand are + // occurrences. All occurrences of the variadic ODS operand are // stored in the VariadicPack. + // Note that a variadic ODS operand may have zero occurrences. + // In that case, there is no operand corresponding to the + // variadic ODS operand. 
const size_t numODSOperands = getNumODSOperands(op); for(size_t i = 0; i < numODSOperands; i++) { auto odsOpInfo = getODSOperandInfo(op, i); const unsigned idx = std::get<0>(odsOpInfo); const unsigned len = std::get<1>(odsOpInfo); const bool isVariadic = std::get<2>(odsOpInfo); - - // TODO The group operation currently expects at least four inputs due to the - // expectation of a aggregation. To make the group operation possible without aggregations, - // we have to use this workaround to create the correct name and skip the creation - // of the variadic pack ops. Should be changed when reworking the lowering to kernels. - if(llvm::dyn_cast(op) && idx >= operandTypes.size()) { - callee << "__char_variadic__size_t"; - auto cvpOp = rewriter.create( - loc, - daphne::VariadicPackType::get( - rewriter.getContext(), - daphne::StringType::get(rewriter.getContext()) - ), - rewriter.getI64IntegerAttr(0) - ); - newOperands.push_back(cvpOp); - newOperands.push_back(rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)) - ); - continue; - } else { - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandTypes[idx], generalizeInputTypes); + + // Determine the MLIR type of the current ODS operand. + Type odsOperandTy; + if(len > 0) { + // If the current ODS operand has occurrences, then + // we use the type of the first operand belonging to + // the current ODS operand. + odsOperandTy = operandTypes[idx]; } + else { // len == 0 + // If the current ODS operand does not have any occurrences + // (e.g., a variadic ODS operand with zero concrete operands + // provided), then we cannot derive the type of the + // current ODS operand from any given operand. Instead, + // we use a default type depending on which ODS operand of + // which operation it is. + // Note that we cannot simply omit the type, since the + // underlying kernel expects an "empty list" (represented + // in the DAPHNE compiler by an empty VariadicPack). 
+ if(llvm::dyn_cast(op) && i == 2) + // A GroupOp may have zero aggregation column names. + odsOperandTy = daphne::StringType::get(rewriter.getContext()); + else + throw std::runtime_error( + "RewriteToCallKernelOpPass encountered a variadic ODS operand with zero occurrences, " + "but does not know how to handle it: ODS operand " + std::to_string(i) + + " of operation " + op->getName().getStringRef().str() + ); + } + + callee << "__" << CompilerUtils::mlirTypeToCppTypeName(odsOperandTy, generalizeInputTypes); if(isVariadic) { // Variadic operand. @@ -239,7 +250,7 @@ namespace loc, daphne::VariadicPackType::get( rewriter.getContext(), - op->getOperand(idx).getType() + odsOperandTy ), rewriter.getI64IntegerAttr(len) ); diff --git a/test/api/cli/sql/SQLTest.cpp b/test/api/cli/sql/SQLTest.cpp index 3eb01fa87..494eee0ea 100644 --- a/test/api/cli/sql/SQLTest.cpp +++ b/test/api/cli/sql/SQLTest.cpp @@ -84,7 +84,7 @@ MAKE_SUCCESS_TEST_CASE("join", 1); MAKE_SUCCESS_TEST_CASE("group", 3); MAKE_PASS_FAILURE_TEST_CASE("group", 1); -MAKE_TEST_CASE("group", 4) +MAKE_TEST_CASE("group", 5) MAKE_TEST_CASE("thetaJoin_equal", 4) diff --git a/test/api/cli/sql/group_5.daphne b/test/api/cli/sql/group_5.daphne new file mode 100644 index 000000000..b6f7790f1 --- /dev/null +++ b/test/api/cli/sql/group_5.daphne @@ -0,0 +1,10 @@ +// Group with and without aggregation in the same DaphneDSL script. + +f = createFrame([1, 2, 3, 2], "a"); +registerView("f", f); + +res1 = sql("SELECT f.a FROM f GROUP BY f.a;"); +print(res1); + +res2 = sql("SELECT sum(f.a) FROM f GROUP BY f.a;"); +print(res2); \ No newline at end of file diff --git a/test/api/cli/sql/group_5.txt b/test/api/cli/sql/group_5.txt new file mode 100644 index 000000000..0149d8101 --- /dev/null +++ b/test/api/cli/sql/group_5.txt @@ -0,0 +1,8 @@ +Frame(3x1, [f.a:int64_t]) +1 +2 +3 +Frame(3x1, [sum(f.a):int64_t]) +1 +4 +3