From 01103d809488f4ecc637277b72542b3f56689f82 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Wed, 8 Nov 2023 23:21:41 -0800 Subject: [PATCH 01/11] Initial lambda implementation --- .../compiler/restrictions/lambda_return.txt | 6 + .../compiler/restrictions/variadic_lambda.txt | 4 + .../tests/runtime/base/lambdas.txt | 56 ++++ .../gmod_wire_expression2/base/compiler.lua | 290 ++++++++++++------ .../gmod_wire_expression2/base/parser.lua | 26 +- .../gmod_wire_expression2/core/e2lib.lua | 2 +- .../gmod_wire_expression2/core/functions.lua | 41 +-- 7 files changed, 300 insertions(+), 125 deletions(-) create mode 100644 data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt create mode 100644 data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt create mode 100644 data/expression2/tests/runtime/base/lambdas.txt diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt new file mode 100644 index 0000000000..9ab6b91f68 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt @@ -0,0 +1,6 @@ +## SHOULD_FAIL:COMPILE + +function() { + if (1) { return "str" } + return 22 +} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt b/data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt new file mode 100644 index 0000000000..adc2c0cbc6 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt @@ -0,0 +1,4 @@ +## SHOULD_FAIL:COMPILE + +function(...A:array) {} +function(...A:table) {} \ No newline at end of file diff --git a/data/expression2/tests/runtime/base/lambdas.txt b/data/expression2/tests/runtime/base/lambdas.txt new file mode 100644 index 0000000000..8e8b319fe5 --- /dev/null +++ b/data/expression2/tests/runtime/base/lambdas.txt @@ -0,0 +1,56 @@ +## SHOULD_PASS:EXECUTE + +# Returns + +assert( (function() { return 55 })()[number] == 55 ) +assert( (function() { return "str" })()[string] == "str" ) + +# Upvalues + +const Wrapper = function(V:number) { + return function() { + return V + } +} + +const F1 = Wrapper(55)[function] + +if (1) { + if (2) { + local V = 22 + assert(F1()[number] == 55) + } +} + +assert(F1()[number] == 55) +assert(F1()[number] == 55) + +const F2 = Wrapper(1238)[function] + +assert(F2()[number] == 1238) +#local V = 21 +assert(F2()[number] == 1238) + +const IsEven = function(N:number) { + return N % 2 == 0 +} + +const Not = function(N:number) { + return !N +} + +const IsOdd = function(N:number) { + return Not(IsEven(N)[number])[number] +} + +assert(IsOdd(1)[number] == 1) +assert(IsOdd(2)[number] == 0) + +assert( ((function() { return function() { return 55 } })()[function])()[number] == 55 ) + +const Identity = function(N:number) { + return N +} + +assert(Identity(2)[number] == 2) +assert(Identity(2193921)[number] == 2193921) \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 5ad1ebe51f..a6a37a6c54 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -224,7 +224,7 @@ local CompileVisitors = { local i = #stmts + 1 stmts[i], traces[i] = stmt, trace - if node:isExpr() and node.variant ~= NodeVariant.ExprStringCall and node.variant ~= NodeVariant.ExprCall and node.variant ~= NodeVariant.ExprMethodCall then + if node:isExpr() and node.variant ~= NodeVariant.ExprDynCall and node.variant ~= NodeVariant.ExprCall and node.variant ~= NodeVariant.ExprMethodCall then self:Warning("This expression has no effect", node.trace) end end @@ -649,7 +649,7 @@ local CompileVisitors = { local existing = {} for i, param in ipairs(data[4]) do if param.type then - local t = self:CheckType(param.type) + local t = self:Assert(self:CheckType(param.type), "Cannot use void as parameter type", param.name.trace) if param.variadic then self:Assert(t == "r" or t == "t", "Variadic parameter must be of type array or table", param.type.trace) variadic_ind, variadic_ty = i, t @@ -679,9 +679,9 @@ local CompileVisitors = { end else if return_type then - self:Assert(fn_data.returns and fn_data.returns[1] == return_type, "Cannot override with differing return type", trace) + self:Assert(fn_data.ret == return_type, "Cannot override with differing return type", trace) else - self:Assert(fn_data.returns == nil, "Cannot override function returning void with differing return type", trace) + self:Assert(fn_data.ret == nil, "Cannot override function returning void with differing return type", trace) end -- Tag function if it is ever re-declared. Used as an optimization @@ -689,7 +689,7 @@ local CompileVisitors = { end end - local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} } + local fn = { args = param_types, ret = return_type, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} } local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or "")) if meta_type then @@ -816,7 +816,7 @@ local CompileVisitors = { end end) - self:Assert((fn.returns and fn.returns[1]) == return_type, "Function " .. name.value .. " expects to return type (" .. (return_type or "void") .. ") but got type (" .. ((fn.returns and fn.returns[1]) or "void") .. ")", trace) + self:Assert(fn.ret == return_type, "Function " .. name.value .. " expects to return type (" .. (return_type or "void") .. ") but got type (" .. (fn.ret or "void") .. ")", trace) local sig = name.value .. "(" .. (meta_type and (meta_type .. ":") or "") .. sig .. ")" local fn = fn.op @@ -910,10 +910,10 @@ local CompileVisitors = { local name, fn = fn[1], fn[2] - if fn.returns then - self:Assert(fn.returns[1] == ret_ty, "Function " .. name .. " expects return type (" .. (fn.returns[1] or "void") .. ") but was given (" .. (ret_ty or "void") .. ")", trace) + if fn.ret then + self:Assert(fn.ret == ret_ty, "Function " .. name .. " expects return type (" .. (fn.ret or "void") .. ") but was given (" .. (ret_ty or "void") .. ")", trace) else - fn.returns = { ret_ty } + fn.ret = ret_ty end if ret_ty then @@ -1273,6 +1273,88 @@ local CompileVisitors = { end end, + ---@param data { [1]: Parameter[], [2]: Node } + [NodeVariant.ExprFunction] = function(self, trace, data) + ---@type EnvFunction + local fn, param_names, param_types, nargs = { attrs = {} }, {}, {}, #data[1] + + local block = self:Scope(function(scope) + scope.data["function"] = { "", fn } + + for i, param in ipairs(data[1]) do + param_names[i], param_types[i] = param.name.value, self:Assert(self:CheckType(param.type), "Cannot use void as parameter", param.name.trace) + self:Assert(not param.variadic, "Variadic lambdas are not supported, use an array instead", param.name.trace) + scope:DeclVar(param.name.value, { type = param_types[i], initialized = true, trace_if_unused = param.name.trace }) + end + + return self:CompileStmt(data[2]) + end) + + local ret = fn.ret + local expected_sig = table.concat(param_types) + + return function(state) + local inherited_scopes, after, before = {}, state.ScopeID + 1, state.ScopeID + for i = 0, state.ScopeID do + inherited_scopes[i] = state.Scopes[i] + end + + if ret then + return function(args, given_sig) + if given_sig ~= expected_sig then + state:forceThrow("Incorrect arguments passed to lambda, expected (" .. expected_sig .. "), got (" .. given_sig .. ")") + end + + local s_scopes, s_scope, s_scopeid = state.Scopes, state.Scope, state.ScopeID + + local scope = { vclk = {} } + state.Scopes = inherited_scopes + state.ScopeID = after -- state.ScopeID + 1 + state.Scopes[after] = scope + state.Scope = scope + + for i = 1, nargs do + local arg = args[i] + if arg then + scope[param_names[i]] = arg + else + state:forceThrow("Missing argument [" .. param_names[i] .. "] to be passed into lambda") + end + end + + block(state) + + state.ScopeID, state.Scope, state.Scopes = s_scopeid, s_scope, s_scopes + + state.__return__ = false + return ret, state.__returnval__ + end + else -- function without return value, don't handle case. + return function(args, given_sig) + if given_sig ~= expected_sig then + state:forceThrow("Incorrect arguments passed to lambda, expected (" .. expected_sig .. "), got (" .. given_sig .. ")") + end + + local s_scopes, s_scopeid, s_scope = state.Scopes, state.ScopeID, state.Scope + + local scope = { vclk = {} } + state.Scopes = inherited_scopes + state.ScopeID = after -- state.ScopeID + 1 + state.Scopes[after] = scope + state.Scope = scope + + for i = 1, nargs do + scope[param_names[i]] = args[i] + end + + block(state) + + state.Scopes[before], state.ScopeID, state.Scope, state.Scopes = nil, s_scopeid, s_scope, s_scopes + end + end + end, "f" + end, + [NodeVariant.ExprArithmetic] = handleInfixOperation, ---@param data { [1]: Node, [2]: Operator, [3]: self } @@ -1466,7 +1548,7 @@ local CompileVisitors = { rargs[k] = args[k](state) end return fn(state, rargs, types) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) else local full_sig = name.value .. "(" .. arg_sig .. ")" return function(state) ---@param state RuntimeContext @@ -1481,7 +1563,7 @@ local CompileVisitors = { else state:forceThrow("No such function defined at runtime: " .. full_sig) end - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) end elseif fn_data.attrs["legacy"] then -- Not a user function. Can get function to call at compile time. local fn, largs = fn_data.op, { [1] = {}, [nargs + 2] = types } @@ -1490,7 +1572,7 @@ local CompileVisitors = { end return function(state) ---@param state RuntimeContext return fn(state, largs) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) else local fn = fn_data.op return function(state) ---@param state RuntimeContext @@ -1500,7 +1582,7 @@ local CompileVisitors = { end return fn(state, rargs, types) - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) end end, @@ -1550,7 +1632,7 @@ local CompileVisitors = { else state:forceThrow("No such method defined at runtime: " .. full_sig) end - end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) + end, fn_data.ret and (fn_data.ret ~= "" and fn_data.ret or nil) end elseif fn_data.attrs["legacy"] then local fn, largs = fn_data.op, { [nargs + 3] = types, [2] = { [1] = meta } } @@ -1560,7 +1642,7 @@ local CompileVisitors = { return function(state) ---@param state RuntimeContext return fn(state, largs) - end, fn_data.returns and fn_data.returns[1] + end, fn_data.ret else local fn = fn_data.op return function(state) ---@param state RuntimeContext @@ -1570,117 +1652,143 @@ local CompileVisitors = { end return fn(state, rargs, types) - end, fn_data.returns and fn_data.returns[1] + end, fn_data.ret end end, ---@param data { [1]: Node, [2]: Node[], [3]: Token? } - [NodeVariant.ExprStringCall] = function (self, trace, data) - local expr = self:CompileExpr(data[1]) + [NodeVariant.ExprDynCall] = function (self, trace, data) + local expr, expr_ty = self:CompileExpr(data[1]) local args, arg_types = {}, {} for i, arg in ipairs(data[2]) do args[i], arg_types[i] = self:CompileExpr(arg) end - local type_sig = table.concat(arg_types) - local arg_sig = "(" .. type_sig .. ")" - local meta_arg_sig = #arg_types >= 1 and ("(" .. arg_types[1] .. ":" .. table.concat(arg_types, "", 2) .. ")") or "()" - local ret_type = data[3] and self:CheckType(data[3]) - local nargs = #args - return function(state) ---@param state RuntimeContext - local rargs = {} - for k = 1, nargs do - rargs[k] = args[k](state) - end + if expr_ty == "s" then + self:Warning("String calls are deprecated. Use lambdas instead. This will be an error on @strict in the future.", trace) + self.scope.data.ops = self.scope.data.ops + 25 - local fn_name = expr(state) - local sig, meta_sig = fn_name .. arg_sig, fn_name .. meta_arg_sig + local type_sig = table.concat(arg_types) + local arg_sig = "(" .. type_sig .. ")" + local meta_arg_sig = #arg_types >= 1 and ("(" .. arg_types[1] .. ":" .. table.concat(arg_types, "", 2) .. ")") or "()" - local fn = state.funcs[sig] or state.funcs[meta_sig] - if fn then -- first check if user defined any functions that match signature - local r = state.funcs_ret[sig] or state.funcs_ret[meta_sig] - if r ~= ret_type then - state:forceThrow( "Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) + local nargs = #args + return function(state) ---@param state RuntimeContext + local rargs = {} + for k = 1, nargs do + rargs[k] = args[k](state) end - return fn(state, rargs, arg_types) - else -- no user defined functions, check builtins - fn = wire_expression2_funcs[sig] or wire_expression2_funcs[meta_sig] - if fn then - local r = fn[2] - if r ~= ret_type and not (ret_type == nil and r == "") then + local fn_name = expr(state) + local sig, meta_sig = fn_name .. arg_sig, fn_name .. meta_arg_sig + + local fn = state.funcs[sig] or state.funcs[meta_sig] + if fn then -- first check if user defined any functions that match signature + local r = state.funcs_ret[sig] or state.funcs_ret[meta_sig] + if r ~= ret_type then state:forceThrow( "Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) end - if fn.attributes.legacy then - local largs = { [1] = {}, [nargs + 2] = arg_types } - for i = 1, nargs do - largs[i + 1] = { [1] = function() return rargs[i] end } + return fn(state, rargs, arg_types) + else -- no user defined functions, check builtins + fn = wire_expression2_funcs[sig] or wire_expression2_funcs[meta_sig] + if fn then + local r = fn[2] + if r ~= ret_type and not (ret_type == nil and r == "") then + state:forceThrow( "Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) end - return fn[3](state, largs, arg_types) - else - return fn[3](state, rargs, arg_types) - end - else -- none found, check variadic builtins - for i = nargs, 0, -1 do - local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "...)" - local fn = wire_expression2_funcs[varsig] - if fn then - local r = fn[2] - if r ~= ret_type and not (ret_type == nil and r == "") then - state:forceThrow("Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) - end - if fn.attributes.legacy then - local largs = { [1] = {}, [nargs + 2] = arg_types } - for i = 1, nargs do - largs[i + 1] = { [1] = function() return rargs[i] end } - end - return fn[3](state, largs, arg_types) - elseif varsig == "array(...)" then -- Need this since can't enforce compile time argument type restrictions on string calls. Woop. Array creation should not be a function.. - local i = 1 - while i <= #arg_types do - local ty = arg_types[i] - if BLOCKED_ARRAY_TYPES[ty] then - table.remove(rargs, i) - table.remove(arg_types, i) - state:forceThrow("Cannot use type " .. ty .. " for argument #" .. i .. " in stringcall array creation") - else - i = i + 1 - end - end + if fn.attributes.legacy then + local largs = { [1] = {}, [nargs + 2] = arg_types } + for i = 1, nargs do + largs[i + 1] = { [1] = function() return rargs[i] end } end - - return fn[3](state, rargs, arg_types) + return fn[3](state, largs, arg_types) else - local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..r)" - local fn = state.funcs[varsig] - + return fn[3](state, rargs, arg_types) + end + else -- none found, check variadic builtins + for i = nargs, 0, -1 do + local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "...)" + local fn = wire_expression2_funcs[varsig] if fn then - for _, ty in ipairs(arg_types) do -- Just block them entirely. Current method of finding variadics wouldn't allow a proper solution that works with x types. Would need to rewrite all of this which I don't think is worth it when already nobody is going to use this functionality. - if BLOCKED_ARRAY_TYPES[ty] then - state:forceThrow("Cannot pass array into variadic array function") + local r = fn[2] + if r ~= ret_type and not (ret_type == nil and r == "") then + state:forceThrow("Mismatching return types. Got " .. (r or "void") .. ", expected " .. (ret_type or "void")) + end + + if fn.attributes.legacy then + local largs = { [1] = {}, [nargs + 2] = arg_types } + for i = 1, nargs do + largs[i + 1] = { [1] = function() return rargs[i] end } + end + return fn[3](state, largs, arg_types) + elseif varsig == "array(...)" then -- Need this since can't enforce compile time argument type restrictions on string calls. Woop. Array creation should not be a function.. + local i = 1 + while i <= #arg_types do + local ty = arg_types[i] + if BLOCKED_ARRAY_TYPES[ty] then + table.remove(rargs, i) + table.remove(arg_types, i) + state:forceThrow("Cannot use type " .. ty .. " for argument #" .. i .. " in stringcall array creation") + else + i = i + 1 + end end end - return fn(state, rargs, arg_types) + return fn[3](state, rargs, arg_types) else - local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..t)" + local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..r)" local fn = state.funcs[varsig] + if fn then + for _, ty in ipairs(arg_types) do -- Just block them entirely. Current method of finding variadics wouldn't allow a proper solution that works with x types. Would need to rewrite all of this which I don't think is worth it when already nobody is going to use this functionality. + if BLOCKED_ARRAY_TYPES[ty] then + state:forceThrow("Cannot pass array into variadic array function") + end + end + return fn(state, rargs, arg_types) + else + local varsig = fn_name .. "(" .. type_sig:sub(1, i) .. "..t)" + local fn = state.funcs[varsig] + if fn then + return fn(state, rargs, arg_types) + end end end end + + state:forceThrow("No such function: " .. fn_name .. arg_sig) end + end + end, ret_type + elseif expr_ty == "f" then + self.scope.data.ops = self.scope.data.ops + 15 -- Since functions are 10 ops, this is pretty lenient. I will decrease this slightly when functions are made static and cheaper. - state:forceThrow("No such function: " .. fn_name .. arg_sig) + local nargs = #args + local sig = table.concat(arg_types) + + return function(state) + local f, rargs = expr(state), {} + for k = 1, nargs do + rargs[k] = args[k](state) end - end - end, ret_type + + local ty, r = f(rargs, sig) + if ty ~= ret_type then -- if void, don't care if it returns something. + state:forceThrow("Expected type " .. (ret_type or "void") .. " from lambda, got " .. (ty or "void")) + end + + return r + end, ret_type + else + self:Error("Cannot call type of " .. expr_ty, trace) + end end, ---@param data { [1]: Token, [2]: Parameter[], [3]: Node } @@ -1829,14 +1937,14 @@ function Compiler:GetFunction(name, types, method) local sig, method_prefix = table.concat(types), method and (method .. ":") or "" local fn = wire_expression2_funcs[name .. "(" .. method_prefix .. sig .. ")"] - if fn then return { op = fn[3], returns = { fn[2] }, args = types, cost = fn[4], attrs = fn.attributes }, false, false end + if fn then return { op = fn[3], ret = fn[2], args = types, cost = fn[4], attrs = fn.attributes }, false, false end local fn, variadic = self:GetUserFunction(name, types, method) if fn then return fn, variadic, true end for i = #sig, 0, -1 do fn = wire_expression2_funcs[name .. "(" .. method_prefix .. sig:sub(1, i) .. "...)"] - if fn then return { op = fn[3], returns = { fn[2] }, args = types, cost = fn[4], attrs = fn.attributes }, true, false end + if fn then return { op = fn[3], ret = fn[2], args = types, cost = fn[4], attrs = fn.attributes }, true, false end end end diff --git a/lua/entities/gmod_wire_expression2/base/parser.lua b/lua/entities/gmod_wire_expression2/base/parser.lua index 8d15fb017e..e8788d0afa 100644 --- a/lua/entities/gmod_wire_expression2/base/parser.lua +++ b/lua/entities/gmod_wire_expression2/base/parser.lua @@ -94,12 +94,13 @@ local NodeVariant = { ExprIndex = 29, -- `[, ?]` ExprGrouped = 30, -- () ExprCall = 31, -- `call()` - ExprStringCall = 32, -- `""()` (Temporary until lambdas are made) + ExprDynCall = 32, -- `Var()` ExprUnaryWire = 33, -- `~Var` `$Var` `->Var` ExprArray = 34, -- `array(1, 2, 3)` or `array(1 = 2, 2 = 3)` ExprTable = 35, -- `table(1, 2, 3)` or `table(1 = 2, "test" = 3)` - ExprLiteral = 36, -- `"test"` `5e2` `4.023` `4j` - ExprIdent = 37 -- `Variable` + ExprFunction = 36, -- `function() {}` + ExprLiteral = 37, -- `"test"` `5e2` `4.023` `4j` + ExprIdent = 38 -- `Variable` } Parser.Variant = NodeVariant @@ -533,8 +534,16 @@ end ---@return Token? function Parser:Type() local type = self:Consume(TokenVariant.LowerIdent) - if type and type.value == "normal" then - type.value = "number" + if type then + if type.value == "normal" then + type.value = "number" + end + else -- workaround to allow "function" as type while also being a keyword + local fn = self:Consume(TokenVariant.Keyword, Keyword.Function) + if fn then + fn.value, fn.variant = "function", TokenVariant.LowerIdent + return fn + end end return type end @@ -885,7 +894,7 @@ function Parser:Expr14() end end - return Node.new(NodeVariant.ExprStringCall, { expr, args, typ }, expr.trace:stitch(self:Prev().trace)) + return Node.new(NodeVariant.ExprDynCall, { expr, args, typ }, expr.trace:stitch(self:Prev().trace)) else break end @@ -915,6 +924,11 @@ function Parser:Expr15() return Node.new(NodeVariant.ExprCall, { fn, self:Arguments() }, fn.trace:stitch(self:Prev().trace)) end + local fn = self:Consume(TokenVariant.Keyword, Keyword.Function) + if fn then + return Node.new(NodeVariant.ExprFunction, { self:Parameters(), self:Assert(self:Block(), "Expected block to follow function") }, fn.trace:stitch(self:Prev().trace)) + end + -- Decimal / Hexadecimal / Binary numbers local num = self:Consume(TokenVariant.Decimal) or self:Consume(TokenVariant.Hexadecimal) or self:Consume(TokenVariant.Binary) if num then diff --git a/lua/entities/gmod_wire_expression2/core/e2lib.lua b/lua/entities/gmod_wire_expression2/core/e2lib.lua index 64864fbc50..dbef930251 100644 --- a/lua/entities/gmod_wire_expression2/core/e2lib.lua +++ b/lua/entities/gmod_wire_expression2/core/e2lib.lua @@ -15,7 +15,7 @@ AddCSLuaFile() ---@class EnvOperator ---@field args TypeSignature[] ----@field returns TypeSignature[] +---@field ret TypeSignature? ---@field op RuntimeOperator ---@field cost integer diff --git a/lua/entities/gmod_wire_expression2/core/functions.lua b/lua/entities/gmod_wire_expression2/core/functions.lua index 217e40dfa4..b5c9cc56f8 100644 --- a/lua/entities/gmod_wire_expression2/core/functions.lua +++ b/lua/entities/gmod_wire_expression2/core/functions.lua @@ -1,30 +1,17 @@ ---[[============================================================ - E2 Function System - By Rusketh - General Operators -============================================================]]-- +--[[ + Lambdas for Expression 2 + Format: fun(args: any[], sig: string): ret_ty string?, ret any +]] -__e2setcost(1) +local function DEFAULT_FUNCTION(state) + state:forceThrow("Invalid function!") +end -registerOperator("function", "", "", function(self, args) - local sig, body = args[2], args[3] - self.funcs[sig] = body - - local cached = self.strfunc_cache[1][sig] - if cached then - self.strfunc_cache[2][ cached[3] ] = nil - self.strfunc_cache[1][sig] = nil +registerType("function", "f", DEFAULT_FUNCTION, + nil, + nil, + nil, + function(v) + return not isfunction(v) end -end) - -__e2setcost(2) - -registerOperator("return", "", "", function(self, args) - if args[2] then - local op = args[2] - local rv = op[1](self, op) - self.func_rv = rv - end - - error("return",0) -end) +) \ No newline at end of file From 12cfb264f54b4928f00010fe0fa0524d186fdde6 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Wed, 8 Nov 2023 23:24:19 -0800 Subject: [PATCH 02/11] Add ops on creation --- lua/entities/gmod_wire_expression2/base/compiler.lua | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index a6a37a6c54..97455ac34f 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1293,6 +1293,8 @@ local CompileVisitors = { local ret = fn.ret local expected_sig = table.concat(param_types) + self.scope.data.ops = self.scope.data.ops + 25 + return function(state) local inherited_scopes, after, before = {}, state.ScopeID + 1, state.ScopeID for i = 0, state.ScopeID do From 7e9913dd8b925302676c28c6b1a158fea097f0e5 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Thu, 9 Nov 2023 09:33:11 -0800 Subject: [PATCH 03/11] New implementation * Functions are now tables, containing parameter signature, return type and inner function * function:getParameterTypes() * function:getReturnType() * Don't reset global variables on `@strict`, fixing issue with top level locals not working as upvalues since they'd be reset by the runtime. I don't think this would actually affect anyone since you shouldn't be able to use variables before they're assigned, but it's `@strict` only behavior anyway. --- .../gmod_wire_expression2/base/compiler.lua | 108 +++++++++--------- .../gmod_wire_expression2/core/e2lib.lua | 47 ++++++++ .../gmod_wire_expression2/core/functions.lua | 48 ++++++-- lua/entities/gmod_wire_expression2/init.lua | 12 +- 4 files changed, 147 insertions(+), 68 deletions(-) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 97455ac34f..5a81adfb22 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1302,57 +1302,52 @@ local CompileVisitors = { end if ret then - return function(args, given_sig) - if given_sig ~= expected_sig then - state:forceThrow("Incorrect arguments passed to lambda, expected (" .. expected_sig .. "), got (" .. given_sig .. ")") - end - - local s_scopes, s_scope, s_scopeid = state.Scopes, state.Scope, state.ScopeID - - local scope = { vclk = {} } - state.Scopes = inherited_scopes - state.ScopeID = after -- state.ScopeID + 1 - state.Scopes[after] = scope - state.Scope = scope - - for i = 1, nargs do - local arg = args[i] - if arg then - scope[param_names[i]] = arg - else - state:forceThrow("Missing argument [" .. param_names[i] .. "] to be passed into lambda") + return E2Lib.Lambda.new( + expected_sig, + ret, + function(args) + local s_scopes, s_scope, s_scopeid = state.Scopes, state.Scope, state.ScopeID + + local scope = { vclk = {} } + state.Scopes = inherited_scopes + state.ScopeID = after + state.Scopes[after] = scope + state.Scope = scope + + for i = 1, nargs do + scope[param_names[i]] = args[i] end - end - block(state) + block(state) - state.ScopeID, state.Scope, state.Scopes = s_scopeid, s_scope, s_scopes + state.ScopeID, state.Scope, state.Scopes = s_scopeid, s_scope, s_scopes - state.__return__ = false - return ret, state.__returnval__ - end - else -- function without return value, don't handle case. - return function(args, given_sig) - if given_sig ~= expected_sig then - state:forceThrow("Incorrect arguments passed to lambda, expected (" .. expected_sig .. "), got (" .. given_sig .. ")") + state.__return__ = false + return state.__returnval__ end + ) + else -- function without return value, don't handle case. + return E2Lib.Lambda.new( + expected_sig, + ret, + function(args) + local s_scopes, s_scopeid, s_scope = state.Scopes, state.ScopeID, state.Scope + + local scope = { vclk = {} } + state.Scopes = inherited_scopes + state.ScopeID = after + state.Scopes[after] = scope + state.Scope = scope + + for i = 1, nargs do + scope[param_names[i]] = args[i] + end - local s_scopes, s_scopeid, s_scope = state.Scopes, state.ScopeID, state.Scope - - local scope = { vclk = {} } - state.Scopes = inherited_scopes - state.ScopeID = after -- state.ScopeID + 1 - state.Scopes[after] = scope - state.Scope = scope + block(state) - for i = 1, nargs do - scope[param_names[i]] = args[i] + state.Scopes[before], state.ScopeID, state.Scope, state.Scopes = nil, s_scopeid, s_scope, s_scopes end - - block(state) - - state.Scopes[before], state.ScopeID, state.Scope, state.Scopes = nil, s_scopeid, s_scope, s_scopes - end + ) end end, "f" end, @@ -1776,17 +1771,20 @@ local CompileVisitors = { local sig = table.concat(arg_types) return function(state) - local f, rargs = expr(state), {} - for k = 1, nargs do - rargs[k] = args[k](state) - end + ---@type E2Lambda + local f = expr(state) - local ty, r = f(rargs, sig) - if ty ~= ret_type then -- if void, don't care if it returns something. - state:forceThrow("Expected type " .. (ret_type or "void") .. " from lambda, got " .. (ty or "void")) + if f.arg_sig ~= sig then + state:forceThrow("Incorrect arguments passed to lambda, expected (" .. f.arg_sig .. ") got (" .. sig .. ")") + elseif f.ret ~= ret_type then + state:forceThrow("Expected type " .. (ret_type or "void") .. " from lambda, got " .. (f.ret or "void")) + else + local rargs = {} + for k = 1, nargs do + rargs[k] = args[k](state) + end + return f.fn(rargs) end - - return r end, ret_type else self:Error("Cannot call type of " .. expr_ty, trace) @@ -1970,15 +1968,15 @@ end ---@return RuntimeOperator function Compiler:Process(ast) for var, type in pairs(self.persist[3]) do - self.scope:DeclVar(var, { initialized = false, trace_if_unused = self.persist[5][var], type = type }) + self.global_scope:DeclVar(var, { initialized = false, trace_if_unused = self.persist[5][var], type = type }) end for var, type in pairs(self.inputs[3]) do - self.scope:DeclVar(var, { initialized = true, trace_if_unused = self.inputs[5][var], type = type }) + self.global_scope:DeclVar(var, { initialized = true, trace_if_unused = self.inputs[5][var], type = type }) end for var, type in pairs(self.outputs[3]) do - self.scope:DeclVar(var, { initialized = false, type = type }) + self.global_scope:DeclVar(var, { initialized = false, type = type }) end return self:CompileStmt(ast) diff --git a/lua/entities/gmod_wire_expression2/core/e2lib.lua b/lua/entities/gmod_wire_expression2/core/e2lib.lua index dbef930251..0d43522d5f 100644 --- a/lua/entities/gmod_wire_expression2/core/e2lib.lua +++ b/lua/entities/gmod_wire_expression2/core/e2lib.lua @@ -70,6 +70,53 @@ function E2Lib.newE2Table() return { n = {}, ntypes = {}, s = {}, stypes = {}, size = 0 } end +---@class E2Lambda +---@field fn fun(args: any[]): any +---@field arg_sig string +---@field ret string +local Function = {} +Function.__index = Function + +function Function.new(args, ret, fn) + return setmetatable({ arg_sig = args, ret = ret, fn = fn }, Function) +end + +E2Lib.Lambda = Function + +--- Call the function without doing any type checking. +--- Only use this when you check self:Args() yourself to ensure you have the correct signature function. +function Function:UnsafeCall(args) + return self.fn(args) +end + +function Function:Call(args, types) + if self.arg_sig == types then + return self.fn(args) + else + error("Incorrect arguments passed to lambda") + end +end + +function Function:Args() + return self.arg_sig +end + +function Function:Ret() + return self.ret +end + +--- If given the correct arguments, returns the inner untyped function you can call. +--- Otherwise, throws an error to the given E2 Context. +---@param arg_sig string +---@param ctx RuntimeContext +function Function:Unwrap(arg_sig, ctx) + if self.arg_sig == arg_sig then + return self.fn + else + ctx:forceThrow("Incorrect function signature passed, expected (" .. arg_sig .. ") got (" .. self.arg_sig .. ")") + end +end + -- Returns a cloned table of the variable given if it is a table. -- TODO: Ditch this system for instead having users provide a function that returns the default value. -- Would be much more efficient and avoid type checks. diff --git a/lua/entities/gmod_wire_expression2/core/functions.lua b/lua/entities/gmod_wire_expression2/core/functions.lua index b5c9cc56f8..6b9b2de8c5 100644 --- a/lua/entities/gmod_wire_expression2/core/functions.lua +++ b/lua/entities/gmod_wire_expression2/core/functions.lua @@ -1,17 +1,47 @@ --[[ Lambdas for Expression 2 Format: fun(args: any[], sig: string): ret_ty string?, ret any + Format: { arg_sig: string, ret: string, fn: fun(args: any[]): any } ]] -local function DEFAULT_FUNCTION(state) - state:forceThrow("Invalid function!") -end - -registerType("function", "f", DEFAULT_FUNCTION, - nil, - nil, +registerType("function", "f", nil, + function(self) self.entity:Error("You may not input a function") end, + function(self) self.entity:Error("You may not output a function") end, nil, function(v) - return not isfunction(v) + return not istable(v) or getmetatable(v) ~= E2Lib.Lambda end -) \ No newline at end of file +) + +__e2setcost(1) + +e2function number operator_is(function f) + return f and 1 or 0 +end + +local function splitTypeFast(sig) + local i, r, count, len = 1, {}, 0, #sig + while i <= len do + count = count + 1 + if string.sub(sig, i, i) == "x" then + r[count] = string.sub(sig, i, i + 2) + i = i + 3 + else + r[count] = string.sub(sig, i, i) + i = i + 1 + end + end + return r +end + +__e2setcost(5) + +e2function array function:getParameterTypes() + return splitTypeFast(this.arg_sig) +end + +__e2setcost(1) + +e2function string function:getReturnType() + return this.ret or "" +end \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/init.lua b/lua/entities/gmod_wire_expression2/init.lua index 266aca895b..b51dedda50 100644 --- a/lua/entities/gmod_wire_expression2/init.lua +++ b/lua/entities/gmod_wire_expression2/init.lua @@ -166,8 +166,10 @@ function ENT:Execute() end self.GlobalScope.vclk = {} - for k, var in pairs(self.globvars_mut) do - self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + if not self.directives.strict then + for k, var in pairs(self.globvars_mut) do + self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + end end if self.context.prfcount + self.context.prf - e2_softquota > e2_hardquota then @@ -457,8 +459,10 @@ function ENT:ResetContext() self.globvars_mut[k] = nil end - for k, var in pairs(self.globvars_mut) do - self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + if not self.directives.strict then -- Need to disable this so local variables at top scope don't get reset + for k, var in pairs(self.globvars_mut) do + self.GlobalScope[k] = fixDefault(wire_expression_types2[var.type][2]) + end end for k, v in pairs(self.Inputs) do From 641a8f28d710c9a7b9b21b8241c1f1302e983f22 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Thu, 9 Nov 2023 09:59:15 -0800 Subject: [PATCH 04/11] Fix missing parity change I'd already fixed this bug with functions that have return values but this was not fixed for functions without return values. Also added a test case for this. --- .../tests/runtime/base/lambdas.txt | 35 ++++++++++++++----- .../gmod_wire_expression2/base/compiler.lua | 4 +-- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/data/expression2/tests/runtime/base/lambdas.txt b/data/expression2/tests/runtime/base/lambdas.txt index 8e8b319fe5..8177d93b87 100644 --- a/data/expression2/tests/runtime/base/lambdas.txt +++ b/data/expression2/tests/runtime/base/lambdas.txt @@ -16,10 +16,10 @@ const Wrapper = function(V:number) { const F1 = Wrapper(55)[function] if (1) { - if (2) { - local V = 22 - assert(F1()[number] == 55) - } + if (2) { + local V = 22 + assert(F1()[number] == 55) + } } assert(F1()[number] == 55) @@ -32,15 +32,15 @@ assert(F2()[number] == 1238) assert(F2()[number] == 1238) const IsEven = function(N:number) { - return N % 2 == 0 + return N % 2 == 0 } const Not = function(N:number) { - return !N + return !N } const IsOdd = function(N:number) { - return Not(IsEven(N)[number])[number] + return Not(IsEven(N)[number])[number] } assert(IsOdd(1)[number] == 1) @@ -49,8 +49,25 @@ assert(IsOdd(2)[number] == 0) assert( ((function() { return function() { return 55 } })()[function])()[number] == 55 ) const Identity = function(N:number) { - return N + return N } assert(Identity(2)[number] == 2) -assert(Identity(2193921)[number] == 2193921) \ No newline at end of file +assert(Identity(2193921)[number] == 2193921) + +local SayMessage = function() {} + +const SetMessage = function(Message:string) { + SayMessage = function() { + return Message + } +} + +SetMessage("There's a snake in my boot!") + +assert( SayMessage()[string] == "There's a snake in my boot!" ) +assert( SayMessage()[string] == "There's a snake in my boot!" ) + +SetMessage("Reach for the sky!") + +assert( SayMessage()[string] == "Reach for the sky!" ) \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 5a81adfb22..3584889c3f 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1296,7 +1296,7 @@ local CompileVisitors = { self.scope.data.ops = self.scope.data.ops + 25 return function(state) - local inherited_scopes, after, before = {}, state.ScopeID + 1, state.ScopeID + local inherited_scopes, after = {}, state.ScopeID + 1 for i = 0, state.ScopeID do inherited_scopes[i] = state.Scopes[i] end @@ -1345,7 +1345,7 @@ local CompileVisitors = { block(state) - state.Scopes[before], state.ScopeID, state.Scope, state.Scopes = nil, s_scopeid, s_scope, s_scopes + state.ScopeID, state.Scope, state.Scopes = s_scopeid, s_scope, s_scopes end ) end From 6aebc04342304a75250a2560e47da717e6ebcdd3 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Thu, 9 Nov 2023 12:22:58 -0800 Subject: [PATCH 05/11] More tests, enforce returns at compile time * Add tests to ensure variadic parameters, void parameters and implicit parameters aren't allowed * Fix lambdas potentially not returning at all codepaths like functions now are expected to do. Added a test for that. --- .../compiler/compiler/restrictions/fn_void_param.txt | 3 +++ .../compiler/restrictions/lambda/implicit_param.txt | 4 ++++ .../compiler/restrictions/lambda/return_codepaths.txt | 6 ++++++ .../{lambda_return.txt => lambda/return_type_mix.txt} | 2 +- .../compiler/restrictions/lambda/variadic_param.txt | 4 ++++ .../compiler/compiler/restrictions/lambda/void_param.txt | 3 +++ .../compiler/compiler/restrictions/variadic_lambda.txt | 4 ---- lua/entities/gmod_wire_expression2/base/compiler.lua | 9 ++++++++- 8 files changed, 29 insertions(+), 6 deletions(-) create mode 100644 data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt create mode 100644 data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt create mode 100644 data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt rename data/expression2/tests/compiler/compiler/restrictions/{lambda_return.txt => lambda/return_type_mix.txt} (72%) create mode 100644 data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt create mode 100644 data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt delete mode 100644 data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt diff --git a/data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt b/data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt new file mode 100644 index 0000000000..5128b062e6 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/fn_void_param.txt @@ -0,0 +1,3 @@ +## SHOULD_FAIL:COMPILE + +function test(X:void) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt new file mode 100644 index 0000000000..7cc9e62e44 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/implicit_param.txt @@ -0,0 +1,4 @@ +## SHOULD_FAIL:COMPILE + +# Implicit number fallback is not going to be allowed. +const J = function(X) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt new file mode 100644 index 0000000000..0c50cbf810 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_codepaths.txt @@ -0,0 +1,6 @@ +## SHOULD_FAIL:COMPILE + +const X = function() { + if (1) { return "str" } + # doesn't return string +} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_type_mix.txt similarity index 72% rename from data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt rename to data/expression2/tests/compiler/compiler/restrictions/lambda/return_type_mix.txt index 9ab6b91f68..6a63dee2bb 100644 --- a/data/expression2/tests/compiler/compiler/restrictions/lambda_return.txt +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/return_type_mix.txt @@ -1,6 +1,6 @@ ## SHOULD_FAIL:COMPILE -function() { +const X = function() { if (1) { return "str" } return 22 } \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt new file mode 100644 index 0000000000..4fb88eca1d --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/variadic_param.txt @@ -0,0 +1,4 @@ +## SHOULD_FAIL:COMPILE + +const X = function(...A:array) {} +const Y = function(...A:table) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt b/data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt new file mode 100644 index 0000000000..bf686d24c5 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/lambda/void_param.txt @@ -0,0 +1,3 @@ +## SHOULD_FAIL:COMPILE + +const X = function(X:void) {} \ No newline at end of file diff --git a/data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt b/data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt deleted file mode 100644 index adc2c0cbc6..0000000000 --- a/data/expression2/tests/compiler/compiler/restrictions/variadic_lambda.txt +++ /dev/null @@ -1,4 +0,0 @@ -## SHOULD_FAIL:COMPILE - -function(...A:array) {} -function(...A:table) {} \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 3584889c3f..2036be5500 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1282,12 +1282,19 @@ local CompileVisitors = { scope.data["function"] = { "", fn } for i, param in ipairs(data[1]) do + self:Assert(param.type, "Cannot omit parameter type for lambda, annotate with :", param.name.trace) param_names[i], param_types[i] = param.name.value, self:Assert(self:CheckType(param.type), "Cannot use void as parameter", param.name.trace) self:Assert(not param.variadic, "Variadic lambdas are not supported, use an array instead", param.name.trace) scope:DeclVar(param.name.value, { type = param_types[i], initialized = true, trace_if_unused = param.name.trace }) end - return self:CompileStmt(data[2]) + local block = self:CompileStmt(data[2]) + + if fn.ret then -- Ensure function either returns or errors + self:Assert(scope.data.dead, "Not all codepaths return a value of type '" .. fn.ret .. "'", trace) + end + + return block end) local ret = fn.ret From db32af27f7d2e1a4fbd4f4af1ee3fc8f7cc78bae Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Sun, 12 Nov 2023 16:41:29 -0800 Subject: [PATCH 06/11] Add timers via callbacks --- .../gmod_wire_expression2/core/timer.lua | 157 ++++++++++-------- lua/wire/client/e2descriptions.lua | 6 + 2 files changed, 91 insertions(+), 72 deletions(-) diff --git a/lua/entities/gmod_wire_expression2/core/timer.lua b/lua/entities/gmod_wire_expression2/core/timer.lua index a2c1b97550..f94db23f6e 100644 --- a/lua/entities/gmod_wire_expression2/core/timer.lua +++ b/lua/entities/gmod_wire_expression2/core/timer.lua @@ -1,121 +1,131 @@ -/******************************************************************************\ - Timer support -\******************************************************************************/ +--[[ + Timers +]] -local timerid = 0 +---@type table> +local Timers = {} -local function Execute(self, name) - self.data.timer.runner = name - - self.data['timer'].timers[name] = nil - - if(self.entity and self.entity.Execute) then - self.entity:Execute() - end - - if !self.data['timer'].timers[name] then - timer.Remove("e2_" .. self.data['timer'].timerid .. "_" .. name) - end - - self.data.timer.runner = nil -end - -local function AddTimer(self, name, delay) - if delay < 10 then delay = 10 end - - local timerName = "e2_" .. self.data.timer.timerid .. "_" .. name - - if self.data.timer.runner == name and timer.Exists(timerName) then - timer.Adjust(timerName, delay / 1000, 2, function() - Execute(self, name) - end) - timer.Start(timerName) - elseif !self.data['timer'].timers[name] then - timer.Create(timerName, delay / 1000, 2, function() - Execute(self, name) - end) - end - - self.data['timer'].timers[name] = true +local function addTimer(self, name, delay, reps, callback) + Timers[self][name] = true + timer.Create(("e2timer_%p_%s"):format(self, name), math.max(delay, 1e-2), reps, callback) end -local function RemoveTimer(self, name) - if self.data['timer'].timers[name] then - timer.Remove("e2_" .. self.data['timer'].timerid .. "_" .. name) - self.data['timer'].timers[name] = nil - end +local function removeTimer(self, name) + Timers[self][name] = nil + timer.Remove(("e2timer_%p_%s"):format(self, name)) end -/******************************************************************************/ - registerCallback("construct", function(self) - self.data['timer'] = {} - self.data['timer'].timerid = timerid - self.data['timer'].timers = {} - - timerid = timerid + 1 + Timers[self] = {} end) registerCallback("destruct", function(self) - for name,_ in pairs(self.data['timer'].timers) do - RemoveTimer(self, name) + for name in pairs(Timers[self]) do + removeTimer(self, name) end -end) -/******************************************************************************/ + Timers[self] = nil +end) __e2setcost(20) +---@param self RuntimeContext +local function MAKE_TRIGGER(id, self) + return function() + self.data.timer = id + removeTimer(self, id) + + if self.entity and self.entity.Execute then + self.entity:Execute() + end + + self.data.timer = nil + end +end + +[deprecated = "Use the timer function with callbacks instead"] e2function void interval(rv1) - AddTimer(self, "interval", rv1) + addTimer(self, "interval", rv1 / 1000, 1, MAKE_TRIGGER("interval", self)) end +[deprecated = "Use the timer function with callbacks instead"] e2function void timer(string rv1, rv2) - AddTimer(self, rv1, rv2) + addTimer(self, rv1, rv2 / 1000, 1, MAKE_TRIGGER(rv1, self)) end __e2setcost(5) e2function void stoptimer(string rv1) - RemoveTimer(self, rv1) + removeTimer(self, rv1) end __e2setcost(1) -[nodiscard] +[nodiscard, deprecated = "Use the timer function with callbacks instead"] e2function number clk() - return self.data.timer.runner == "interval" and 1 or 0 + return self.data.timer == "interval" and 1 or 0 end -[nodiscard] +[nodiscard, deprecated = "Use the timer function with callbacks instead"] e2function number clk(string rv1) - return self.data.timer.runner == rv1 and 1 or 0 + return self.data.timer == rv1 and 1 or 0 end -[nodiscard] +[nodiscard, deprecated = "Use the timer function with callbacks instead"] e2function string clkName() - return self.data.timer.runner or "" + return self.data.timer or "" end +__e2setcost(5) + +[deprecated = "Use the timer function with callbacks instead"] e2function array getTimers() - local ret = {} + local ret, timers = {}, Timers[self] + if not timers then return ret end + + self.prf = self.prf + #timers * 2 + local i = 0 - for name in pairs( self.data.timer.timers ) do + for name in pairs(timers) do i = i + 1 ret[i] = name end - self.prf = self.prf + i * 5 + return ret end e2function void stopAllTimers() - for name in pairs(self.data.timer.timers) do - self.prf = self.prf + 5 - RemoveTimer(self,name) + local timers = Timers[self] + if not timers then return end + + self.prf = self.prf + #timers * 2 + + for name in pairs(timers) do + removeTimer(self, name) end end -/******************************************************************************/ +--[[ + Timers 2.0 +]] + +__e2setcost(10) + +e2function void timer(string name, number delay, number reps, function callback) + local fn = callback:Unwrap("", self) + addTimer(self, name, delay, reps, fn) +end + +-- Create "anonymous" timer using address of arguments, which should be different for each function call. +-- Definitely hacky, but should work properly. I think this is better than just incrementing a number infinitely. +e2function void timer(number delay, function callback) + local fn = callback:Unwrap("", self) + addTimer(self,("%p"):format(args), delay, 1, fn) +end + +--[[ + Time Monitoring +]] [nodiscard] e2function number curtime() @@ -132,7 +142,9 @@ e2function number systime() return SysTime() end ------------------------------------------------------------------------------------ +--[[ + Datetime +]] local function luaDateToE2Table( time, utc ) local ret = E2Lib.newE2Table() @@ -170,13 +182,14 @@ e2function table dateUTC() return luaDateToE2Table(nil,true) end --- Returns the specified time formatted neatly in a table using UTC +[nodiscard] e2function table dateUTC( time ) return luaDateToE2Table(time,true) end -- This function has a strange and slightly misleading name, but changing it might break older E2s, so I'm leaving it -- It's essentially the same as the date function above +[nodiscard] e2function number time(string component) local ostime = os.date("!*t") local ret = ostime[component] @@ -188,7 +201,7 @@ end ----------------------------------------------------------------------------------- __e2setcost(2) --- Returns the time in seconds + [nodiscard] e2function number time() return os.time() diff --git a/lua/wire/client/e2descriptions.lua b/lua/wire/client/e2descriptions.lua index 268d4856f7..56251bc1a2 100644 --- a/lua/wire/client/e2descriptions.lua +++ b/lua/wire/client/e2descriptions.lua @@ -942,9 +942,15 @@ E2Helper.Descriptions["getTimers()"] = "Returns an array of all timers used in t E2Helper.Descriptions["interval(n)"] = "Sets a one-time timer with name \"interval\" and delay in milliseconds (minimum delay for timers is 10ms)" E2Helper.Descriptions["runOnTick(n)"] = "If set to 1, the expression will execute once every game tick" E2Helper.Descriptions["timer(sn)"] = "Sets a one-time timer with entered name and delay in milliseconds" + E2Helper.Descriptions["stoptimer(s)"] = "Stops a timer, can stop interval with stoptimer(\"interval\")" E2Helper.Descriptions["stopAllTimers()"] = "Stops all timers" +-- Timers 2.0 +E2Helper.Descriptions["timer(nf)"] = "Sets a callback to run after n seconds" +E2Helper.Descriptions["timer(snf)"] = "Sets a named timer to run a callback after n seconds" +E2Helper.Descriptions["timer(snnf)"] = "Sets a named timer to run a callback after n seconds, repeating n2 times" + -- Unit conversion E2Helper.Descriptions["toUnit(sn)"] = "Converts default garrysmod units to specified units" E2Helper.Descriptions["fromUnit(sn)"] = "Converts specified units to default garrysmod units" From 342f15639824cda4065dc503377b64306987f7e2 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Sun, 12 Nov 2023 17:14:39 -0800 Subject: [PATCH 07/11] Execute timer with chip entity, move ops * Ops will increase inside of the actual lambda's body instead of in ExprDynCall, which means it won't cost 0 ops to call them outside of an E2 chip. * Lambdas outside of E2 will now need to be called with ENT:Execute, which now optionally takes a "script" and "args" corresponding to what you'd pass to the lambda. It will handle perf and everything for you. * Remove Lambda:Call, UnsafeCall still exists in case someone wants to handle perf on their own in some extension. --- .../gmod_wire_expression2/base/compiler.lua | 4 ++-- lua/entities/gmod_wire_expression2/core/e2lib.lua | 15 ++++----------- lua/entities/gmod_wire_expression2/core/timer.lua | 12 ++++++++---- lua/entities/gmod_wire_expression2/init.lua | 9 +++++++-- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 3bbf959e4a..353f9d090b 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1360,6 +1360,8 @@ local CompileVisitors = { function(args) local s_scopes, s_scope, s_scopeid = state.Scopes, state.Scope, state.ScopeID + state.prf = state.prf + 15 + local scope = { vclk = {} } state.Scopes = inherited_scopes state.ScopeID = after @@ -1792,8 +1794,6 @@ local CompileVisitors = { end end, ret_type elseif expr_ty == "f" then - self.scope.data.ops = self.scope.data.ops + 15 -- Since functions are 10 ops, this is pretty lenient. I will decrease this slightly when functions are made static and cheaper. - local nargs = #args local sig = table.concat(arg_types) diff --git a/lua/entities/gmod_wire_expression2/core/e2lib.lua b/lua/entities/gmod_wire_expression2/core/e2lib.lua index 25053fe1b0..6b842beaf7 100644 --- a/lua/entities/gmod_wire_expression2/core/e2lib.lua +++ b/lua/entities/gmod_wire_expression2/core/e2lib.lua @@ -83,20 +83,13 @@ end E2Lib.Lambda = Function ---- Call the function without doing any type checking. ---- Only use this when you check self:Args() yourself to ensure you have the correct signature function. +--- Call the function without doing any type checking, and outside of the proper entity context. +--- **ONLY USE THIS IF YOU KNOW WHAT YOU'RE DOING.** This could hit perf without actually erroring the chip. +--- Also make sure to check self:Args() yourself to ensure you have the correct signature function. function Function:UnsafeCall(args) return self.fn(args) end -function Function:Call(args, types) - if self.arg_sig == types then - return self.fn(args) - else - error("Incorrect arguments passed to lambda") - end -end - function Function:Args() return self.arg_sig end @@ -105,7 +98,7 @@ function Function:Ret() return self.ret end ---- If given the correct arguments, returns the inner untyped function you can call. +--- If given the correct arguments, returns the inner untyped function you can then call with ENT:Execute(f). --- Otherwise, throws an error to the given E2 Context. ---@param arg_sig string ---@param ctx RuntimeContext diff --git a/lua/entities/gmod_wire_expression2/core/timer.lua b/lua/entities/gmod_wire_expression2/core/timer.lua index f94db23f6e..2ef84e7121 100644 --- a/lua/entities/gmod_wire_expression2/core/timer.lua +++ b/lua/entities/gmod_wire_expression2/core/timer.lua @@ -112,15 +112,19 @@ end __e2setcost(10) e2function void timer(string name, number delay, number reps, function callback) - local fn = callback:Unwrap("", self) - addTimer(self, name, delay, reps, fn) + local fn, ent = callback:Unwrap("", self), self.entity + addTimer(self, name, delay, reps, function() + ent:Execute(fn) + end) end -- Create "anonymous" timer using address of arguments, which should be different for each function call. -- Definitely hacky, but should work properly. I think this is better than just incrementing a number infinitely. e2function void timer(number delay, function callback) - local fn = callback:Unwrap("", self) - addTimer(self,("%p"):format(args), delay, 1, fn) + local fn, ent = callback:Unwrap("", self), self.entity + addTimer(self,("%p"):format(args), delay, 1, function() + ent:Execute(fn) + end) end --[[ diff --git a/lua/entities/gmod_wire_expression2/init.lua b/lua/entities/gmod_wire_expression2/init.lua index b51dedda50..d2837a639a 100644 --- a/lua/entities/gmod_wire_expression2/init.lua +++ b/lua/entities/gmod_wire_expression2/init.lua @@ -111,9 +111,12 @@ function ENT:UpdatePerf() end -function ENT:Execute() +function ENT:Execute(script, args) if self.error or not self.context or self.context.resetting then return end + script = script or self.script + args = args or self.context + self:PCallHook('preexecute') self.context.stackdepth = self.context.stackdepth + 1 @@ -124,7 +127,7 @@ function ENT:Execute() local bench = SysTime() - local ok, msg = pcall(self.script, self.context) + local ok, msg = pcall(script, args) if not ok then local _catchable, msg, trace = E2Lib.unpackException(msg) @@ -180,6 +183,8 @@ function ENT:Execute() if self.error then self:Destruct() end + + return msg end ---@param evt string From 1a78d0d6dcea3cf708f4d98a27640ac4eee8a33c Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Sun, 12 Nov 2023 17:15:51 -0800 Subject: [PATCH 08/11] Add timer(snf) --- .../gmod_wire_expression2/core/timer.lua | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/lua/entities/gmod_wire_expression2/core/timer.lua b/lua/entities/gmod_wire_expression2/core/timer.lua index 2ef84e7121..e79ab5d082 100644 --- a/lua/entities/gmod_wire_expression2/core/timer.lua +++ b/lua/entities/gmod_wire_expression2/core/timer.lua @@ -111,18 +111,25 @@ end __e2setcost(10) -e2function void timer(string name, number delay, number reps, function callback) +-- Create "anonymous" timer using address of arguments, which should be different for each function call. +-- Definitely hacky, but should work properly. I think this is better than just incrementing a number infinitely. +e2function void timer(number delay, function callback) local fn, ent = callback:Unwrap("", self), self.entity - addTimer(self, name, delay, reps, function() + addTimer(self,("%p"):format(args), delay, 1, function() ent:Execute(fn) end) end --- Create "anonymous" timer using address of arguments, which should be different for each function call. --- Definitely hacky, but should work properly. I think this is better than just incrementing a number infinitely. -e2function void timer(number delay, function callback) +e2function void timer(string name, number delay, function callback) local fn, ent = callback:Unwrap("", self), self.entity - addTimer(self,("%p"):format(args), delay, 1, function() + addTimer(self, name, delay, 1, function() + ent:Execute(fn) + end) +end + +e2function void timer(string name, number delay, number reps, function callback) + local fn, ent = callback:Unwrap("", self), self.entity + addTimer(self, name, delay, reps, function() ent:Execute(fn) end) end From 274325de9cec99614f5c826ba4b1fba30d611efd Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:09:38 -0800 Subject: [PATCH 09/11] Refactor, add per-chip limit * Fix error when chip quotas with timers active * Make getTimers nodiscard, change deprecation message * Optimize code to not create / remove timers needlessly * Auto-remove timer with reps properly --- .../gmod_wire_expression2/core/timer.lua | 74 ++++++++++++++----- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/lua/entities/gmod_wire_expression2/core/timer.lua b/lua/entities/gmod_wire_expression2/core/timer.lua index e79ab5d082..0b37f9527c 100644 --- a/lua/entities/gmod_wire_expression2/core/timer.lua +++ b/lua/entities/gmod_wire_expression2/core/timer.lua @@ -2,43 +2,68 @@ Timers ]] ----@type table> +---@type table, count: integer }> local Timers = {} +--- Max timers that can exist at one time per chip. +local MAX_TIMERS = CreateConVar("wire_expression2_timer_max", 100) + local function addTimer(self, name, delay, reps, callback) - Timers[self][name] = true + local timers = Timers[self] + if not timers.lookup[name] then + timers.lookup[name] = true + timers.count = timers.count + 1 + + if timers.count > MAX_TIMERS:GetInt() then + return self:throw("Hit per-chip timer limit of " .. MAX_TIMERS:GetInt() .. "!", nil) + end + end + timer.Create(("e2timer_%p_%s"):format(self, name), math.max(delay, 1e-2), reps, callback) end local function removeTimer(self, name) - Timers[self][name] = nil - timer.Remove(("e2timer_%p_%s"):format(self, name)) + local timers = Timers[self] + if timers.lookup[name] then + timers.lookup[name] = nil + timers.count = timers.count - 1 + + timer.Remove(("e2timer_%p_%s"):format(self, name)) + end end registerCallback("construct", function(self) - Timers[self] = {} + Timers[self] = { lookup = {}, count = 0 } end) registerCallback("destruct", function(self) - for name in pairs(Timers[self]) do + for name in pairs(Timers[self].lookup) do removeTimer(self, name) end Timers[self] = nil end) -__e2setcost(20) +__e2setcost(25) ---@param self RuntimeContext local function MAKE_TRIGGER(id, self) return function() self.data.timer = id - removeTimer(self, id) + + Timers[self].lookup[id] = nil if self.entity and self.entity.Execute then self.entity:Execute() end + if + Timers[self] -- This case is needed if chip tick quotas, which would call destruct hook on :Execute(). + and not Timers[self].lookup[id] + then + removeTimer(self, id) -- only remove if not immediately re-created + end + self.data.timer = nil end end @@ -78,15 +103,13 @@ end __e2setcost(5) -[deprecated = "Use the timer function with callbacks instead"] +[nodiscard, deprecated = "You should keep track of timers with callbacks instead"] e2function array getTimers() local ret, timers = {}, Timers[self] - if not timers then return ret end - - self.prf = self.prf + #timers * 2 + self.prf = self.prf + timers.count * 2 local i = 0 - for name in pairs(timers) do + for name in pairs(timers.lookup) do i = i + 1 ret[i] = name end @@ -96,11 +119,9 @@ end e2function void stopAllTimers() local timers = Timers[self] - if not timers then return end - - self.prf = self.prf + #timers * 2 + self.prf = self.prf + timers.count * 2 - for name in pairs(timers) do + for name in pairs(timers.lookup) do removeTimer(self, name) end end @@ -109,13 +130,20 @@ end Timers 2.0 ]] -__e2setcost(10) +__e2setcost(15) + +local simpletimer = 1 -- Create "anonymous" timer using address of arguments, which should be different for each function call. -- Definitely hacky, but should work properly. I think this is better than just incrementing a number infinitely. e2function void timer(number delay, function callback) local fn, ent = callback:Unwrap("", self), self.entity - addTimer(self,("%p"):format(args), delay, 1, function() + + simpletimer = (simpletimer + 1) % (MAX_TIMERS:GetInt() * 100000000) -- if this ends up overwriting other timers you have a much bigger problem. wrap to avoid inf. + local name = tostring(simpletimer) + + addTimer(self, name, delay, 1, function() + removeTimer(self, name) ent:Execute(fn) end) end @@ -123,13 +151,19 @@ end e2function void timer(string name, number delay, function callback) local fn, ent = callback:Unwrap("", self), self.entity addTimer(self, name, delay, 1, function() + removeTimer(self, name) ent:Execute(fn) end) end e2function void timer(string name, number delay, number reps, function callback) - local fn, ent = callback:Unwrap("", self), self.entity + local fn, ent, rep = callback:Unwrap("", self), self.entity, 0 addTimer(self, name, delay, reps, function() + rep = rep + 1 + if rep == reps then + removeTimer(self, name) + end + ent:Execute(fn) end) end From 8d317649e10c3abd6122f13f2bd00e4aa05388e7 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:18:20 -0800 Subject: [PATCH 10/11] Decrease prf cost, clarify change 15 ops to call lambda -> 10 ops --- lua/entities/gmod_wire_expression2/base/compiler.lua | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index d76288a9fe..3db0047990 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1440,7 +1440,9 @@ local CompileVisitors = { function(args) local s_scopes, s_scope, s_scopeid = state.Scopes, state.Scope, state.ScopeID - state.prf = state.prf + 15 + -- Not using `self.scope.data.ops` in order to add prf when builtin functions call lambdas. + -- This behavior may change. + state.prf = state.prf + 10 local scope = { vclk = {} } state.Scopes = inherited_scopes From 0b2c88eaf43add2ca24eb2c51bc667f7b5979cdc Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:51:31 -0800 Subject: [PATCH 11/11] Remove outdated comment --- lua/entities/gmod_wire_expression2/core/timer.lua | 2 -- 1 file changed, 2 deletions(-) diff --git a/lua/entities/gmod_wire_expression2/core/timer.lua b/lua/entities/gmod_wire_expression2/core/timer.lua index 0b37f9527c..d2d82cf9bf 100644 --- a/lua/entities/gmod_wire_expression2/core/timer.lua +++ b/lua/entities/gmod_wire_expression2/core/timer.lua @@ -134,8 +134,6 @@ __e2setcost(15) local simpletimer = 1 --- Create "anonymous" timer using address of arguments, which should be different for each function call. --- Definitely hacky, but should work properly. I think this is better than just incrementing a number infinitely. e2function void timer(number delay, function callback) local fn, ent = callback:Unwrap("", self), self.entity