Skip to content

Commit

Permalink
Tweak logic for better constant expression evaluation (respect file-l…
Browse files Browse the repository at this point in the history
…evel constants, support more node types and cast operations)
  • Loading branch information
blitz-1306 committed Oct 6, 2023
1 parent f467859 commit ac2dfc8
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 34 deletions.
162 changes: 130 additions & 32 deletions src/types/eval_const.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ import {
FunctionCall,
FunctionCallKind,
Identifier,
IndexAccess,
Literal,
LiteralKind,
MemberAccess,
TimeUnit,
TupleExpression,
UnaryOperation,
VariableDeclaration
} from "../ast";
import { assert, pp } from "../misc";
import { IntType, NumericLiteralType } from "./ast";
import { pp } from "../misc";
import { BytesType, FixedBytesType, IntType, NumericLiteralType, StringType } from "./ast";
import { InferType } from "./infer";
import { BINARY_OPERATOR_GROUPS, SUBDENOMINATION_MULTIPLIERS, clampIntToType } from "./utils";
/**
Expand All @@ -27,7 +29,7 @@ import { BINARY_OPERATOR_GROUPS, SUBDENOMINATION_MULTIPLIERS, clampIntToType } f
*/
Decimal.set({ precision: 100 });

export type Value = Decimal | boolean | string | bigint;
export type Value = Decimal | boolean | string | bigint | Buffer;

export class EvalError extends Error {
expr?: Expression;
Expand Down Expand Up @@ -62,14 +64,18 @@ function promoteToDec(v: Value): Decimal {
return new Decimal(v === "" ? 0 : "0x" + Buffer.from(v, "utf-8").toString("hex"));
}

if (v instanceof Buffer) {
return new Decimal(v.length === 0 ? 0 : "0x" + v.toString("hex"));
}

throw new Error(`Expected number not ${v}`);
}

function demoteFromDec(d: Decimal): Decimal | bigint {
return d.isInt() ? BigInt(d.toFixed()) : d;
}

export function isConstant(expr: Expression): boolean {
export function isConstant(expr: Expression | VariableDeclaration): boolean {
if (expr instanceof Literal) {
return true;
}
Expand All @@ -78,6 +84,15 @@ export function isConstant(expr: Expression): boolean {
return true;
}

if (
expr instanceof VariableDeclaration &&
expr.constant &&
expr.vValue &&
isConstant(expr.vValue)
) {
return true;
}

if (
expr instanceof BinaryOperation &&
isConstant(expr.vLeftExpression) &&
Expand Down Expand Up @@ -108,17 +123,19 @@ export function isConstant(expr: Expression): boolean {
return true;
}

if (expr instanceof Identifier) {
const decl = expr.vReferencedDeclaration;
if (expr instanceof Identifier || expr instanceof MemberAccess) {
return (
expr.vReferencedDeclaration instanceof VariableDeclaration &&
isConstant(expr.vReferencedDeclaration)
);
}

if (
decl instanceof VariableDeclaration &&
decl.constant &&
decl.vValue &&
isConstant(decl.vValue)
) {
return true;
}
if (expr instanceof IndexAccess) {
return (
isConstant(expr.vBaseExpression) &&
expr.vIndexExpression !== undefined &&
isConstant(expr.vIndexExpression)
);
}

if (
Expand All @@ -142,7 +159,7 @@ export function evalLiteralImpl(
}

if (kind === LiteralKind.HexString) {
return value === "" ? 0n : BigInt("0x" + value);
return Buffer.from(value, "hex");
}

if (kind === LiteralKind.String || kind === LiteralKind.UnicodeString) {
Expand Down Expand Up @@ -408,22 +425,93 @@ export function evalBinary(node: BinaryOperation, inference: InferType): Value {
}
}

export function evalIndexAccess(node: IndexAccess, inference: InferType): Value {
const base = evalConstantExpr(node.vBaseExpression, inference);
const index = evalConstantExpr(node.vIndexExpression as Expression, inference);

if (!(typeof index === "bigint" || index instanceof Decimal)) {
throw new EvalError(
`Unexpected non-numeric index into base in expression ${pp(node)}`,
node
);
}

const plainIndex = index instanceof Decimal ? index.toNumber() : Number(index);

if (typeof base === "bigint" || base instanceof Decimal) {
let baseHex = base instanceof Decimal ? base.toHex().slice(2) : base.toString(16);

if (baseHex.length % 2 !== 0) {
baseHex = "0" + baseHex;
}

const indexInHex = plainIndex * 2;

return BigInt("0x" + baseHex.slice(indexInHex, indexInHex + 2));
}

if (base instanceof Buffer) {
const res = base.at(plainIndex);

if (res === undefined) {
throw new EvalError(
`Out-of-bounds index access ${plainIndex} to ${base.toString("hex")}`
);
}

return BigInt(res);
}

throw new EvalError(`Unable to process ${pp(node)}`, node);
}

export function evalFunctionCall(node: FunctionCall, inference: InferType): Value {
assert(
node.kind === FunctionCallKind.TypeConversion,
'Expected constant call to be a "{0}", but got "{1}" instead',
FunctionCallKind.TypeConversion,
node.kind
);
if (node.kind !== FunctionCallKind.TypeConversion) {
throw new EvalError(
`Expected function call to have kind "${FunctionCallKind.TypeConversion}", but got "${node.kind}" instead`,
node
);
}

if (!(node.vExpression instanceof ElementaryTypeNameExpression)) {
throw new EvalError(
`Expected function call expression to be an ${ElementaryTypeNameExpression.name}, but got "${node.type}" instead`,
node
);
}

const val = evalConstantExpr(node.vArguments[0], inference);
const castT = inference.typeOfElementaryTypeNameExpression(node.vExpression).type;

if (typeof val === "bigint" && node.vExpression instanceof ElementaryTypeNameExpression) {
const castT = inference.typeOfElementaryTypeNameExpression(node.vExpression);
const toT = castT.type;
if (typeof val === "bigint") {
if (castT instanceof IntType) {
return clampIntToType(val, castT);
}

if (toT instanceof IntType) {
return clampIntToType(val, toT);
if (castT instanceof FixedBytesType) {
return val;
}
}

if (typeof val === "string") {
if (castT instanceof BytesType) {
return Buffer.from(val, "utf-8");
}

if (castT instanceof FixedBytesType) {
const buf = Buffer.from(val, "utf-8");

return BigInt("0x" + buf.slice(0, castT.size).toString("hex"));
}
}

if (val instanceof Buffer) {
if (castT instanceof StringType) {
return val.toString("utf-8");
}

if (castT instanceof FixedBytesType) {
return BigInt("0x" + val.slice(0, castT.size).toString("hex"));
}
}

Expand All @@ -437,7 +525,10 @@ export function evalFunctionCall(node: FunctionCall, inference: InferType): Valu
* @todo The order of some operations changed in some version.
* Current implementation does not yet take it into an account.
*/
export function evalConstantExpr(node: Expression, inference: InferType): Value {
export function evalConstantExpr(
node: Expression | VariableDeclaration,
inference: InferType
): Value {
if (!isConstant(node)) {
throw new NonConstantExpressionError(node);
}
Expand All @@ -464,12 +555,19 @@ export function evalConstantExpr(node: Expression, inference: InferType): Value
: evalConstantExpr(node.vFalseExpression, inference);
}

if (node instanceof Identifier) {
const decl = node.vReferencedDeclaration;
if (node instanceof VariableDeclaration) {
return evalConstantExpr(node.vValue as Expression, inference);
}

if (decl instanceof VariableDeclaration) {
return evalConstantExpr(decl.vValue as Expression, inference);
}
if (node instanceof Identifier || node instanceof MemberAccess) {
return evalConstantExpr(
node.vReferencedDeclaration as Expression | VariableDeclaration,
inference
);
}

if (node instanceof IndexAccess) {
return evalIndexAccess(node, inference);
}

if (node instanceof FunctionCall) {
Expand Down
101 changes: 101 additions & 0 deletions test/integration/eval_const.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import expect from "expect";
import {
assert,
ASTReader,
compileSol,
detectCompileErrors,
evalConstantExpr,
Expression,
InferType,
LatestCompilerVersion,
SourceUnit,
Value,
VariableDeclaration,
XPath
} from "../../src";

const cases: Array<[string, Array<[string, Value]>]> = [
[
"test/samples/solidity/consts/consts.sol",
[
["//VariableDeclaration[@id=5]", 100n],
["//VariableDeclaration[@id=8]", 15n],
["//VariableDeclaration[@id=13]", 115n],
["//VariableDeclaration[@id=18]", 158n],
["//VariableDeclaration[@id=24]", 158n],
["//VariableDeclaration[@id=31]", false],
["//VariableDeclaration[@id=37]", 158n],
["//VariableDeclaration[@id=44]", 85n],
["//VariableDeclaration[@id=47]", "abcd"],
["//VariableDeclaration[@id=53]", Buffer.from("abcd", "utf-8")],
["//VariableDeclaration[@id=58]", 97n],
["//VariableDeclaration[@id=64]", "abcd"],
["//VariableDeclaration[@id=73]", 30841n],
["//VariableDeclaration[@id=82]", 30841n],
["//VariableDeclaration[@id=88]", 258n]
]
]
];

describe("Constant expression evaluator integration test", () => {
for (const [sample, mapping] of cases) {
describe(sample, () => {
let units: SourceUnit[];
let inference: InferType;

before(async () => {
const result = await compileSol(sample, "auto");

const data = result.data;
const compilerVersion = result.compilerVersion || LatestCompilerVersion;

const errors = detectCompileErrors(data);

expect(errors).toHaveLength(0);

const reader = new ASTReader();

units = reader.read(data);

expect(units.length).toBeGreaterThanOrEqual(1);

inference = new InferType(compilerVersion);
});

for (const [selector, expectation] of mapping) {
let found = false;

it(`${selector} -> ${expectation}`, () => {
for (const unit of units) {
const results = new XPath(unit).query(selector);

if (results.length > 0) {
const [expr] = results;

assert(
expr instanceof Expression || expr instanceof VariableDeclaration,
`Expected selector result to be an {0} or {1} descendant, got {2} instead`,
Expression.name,
VariableDeclaration.name,
expr
);

found = true;

expect(evalConstantExpr(expr, inference)).toEqual(expectation);

break;
}
}

assert(
found,
`Selector "{0}" not found in source units of sample "{1}"`,
selector,
sample
);
});
}
});
}
});
22 changes: 22 additions & 0 deletions test/samples/solidity/consts/consts.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import "./lib_a.sol";
import "./lib_a.sol" as LibA;

uint constant SOME_CONST = 100;
uint constant SOME_OTHER = 15;
uint constant SOME_ELSE = SOME_CONST + SOME_OTHER;
uint constant C2 = SOME_ELSE + ANOTHER_CONST;
uint constant C3 = SOME_ELSE + LibA.ANOTHER_CONST;
uint constant C4 = -SOME_CONST;
bool constant C5 = false;
uint constant C6 = C5 ? SOME_ELSE : C3;
uint constant C7 = LibA.ANOTHER_CONST + LibB.AND_ANOTHER_CONST;
// uint constant C8 = LibA.ANOTHER_CONST + LibA.LibB.AND_ANOTHER_CONST;

string constant FOO = "abcd";
bytes constant BOO = bytes("abcd");
bytes1 constant MOO = BOO[0];
string constant WOO = string(BOO);

uint16 constant U16S = uint16(bytes2("xy"));
uint16 constant U16B = uint16(bytes2(hex"7879"));
bytes2 constant B2U = bytes2(0x0102);
5 changes: 5 additions & 0 deletions test/samples/solidity/consts/lib_a.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pragma solidity ^0.7.5;

import "./lib_b.sol" as LibB;

uint constant ANOTHER_CONST = LibB.AND_ANOTHER_CONST + 1;
3 changes: 3 additions & 0 deletions test/samples/solidity/consts/lib_b.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pragma solidity ^0.7.5;

uint constant AND_ANOTHER_CONST = 42;
Loading

0 comments on commit ac2dfc8

Please sign in to comment.