Skip to content

Commit

Permalink
[DAPHNE-daphne-eu#186,daphne-eu#560,daphne-eu#667] Parsing unary minu…
Browse files Browse the repository at this point in the history
…s operator (daphne-eu#607)

[DAPHNE-daphne-eu#560] Parsing unary minus operator

- This change decouples the '-' from INT_LITERAL and FLOAT_LITERAL and handles a unary minus separately.
- The additive inverse is internally achieved by EwMinusOp, a concrete elementwise unary operation.
- Accordingly, a new unary op code for the ewUnarySca/ewUnaryMat kernels was added.
- As a leading minus is technically not part of the literal anymore, the parsing of integer literals needs a special case for (-)2^63.
- Added script-level test cases for unary minus, its interplay with other operators, and its use in matrix literals.
- Updated the DaphneDSL language reference.
  - No changes to the description of literals: leading minus is still presented as a part of the literal here, since this is more intuitive from a user's point of view.
  - Mentioning unary -/+ in the operators table.
- And some more minor things.
- Closes daphne-eu#186, closes daphne-eu#560, closes daphne-eu#667.

---------

Co-authored-by: Patrick Damme <patrick.damme@tu-berlin.de>
  • Loading branch information
corepointer and pdamme authored Jul 2, 2024
1 parent 81667eb commit 5a47688
Show file tree
Hide file tree
Showing 19 changed files with 244 additions and 17 deletions.
5 changes: 3 additions & 2 deletions doc/DaphneDSL/LanguageRef.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ DaphneDSL currently supports the following binary operators:

| Operator | Meaning |
| --- | --- |
| `-`, `+` | additive inverse (unary operators) |
| `@` | matrix multiplication (highest precedence) |
| `^` | exponentiation |
| `%` | modulo |
Expand All @@ -234,12 +235,12 @@ DaphneDSL currently supports the following binary operators:
| `&&` | logical AND |
| `\|\|` | logical OR (lowest precedence) |

*We plan to add more operators, including unary operators.*
*We plan to add more unary and binary operators in the future.*

*Matrix multiplication (`@`):*
The inputs must be matrices of compatible shapes, and the output is always a matrix.

*All other operators:*
*All other binary operators:*
The following table shows which combinations of inputs are allowed and which result they yield:

| Left input | Right input | Result | Details |
Expand Down
1 change: 1 addition & 0 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class Daphne_EwUnaryOp<string name, Type scalarType, list<Trait> traits = []> :
// Arithmetic/general math
// ----------------------------------------------------------------------------

// TODO EwMinusOp: Should an unsigned integer argument yield a signed integer result?
def Daphne_EwMinusOp : Daphne_EwUnaryOp<"ewMinus", NumScalar, [ValueTypeFromFirstArg]>;
def Daphne_EwAbsOp : Daphne_EwUnaryOp<"ewAbs", NumScalar, [ValueTypeFromFirstArg]>;
def Daphne_EwSignOp : Daphne_EwUnaryOp<"ewSign", NumScalar, [ValueTypeFromFirstArg]>;
Expand Down
7 changes: 3 additions & 4 deletions src/parser/daphnedsl/DaphneDSLGrammar.g4
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ expr:
| KW_AS (('.' DATA_TYPE) | ('.' VALUE_TYPE) | ('.' DATA_TYPE '<' VALUE_TYPE '>')) '(' expr ')' # castExpr
| obj=expr '[[' (rows=expr)? ',' (cols=expr)? ']]' # rightIdxFilterExpr
| obj=expr idx=indexing # rightIdxExtractExpr
| op=('+'|'-') arg=expr # minusExpr
| lhs=expr op='@' rhs=expr # matmulExpr
| lhs=expr op='^' rhs=expr # powExpr
| lhs=expr op='%' rhs=expr # modExpr
Expand Down Expand Up @@ -161,16 +162,14 @@ VALUE_TYPE:
) ;
INT_LITERAL:
('0' | '-'? NON_ZERO_DIGIT (DIGIT_SEP? DIGIT)* ('l' | 'u' | 'ull' | 'z')?);
('0' | NON_ZERO_DIGIT (DIGIT_SEP? DIGIT)*) ('l' | 'u' | 'ull' | 'z')? ;
FLOAT_LITERAL:
(
// special values
'nan' | 'nanf' | '-'? ('inf' | 'inff')
'nan' | 'nanf' | 'inf' | 'inff'
|
// ordinary values
// optional minus
'-'?
// part before the decimal point
('0' | NON_ZERO_DIGIT (DIGIT_SEP? DIGIT)*)
(
Expand Down
50 changes: 46 additions & 4 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,17 +794,35 @@ antlrcpp::Any DaphneDSLVisitor::visitArgExpr(DaphneDSLGrammarParser::ArgExprCont
"argument " + arg + " referenced, but not provided as a command line argument"
);

bool hasMinus;
std::string litStr;
if(!it->second.empty() && it->second[0] == '-') {
hasMinus = true;
litStr = it->second.substr(1);
}
else {
hasMinus = false;
litStr = it->second;
}

// Parse the string that was passed as the value for this argument on the
// command line as a DaphneDSL literal.
// TODO: fix for string literals when " are not escaped or not present
std::istringstream stream(it->second);
std::istringstream stream(litStr);
antlr4::ANTLRInputStream input(stream);
input.name = "argument"; // TODO Does this make sense?
DaphneDSLGrammarLexer lexer(&input);
antlr4::CommonTokenStream tokens(&lexer);
DaphneDSLGrammarParser parser(&tokens);
DaphneDSLGrammarParser::LiteralContext * literalCtx = parser.literal();
return visitLiteral(literalCtx);

mlir::Value lit = visitLiteral(literalCtx);
if(!hasMinus)
return lit;
else
return utils.retValWithInferedType(builder.create<mlir::daphne::EwMinusOp>(
utils.getLoc(ctx->start), utils.unknownType, lit
));
}

antlrcpp::Any DaphneDSLVisitor::visitIdentifierExpr(DaphneDSLGrammarParser::IdentifierExprContext * ctx) {
Expand Down Expand Up @@ -1235,6 +1253,21 @@ antlrcpp::Any DaphneDSLVisitor::visitRightIdxExtractExpr(DaphneDSLGrammarParser:
return obj;
}

antlrcpp::Any DaphneDSLVisitor::visitMinusExpr(DaphneDSLGrammarParser::MinusExprContext *ctx) {
std::string op = ctx->op->getText();
mlir::Location loc = utils.getLoc(ctx->op);
mlir::Value arg = utils.valueOrError(visit(ctx->arg));

if(op == "-")
return utils.retValWithInferedType(
builder.create<mlir::daphne::EwMinusOp>(loc, utils.unknownType, arg)
);
if(op == "+")
return arg;

throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol");
}

antlrcpp::Any DaphneDSLVisitor::visitMatmulExpr(DaphneDSLGrammarParser::MatmulExprContext * ctx) {
std::string op = ctx->op->getText();
mlir::Location loc = utils.getLoc(ctx->op);
Expand Down Expand Up @@ -1757,8 +1790,17 @@ antlrcpp::Any DaphneDSLVisitor::visitLiteral(DaphneDSLGrammarParser::LiteralCont
static_cast<std::size_t>(std::stoll(litStr))));
}
else {
return static_cast<mlir::Value>(builder.create<mlir::daphne::ConstantOp>(loc,
static_cast<int64_t>(std::stoll(litStr))));
// Note that a leading minus of a numeric literal is not parsed as part of the literal itself,
// but handled separately as a unary minus operator. Thus, this visitor actually sees the
// number without the minus. This is problematic when a DaphneDSL script contains the minimum
// int64 value -2^63, because without the minus, 2^63 is beyond the range of int64, as the
// maximum int64 value is 2^63 - 1. Thus, we need a special case here.
if(std::stoull(litStr) == (std::numeric_limits<int64_t>::max() + 1ull))
return static_cast<mlir::Value>(builder.create<mlir::daphne::ConstantOp>(loc,
static_cast<int64_t>(std::numeric_limits<int64_t>::min())));
else
return static_cast<mlir::Value>(builder.create<mlir::daphne::ConstantOp>(loc,
static_cast<int64_t>(std::stoll(litStr))));
}
}
if(auto lit = ctx->FLOAT_LITERAL()) {
Expand Down
2 changes: 2 additions & 0 deletions src/parser/daphnedsl/DaphneDSLVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ class DaphneDSLVisitor : public DaphneDSLGrammarVisitor {
antlrcpp::Any visitRightIdxFilterExpr(DaphneDSLGrammarParser::RightIdxFilterExprContext * ctx) override;

antlrcpp::Any visitRightIdxExtractExpr(DaphneDSLGrammarParser::RightIdxExtractExprContext * ctx) override;

antlrcpp::Any visitMinusExpr(DaphneDSLGrammarParser::MinusExprContext *ctx) override;

antlrcpp::Any visitMatmulExpr(DaphneDSLGrammarParser::MatmulExprContext * ctx) override;

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/local/kernels/EwUnarySca.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ EwUnaryScaFuncPtr<VTRes, VTArg> getEwUnaryScaFuncPtr(UnaryOpCode opCode) {
switch(opCode) {
#define MAKE_CASE(opCode) case opCode: return &EwUnarySca<opCode, VTRes, VTArg>::apply;
// Arithmetic/general math.
MAKE_CASE(UnaryOpCode::MINUS)
MAKE_CASE(UnaryOpCode::ABS)
MAKE_CASE(UnaryOpCode::SIGN)
MAKE_CASE(UnaryOpCode::SQRT)
Expand Down Expand Up @@ -142,6 +143,7 @@ TRes ewUnarySca(UnaryOpCode opCode, TArg arg, DCTX(ctx)) {

// One such line for each unary function to support.
// Arithmetic/general math.
MAKE_EW_UNARY_SCA(UnaryOpCode::MINUS, -arg);
MAKE_EW_UNARY_SCA(UnaryOpCode::ABS, abs(arg));
MAKE_EW_UNARY_SCA(UnaryOpCode::SIGN, (arg == 0) ? 0 : ((arg < 0) ? -1 : ((arg > 0) ? 1 : std::numeric_limits<TRes>::quiet_NaN())));
MAKE_EW_UNARY_SCA_OPEN_DOMAIN_ERROR(UnaryOpCode::SQRT, sqrt(arg),
Expand Down
1 change: 1 addition & 0 deletions src/runtime/local/kernels/UnaryOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

enum class UnaryOpCode {
// Arithmetic/general math.
MINUS,
ABS,
SIGN, // signum (-1, 0, +1)
SQRT,
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -2908,7 +2908,7 @@
[["DenseMatrix", "double"],["DenseMatrix", "double"]],
[["DenseMatrix", "int64_t"],["DenseMatrix", "int64_t"]]
],
"opCodes": ["SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN",
"opCodes": ["MINUS", "SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN",
"SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH"]
},
{
Expand Down Expand Up @@ -2942,7 +2942,7 @@
["float", "float"],
["int64_t", "int64_t"]
],
"opCodes": ["SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN",
"opCodes": ["MINUS", "SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN",
"SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH"]
},
{
Expand Down
2 changes: 1 addition & 1 deletion test/api/cli/expressions/matrix_literal_success_6.daphne
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ def testFunc() -> f64 {
return 1.1;
}

testExprMat = [testVarExpr, testFunc(), sqrt(3), 2+2];
testExprMat = [testVarExpr, testFunc(), sqrt(3), 2+2, -5.5];
print(testExprMat);
3 changes: 2 additions & 1 deletion test/api/cli/expressions/matrix_literal_success_6.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
DenseMatrix(4x1, double)
DenseMatrix(5x1, double)
10
1.1
1.73205
4
-5.5
2 changes: 1 addition & 1 deletion test/api/cli/operations/OperationsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ MAKE_TEST_CASE("idxMin", 1)
MAKE_TEST_CASE("mean", 1)
MAKE_TEST_CASE("operator_at", 2)
MAKE_TEST_CASE("operator_eq", 2)
MAKE_TEST_CASE("operator_minus", 1)
MAKE_TEST_CASE("operator_minus", 4)
MAKE_TEST_CASE("operator_plus", 2)
MAKE_TEST_CASE("operator_slash", 1)
MAKE_TEST_CASE("operator_times", 1)
Expand Down
2 changes: 1 addition & 1 deletion test/api/cli/operations/operator_minus_1.daphne
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// '-' for subtraction.
// binary '-' for subtraction.

print("scalar - scalar");
print(1 - 2);
Expand Down
21 changes: 21 additions & 0 deletions test/api/cli/operations/operator_minus_2.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// binary '-' for subtraction, no spaces around the operator (used to be a problem in the past).

print("scalar - scalar");
print(1-2);
print(1.1-2.2);
print(1-2.3);

print("matrix - matrix");
print([1, 2]-[3, 4]);
print([1.1, 2.2]-[3.3, 4.4]);
print([1, 2]-[3.4, 4.6]);

print("matrix - scalar");
print([1, 2]-3);
print([1.1, 2.2]-3.3);
print([1, 2]-3.3);

print("scalar - matrix");
print(3-[1, 2]);
print(3.3-[1.1, 2.2]);
print(3.3-[1, 2]);
34 changes: 34 additions & 0 deletions test/api/cli/operations/operator_minus_2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
scalar - scalar
-1
-1.1
-1.3
matrix - matrix
DenseMatrix(2x1, int64_t)
-2
-2
DenseMatrix(2x1, double)
-2.2
-2.2
DenseMatrix(2x1, double)
-2.4
-2.6
matrix - scalar
DenseMatrix(2x1, int64_t)
-2
-1
DenseMatrix(2x1, double)
-2.2
-1.1
DenseMatrix(2x1, double)
-2.3
-1.3
scalar - matrix
DenseMatrix(2x1, int64_t)
2
1
DenseMatrix(2x1, double)
2.2
1.1
DenseMatrix(2x1, double)
2.3
1.3
39 changes: 39 additions & 0 deletions test/api/cli/operations/operator_minus_3.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// unary '-' for additive inverse.

print("scalar");
# single minus
print(-2);
print(-2.2);
# single plus
print(+2);
print(+2.2);
# double minus (cancels out)
print(--2);
print(--2.2);
# double plus (cancels out)
print(++2);
print(++2.2);
# mix of plus and minus
print(+-2);
print(+-2.2);
print(+-+2);
print(+-+2.2);

print("matrix");
# single minus
print(-[2]);
print(-[2.2]);
# single plus
print(+[2]);
print(+[2.2]);
# double minus (cancels out)
print(--[2]);
print(--[2.2]);
# double plus (cancels out)
print(++[2]);
print(++[2.2]);
# mix of plus and minus
print(+-[2]);
print(+-[2.2]);
print(+-+[2]);
print(+-+[2.2]);
38 changes: 38 additions & 0 deletions test/api/cli/operations/operator_minus_3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
scalar
-2
-2.2
2
2.2
2
2.2
2
2.2
-2
-2.2
-2
-2.2
matrix
DenseMatrix(1x1, int64_t)
-2
DenseMatrix(1x1, double)
-2.2
DenseMatrix(1x1, int64_t)
2
DenseMatrix(1x1, double)
2.2
DenseMatrix(1x1, int64_t)
2
DenseMatrix(1x1, double)
2.2
DenseMatrix(1x1, int64_t)
2
DenseMatrix(1x1, double)
2.2
DenseMatrix(1x1, int64_t)
-2
DenseMatrix(1x1, double)
-2.2
DenseMatrix(1x1, int64_t)
-2
DenseMatrix(1x1, double)
-2.2
19 changes: 19 additions & 0 deletions test/api/cli/operations/operator_minus_4.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// unary '-' for additive inverse, interplay with other operators.

a = 2;
print("interplay with binary plus/minus");
print(1 + -a);
print(1 + +a);
print(1 - -a);
print(1 - +a);
print(-1 + -a);
print(-1 + +a);
print(-1 - -a);
print(-1 - +a);

print("interplay with other operators");
print(-log(8, 2));
print(-sum([1, 2, 3]));
print(-[10, 20, 30][1, 0]);
print(-(3 * 5));
print(10.0^-3);
16 changes: 16 additions & 0 deletions test/api/cli/operations/operator_minus_4.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
interplay with binary plus/minus
-1
3
3
-1
-3
1
1
-3
interplay with other operators
-3
-6
DenseMatrix(1x1, int64_t)
-20
-15
0.001
Loading

0 comments on commit 5a47688

Please sign in to comment.