From b3410673784184b494097b7395adf7f9697cbf7b Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Wed, 21 Aug 2024 16:09:25 +0200 Subject: [PATCH] [ruby] Apply Cached Side-Effect Variables (#4868) For chained calls such as `a().b()`, where the base/receiver of `b` involves a call, to avoid invoking the call in both the base and receiver, we now assign a temporary variable to the first invocation and refer to it on the second invocation. --- .../astcreation/AstCreatorHelper.scala | 13 +- .../AstForExpressionsCreator.scala | 146 ++++++++++++------ .../astcreation/AstForStatementsCreator.scala | 47 +++--- .../rubysrc2cpg/dataflow/CallTests.scala | 23 +-- .../rubysrc2cpg/querying/CallTests.scala | 64 ++++++-- .../rubysrc2cpg/querying/ClassTests.scala | 16 +- .../querying/ControlStructureTests.scala | 2 +- .../rubysrc2cpg/querying/DoBlockTests.scala | 7 +- .../querying/FieldAccessTests.scala | 60 ++++--- .../rubysrc2cpg/querying/MethodTests.scala | 57 +++---- .../querying/SingleAssignmentTests.scala | 2 +- 11 files changed, 256 insertions(+), 181 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala index 32f5f29104f0..5dd06d9ff9b5 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala @@ -115,12 +115,19 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As astForAssignment(Ast(lhs), Ast(rhs), lineNumber, columnNumber) } - protected def astForAssignment(lhs: Ast, rhs: Ast, lineNumber: Option[Int], columnNumber: Option[Int]): Ast = { - val code = Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = ") + protected def astForAssignment( + lhs: Ast, + rhs: Ast, + lineNumber: Option[Int], + columnNumber: Option[Int], + code: Option[String] = None + ): Ast = { + val _code = + code.getOrElse(Seq(lhs, rhs).flatMap(_.root).collect { case x: ExpressionNew => x.code }.mkString(" = ")) val assignment = NewCall() .name(Operators.assignment) .methodFullName(Operators.assignment) - .code(code) + .code(_code) .dispatchType(DispatchTypes.STATIC_DISPATCH) .lineNumber(lineNumber) .columnNumber(columnNumber) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index f0b43a56c2c0..1940049f9e90 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -17,11 +17,17 @@ import io.shiftleft.codepropertygraph.generated.{ PropertyNames } +import scala.collection.mutable + trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => val tmpGen: FreshNameGenerator[String] = FreshNameGenerator(i => s"") + /** For tracking aliased calls that occur on the LHS of a member access or call. + */ + protected val baseAstCache = mutable.Map.empty[RubyNode, String] + protected def astForExpression(node: RubyNode): Ast = node match case node: StaticLiteral => astForStaticLiteral(node) case node: HereDocNode => astForHereDoc(node) @@ -180,33 +186,36 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForMemberCall(node: MemberCall, isStatic: Boolean = false): Ast = { def createMemberCall(n: MemberCall): Ast = { - val baseAst = n.target match { - case target: MemberAccess => astForFieldAccess(target, stripLeadingAt = true) - case _ => astForExpression(n.target) - } val receiverAst = astForFieldAccess(MemberAccess(n.target, ".", n.methodName)(n.span), stripLeadingAt = true) + val (baseAst, baseCode) = astForMemberAccessTarget(n.target) val builtinType = n.target match { case MemberAccess(_: SelfIdentifier, _, memberName) if isBundledClass(memberName) => Option(prefixAsBundledType(memberName)) case x: TypeIdentifier if x.isBuiltin => Option(x.typeFullName) case _ => None } - val (receiverFullName, methodFullName) = receiverAst.nodes + val methodFullName = receiverAst.nodes .collectFirst { - case _ if builtinType.isDefined => builtinType.get -> s"${builtinType.get}.${n.methodName}" - case x: NewMethodRef => x.methodFullName -> x.methodFullName + case _ if builtinType.isDefined => s"${builtinType.get}.${n.methodName}" + case x: NewMethodRef => x.methodFullName case _ => (n.target match { case ma: MemberAccess => scope.tryResolveTypeReference(ma.memberName).map(_.name) case _ => typeFromCallTarget(n.target) - }).map(x => x -> s"$x.${n.methodName}") - .getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName) + }).map(x => s"$x.${n.methodName}") + .getOrElse(XDefines.DynamicCallUnknownFullName) } - .getOrElse(XDefines.Any -> XDefines.DynamicCallUnknownFullName) + .getOrElse(XDefines.DynamicCallUnknownFullName) val argumentAsts = n.arguments.map(astForMethodCallArgument) val dispatchType = if (isStatic) DispatchTypes.STATIC_DISPATCH else DispatchTypes.DYNAMIC_DISPATCH - val call = callNode(n, code(n), n.methodName, XDefines.DynamicCallUnknownFullName, dispatchType) + val callCode = if (baseCode.contains(" x case None => MemberAccess(SelfIdentifier()(x.span.spanStart(Defines.Self)), ".", x.text)(x.span) } - case x @ MemberAccess(ma, op, memberName) => x.copy(target = determineMemberAccessBase(ma))(x.span) - case _ => target + case x @ MemberAccess(ma, _, _) => x.copy(target = determineMemberAccessBase(ma))(x.span) + case _ => target } node.target match { + case _: LiteralExpr => + createMemberCall(node) case x: SimpleIdentifier if isBundledClass(x.text) => createMemberCall(node.copy(target = TypeIdentifier(prefixAsBundledType(x.text))(x.span))(node.span)) case x: SimpleIdentifier => createMemberCall(node.copy(target = determineMemberAccessBase(x))(node.span)) case memAccess: MemberAccess => createMemberCall(node.copy(target = determineMemberAccessBase(memAccess))(node.span)) - case x => createMemberCall(node) + case _ => createMemberCall(node) + } + } + + protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = { + val (memberName, memberCode) = node.target match { + case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize + case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@") + case _: TypeIdentifier => node.memberName -> node.memberName + case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) => + s"@${node.memberName}" -> node.memberName + case _ => node.memberName -> node.memberName + } + + val fieldIdentifierAst = Ast(fieldIdentifierNode(node, memberName, memberCode)) + val (targetAst, _code) = astForMemberAccessTarget(node.target) + val code = s"$_code${node.op}$memberCode" + val memberType = typeFromCallTarget(node.target) + .flatMap(scope.tryResolveTypeReference) + .map(_.fields) + .getOrElse(List.empty) + .collectFirst { + case x if x.name == memberName => + scope.tryResolveTypeReference(x.typeName).map(_.name).getOrElse(Defines.Any) + } + .orElse(Option(Defines.Any)) + val fieldAccess = callNode( + node, + code, + Operators.fieldAccess, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + signature = None, + typeFullName = Option(Defines.Any) + ).possibleTypes(IndexedSeq(memberType.get)) + callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst)) + } + + private def astForMemberAccessTarget(target: RubyNode): (Ast, String) = { + target match { + case simpleLhs: (LiteralExpr | SimpleIdentifier | SelfIdentifier | TypeIdentifier) => + astForExpression(simpleLhs) -> code(target) + case target: MemberAccess => handleTmpGen(target, astForFieldAccess(target, stripLeadingAt = true)) + case target => handleTmpGen(target, astForExpression(target)) + } + } + + private def handleTmpGen(target: RubyNode, rhs: Ast): (Ast, String) = { + // Check cache + val createAssignmentToTmp = !baseAstCache.contains(target) + val tmpName = baseAstCache + .updateWith(target) { + case Some(tmpName) => Option(tmpName) + case None => + val tmpName = tmpGen.fresh + val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any) + scope.addToScope(tmpName, tmpGenLocal) match { + case BlockScope(block) => diffGraph.addEdge(block, tmpGenLocal, EdgeTypes.AST) + case _ => + } + Option(tmpName) + } + .get + val tmpIden = NewIdentifier().name(tmpName).code(tmpName).typeFullName(Defines.Any) + val tmpIdenAst = + scope.lookupVariable(tmpName).map(x => Ast(tmpIden).withRefEdge(tmpIden, x)).getOrElse(Ast(tmpIden)) + val code = s"$tmpName = ${target.text}" + if (createAssignmentToTmp) { + astForAssignment(tmpIdenAst, rhs, target.line, target.column, Option(code)) -> s"($code)" + } else { + tmpIdenAst -> s"($code)" } } @@ -882,43 +963,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { astForExpression(assoc) } - protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = { - val (memberName, memberCode) = node.target match { - case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize - case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@") - case _: TypeIdentifier => node.memberName -> node.memberName - case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) => - s"@${node.memberName}" -> node.memberName - case _ => node.memberName -> node.memberName - } - - val fieldIdentifierAst = Ast(fieldIdentifierNode(node, memberName, memberCode)) - val targetAst = node.target match { - case target: MemberAccess => astForFieldAccess(target, stripLeadingAt = true) - case _ => astForExpression(node.target) - } - val code = s"${node.target.text}${node.op}$memberCode" - val memberType = typeFromCallTarget(node.target) - .flatMap(scope.tryResolveTypeReference) - .map(_.fields) - .getOrElse(List.empty) - .collectFirst { - case x if x.name == memberName => - scope.tryResolveTypeReference(x.typeName).map(_.name).getOrElse(Defines.Any) - } - .orElse(Option(Defines.Any)) - val fieldAccess = callNode( - node, - code, - Operators.fieldAccess, - Operators.fieldAccess, - DispatchTypes.STATIC_DISPATCH, - signature = None, - typeFullName = Option(Defines.Any) - ).possibleTypes(IndexedSeq(memberType.get)) - callAst(fieldAccess, Seq(targetAst, fieldIdentifierAst)) - } - protected def astForSplattingRubyNode(node: SplattingRubyNode): Ast = { val splattingCall = callNode(node, code(node), RubyOperators.splat, RubyOperators.splat, DispatchTypes.STATIC_DISPATCH) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index fbe43a742078..97c165e13ae3 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -10,28 +10,31 @@ import io.shiftleft.codepropertygraph.generated.nodes.{NewControlStructure, NewM trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def astsForStatement(node: RubyNode): Seq[Ast] = node match - case node: WhileExpression => astForWhileStatement(node) :: Nil - case node: DoWhileExpression => astForDoWhileStatement(node) :: Nil - case node: UntilExpression => astForUntilStatement(node) :: Nil - case node: IfExpression => astForIfStatement(node) :: Nil - case node: UnlessExpression => astForUnlessStatement(node) :: Nil - case node: ForExpression => astForForExpression(node) :: Nil - case node: CaseExpression => astsForCaseExpression(node) - case node: StatementList => astForStatementList(node) :: Nil - case node: SimpleCallWithBlock => astForCallWithBlock(node) :: Nil - case node: MemberCallWithBlock => astForCallWithBlock(node) :: Nil - case node: ReturnExpression => astForReturnStatement(node) :: Nil - case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) :: Nil - case node: TypeDeclaration => astForClassDeclaration(node) - case node: FieldsDeclaration => astsForFieldDeclarations(node) - case node: AccessModifier => registerAccessModifier(node) - case node: MethodDeclaration => astForMethodDeclaration(node) - case node: SingletonMethodDeclaration => astForSingletonMethodDeclaration(node) - case node: MultipleAssignment => node.assignments.map(astForExpression) - case node: BreakStatement => astForBreakStatement(node) :: Nil - case node: SingletonStatementList => astForSingletonStatementList(node) - case _ => astForExpression(node) :: Nil + protected def astsForStatement(node: RubyNode): Seq[Ast] = { + baseAstCache.clear() // A safe approximation on where to reset the cache + node match + case node: WhileExpression => astForWhileStatement(node) :: Nil + case node: DoWhileExpression => astForDoWhileStatement(node) :: Nil + case node: UntilExpression => astForUntilStatement(node) :: Nil + case node: IfExpression => astForIfStatement(node) :: Nil + case node: UnlessExpression => astForUnlessStatement(node) :: Nil + case node: ForExpression => astForForExpression(node) :: Nil + case node: CaseExpression => astsForCaseExpression(node) + case node: StatementList => astForStatementList(node) :: Nil + case node: SimpleCallWithBlock => astForCallWithBlock(node) :: Nil + case node: MemberCallWithBlock => astForCallWithBlock(node) :: Nil + case node: ReturnExpression => astForReturnStatement(node) :: Nil + case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node) :: Nil + case node: TypeDeclaration => astForClassDeclaration(node) + case node: FieldsDeclaration => astsForFieldDeclarations(node) + case node: AccessModifier => registerAccessModifier(node) + case node: MethodDeclaration => astForMethodDeclaration(node) + case node: SingletonMethodDeclaration => astForSingletonMethodDeclaration(node) + case node: MultipleAssignment => node.assignments.map(astForExpression) + case node: BreakStatement => astForBreakStatement(node) :: Nil + case node: SingletonStatementList => astForSingletonStatementList(node) + case _ => astForExpression(node) :: Nil + } private def astForWhileStatement(node: WhileExpression): Ast = { val conditionAst = astForExpression(node.condition) diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala index d54cff93bd70..df3264db6379 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/CallTests.scala @@ -287,8 +287,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true, withDataF | |x = 1 |foo = Foo.new - |y = foo - | .bar(1) + |y = foo.bar(1) |puts y |""".stripMargin) @@ -296,25 +295,13 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true, withDataF val sink = cpg.call.name("puts").argument(1).l val List(flow) = sink.reachableByFlows(src).map(flowToResultPairs).distinct.sortBy(_.length).l flow shouldBe List( - ( - """|foo - | .bar(1)""".stripMargin, - 11 - ), + ("foo.bar(1)", 10), ("bar(self, x)", 3), ("return x", 4), ("RET", 3), - ( - """|foo - | .bar(1)""".stripMargin, - 10 - ), - ( - """|y = foo - | .bar(1)""".stripMargin, - 10 - ), - ("puts y", 12) + ("foo.bar(1)", 10), + ("y = foo.bar(1)", 10), + ("puts y", 11) ) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala index 0e9b9b5fcf6a..980906b7b592 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala @@ -177,11 +177,11 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { } "create an assignment from a temp variable to the alloc call" in { - inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("")).l) { + inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("")).l) { case assignment :: Nil => inside(assignment.argument.l) { case (a: Identifier) :: (alloc: Call) :: Nil => - a.name shouldBe "" + a.name shouldBe "" alloc.name shouldBe Operators.alloc alloc.methodFullName shouldBe Operators.alloc @@ -198,7 +198,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { case constructor :: Nil => inside(constructor.argument.l) { case (a: Identifier) :: (one: Literal) :: (two: Literal) :: Nil => - a.name shouldBe "" + a.name shouldBe "" a.typeFullName shouldBe s"Test0.rb:$Main.A" a.argumentIndex shouldBe 0 @@ -243,18 +243,21 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { val recv = constructor.receiver.head.asInstanceOf[Call] recv.methodFullName shouldBe Operators.fieldAccess recv.name shouldBe Operators.fieldAccess - recv.code shouldBe s"params[:type].constantize.${RubyDefines.Initialize}" + recv.code shouldBe s"( = params[:type].constantize).${RubyDefines.Initialize}" - inside(recv.argument.l) { case (constantize: Call) :: (initialize: FieldIdentifier) :: Nil => - constantize.code shouldBe "params[:type].constantize" - inside(constantize.argument.l) { case (indexAccess: Call) :: (const: FieldIdentifier) :: Nil => - indexAccess.name shouldBe Operators.indexAccess - indexAccess.code shouldBe "params[:type]" + recv.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe RubyDefines.Initialize - const.canonicalName shouldBe "constantize" - } + inside(recv.argument(1).start.isCall.argument(2).isCall.argument.l) { + case (paramsAssign: Call) :: (constantize: FieldIdentifier) :: Nil => + paramsAssign.code shouldBe " = params[:type]" + inside(paramsAssign.argument.l) { case (tmpIdent: Identifier) :: (indexAccess: Call) :: Nil => + tmpIdent.name shouldBe "" - initialize.canonicalName shouldBe RubyDefines.Initialize + indexAccess.name shouldBe Operators.indexAccess + indexAccess.code shouldBe "params[:type]" + } + + constantize.canonicalName shouldBe "constantize" } case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]") } @@ -336,12 +339,12 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { val cpg = code("::Augeas.open { |aug| aug.get('/augeas/version') }") val augeasReceiv = cpg.call.nameExact("open").receiver.head.asInstanceOf[Call] augeasReceiv.methodFullName shouldBe Operators.fieldAccess - augeasReceiv.code shouldBe "::Augeas.open" + augeasReceiv.code shouldBe "( = ::Augeas).open" val selfAugeas = augeasReceiv.argument(1).asInstanceOf[Call] - selfAugeas.argument(1).asInstanceOf[Identifier].name shouldBe RubyDefines.Self - selfAugeas.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "Augeas" + selfAugeas.argument(1).asInstanceOf[Identifier].name shouldBe "" + selfAugeas.argument(2).asInstanceOf[Call].code shouldBe "self::Augeas" augeasReceiv.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "open" } @@ -364,4 +367,35 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { case xs => fail(s"Expected one call to initialize, got ${xs.code.mkString}") } } + + "Member calls where the LHS is a call" should { + + "assign the first call to a temp variable to avoid a second invocation at arg 0" in { + val cpg = code("a().b()") + + val bCall = cpg.call("b").head + bCall.code shouldBe "( = a()).b()" + + // Check receiver + val bAccess = bCall.receiver.isCall.head + bAccess.name shouldBe Operators.fieldAccess + bAccess.methodFullName shouldBe Operators.fieldAccess + bAccess.code shouldBe "( = a()).b" + + bAccess.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "b" + + val aAssign = bAccess.argument(1).asInstanceOf[Call] + aAssign.name shouldBe Operators.assignment + aAssign.methodFullName shouldBe Operators.assignment + aAssign.code shouldBe " = a()" + + aAssign.argument(1).asInstanceOf[Identifier].name shouldBe "" + aAssign.argument(2).asInstanceOf[Call].name shouldBe "a" + + // Check (cached) base + val base = bCall.argument(0).asInstanceOf[Identifier] + base.name shouldBe "" + } + + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala index a2959d7b49cd..0748a30d7f21 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala @@ -361,9 +361,12 @@ class ClassTests extends RubyCode2CpgFixture { "generate an assignment to the variable `a` with the source being a constructor invocation of the class" in { inside(cpg.method.isModule.assignment.l) { - case aAssignment :: Nil => + case aAssignment :: tmpAssign :: Nil => aAssignment.target.code shouldBe "a" - aAssignment.source.code shouldBe "Class.new (...)" + aAssignment.source.code shouldBe "( = Class.new (...)).new" + + tmpAssign.target.code shouldBe "" + tmpAssign.source.code shouldBe "self.Class.new (...)" case xs => fail(s"Expected a single assignment, but got [${xs.map(x => x.label -> x.code).mkString(",")}]") } } @@ -613,14 +616,9 @@ class ClassTests extends RubyCode2CpgFixture { case Some(bodyCall) => bodyCall.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH bodyCall.methodFullName shouldBe s"Test0.rb:$Main.Foo.${RubyDefines.TypeDeclBody}" - + bodyCall.code shouldBe "( = self::Foo)." bodyCall.receiver.isEmpty shouldBe true - inside(bodyCall.argumentOption(0)) { - case Some(selfArg: Call) => - selfArg.name shouldBe Operators.fieldAccess - selfArg.code shouldBe "self::Foo" - case None => fail("Expected `self` argument") - } + bodyCall.argument(0).code shouldBe "" case None => fail("Expected call") } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala index 2d5c43140b89..acd956e619c7 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala @@ -595,7 +595,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { inside(ternary.argument.l) { case condition :: (leftOpt: Block) :: (rightOpt: Block) :: Nil => - condition.code shouldBe "@user.admin" + condition.code shouldBe "( = @user).admin" condition.ast.isFieldIdentifier.code.l shouldBe List("@user", "admin") leftOpt.ast.fieldAccess.code.head shouldBe "User.all" diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala index 4ee394b8eda3..b4a7f04eb0a1 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala @@ -400,11 +400,14 @@ class DoBlockTests extends RubyCode2CpgFixture { |""".stripMargin) inside(cpg.local.l) { - case jfsOutsideLocal :: hashInsideLocal :: jfsCapturedLocal :: _ :: Nil => + case jfsOutsideLocal :: hashInsideLocal :: jfsCapturedLocal :: tmp0 :: tmp1 :: Nil => jfsOutsideLocal.closureBindingId shouldBe None hashInsideLocal.closureBindingId shouldBe None jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") - case xs => fail(s"Expected 4 locals, got ${xs.code.mkString(",")}") + + tmp0.name shouldBe "" + tmp1.name shouldBe "" + case xs => fail(s"Expected 5 locals, got ${xs.code.mkString(",")}") } inside(cpg.method.isLambda.local.l) { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala index 620984dbad1a..3212b9310e70 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FieldAccessTests.scala @@ -109,20 +109,19 @@ class FieldAccessTests extends RubyCode2CpgFixture { } "give external type accesses on script-level the `self.` base" in { - val call = cpg.method.isModule.call.codeExact("Base64::decode64()").head + val call = cpg.method.isModule.call.nameExact("decode64").head call.name shouldBe "decode64" - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "self.Base64" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "Base64.decode64" + receiver.code shouldBe "( = Base64).decode64" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "self.Base64" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = Base64" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "decode64" @@ -130,20 +129,19 @@ class FieldAccessTests extends RubyCode2CpgFixture { } "give internal type accesses on script-level the `self.` base" in { - val call = cpg.method.isModule.call.codeExact("Baz::func1()").head + val call = cpg.method.isModule.call.nameExact("func1").head call.name shouldBe "func1" - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "self.Baz" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "Baz.func1" + receiver.code shouldBe "( = Baz).func1" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "self.Baz" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = Baz" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "func1" @@ -175,17 +173,16 @@ class FieldAccessTests extends RubyCode2CpgFixture { val call = cpg.method.nameExact("func").call.nameExact("func1").head call.name shouldBe "func1" - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "self.Baz" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "Baz.func1" + receiver.code shouldBe "( = Baz).func1" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "self.Baz" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = Baz" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "func1" @@ -211,23 +208,24 @@ class FieldAccessTests extends RubyCode2CpgFixture { "create `TYPE_REF` targets for the field accesses" in { val call = cpg.call.nameExact("func").head - val base = call.argument(0).asInstanceOf[Call] - base.name shouldBe Operators.fieldAccess - base.code shouldBe "A::B" - - base.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe s"Test0.rb:$Main.A" - base.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "B" + val base = call.argument(0).asInstanceOf[Identifier] + base.code shouldBe "" val receiver = call.receiver.isCall.head receiver.name shouldBe Operators.fieldAccess - receiver.code shouldBe "A::B.func" + receiver.code shouldBe "( = A::B).func" val selfArg1 = receiver.argument(1).asInstanceOf[Call] - selfArg1.name shouldBe Operators.fieldAccess - selfArg1.code shouldBe "A::B" + selfArg1.name shouldBe Operators.assignment + selfArg1.code shouldBe " = A::B" + + selfArg1.argument(1).asInstanceOf[Identifier].code shouldBe s"" + + val abRhs = selfArg1.argument(2).asInstanceOf[Call] + abRhs.code shouldBe "A::B" - selfArg1.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe s"Test0.rb:$Main.A" - selfArg1.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "B" + abRhs.argument(1).asInstanceOf[TypeRef].typeFullName shouldBe s"Test0.rb:$Main.A" + abRhs.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "B" val selfArg2 = receiver.argument(2).asInstanceOf[FieldIdentifier] selfArg2.canonicalName shouldBe "func" diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala index dc1f38f881a9..ba713567a236 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala @@ -7,6 +7,7 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes, Operators} import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} class MethodTests extends RubyCode2CpgFixture { @@ -558,23 +559,23 @@ class MethodTests extends RubyCode2CpgFixture { leftArg.name shouldBe "a" rightArg.name shouldBe "hexdigest" - rightArg.code shouldBe "Digest::MD5.hexdigest(password)" + rightArg.code shouldBe "( = Digest::MD5).hexdigest(password)" - inside(rightArg.argument.l) { - case (md5: Call) :: (passwordArg: Identifier) :: Nil => - md5.name shouldBe Operators.fieldAccess - md5.code shouldBe "Digest::MD5" + val hexDigestFa = rightArg.receiver.head.asInstanceOf[FieldAccess] + hexDigestFa.code shouldBe "( = Digest::MD5).hexdigest" - val md5Base = md5.argument(1).asInstanceOf[Call] - md5.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "MD5" + val tmp1Assign = hexDigestFa.argument(1).asInstanceOf[Assignment] + tmp1Assign.code shouldBe " = Digest::MD5" - md5Base.name shouldBe Operators.fieldAccess - md5Base.code shouldBe "self.Digest" + val md5Fa = tmp1Assign.source.asInstanceOf[FieldAccess] + md5Fa.code shouldBe "( = Digest)::MD5" - md5Base.argument(1).asInstanceOf[Identifier].name shouldBe RDefines.Self - md5Base.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "Digest" - case xs => fail(s"Expected identifier and call, got ${xs.code.mkString(", ")} instead") - } + val tmp0Assign = md5Fa.argument(1).asInstanceOf[Assignment] + tmp0Assign.code shouldBe " = Digest" + + val digestFa = tmp0Assign.source.asInstanceOf[FieldAccess] + digestFa.argument(1).asInstanceOf[Identifier].name shouldBe RDefines.Self + digestFa.argument(2).asInstanceOf[FieldIdentifier].canonicalName shouldBe "Digest" case xs => fail(s"Expected 2 arguments, got ${xs.code.mkString(", ")} instead") } case None => fail("Expected if-condition") @@ -674,12 +675,12 @@ class MethodTests extends RubyCode2CpgFixture { } "be placed in order of definition" in { - inside(cpg.method.name(RDefines.Main).filename("t1.rb").block.astChildren.l) { + inside(cpg.method.name(RDefines.Main).filename("t1.rb").block.astChildren.isCall.l) { case (a1: Call) :: (a2: Call) :: (a3: Call) :: (a4: Call) :: (a5: Call) :: Nil => a1.code shouldBe "self.A = module A (...)" - a2.code shouldBe "self::A::" + a2.code shouldBe "( = self::A)." a3.code shouldBe "self.B = class B (...)" - a4.code shouldBe "self::B::" + a4.code shouldBe "( = self::B)." a5.code shouldBe "self.c = def c (...)" case xs => fail(s"Expected assignments to appear before definitions, instead got [${xs.mkString("\n")}]") } @@ -766,15 +767,15 @@ class MethodTests extends RubyCode2CpgFixture { inside(cpg.call.name(".*retry!").l) { case batchCall :: Nil => batchCall.name shouldBe "retry!" - batchCall.code shouldBe "batch.retry!()" + batchCall.code shouldBe "( = batch).retry!()" inside(batchCall.receiver.l) { case (receiverCall: Call) :: Nil => receiverCall.name shouldBe Operators.fieldAccess - receiverCall.code shouldBe "batch.retry!" + receiverCall.code shouldBe "( = batch).retry!" val selfBatch = receiverCall.argument(1).asInstanceOf[Call] - selfBatch.code shouldBe "self.batch" + selfBatch.code shouldBe " = batch" val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] retry.code shouldBe "retry!" @@ -794,15 +795,15 @@ class MethodTests extends RubyCode2CpgFixture { inside(cpg.call.name(".*retry!").l) { case batchCall :: Nil => batchCall.name shouldBe "retry!" - batchCall.code shouldBe "batch::retry!()" + batchCall.code shouldBe "( = batch).retry!()" inside(batchCall.receiver.l) { case (receiverCall: Call) :: Nil => receiverCall.name shouldBe Operators.fieldAccess - receiverCall.code shouldBe "batch.retry!" + receiverCall.code shouldBe "( = batch).retry!" val selfBatch = receiverCall.argument(1).asInstanceOf[Call] - selfBatch.code shouldBe "self.batch" + selfBatch.code shouldBe " = batch" val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] retry.code shouldBe "retry!" @@ -822,15 +823,15 @@ class MethodTests extends RubyCode2CpgFixture { inside(cpg.call.name(".*retry!").l) { case batchCall :: Nil => batchCall.name shouldBe "retry!" - batchCall.code shouldBe "retry.retry!()" + batchCall.code shouldBe "( = retry).retry!()" inside(batchCall.receiver.l) { case (receiverCall: Call) :: Nil => receiverCall.name shouldBe Operators.fieldAccess - receiverCall.code shouldBe "retry.retry!" + receiverCall.code shouldBe "( = retry).retry!" val selfBatch = receiverCall.argument(1).asInstanceOf[Call] - selfBatch.code shouldBe "self.retry" + selfBatch.code shouldBe " = retry" val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] retry.code shouldBe "retry!" @@ -850,15 +851,15 @@ class MethodTests extends RubyCode2CpgFixture { inside(cpg.call.name(".*retry!").l) { case batchCall :: Nil => batchCall.name shouldBe "retry!" - batchCall.code shouldBe "retry::retry!()" + batchCall.code shouldBe "( = retry).retry!()" inside(batchCall.receiver.l) { case (receiverCall: Call) :: Nil => receiverCall.name shouldBe Operators.fieldAccess - receiverCall.code shouldBe "retry.retry!" + receiverCall.code shouldBe "( = retry).retry!" val selfBatch = receiverCall.argument(1).asInstanceOf[Call] - selfBatch.code shouldBe "self.retry" + selfBatch.code shouldBe " = retry" val retry = receiverCall.argument(2).asInstanceOf[FieldIdentifier] retry.code shouldBe "retry!" diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala index 02622afb5676..4aacdf38f397 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala @@ -319,7 +319,7 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { inside(cpg.method.name("foo").controlStructure.l) { case ifStruct :: Nil => ifStruct.controlStructureType shouldBe ControlStructureTypes.IF - ifStruct.condition.code.l shouldBe List("hash[:id].nil?") + ifStruct.condition.code.l shouldBe List("( = hash[:id]).nil?") inside(ifStruct.whenTrue.ast.isCall.name(Operators.assignment).l) { case assignmentCall :: Nil =>