Skip to content

Commit

Permalink
[ruby] Apply Cached Side-Effect Variables (#4868)
Browse files Browse the repository at this point in the history
For chained calls such as `a().b()`, where the base/receiver of `b` is itself a call, we now assign the result of the first invocation to a temporary variable and refer to that variable on the second use, so that the call is not invoked in both the base and the receiver.
  • Loading branch information
DavidBakerEffendi authored Aug 21, 2024
1 parent caf75bc commit b341067
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 181 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,19 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
astForAssignment(Ast(lhs), Ast(rhs), lineNumber, columnNumber)
}

protected def astForAssignment(lhs: Ast, rhs: Ast, lineNumber: Option[Int], columnNumber: Option[Int]): Ast = {
val code = Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = ")
protected def astForAssignment(
lhs: Ast,
rhs: Ast,
lineNumber: Option[Int],
columnNumber: Option[Int],
code: Option[String] = None
): Ast = {
val _code =
code.getOrElse(Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = "))
val assignment = NewCall()
.name(Operators.assignment)
.methodFullName(Operators.assignment)
.code(code)
.code(_code)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.lineNumber(lineNumber)
.columnNumber(columnNumber)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@ import io.shiftleft.codepropertygraph.generated.{
PropertyNames
}

import scala.collection.mutable

trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
this: AstCreator =>

/** Generator for the names of temporary variables; names are rendered as `<tmp-i>`, where `i` is the value supplied by
  * the generator (presumably an increasing counter -- confirm against `FreshNameGenerator`).
  */
val tmpGen: FreshNameGenerator[String] = FreshNameGenerator(i => s"<tmp-$i>")

/** For tracking aliased calls that occur on the LHS of a member access or call.
  *
  * Maps a target node to the name of the temporary variable that `handleTmpGen` created for it. The cache is cleared
  * at the start of `astsForStatement`.
  */
protected val baseAstCache = mutable.Map.empty[RubyNode, String]

protected def astForExpression(node: RubyNode): Ast = node match
case node: StaticLiteral => astForStaticLiteral(node)
case node: HereDocNode => astForHereDoc(node)
Expand Down Expand Up @@ -180,33 +186,36 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
protected def astForMemberCall(node: MemberCall, isStatic: Boolean = false): Ast = {

def createMemberCall(n: MemberCall): Ast = {
val baseAst = n.target match {
case target: MemberAccess => astForFieldAccess(target, stripLeadingAt = true)
case _ => astForExpression(n.target)
}
val receiverAst = astForFieldAccess(MemberAccess(n.target, ".", n.methodName)(n.span), stripLeadingAt = true)
val (baseAst, baseCode) = astForMemberAccessTarget(n.target)
val builtinType = n.target match {
case MemberAccess(_: SelfIdentifier, _, memberName) if isBundledClass(memberName) =>
Option(prefixAsBundledType(memberName))
case x: TypeIdentifier if x.isBuiltin => Option(x.typeFullName)
case _ => None
}
val (receiverFullName, methodFullName) = receiverAst.nodes
val methodFullName = receiverAst.nodes
.collectFirst {
case _ if builtinType.isDefined => builtinType.get -> s"${builtinType.get}.${n.methodName}"
case x: NewMethodRef => x.methodFullName -> x.methodFullName
case _ if builtinType.isDefined => s"${builtinType.get}.${n.methodName}"
case x: NewMethodRef => x.methodFullName
case _ =>
(n.target match {
case ma: MemberAccess => scope.tryResolveTypeReference(ma.memberName).map(_.name)
case _ => typeFromCallTarget(n.target)
}).map(x => x -> s"$x.${n.methodName}")
.getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName)
}).map(x => s"$x.${n.methodName}")
.getOrElse(XDefines.DynamicCallUnknownFullName)
}
.getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName)
.getOrElse(XDefines.DynamicCallUnknownFullName)
val argumentAsts = n.arguments.map(astForMethodCallArgument)
val dispatchType = if (isStatic) DispatchTypes.STATIC_DISPATCH else DispatchTypes.DYNAMIC_DISPATCH

val call = callNode(n, code(n), n.methodName, XDefines.DynamicCallUnknownFullName, dispatchType)
val callCode = if (baseCode.contains("<tmp-")) {
val rhsCode = if (n.methodName == "new") n.methodName else code(n).replace("::", ".").split('.').last
s"$baseCode.$rhsCode"
} else {
code(n)
}
val call = callNode(n, callCode, n.methodName, XDefines.DynamicCallUnknownFullName, dispatchType)
if methodFullName != XDefines.DynamicCallUnknownFullName then call.possibleTypes(Seq(methodFullName))
if (isStatic) {
callAst(call, argumentAsts, base = Option(baseAst)).copy(receiverEdges = Nil)
Expand All @@ -225,18 +234,90 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case None if scope.lookupVariable(x.text).isDefined => x
case None => MemberAccess(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text)(x.span)
}
case x @ MemberAccess(ma, op, memberName) => x.copy(target = determineMemberAccessBase(ma))(x.span)
case _ => target
case x @ MemberAccess(ma, _, _) => x.copy(target = determineMemberAccessBase(ma))(x.span)
case _ => target
}

node.target match {
case _: LiteralExpr =>
createMemberCall(node)
case x: SimpleIdentifier if isBundledClass(x.text) =>
createMemberCall(node.copy(target = TypeIdentifier(prefixAsBundledType(x.text))(x.span))(node.span))
case x: SimpleIdentifier =>
createMemberCall(node.copy(target = determineMemberAccessBase(x))(node.span))
case memAccess: MemberAccess =>
createMemberCall(node.copy(target = determineMemberAccessBase(memAccess))(node.span))
case x => createMemberCall(node)
case _ => createMemberCall(node)
}
}

/** Lowers a member access to an `Operators.fieldAccess` call whose arguments are the AST of the access target and a
  * field identifier for the member.
  *
  * @param node
  *   the member access to lower.
  * @param stripLeadingAt
  *   if true, a leading `@` is removed from the member's code while the member name itself is kept as written.
  * @return
  *   the AST of the field access call.
  */
protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = {
  // The first matching rule wins, so the order of these cases is significant.
  val (memberName, memberCode) = node.target match {
    case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize
    case _ if stripLeadingAt                        => node.memberName -> node.memberName.stripPrefix("@")
    case _: TypeIdentifier                          => node.memberName -> node.memberName
    case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) =>
      // Lower-case members without an `@` get one prepended to the name; the code stays as written
      s"@${node.memberName}" -> node.memberName
    case _ => node.memberName -> node.memberName
  }

  val fieldIdentifierAst = Ast(fieldIdentifierNode(node, memberName, memberCode))
  // The target code is whatever `astForMemberAccessTarget` reports, which may mention a temporary variable
  val (targetAst, targetCode) = astForMemberAccessTarget(node.target)
  val fieldAccessCode         = s"$targetCode${node.op}$memberCode"
  // Type of the matching field on the resolved target type, falling back to `Defines.Any`
  val memberType = typeFromCallTarget(node.target)
    .flatMap(scope.tryResolveTypeReference)
    .map(_.fields)
    .getOrElse(List.empty)
    .collectFirst {
      case field if field.name == memberName =>
        scope.tryResolveTypeReference(field.typeName).map(_.name).getOrElse(Defines.Any)
    }
    .getOrElse(Defines.Any)
  val fieldAccess = callNode(
    node,
    fieldAccessCode,
    Operators.fieldAccess,
    Operators.fieldAccess,
    DispatchTypes.STATIC_DISPATCH,
    signature = None,
    typeFullName = Option(Defines.Any)
  ).possibleTypes(IndexedSeq(memberType))
  callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst))
}

/** Builds the AST for the target (left-hand side) of a member access or member call, together with the code string
  * to use for that target.
  *
  * Literals and plain/self/type identifiers are lowered directly and keep their own code. Every other target is routed
  * through `handleTmpGen`.
  */
private def astForMemberAccessTarget(target: RubyNode): (Ast, String) = target match {
  case simple: (LiteralExpr | SimpleIdentifier | SelfIdentifier | TypeIdentifier) =>
    (astForExpression(simple), code(target))
  case fieldAccess: MemberAccess =>
    handleTmpGen(fieldAccess, astForFieldAccess(fieldAccess, stripLeadingAt = true))
  case other =>
    handleTmpGen(other, astForExpression(other))
}

/** Aliases `target` with a temporary variable, reusing the variable already recorded in `baseAstCache` if there is
  * one.
  *
  * On a cache miss a fresh temporary name is generated, a local for it is added to the scope (and attached to the
  * enclosing block when the scope is a `BlockScope`), and the name is recorded in the cache.
  *
  * @param target
  *   the node being aliased; used as the cache key and as the source of the code text and position.
  * @param rhs
  *   the already-built AST of `target`; only used when the assignment is created.
  * @return
  *   on a cache miss, the AST of `<tmp> = target`; on a cache hit, just an identifier referring to the temporary. In
  *   both cases the code string is `(<tmp> = <target text>)`.
  */
private def handleTmpGen(target: RubyNode, rhs: Ast): (Ast, String) = {
  // A single lookup tells us both the cached name (if any) and whether the assignment still has to be created
  val cachedName = baseAstCache.get(target)
  val tmpName = cachedName.getOrElse {
    val freshName = tmpGen.fresh
    val tmpLocal  = NewLocal().name(freshName).code(freshName).typeFullName(Defines.Any)
    scope.addToScope(freshName, tmpLocal) match {
      case BlockScope(block) => diffGraph.addEdge(block, tmpLocal, EdgeTypes.AST)
      case _                 =>
    }
    baseAstCache.update(target, freshName)
    freshName
  }
  val tmpIden = NewIdentifier().name(tmpName).code(tmpName).typeFullName(Defines.Any)
  val tmpIdenAst =
    scope.lookupVariable(tmpName).map(decl => Ast(tmpIden).withRefEdge(tmpIden, decl)).getOrElse(Ast(tmpIden))
  val assignmentCode = s"$tmpName = ${target.text}"
  val tmpAst =
    if (cachedName.isEmpty) astForAssignment(tmpIdenAst, rhs, target.line, target.column, Option(assignmentCode))
    else tmpIdenAst
  tmpAst -> s"($assignmentCode)"
}

Expand Down Expand Up @@ -882,43 +963,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
astForExpression(assoc)
}

/** Lowers a member access to an `Operators.fieldAccess` call whose arguments are the AST of the access target and a
  * field identifier for the member.
  *
  * @param node
  *   the member access to lower.
  * @param stripLeadingAt
  *   if true, a leading `@` is removed from the member's code while the member name itself is kept as written.
  * @return
  *   the AST of the field access call.
  */
protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = {
  // The first matching rule wins, so the order of these cases is significant.
  val (memberName, memberCode) = node.target match {
    case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize
    case _ if stripLeadingAt                        => node.memberName -> node.memberName.stripPrefix("@")
    case _: TypeIdentifier                          => node.memberName -> node.memberName
    case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) =>
      // Lower-case members without an `@` get one prepended to the name; the code stays as written
      s"@${node.memberName}" -> node.memberName
    case _ => node.memberName -> node.memberName
  }

  val fieldIdentifierAst = Ast(fieldIdentifierNode(node, memberName, memberCode))
  // Nested member accesses are lowered recursively as field accesses; anything else as a plain expression
  val targetAst = node.target match {
    case nested: MemberAccess => astForFieldAccess(nested, stripLeadingAt = true)
    case _                    => astForExpression(node.target)
  }
  val fieldAccessCode = s"${node.target.text}${node.op}$memberCode"
  // Type of the matching field on the resolved target type, falling back to `Defines.Any`
  val memberType = typeFromCallTarget(node.target)
    .flatMap(scope.tryResolveTypeReference)
    .map(_.fields)
    .getOrElse(List.empty)
    .collectFirst {
      case field if field.name == memberName =>
        scope.tryResolveTypeReference(field.typeName).map(_.name).getOrElse(Defines.Any)
    }
    .getOrElse(Defines.Any)
  val fieldAccess = callNode(
    node,
    fieldAccessCode,
    Operators.fieldAccess,
    Operators.fieldAccess,
    DispatchTypes.STATIC_DISPATCH,
    signature = None,
    typeFullName = Option(Defines.Any)
  ).possibleTypes(IndexedSeq(memberType))
  callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst))
}

protected def astForSplattingRubyNode(node: SplattingRubyNode): Ast = {
val splattingCall =
callNode(node, code(node), RubyOperators.splat, RubyOperators.splat, DispatchTypes.STATIC_DISPATCH)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,31 @@ import io.shiftleft.codepropertygraph.generated.nodes.{NewControlStructure, NewM

trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

/** Dispatches a statement-level node to its dedicated AST builder. Nodes without a dedicated builder are lowered as
  * expressions. Cases are tried in order.
  */
protected def astsForStatement(node: RubyNode): Seq[Ast] = node match {
  case stmt: WhileExpression            => List(astForWhileStatement(stmt))
  case stmt: DoWhileExpression          => List(astForDoWhileStatement(stmt))
  case stmt: UntilExpression            => List(astForUntilStatement(stmt))
  case stmt: IfExpression               => List(astForIfStatement(stmt))
  case stmt: UnlessExpression           => List(astForUnlessStatement(stmt))
  case stmt: ForExpression              => List(astForForExpression(stmt))
  case stmt: CaseExpression             => astsForCaseExpression(stmt)
  case stmts: StatementList             => List(astForStatementList(stmts))
  case call: SimpleCallWithBlock        => List(astForCallWithBlock(call))
  case call: MemberCallWithBlock        => List(astForCallWithBlock(call))
  case ret: ReturnExpression            => List(astForReturnStatement(ret))
  case decl: AnonymousTypeDeclaration   => List(astForAnonymousTypeDeclaration(decl))
  case decl: TypeDeclaration            => astForClassDeclaration(decl)
  case decl: FieldsDeclaration          => astsForFieldDeclarations(decl)
  case modifier: AccessModifier         => registerAccessModifier(modifier)
  case decl: MethodDeclaration          => astForMethodDeclaration(decl)
  case decl: SingletonMethodDeclaration => astForSingletonMethodDeclaration(decl)
  case multi: MultipleAssignment        => multi.assignments.map(astForExpression)
  case brk: BreakStatement              => List(astForBreakStatement(brk))
  case stmts: SingletonStatementList    => astForSingletonStatementList(stmts)
  case _                                => List(astForExpression(node))
}
/** Dispatches a statement-level node to its dedicated AST builder. Nodes without a dedicated builder are lowered as
  * expressions. Cases are tried in order.
  *
  * `baseAstCache` is emptied before every statement is processed.
  */
protected def astsForStatement(node: RubyNode): Seq[Ast] = {
  // Statement boundaries are used as a safe approximation of where the cache has to be reset
  baseAstCache.clear()
  node match {
    case stmt: WhileExpression            => List(astForWhileStatement(stmt))
    case stmt: DoWhileExpression          => List(astForDoWhileStatement(stmt))
    case stmt: UntilExpression            => List(astForUntilStatement(stmt))
    case stmt: IfExpression               => List(astForIfStatement(stmt))
    case stmt: UnlessExpression           => List(astForUnlessStatement(stmt))
    case stmt: ForExpression              => List(astForForExpression(stmt))
    case stmt: CaseExpression             => astsForCaseExpression(stmt)
    case stmts: StatementList             => List(astForStatementList(stmts))
    case call: SimpleCallWithBlock        => List(astForCallWithBlock(call))
    case call: MemberCallWithBlock        => List(astForCallWithBlock(call))
    case ret: ReturnExpression            => List(astForReturnStatement(ret))
    case decl: AnonymousTypeDeclaration   => List(astForAnonymousTypeDeclaration(decl))
    case decl: TypeDeclaration            => astForClassDeclaration(decl)
    case decl: FieldsDeclaration          => astsForFieldDeclarations(decl)
    case modifier: AccessModifier         => registerAccessModifier(modifier)
    case decl: MethodDeclaration          => astForMethodDeclaration(decl)
    case decl: SingletonMethodDeclaration => astForSingletonMethodDeclaration(decl)
    case multi: MultipleAssignment        => multi.assignments.map(astForExpression)
    case brk: BreakStatement              => List(astForBreakStatement(brk))
    case stmts: SingletonStatementList    => astForSingletonStatementList(stmts)
    case _                                => List(astForExpression(node))
  }
}

private def astForWhileStatement(node: WhileExpression): Ast = {
val conditionAst = astForExpression(node.condition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,34 +287,21 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true, withDataF
|
|x = 1
|foo = Foo.new
|y = foo
| .bar(1)
|y = foo.bar(1)
|puts y
|""".stripMargin)

val src = cpg.literal.code("1").l
val sink = cpg.call.name("puts").argument(1).l
val List(flow) = sink.reachableByFlows(src).map(flowToResultPairs).distinct.sortBy(_.length).l
flow shouldBe List(
(
"""|foo
| .bar(1)""".stripMargin,
11
),
("foo.bar(1)", 10),
("bar(self, x)", 3),
("return x", 4),
("RET", 3),
(
"""|foo
| .bar(1)""".stripMargin,
10
),
(
"""|y = foo
| .bar(1)""".stripMargin,
10
),
("puts y", 12)
("foo.bar(1)", 10),
("y = foo.bar(1)", 10),
("puts y", 11)
)
}

Expand Down
Loading

0 comments on commit b341067

Please sign in to comment.