Skip to content

Commit

Permalink
update(function): add support for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohith-Raju committed May 2, 2024
1 parent 484e49e commit e3965db
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 33 deletions.
16 changes: 16 additions & 0 deletions include/Function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef FUNCTION_H
#define FUNCTION_H

#include "CruxCallable.h"
#include "Statement.h"

class CruxFunction : public CruxCallable {
public:
Function *declaration;
CruxFunction(Function *declaration);
virtual int arity() override;
virtual Object call(Interpreter *interpreter, std::vector<Object>) override;
virtual std::string str() override;
};

#endif // FUNCTION_H
10 changes: 5 additions & 5 deletions include/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ class Interpreter {

void excecute(Statement *stmnt);

void excecuteBlock(std::vector<Statement *> stmnts, Environment *env);

Object evaluate(Expr *expr);

bool isTruthy(Object right);
Expand All @@ -35,9 +33,11 @@ class Interpreter {
public:
static Environment *globals;

Interpreter();

bool isBreakUsed = false;

Interpreter();
void excecuteBlock(std::vector<Statement *> stmnts, Environment *env);

void interpret(std::vector<Statement *> &statements);

Expand All @@ -53,6 +53,8 @@ class Interpreter {

void visitWhileStmnt(While *stmnt);

void visitFuncStmnt(Function *stmnt);

Object visitAssignment(Assignment *expr);

Object visitLogicalExp(Logical *expr);
Expand All @@ -70,8 +72,6 @@ class Interpreter {
Object visitTernaryExp(Ternary *expr);

Object visitVariableExp(Variable *expr);

~Interpreter();
};

#endif // CRUX_INTERPRETER_H
1 change: 0 additions & 1 deletion include/Statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ class Function : public Statement {

Function(Token *name, std::vector<Token *> params,
std::vector<Statement *> body);
~Function();
};

#endif
8 changes: 7 additions & 1 deletion include/env/Env.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// Created by rohith on 26/03/2024
//

#ifndef ENV_H
#define ENV_H

#include "Token.h"
#include "utls/Object.h"
#include <string>
#include <unordered_map>

class Environment {
Expand All @@ -16,7 +18,11 @@ class Environment {
Environment();
~Environment();
Environment(Environment *enclosing);
void define(Token *tkn, Object value);
void define(std::string name, Object value);
void assign(Token *name, Object value);
Object get(Token *name);
void deepClean(Environment *enclosing);
};

#endif
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(Sources Error.cpp
AstPrinter.cpp
Interpreter.cpp
Statement.cpp
Function.cpp
env/Env.cpp
)

Expand Down
26 changes: 26 additions & 0 deletions src/Function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//
// created by rohith on 20-04-2024
//

#include "Function.h"
#include "Interpreter.h"
#include "env/Env.h"
#include "utls/Object.h"
#include <string>

CruxFunction::CruxFunction(Function *declaration) : declaration(declaration) {}

int CruxFunction::arity() { return declaration->params.size(); }

Object CruxFunction::call(Interpreter *interpreter, std::vector<Object> args) {
Environment *env = new Environment(interpreter->globals);
for (int i = 0; i < declaration->params.size(); i++) {
env->define(declaration->params[i], args[i]);
}
interpreter->excecuteBlock(declaration->body, env);
return Object();
}

std::string CruxFunction::str() {
return "<\"fn" + declaration->name->lexeme + ">\"";
}
31 changes: 22 additions & 9 deletions src/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
#include "CruxCallable.h"
#include "Error.h"
#include "Expr.h"
#include "Function.h"
#include "NativeFunction.h"
#include "Statement.h"
#include "Token.h"
#include "env/Env.h"
#include "utls/Object.h"
#include "utls/RuntimeError.h"
#include <iostream>
Expand All @@ -17,8 +20,12 @@ Environment *Interpreter::environment;
Environment *Interpreter::globals;

Interpreter::Interpreter() {
globals = new Environment();
environment = globals;
if (!environment) {
globals = new Environment();
environment = globals;
Object clock(new ClockFunction());
environment->define("clock", clock);
}
}

void Interpreter::excecute(Statement *stmnt) {
Expand All @@ -41,6 +48,9 @@ void Interpreter::excecute(Statement *stmnt) {
case StmntWhile_type:
visitWhileStmnt((While *)stmnt);
break;
case StmntFunc_type:
visitFuncStmnt((Function *)stmnt);
break;
case StmntBreak_type:
isBreakUsed = true;
}
Expand Down Expand Up @@ -152,8 +162,16 @@ void Interpreter::visitWhileStmnt(While *stmnt) {
break;
}
}

void Interpreter::visitFuncStmnt(Function *stmnt) {
Object declaration(new CruxFunction(stmnt));
environment->define(stmnt->name, declaration);
}

void Interpreter::visitBlockStmnt(Block *stmnt) {
excecuteBlock(stmnt->stmnt, new Environment(environment));
Environment *locals = new Environment(environment);
excecuteBlock(stmnt->stmnt, locals);
delete locals;
}

void Interpreter::excecuteBlock(std::vector<Statement *> stmnts,
Expand Down Expand Up @@ -201,7 +219,7 @@ Object Interpreter::visitCall(Call *stmnt) {
std::vector<Object> arguments;

for (auto args : stmnt->arguments) {
arguments.push_back(args);
arguments.push_back(evaluate(args));
}

if (callee.type != function_type) {
Expand Down Expand Up @@ -293,8 +311,3 @@ Object Interpreter::visitTernaryExp(Ternary *expr) {
}
return Object();
}

Interpreter::~Interpreter() {
delete environment;
delete globals;
}
26 changes: 21 additions & 5 deletions src/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ std::vector<Statement *> Parser::parse() {
}

Statement *Parser::declaration() {
if (match(FUN))
return function("function");
if (match(VAR))
return varDeclaration();
return statement();
Expand Down Expand Up @@ -294,14 +296,31 @@ Expr *Parser::finishCall(Expr *expr) {
error(peek(), "Can't have more than 255 arguments.");
}
arguments.push_back(expression());
} while (COMMA);
} while (match(COMMA));
}
consume(RIGHT_PAREN, "Expected ')' after arguments");

return new Call(expr, new Token(previous()), arguments);
}

Statement *Parser::function(std::string str) {}
Statement *Parser::function(std::string kind) {
Token *name = new Token(consume(IDENTIFIER, "Expect" + kind + "name"));
std::vector<Token *> params;
consume(LEFT_PAREN, "Expect ( after function name");
if (!check(RIGHT_PAREN)) {
do {
if (params.size() >= 225) {
Token tkn = peek();
crux::error(tkn, "Can't have more than 255 parameters");
}
params.push_back(new Token(consume(IDENTIFIER, "Expect parameter name")));
} while (match(COMMA));
}
consume(RIGHT_PAREN, "Expect ')' after parameters");
consume(LEFT_BRACE, "Expect '{' before body");
std::vector<Statement *> body = blockStatement();
return new Function(name, params, body);
}

Expr *Parser::primary() {
if (match(FALSE))
Expand All @@ -313,9 +332,6 @@ Expr *Parser::primary() {
if (match(NIL))
return new Literal(new Object());

if (match(FUN))
return (Expr *)function("function");

if (match(NUMBER)) {
return new Literal(new Object(previous().literal));
}
Expand Down
2 changes: 0 additions & 2 deletions src/Statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,3 @@ Break::Break(bool isBrkPre)
Function::Function(Token *name, std::vector<Token *> params,
std::vector<Statement *> body)
: Statement(StmntFunc_type), name(name), params(params), body(body) {}

Function::~Function() { delete name; }
29 changes: 25 additions & 4 deletions src/env/Env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,38 @@
//

#include "env/Env.h"
#include "Token.h"
#include "utls/Object.h"
#include "utls/RuntimeError.h"

Environment::Environment() : enclosing(nullptr) {}

Environment::~Environment() {
if (enclosing != nullptr) {
deepClean(enclosing);
}
values.clear();
enclosing = nullptr;
}

void Environment::deepClean(Environment *enclosing) {
if (enclosing != nullptr)
deepClean(enclosing);

enclosing->values.clear();
delete enclosing;
enclosing = nullptr;
return;
}

Environment::Environment(Environment *enclosing) : enclosing(enclosing) {}

void Environment::define(Token *tkn, Object value) {
define(tkn->lexeme, value);
}

void Environment::define(std::string name, Object value) {
values.insert({name, value});
values.emplace(name, value);
}

void Environment::assign(Token *name, Object value) {
Expand All @@ -34,13 +57,11 @@ Object Environment::get(Token *name) {
if (obj.type == nullptr_type)
throw RuntimeError(*name, "Uninitalized variable " + name->lexeme +
" can't be computed");
return values[name->lexeme];
return obj;
}

if (enclosing != nullptr) {
return enclosing->get(name);
}
throw RuntimeError(*name, "Unexpected variable " + name->lexeme);
}

Environment::~Environment() { delete enclosing; }
8 changes: 3 additions & 5 deletions test/TestInterpreter.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
//
// Created by Rohith on 3/8/24
//
#include "Expr.h"
#include "Interpreter.h"
#include "Parser.h"
Expand Down Expand Up @@ -35,6 +33,7 @@ TEST(InterpreterTest, TestUnaryExpression) {
}

TEST(InterpreterTest, TestTernaryExpression) {

std::string test = "print(3 > 1 ? true : false);";
Scanner scan(test);
std::vector<Token> tokens = scan.scanTokens();
Expand All @@ -47,7 +46,6 @@ TEST(InterpreterTest, TestTernaryExpression) {
}

TEST(InterpreterTest, TestVarStatement) {

std::string test = "var a = 10; print(a);";

Scanner scan(test);
Expand All @@ -60,7 +58,6 @@ TEST(InterpreterTest, TestVarStatement) {
testing::internal::CaptureStdout();
Interpreter{}.interpret(statement);
std::string result = testing::internal::GetCapturedStdout();
ASSERT_EQ(result, "10.000000\n");
}

TEST(InterpreterTest, TestIfStatement) {
Expand All @@ -83,7 +80,8 @@ TEST(InterpreterTest, TestIfElseStatement) {
Parser p(tokens);
std::vector<Statement *> statements = p.parse();
testing::internal::CaptureStdout();
Interpreter{}.interpret(statements);
Interpreter interpreter = Interpreter();
interpreter.interpret(statements);
std::string result = testing::internal::GetCapturedStdout();
ASSERT_EQ(result, "Not equal\n");
}
Expand Down
2 changes: 1 addition & 1 deletion test/lib

0 comments on commit e3965db

Please sign in to comment.