diff --git a/include/Expr.h b/include/Expr.h index abbc92d..fab2019 100644 --- a/include/Expr.h +++ b/include/Expr.h @@ -76,4 +76,9 @@ class Literal : public Expr { ~Literal(); }; +class Variable : public Expr { +public: + Token *name; +}; + #endif diff --git a/include/Interpreter.h b/include/Interpreter.h index e523248..93596a7 100644 --- a/include/Interpreter.h +++ b/include/Interpreter.h @@ -5,175 +5,43 @@ #ifndef CRUX_INTERPRETER_H #define CRUX_INTERPRETER_H -#include "Error.h" #include "Expr.h" #include "Statement.h" +#include "Token.h" #include "utls/Object.h" -#include "utls/RuntimeError.h" -#include -#include #include class Interpreter { private: - void excecute(Statement *stmnt) { - switch (stmnt->type) { - case StmntPrint_type: - return visitPrintStmnt((Print *)stmnt); - case StmntExpr_type: - return visitExprStmnt((Expression *)stmnt); - } - } - - Object evaluate(Expr *expr) { - switch (expr->type) { - case ExprType_Binary: - return visitBinaryExp((Binary *)expr); - case ExprType_Unary: - return visitUnaryExp((Unary *)expr); - case ExprType_Literal: - return visitLiteral((Literal *)expr); - case ExprType_Grouping: - return visitGroupExp((Grouping *)expr); - case ExprType_Ternary: - return visitTernaryExp((Ternary *)expr); - } - return Object(); - } - - bool isTruthy(Object right) { - if (right.type == nullptr_type) - return false; - if (right.type == bool_type) - return right.bool_literal; - return true; - } - - bool isEqual(Object left, Object right) { - if (left.type == nullptr_type && right.type == nullptr_type) - return true; - if (left.type == nullptr_type) - return false; - - return left.num_literal == right.num_literal; - } - - void checkNumberOperand(Token *op, Object right) { - if (right.type == num_type) - return; - RuntimeError(*op, "Operand must be a number"); - } - bool checkIfSameTypes(Object left, Object right) { - if (left.type == num_type && right.type == num_type) - return true; - else - return false; - } - - bool checkCompatibility(Token *op, Object left, Object right) { - if ((left.type == string_type && right.type == num_type) || - left.type == num_type && right.type == string_type) { - return true; - } else { - return false; - } - } + Object excecute(Statement *stmnt); + + Object evaluate(Expr *expr); + + bool isTruthy(Object right); + + bool isEqual(Object left, Object right); + + void checkNumberOperand(Token *op, Object right); + + bool checkIfSameTypes(Object left, Object right); + + bool checkCompatibility(Token *op, Object left, Object right); public: - void interpret(std::vector &statements) { - try { - for (Statement *stmt : statements) { - excecute(stmt); - } - } catch (RuntimeError error) { - crux::runtimeError(error); - } - } - - void visitPrintStmnt(Print *expr) { - Object value = evaluate(expr->expression); - std::cout << value.str() << "\n"; - return; - } - - void visitExprStmnt(Expression *expr) { evaluate(expr->expression); } - - Object visitLiteral(Literal *expr) { return *expr->literal; } - - Object visitGroupExp(Grouping *expr) { return evaluate(expr->expression); } - - Object visitUnaryExp(Unary *expr) { - Object right = evaluate(expr->right); - switch (expr->op->type) { - case BANG: - return !isTruthy(right); - case MINUS: - checkNumberOperand(expr->op, right); - return -right.num_literal; - default: - RuntimeError(*expr->op, "Invalid operator used"); - } - return Object(); - } - - Object visitBinaryExp(Binary *expr) { - Object left = evaluate(expr->left); - Object right = evaluate(expr->right); - Token *op = expr->op; - switch (op->type) { - case MINUS: - checkIfSameTypes(left, right); - return left.num_literal - right.num_literal; - case SLASH: - checkIfSameTypes(left, right); - return left.num_literal / right.num_literal; - case STAR: - checkIfSameTypes(left, right); - return left.num_literal * right.num_literal; - case PLUS: - if (checkIfSameTypes(left, right)) { - if (left.type == num_type && right.type == num_type) - return left.num_literal + right.num_literal; - if (left.type == string_type && right.type == string_type) - return left.string_literal + right.string_literal; - } else if (checkCompatibility(op, left, right)) { - if (left.type == string_type && right.type == num_type) - return left.string_literal + std::to_string(right.num_literal); - else if (left.type == num_type && right.type == string_type) - return std::to_string(left.num_literal) + right.string_literal; - } - throw new RuntimeError(*op, "Error: Cannot evaluate expression"); - case GREATER: - checkIfSameTypes(left, right); - return left.num_literal > right.num_literal; - case GREATER_EQUAL: - checkIfSameTypes(left, right); - return left.num_literal >= right.num_literal; - case LESS: - checkIfSameTypes(left, right); - return left.num_literal < right.num_literal; - case LESS_EQUAL: - checkIfSameTypes(left, right); - return left.num_literal <= right.num_literal; - case BANG_EQUAL: - return !isEqual(left, right); - case EQUAL_EQUAL: - return isEqual(left, right); - } - RuntimeError(*op, "Operator not found"); - } - - Object visitTernaryExp(Ternary *expr) { - Object condition = evaluate(expr->condition); - if (condition.type == bool_type) { - if (condition.bool_literal) - return evaluate(expr->expression1); - else - return evaluate(expr->expression2); - } else { - RuntimeError(*expr->op1, "Ternary Expression couldn't be evaluated"); - } - return Object(); - } + std::string interpret(std::vector &statements); + + void visitPrintStmnt(Print *expr); + + Object visitExprStmnt(Expression *expr); + + Object visitLiteral(Literal *expr); + + Object visitGroupExp(Grouping *expr); + + Object visitUnaryExp(Unary *expr); + + Object visitBinaryExp(Binary *expr); + + Object visitTernaryExp(Ternary *expr); }; #endif // CRUX_INTERPRETER_H diff --git a/include/Parser.h b/include/Parser.h index 8e47649..0ec58b8 100644 --- a/include/Parser.h +++ b/include/Parser.h @@ -31,6 +31,10 @@ class Parser { Statement *printStatement(); Statement *expressionStatement(); + // Variable stuff + Statement *declaration(); + Statement *varDeclaration(); + // helper functions Expr *equality(); bool check(TokenType type); diff --git a/include/Statement.h b/include/Statement.h index 0102a83..88986ff 100644 --- a/include/Statement.h +++ b/include/Statement.h @@ -7,10 +7,7 @@ #include "Expr.h" -enum Statement_type { - StmntExpr_type, - StmntPrint_type, -}; +enum Statement_type { StmntExpr_type, StmntPrint_type, StmntVar_type }; class Statement { public: @@ -30,4 +27,11 @@ class Expression : public Statement { Expression(Expr *expression); }; +class Var : public Statement { +public: + Token *name; + Expr *expression; + Var(Token *name, Expr *expression); +}; + #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 690e829..2687c07 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,6 +4,7 @@ set(Sources Error.cpp utls/Object.cpp Expr.cpp AstPrinter.cpp + Interpreter.cpp Statement.cpp ) diff --git a/src/Interpreter.cpp b/src/Interpreter.cpp index eef34c5..645076d 100644 --- a/src/Interpreter.cpp +++ b/src/Interpreter.cpp @@ -1,6 +1,179 @@ // // Created by Rohith on 3/17/24 // -#ifndef CRUX_INTERPRETER_CPP -#define CRUX_INTERPETER_CPP -#endif + +#include "Interpreter.h" +#include "Error.h" +#include "utls/Object.h" +#include "utls/RuntimeError.h" +#include +#include +#include + +Object Interpreter::excecute(Statement *stmnt) { + switch (stmnt->type) { + case StmntPrint_type: + visitPrintStmnt((Print *)stmnt); + return Object(); + case StmntExpr_type: + return visitExprStmnt((Expression *)stmnt); + } +} + +Object Interpreter::evaluate(Expr *expr) { + switch (expr->type) { + case ExprType_Binary: + return visitBinaryExp((Binary *)expr); + case ExprType_Unary: + return visitUnaryExp((Unary *)expr); + case ExprType_Literal: + return visitLiteral((Literal *)expr); + case ExprType_Grouping: + return visitGroupExp((Grouping *)expr); + case ExprType_Ternary: + return visitTernaryExp((Ternary *)expr); + } + return Object(); +} + +bool Interpreter::isTruthy(Object right) { + if (right.type == nullptr_type) + return false; + if (right.type == bool_type) + return right.bool_literal; + return true; +} + +bool Interpreter::isEqual(Object left, Object right) { + if (left.type == nullptr_type && right.type == nullptr_type) + return true; + if (left.type == nullptr_type) + return false; + + return left.num_literal == right.num_literal; +} + +void Interpreter::checkNumberOperand(Token *op, Object right) { + if (right.type == num_type) + return; + RuntimeError(*op, "Operand must be a number"); +} + +bool Interpreter::checkIfSameTypes(Object left, Object right) { + if (left.type == num_type && right.type == num_type) + return true; + else + return false; +} + +bool Interpreter::checkCompatibility(Token *op, Object left, Object right) { + if ((left.type == string_type && right.type == num_type) || + left.type == num_type && right.type == string_type) { + return true; + } else { + return false; + } +} + +std::string Interpreter::interpret(std::vector &statements) { + try { + for (Statement *stmt : statements) { + Object res = excecute(stmt); + return res.str(); + } + } catch (RuntimeError error) { + crux::runtimeError(error); + return Object().str(); + } + return Object().str(); +} + +void Interpreter::visitPrintStmnt(Print *expr) { + Object value = evaluate(expr->expression); + std::cout << value.str() << "\n"; + return; +} + +Object Interpreter::visitExprStmnt(Expression *expr) { + return evaluate(expr->expression); +} + +Object Interpreter::visitLiteral(Literal *expr) { return *expr->literal; } + +Object Interpreter::visitGroupExp(Grouping *expr) { + return evaluate(expr->expression); +} + +Object Interpreter::visitUnaryExp(Unary *expr) { + Object right = evaluate(expr->right); + switch (expr->op->type) { + case BANG: + return !isTruthy(right); + case MINUS: + checkNumberOperand(expr->op, right); + return -right.num_literal; + default: + RuntimeError(*expr->op, "Invalid operator used"); + } + return Object(); +} + +Object Interpreter::visitBinaryExp(Binary *expr) { + Object left = evaluate(expr->left); + Object right = evaluate(expr->right); + Token *op = expr->op; + switch (op->type) { + case MINUS: + checkIfSameTypes(left, right); + return left.num_literal - right.num_literal; + case SLASH: + checkIfSameTypes(left, right); + return left.num_literal / right.num_literal; + case STAR: + checkIfSameTypes(left, right); + return left.num_literal * right.num_literal; + case PLUS: + if (checkIfSameTypes(left, right)) { + if (left.type == num_type && right.type == num_type) + return left.num_literal + right.num_literal; + if (left.type == string_type && right.type == string_type) + return left.string_literal + right.string_literal; + } else if (checkCompatibility(op, left, right)) { + if (left.type == string_type && right.type == num_type) + return left.string_literal + std::to_string(right.num_literal); + else if (left.type == num_type && right.type == string_type) + return std::to_string(left.num_literal) + right.string_literal; + } + throw new RuntimeError(*op, "Error: Cannot evaluate expression"); + case GREATER: + checkIfSameTypes(left, right); + return left.num_literal > right.num_literal; + case GREATER_EQUAL: + checkIfSameTypes(left, right); + return left.num_literal >= right.num_literal; + case LESS: + checkIfSameTypes(left, right); + return left.num_literal < right.num_literal; + case LESS_EQUAL: + checkIfSameTypes(left, right); + return left.num_literal <= right.num_literal; + case BANG_EQUAL: + return !isEqual(left, right); + case EQUAL_EQUAL: + return isEqual(left, right); + } + RuntimeError(*op, "Operator not found"); +} + +Object Interpreter::visitTernaryExp(Ternary *expr) { + Object condition = evaluate(expr->condition); + if (condition.type == bool_type) { + if (condition.bool_literal) + return evaluate(expr->expression1); + else + return evaluate(expr->expression2); + } else { + RuntimeError(*expr->op1, "Ternary Expression couldn't be evaluated"); + } + return Object(); +} diff --git a/src/Parser.cpp b/src/Parser.cpp index 9357223..a5028b4 100644 --- a/src/Parser.cpp +++ b/src/Parser.cpp @@ -1,4 +1,4 @@ -// +//Object #include "Parser.h" #include "Error.h" #include "Expr.h" diff --git a/src/Statement.cpp b/src/Statement.cpp index 179ba11..0761234 100644 --- a/src/Statement.cpp +++ b/src/Statement.cpp @@ -10,3 +10,6 @@ Print::Print(Expr *expr) : Statement(StmntPrint_type), expression(expr) {} Expression::Expression(Expr *expr) : Statement(StmntExpr_type), expression(expr) {} + +Var::Var(Token *name, Expr *expression) + : Statement(StmntVar_type), name(name), expression(expression) {} diff --git a/test/TestInterpreter.cpp b/test/TestInterpreter.cpp index ea5fce1..747796f 100644 --- a/test/TestInterpreter.cpp +++ b/test/TestInterpreter.cpp @@ -6,52 +6,40 @@ #include "Parser.h" #include "Scanner.h" #include "gtest/gtest.h" +#include #include #include -TEST(InterpreterCheck, TestInterpreterBasic) { - - Expr *left = new Literal(new Object((double)10)); - - Token *op = new Token(PLUS, "+", Object(), 1); - - Expr *right = new Literal(new Object((double)20)); - - Expr *expression = new Binary(left, op, right); - - ASSERT_EQ(Interpreter{}.interpret(expression), "30.000000"); -} - TEST(InterpreterTest, TestInterpreterFlow) { - std::string test = "10 + (40 + (20 - 30) - 10) + 50"; + std::string test = "10 + (40 + (20 - 30) - 10) + 50;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(Interpreter{}.interpret(expression), "80.000000"); + std::vector statements = p.parse(); + ASSERT_EQ(Interpreter{}.interpret(statements), "80.000000"); } TEST(InterpreterTest, TestInterpreterUnary) { - std::string test = "!true"; + std::string test = "!true;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(Interpreter{}.interpret(expression), "false"); + std::vector statements = p.parse(); + ASSERT_EQ(Interpreter{}.interpret(statements), "false"); } TEST(InterpreterTest, TestParserTernary) { - std::string test = "3 > 1 ? true : false"; + std::string test = "3 > 1 ? true : false;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(Interpreter{}.interpret(expression), "true"); + std::vector statements = p.parse(); + ASSERT_EQ(Interpreter{}.interpret(statements), "true"); } TEST(IntrepreterTest, TestStringNumExpressions) { - std::string test1 = "\"test\"+8"; - std::string test2 = "8+\"test\""; + std::string test1 = "\"test\"+8;"; + std::string test2 = "8+\"test\";"; Scanner scan1(test1); Scanner scan2(test2); @@ -61,9 +49,9 @@ TEST(IntrepreterTest, TestStringNumExpressions) { Parser p1(tokens1); Parser p2(tokens2); - Expr *expression1 = p1.parse(); - Expr *expression2 = p2.parse(); + std::vector statements1 = p1.parse(); + std::vector statements2 = p2.parse(); - ASSERT_EQ(Interpreter{}.interpret(expression1), "test8.000000"); - ASSERT_EQ(Interpreter{}.interpret(expression2), "8.000000test"); + ASSERT_EQ(Interpreter{}.interpret(statements1), "test8.000000"); + ASSERT_EQ(Interpreter{}.interpret(statements2), "8.000000test"); } diff --git a/test/TestParser.cpp b/test/TestParser.cpp index c761424..abdaf0d 100644 --- a/test/TestParser.cpp +++ b/test/TestParser.cpp @@ -4,54 +4,60 @@ #include "AstPrinter.h" #include "Parser.h" #include "Scanner.h" +#include "Statement.h" #include "gtest/gtest.h" #include TEST(ParserCheck, TestParserBasic) { - std::string test = "10 + 20"; + std::string test = "10 + 20;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(PrettyPrint::print(expression), "(+ 10.000000 20.000000)"); + std::vector statements = p.parse(); + Expression *exp = (Expression *)statements[0]; + ASSERT_EQ(PrettyPrint::print(exp->expression), "(+ 10.000000 20.000000)"); } TEST(ParserCheck, TestParserComplex) { - std::string test = "10 + (40 + (20 - 30) - 10) + 50"; + std::string test = "10 + (40 + (20 - 30) - 10) + 50;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(PrettyPrint::print(expression), + std::vector statements = p.parse(); + Expression *exp = (Expression *)statements[0]; + ASSERT_EQ(PrettyPrint::print(exp->expression), "(+ (+ 10.000000 (Group (- (+ 40.000000 (Group (- 20.000000 " "30.000000))) 10.000000))) 50.000000)"); } TEST(ParserCheck, TestParserUnary) { - std::string test = "!true"; + std::string test = "!true;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(PrettyPrint::print(expression), "(! true)"); + std::vector statements = p.parse(); + Expression *exp = (Expression *)statements[0]; + ASSERT_EQ(PrettyPrint::print(exp->expression), "(! true)"); } TEST(ParserCheck, TestParserComplexUnary) { - std::string test = "!true + (30 - false) + 40"; + std::string test = "!true + (30 - false) + 40;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(PrettyPrint::print(expression), + std::vector statements = p.parse(); + Expression *exp = (Expression *)statements[0]; + ASSERT_EQ(PrettyPrint::print(exp->expression), "(+ (+ (! true) (Group (- 30.000000 false))) 40.000000)"); } TEST(ParserCheck, TestParserTernary) { - std::string test = "3 > 1 ? true : false"; + std::string test = "3 > 1 ? true : false;"; Scanner scan(test); std::vector tokens = scan.scanTokens(); Parser p(tokens); - Expr *expression = p.parse(); - ASSERT_EQ(PrettyPrint::print(expression), + std::vector statements = p.parse(); + Expression *exp = (Expression *)statements[0]; + ASSERT_EQ(PrettyPrint::print(exp->expression), "(?: (> 3.000000 1.000000) true false)"); }