Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rubysrc2cpg] Enhancing TypeDecl implementation #3358

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class AstCreator(
protected val methodAliases = mutable.HashMap[String, String]()
protected val methodNameToMethod = mutable.HashMap[String, nodes.NewMethod]()

protected val typeDeclNameToTypeDecl = mutable.HashMap[String, nodes.NewTypeDecl]()

protected val methodNamesWithYield = mutable.HashSet[String]()

/*
Expand Down Expand Up @@ -190,12 +192,35 @@ class AstCreator(
methodRefAssignmentAst
}.toList

val typeRefAssignmentAst = typeDeclNameToTypeDecl.values.map { typeDeclNode =>

val typeRefNode = NewTypeRef()
.code("class " + typeDeclNode.name + "(...)")
.typeFullName(typeDeclNode.fullName)
.lineNumber(typeDeclNode.lineNumber)
.columnNumber(typeDeclNode.columnNumber)

val typeDeclNameIdentifier = NewIdentifier()
.code(typeDeclNode.name)
.name(typeDeclNode.name)
.typeFullName(Defines.Any)
.lineNumber(lineColNum)
.columnNumber(lineColNum)

val typeRefAssignmentAst =
astForAssignment(typeDeclNameIdentifier, typeRefNode, typeDeclNode.lineNumber, typeDeclNode.columnNumber)
typeRefAssignmentAst
}

val blockNode = NewBlock().typeFullName(Defines.Any)
val programAst =
methodAst(
programMethod,
Seq(Ast()),
blockAst(blockNode, statementAsts.toList ++ builtInMethodAst ++ methodRefAssignmentAsts),
blockAst(
blockNode,
statementAsts.toList ++ builtInMethodAst ++ methodRefAssignmentAsts ++ typeRefAssignmentAst
),
methodRetNode
)

Expand Down Expand Up @@ -1121,65 +1146,43 @@ class AstCreator(

private def astForParametersContext(ctx: ParametersContext): Seq[Ast] = {
if (ctx == null) return Seq()
val localVarList = ListBuffer[Option[TerminalNode]]()
// NOT differentiating between the productions here since either way we get parameters
val mandatoryParameters = ctx
.parameter()
.asScala
.filter(ctx => Option(ctx.mandatoryParameter()).isDefined)
.map(ctx => Option(ctx.mandatoryParameter().LOCAL_VARIABLE_IDENTIFIER()))
val optionalParameters = ctx
.parameter()
.asScala
.filter(ctx => Option(ctx.optionalParameter()).isDefined)
.map(ctx => Option(ctx.optionalParameter().LOCAL_VARIABLE_IDENTIFIER()))
val arrayParameter = ctx
.parameter()
.asScala
.filter(ctx => Option(ctx.arrayParameter()).isDefined)
.map(ctx => Option(ctx.arrayParameter().LOCAL_VARIABLE_IDENTIFIER()))
val procParameter = ctx
.parameter()
.asScala
.filter(ctx => Option(ctx.procParameter()).isDefined)
.map(ctx => Option(ctx.procParameter().LOCAL_VARIABLE_IDENTIFIER()))

localVarList.addAll(mandatoryParameters)
localVarList.addAll(optionalParameters)
localVarList.addAll(arrayParameter)
localVarList.addAll(procParameter)

localVarList.map {
case localVar @ Some(paramContext) => {
val varSymbol = paramContext.getSymbol
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any))
val param = NewMethodParameterIn()
.name(varSymbol.getText)
.code(varSymbol.getText)
.lineNumber(varSymbol.getLine)
.typeFullName(Defines.Any)
.columnNumber(varSymbol.getCharPositionInLine)
if (Option(arrayParameter).isDefined) {
param.isVariadic = true
}
Ast(param)
}
case localVar @ _ => {
val identifierName = getUnusedVariableNames(usedVariableNames, Defines.TempIdentifier)
val parameterName = getUnusedVariableNames(usedVariableNames, Defines.TempParameter)
createIdentifierWithScope(ctx, identifierName, identifierName, Defines.Any, Seq[String](Defines.Any))
val param = NewMethodParameterIn()
.name(parameterName)
.code(parameterName)
.lineNumber(None)
.typeFullName(Defines.Any)
.columnNumber(None)
if (Option(arrayParameter).isDefined) {
param.isVariadic = true
}
Ast(param)
}
}.toSeq

// the parameterTupleList holds the parameter terminal node and is the parameter a variadic parameter
val parameterTupleList = ctx.parameter().asScala.map {
case procCtx if procCtx.procParameter() != null =>
(Option(procCtx.procParameter().LOCAL_VARIABLE_IDENTIFIER()), false)
case optCtx if optCtx.optionalParameter() != null =>
(Option(optCtx.optionalParameter().LOCAL_VARIABLE_IDENTIFIER()), false)
case manCtx if manCtx.mandatoryParameter() != null =>
(Option(manCtx.mandatoryParameter().LOCAL_VARIABLE_IDENTIFIER()), false)
case arrCtx if arrCtx.arrayParameter() != null =>
(Option(arrCtx.arrayParameter().LOCAL_VARIABLE_IDENTIFIER()), arrCtx.arrayParameter().STAR() != null)
case _ => (None, false)
}

parameterTupleList.zipWithIndex.map { case (paraTuple, paraIndex) =>
paraTuple match
case (Some(paraValue), isVariadic) =>
val varSymbol = paraValue.getSymbol
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, Seq[String](Defines.Any))
Ast(
createMethodParameterIn(
varSymbol.getText,
lineNumber = Some(varSymbol.getLine),
colNumber = Some(varSymbol.getCharPositionInLine),
order = paraIndex + 1,
index = paraIndex + 1
).isVariadic(isVariadic)
)
case _ =>
Ast(
createMethodParameterIn(
getUnusedVariableNames(usedVariableNames, Defines.TempParameter),
order = paraIndex + 1,
index = paraIndex + 1
)
)
}.toList
}

// TODO: Rewrite for simplicity and take into account more than parameter names.
Expand Down Expand Up @@ -1314,9 +1317,25 @@ class AstCreator(

def astForMethodDefinitionContext(ctx: MethodDefinitionContext): Seq[Ast] = {
scope.pushNewScope(())
val astMethodParamSeq = astForMethodParameterPartContext(ctx.methodParameterPart())
val astMethodName = astForMethodNamePartContext(ctx.methodNamePart())
val callNode = astMethodName.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall]
val astMethodName = astForMethodNamePartContext(ctx.methodNamePart())
val callNode = astMethodName.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall]

// Create thisParameter if this is an instance method
// TODO may need to revisit to make this more robust
val astMethodParamSeq = ctx.methodNamePart() match {
case _: SimpleMethodNamePartContext if !classStack.top.endsWith(":program") =>
val thisParameterNode = createMethodParameterIn(
"this",
typeFullName = callNode.methodFullName,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this call node refer to the :program method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the call node refers to methodDefinition. We use the information from the callNode and discard it

Basically here for any methodDefinition encountered we are categorizing the whether they are class method or instance method

lineNumber = callNode.lineNumber,
colNumber = callNode.columnNumber,
index = 0,
order = 0
)
Seq(Ast(thisParameterNode)) ++ astForMethodParameterPartContext(ctx.methodParameterPart())
case _ => astForMethodParameterPartContext(ctx.methodParameterPart())
}

// there can be only one call node
val astBody = astForBodyStatementContext(ctx.bodyStatement(), true)
scope.popScope()
Expand Down Expand Up @@ -1395,12 +1414,12 @@ class AstCreator(
.filter(_.isInstanceOf[NewMethodParameterIn])
.asInstanceOf[Seq[NewMethodParameterIn]]
)
.foreach(paramNode => {
.foreach { paramNode =>
val linkIdentifiers = identifiers.filter(_.name == paramNode.name)
identifiers.foreach { identifier =>
linkIdentifiers.foreach { identifier =>
diffGraph.addEdge(identifier, paramNode, EdgeTypes.REF)
}
})
}

Seq(
methodAst(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package io.joern.rubysrc2cpg.astcreation

import io.joern.x2cpg.{Ast, Defines}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewNode}
import io.joern.rubysrc2cpg.passes.Defines as RubyDefines
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Operators, nodes}
import io.shiftleft.codepropertygraph.generated.nodes.{
AstNodeNew,
NewCall,
NewFieldIdentifier,
NewMethodParameterIn,
NewNode
}
import scala.collection.mutable
trait AstCreatorHelper { this: AstCreator =>

Expand All @@ -26,6 +33,49 @@ trait AstCreatorHelper { this: AstCreator =>
callAst(callNode, Seq(Ast(lhs), Ast(rhs)))
}

protected def createFieldAccess(
baseNode: NewNode,
fieldName: String,
lineNumber: Option[Integer],
colNumber: Option[Integer]
) = {
val fieldIdNode = NewFieldIdentifier()
.code(fieldName)
.canonicalName(fieldName)
.lineNumber(lineNumber)
.columnNumber(colNumber)

val baseNodeCopy = baseNode.copy
val code = codeOf(baseNode) + "." + codeOf(fieldIdNode)
val callNode = NewCall()
.code(code)
.name(Operators.fieldAccess)
.methodFullName(Operators.fieldAccess)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.lineNumber(lineNumber)
.columnNumber(colNumber)

callAst(callNode, Seq(Ast(baseNodeCopy), Ast(fieldIdNode)))
}

protected def createMethodParameterIn(
name: String,
lineNumber: Option[Integer] = None,
colNumber: Option[Integer] = None,
typeFullName: String = RubyDefines.Any,
order: Int = -1,
index: Int = -1
) = {
NewMethodParameterIn()
.name(name)
.code(name)
.lineNumber(lineNumber)
.typeFullName(typeFullName)
.columnNumber(colNumber)
.order(order)
.index(index)
}

protected def codeOf(node: NewNode): String = {
node.asInstanceOf[AstNodeNew].code
}
Expand All @@ -36,6 +86,10 @@ trait AstCreatorHelper { this: AstCreator =>
usedVariableNames.put(variableName, counter)
currentVariableName
}

protected def addReceiverEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this exists in the Ast() class already

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove this code as anyways we are not using this function anywhere in ruby

diffGraph.addEdge(srcNode, dstNode, EdgeTypes.RECEIVER)
}
}

object GlobalTypes {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ trait AstForTypesCreator { this: AstCreator =>
val typeDeclNode = NewTypeDecl()
.name(className)
.fullName(fullName)

typeDeclNameToTypeDecl.put(className, typeDeclNode)
Seq(Ast(typeDeclNode).withChildren(bodyAst))
} else {
Seq.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,11 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder,
symbolTable.append(c, callTypes)
}

override protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef, rec: Option[String]): Set[String] =
t.typ.referencedTypeDecl
.map(_.fullName.stripSuffix("<meta>"))
.map(td => symbolTable.append(CallAlias(i.name, rec), Set(td)))
.headOption
.getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec))

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.joern.rubysrc2cpg.dataflow

import io.joern.dataflowengineoss.language.*
import io.joern.rubysrc2cpg.RubySrc2Cpg
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.semanticcpg.language.*

Expand Down Expand Up @@ -319,8 +320,8 @@ class DataFlowTests extends RubyCode2CpgFixture(withPostProcessing = true, withD
sink.reachableByFlows(src).l.size shouldBe 2
}
}
// TODO:
"Data flow through class member" ignore {

"Data flow through class member" should {
val cpg = code("""
|class MyClass
| @instanceVariable
Expand Down Expand Up @@ -1781,8 +1782,7 @@ class DataFlowTests extends RubyCode2CpgFixture(withPostProcessing = true, withD
}
}

// TODO: Need to be fixed.
"Across the file data flow test" ignore {
"Across the file data flow test" should {
val cpg = code(
"""
|def foo(arg)
Expand Down Expand Up @@ -1815,7 +1815,8 @@ class DataFlowTests extends RubyCode2CpgFixture(withPostProcessing = true, withD
sink.reachableByFlows(src).size shouldBe 1
}

"be found for sink in nested block" in {
// TODO: Need to be fixed.
"be found for sink in nested block" ignore {
val src = cpg.identifier("x").lineNumber(3).l
val sink = cpg.call.name("puts").argument(1).lineNumber(7).l
sink.reachableByFlows(src).size shouldBe 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class FunctionTests extends RubyCode2CpgFixture {
| end
|end
|
|p = Person. new
|p = Person.new
|p.greet
|""".stripMargin)

Expand All @@ -35,7 +35,7 @@ class FunctionTests extends RubyCode2CpgFixture {
cpg.identifier.name("age").size shouldBe 1
cpg.fieldAccess.fieldIdentifier.canonicalName("name").size shouldBe 2
cpg.fieldAccess.fieldIdentifier.canonicalName("age").size shouldBe 4
cpg.identifier.size shouldBe 15 // 4 identifier node is for `puts = typeDef(__builtin.puts)` and methodRef's assignment
cpg.identifier.size shouldBe 16 // 4 identifier node is for `puts = typeDef(__builtin.puts)` and methodRef's assignment, 1 node for class Person = typeDef
}

"recognize all call nodes" in {
Expand Down Expand Up @@ -77,10 +77,12 @@ class FunctionTests extends RubyCode2CpgFixture {
}

"recognize all call nodes" in {
cpg.call.name(Operators.assignment).size shouldBe 6 // 3 identifier node is for methodRef's assignment
cpg.call
.name(Operators.assignment)
.size shouldBe 7 // 3 identifier node is for methodRef's assignment, 1 identifier node for TypeRef's assignment
cpg.call.name("to_s").size shouldBe 2
cpg.call.name("new").size shouldBe 1
cpg.call.size shouldBe 14 // 3 identifier node is for methodRef's assignment
cpg.call.size shouldBe 15 // 3 identifier node is for methodRef's assignment, 1 identifier node for TypeRef's assignment
}

"recognize all identifier nodes" in {
Expand Down