Skip to content

Commit

Permalink
Add support for ternary short-circuiting (#707)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer authored Jan 13, 2025
1 parent 524e910 commit b9fb8da
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 47 deletions.
2 changes: 1 addition & 1 deletion docs/docs/language/operator-overloading.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ title: Operator Overloading
---

Spice allows overloading operators for [custom struct types](structs.md).
Currently, this works for the operators `+`, `-`, `*`, `/`, `==`, `!=`, `<<`, `>>`, `+=`, `-=`, `*=`, `/=`,
Currently, this works for the operators `+`, `-`, `*`, `/`, `==`, `!=`, `<<`, `>>`, `+=`, `-=`, `*=`, `/=`, `[]`,
`++` (postfix) and `--` (postfix).
In the future, more operators will be supported for overloading.

Expand Down
44 changes: 34 additions & 10 deletions src/irgenerator/GenExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,43 @@ std::any IRGenerator::visitTernaryExpr(const TernaryExprNode *node) {
// It is a ternary
// Retrieve the condition value
llvm::Value *condValue = resolveValue(node->condition);

// Get the values of true and false
llvm::Value *trueValue;
llvm::Value *falseValue;
if (node->isShortened) {
trueValue = condValue;
falseValue = resolveValue(node->falseExpr);
const LogicalOrExprNode *trueNode = node->isShortened ? node->condition : node->trueExpr;
const LogicalOrExprNode *falseNode = node->falseExpr;

llvm::Value* resultValue;
if (trueNode->hasCompileTimeValue() && falseNode->hasCompileTimeValue()) {
// If both are constants, we can simply emit a selection instruction
llvm::Value *trueValue = resolveValue(trueNode);
llvm::Value *falseValue = resolveValue(falseNode);
resultValue = builder.CreateSelect(condValue, trueValue, falseValue);
} else {
trueValue = resolveValue(node->trueExpr);
falseValue = resolveValue(node->falseExpr);
// We have at least one non-constant value, use branching to not perform both sides
const std::string codeLoc = node->codeLoc.toPrettyLineAndColumn();
llvm::BasicBlock *condTrue = createBlock("cond.true." + codeLoc);
llvm::BasicBlock *condFalse = createBlock("cond.false." + codeLoc);
llvm::BasicBlock *condExit = createBlock("cond.exit." + codeLoc);

// Jump from original block to true or false block, depending on condition
insertCondJump(condValue, condTrue, condFalse);

// Fill true block
switchToBlock(condTrue);
llvm::Value *trueValue = resolveValue(trueNode);
insertJump(condExit);

// Fill false block
switchToBlock(condFalse);
llvm::Value *falseValue = resolveValue(falseNode);
insertJump(condExit);

// Fill the exit block
switchToBlock(condExit);
llvm::PHINode* phiInst = builder.CreatePHI(trueValue->getType(), 2, "cond.result");
phiInst->addIncoming(trueValue, condTrue);
phiInst->addIncoming(falseValue, condFalse);
resultValue = phiInst;
}

llvm::Value *resultValue = builder.CreateSelect(condValue, trueValue, falseValue);
return LLVMExprResult{.value = resultValue};
}

Expand Down
6 changes: 3 additions & 3 deletions src/irgenerator/IRGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ llvm::Value *IRGenerator::insertInBoundsGEP(llvm::Type *type, llvm::Value *baseP
std::string varName) const {
assert(basePtr->getType()->isPointerTy());
assert(!indices.empty());
assert(std::ranges::all_of(indices, [](llvm::Value *index) {
llvm::Type *indexType = index->getType();
assert(std::ranges::all_of(indices, [](const llvm::Value *index) {
const llvm::Type *indexType = index->getType();
return indexType->isIntegerTy(32) || indexType->isIntegerTy(64);
}));

Expand Down Expand Up @@ -465,7 +465,7 @@ LLVMExprResult IRGenerator::doAssignment(llvm::Value *lhsAddress, SymbolTableEnt

if (isDecl && rhsSType.is(TY_STRUCT) && rhs.isTemporary()) {
assert(lhsEntry != nullptr);
// Directly set the address to the lhs entry
// Directly set the address to the lhs entry (temp stealing)
llvm::Value *rhsAddress = resolveAddress(rhs);
lhsEntry->updateAddress(rhsAddress);
rhs.entry = lhsEntry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ define dso_local i32 @main() #1 {
%result = alloca i32, align 4
store i32 0, ptr %result, align 4
%1 = call i1 @_Z2f1v()
%2 = call i1 @_Z2f2v()
%3 = select i1 %1, i1 %1, i1 %2
%4 = zext i1 %3 to i32
br i1 %1, label %cond.true.L12C26, label %cond.false.L12C26

cond.true.L12C26: ; preds = %0
%2 = call i1 @_Z2f1v()
br label %cond.exit.L12C26

cond.false.L12C26: ; preds = %0
%3 = call i1 @_Z2f2v()
br label %cond.exit.L12C26

cond.exit.L12C26: ; preds = %cond.false.L12C26, %cond.true.L12C26
%cond.result = phi i1 [ %2, %cond.true.L12C26 ], [ %3, %cond.false.L12C26 ]
%4 = zext i1 %cond.result to i32
%5 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.2, i32 %4)
%6 = load i32, ptr %result, align 4
ret i32 %6
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Result: 3
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

@printf.str.0 = private unnamed_addr constant [11 x i8] c"Result: %d\00", align 1

define private i1 @_Z10condition1v() {
%result = alloca i1, align 1
ret i1 false
}

define private i1 @_Z10condition2v() {
%result = alloca i1, align 1
ret i1 true
}

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #0 {
%result = alloca i32, align 4
store i32 0, ptr %result, align 4
%1 = call i1 @_Z10condition1v()
br i1 %1, label %land.1.L10C26, label %land.exit.L10C26

land.1.L10C26: ; preds = %0
%2 = call i1 @_Z10condition2v()
br label %land.exit.L10C26

land.exit.L10C26: ; preds = %land.1.L10C26, %0
%land_phi = phi i1 [ %1, %0 ], [ %2, %land.1.L10C26 ]
%3 = select i1 %land_phi, i32 2, i32 3
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %3)
%5 = load i32, ptr %result, align 4
ret i32 %5
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #1

attributes #0 = { noinline nounwind optnone uwtable }
attributes #1 = { nofree nounwind }
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
f<bool> condition1() {
return false;
}

f<bool> condition2() {
return true;
}

f<int> main() {
printf("Result: %d", condition1() && condition2() ? 2: 3);
}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Result: 3
Result: false
Original file line number Diff line number Diff line change
@@ -1,39 +1,64 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

@printf.str.0 = private unnamed_addr constant [11 x i8] c"Result: %d\00", align 1
@anon.string.0 = private unnamed_addr constant [56 x i8] c"Assertion failed: Condition 'false' evaluated to false.\00", align 1
@anon.string.1 = private unnamed_addr constant [6 x i8] c"false\00", align 1
@printf.str.0 = private unnamed_addr constant [11 x i8] c"Result: %s\00", align 1

define private i1 @_Z10condition1v() {
define private i1 @_Z7condFctv() {
%result = alloca i1, align 1
ret i1 false
}

define private i1 @_Z10condition2v() {
%result = alloca i1, align 1
ret i1 true
define private ptr @_Z7trueFctv() {
%result = alloca ptr, align 8
br i1 false, label %assert.exit.L6, label %assert.then.L6, !prof !0

assert.then.L6: ; preds = %0
%1 = call i32 (ptr, ...) @printf(ptr @anon.string.0)
call void @exit(i32 1)
unreachable

assert.exit.L6: ; preds = %0
%2 = load ptr, ptr %result, align 8
ret ptr %2
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #0

; Function Attrs: cold noreturn nounwind
declare void @exit(i32) #1

define private ptr @_Z8falseFctv() {
%result = alloca ptr, align 8
ret ptr @anon.string.1
}

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #0 {
define dso_local i32 @main() #2 {
%result = alloca i32, align 4
store i32 0, ptr %result, align 4
%1 = call i1 @_Z10condition1v()
br i1 %1, label %land.1.L10C26, label %land.exit.L10C26
%1 = call i1 @_Z7condFctv()
br i1 %1, label %cond.true.L15C26, label %cond.false.L15C26

land.1.L10C26: ; preds = %0
%2 = call i1 @_Z10condition2v()
br label %land.exit.L10C26
cond.true.L15C26: ; preds = %0
%2 = call ptr @_Z7trueFctv()
br label %cond.exit.L15C26

land.exit.L10C26: ; preds = %land.1.L10C26, %0
%land_phi = phi i1 [ %1, %0 ], [ %2, %land.1.L10C26 ]
%3 = select i1 %land_phi, i32 2, i32 3
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %3)
cond.false.L15C26: ; preds = %0
%3 = call ptr @_Z8falseFctv()
br label %cond.exit.L15C26

cond.exit.L15C26: ; preds = %cond.false.L15C26, %cond.true.L15C26
%cond.result = phi ptr [ %2, %cond.true.L15C26 ], [ %3, %cond.false.L15C26 ]
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, ptr %cond.result)
%5 = load i32, ptr %result, align 4
ret i32 %5
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #1
attributes #0 = { nofree nounwind }
attributes #1 = { cold noreturn nounwind }
attributes #2 = { noinline nounwind optnone uwtable }

attributes #0 = { noinline nounwind optnone uwtable }
attributes #1 = { nofree nounwind }
!0 = !{!"branch_weights", i32 2000, i32 1}
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
f<bool> condition1() {
f<bool> condFct() {
return false;
}

f<bool> condition2() {
return true;
f<string> trueFct() {
assert false; // Should not be called
return "true";
}

f<string> falseFct() {
return "false";
}

f<int> main() {
printf("Result: %d", condition1() && condition2() ? 2: 3);
printf("Result: %s", condFct() ? trueFct() : falseFct());
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,22 @@ define dso_local i32 @main() #0 {
store i32 0, ptr %result, align 4
store i1 true, ptr %condition, align 1
%1 = load i1, ptr %condition, align 1
br i1 %1, label %cond.true.L7C13, label %cond.false.L7C13

cond.true.L7C13: ; preds = %0
%2 = call i32 @_Z3getv()
%3 = select i1 %1, i32 %2, i32 24
store i32 %3, ptr %r, align 4
%4 = load i32, ptr %r, align 4
%5 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %4)
%6 = load i32, ptr %result, align 4
ret i32 %6
br label %cond.exit.L7C13

cond.false.L7C13: ; preds = %0
br label %cond.exit.L7C13

cond.exit.L7C13: ; preds = %cond.false.L7C13, %cond.true.L7C13
%cond.result = phi i32 [ %2, %cond.true.L7C13 ], [ 24, %cond.false.L7C13 ]
store i32 %cond.result, ptr %r, align 4
%3 = load i32, ptr %r, align 4
%4 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0, i32 %3)
%5 = load i32, ptr %result, align 4
ret i32 %5
}

; Function Attrs: nofree nounwind
Expand Down

0 comments on commit b9fb8da

Please sign in to comment.