Skip to content

Commit

Permalink
[rubysrc2cpg] this style field accesses handled (#3071)
Browse files Browse the repository at this point in the history
* Some code cleanups
* Detected when `@`-style field accesses were called and generated the necessary `Operators.fieldAccess` structure
* For members: Added static modifier when `@@` notation detected and the virtual modifier otherwise

Partially resolves #3068, more to be done for instance object field accesses.
  • Loading branch information
DavidBakerEffendi authored Jul 8, 2023
1 parent 9b20b6c commit f253a0e
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class AstCreator(
null
}
}
val varSymbol = localVar.getSymbol()
val varSymbol = localVar.getSymbol
val node =
createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
val yAst = Ast(node)
Expand All @@ -207,11 +207,11 @@ class AstCreator(
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
.lineNumber(localVar.getSymbol.getLine)
.columnNumber(localVar.getSymbol.getCharPositionInLine())
.columnNumber(localVar.getSymbol.getCharPositionInLine)
Seq(callAst(callNode, xAsts ++ Seq(yAst)))
case ctx: ScopedConstantAccessSingleLeftHandSideContext =>
val localVar = ctx.CONSTANT_IDENTIFIER()
val varSymbol = localVar.getSymbol()
val varSymbol = localVar.getSymbol
val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
case _ =>
Expand Down Expand Up @@ -244,17 +244,7 @@ class AstCreator(
val rightAst = astForMultipleRightHandSideContext(ctx.multipleRightHandSide())
val leftAst = astForSingleLeftHandSideContext(ctx.singleLeftHandSide())

val operatorName = getOperatorName(ctx.op)
val isSelfFieldAccess = ctx.singleLeftHandSide().getText.startsWith("@")

// Very basic field detection
// TODO: Create a <operator.fieldAccess>
if (isSelfFieldAccess) {
fieldReferences.updateWith(classStack.top) {
case Some(xs) => Option(xs ++ Set(ctx.singleLeftHandSide()))
case None => Option(Set(ctx.singleLeftHandSide()))
}
}
val operatorName = getOperatorName(ctx.op)

if (leftAst.size == 1 && rightAst.size > 1) {
/*
Expand All @@ -267,8 +257,8 @@ class AstCreator(
.methodFullName(operatorName)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
.lineNumber(ctx.op.getLine())
.columnNumber(ctx.op.getCharPositionInLine())
.lineNumber(ctx.op.getLine)
.columnNumber(ctx.op.getCharPositionInLine)

val packedRHS = getPackedRHS(rightAst)
Seq(callAst(callNode, leftAst ++ packedRHS))
Expand All @@ -280,8 +270,8 @@ class AstCreator(
.signature("")
.dispatchType(DispatchTypes.STATIC_DISPATCH)
.typeFullName(Defines.Any)
.lineNumber(ctx.op.getLine())
.columnNumber(ctx.op.getCharPositionInLine())
.lineNumber(ctx.op.getLine)
.columnNumber(ctx.op.getCharPositionInLine)
Seq(callAst(callNode, leftAst ++ rightAst))
}
}
Expand Down Expand Up @@ -1096,7 +1086,13 @@ class AstCreator(
.name(code.replaceAll("@", ""))
.code(code)
.typeFullName(Defines.Any)
}).toList.distinctBy(_.name).map(Ast.apply)
}).toList.distinctBy(_.name).map { m =>
val modifierType = m.name match
case x if x.startsWith("@@") => ModifierTypes.STATIC
case _ => ModifierTypes.VIRTUAL
val modifierAst = Ast(NewModifier().modifierType(modifierType))
Ast(m).withChild(modifierAst)
}
Seq(blockAst(blockNode(classCtx), blockStmts.toList)) ++ uniqueMemberReferences ++ methodStmts
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.parser.RubyParser.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.Ast
import io.shiftleft.codepropertygraph.generated.nodes.NewJumpTarget
import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewJumpTarget, NewNode}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import org.antlr.v4.runtime.ParserRuleContext

import scala.collection.immutable.Set
import scala.jdk.CollectionConverters.CollectionHasAsScala

trait AstForExpressionsCreator { this: AstCreator =>
Expand Down Expand Up @@ -206,8 +207,17 @@ trait AstForExpressionsCreator { this: AstCreator =>
* 4. Otherwise default to identifier node creation since there is no reason (point 2) to create a call node
*/

val variableName = ctx.getText
if (definitelyIdentifier || scope.lookupVariable(variableName).isDefined) {
val variableName = ctx.getText
val isSelfFieldAccess = variableName.startsWith("@")
if (isSelfFieldAccess) {
// Very basic field detection
fieldReferences.updateWith(classStack.top) {
case Some(xs) => Option(xs ++ Set(ctx))
case None => Option(Set(ctx))
}
val thisNode = createIdentifierWithScope(ctx, "this", "this", Defines.Any, List.empty)
astForFieldAccess(ctx, thisNode)
} else if (definitelyIdentifier || scope.lookupVariable(variableName).isDefined) {
val node = createIdentifierWithScope(ctx, variableName, variableName, Defines.Any, List())
Ast(node)
} else if (methodNames.contains(variableName)) {
Expand All @@ -230,4 +240,22 @@ trait AstForExpressionsCreator { this: AstCreator =>
controlStructureAst(ifNode, testAst.headOption, thenAst ++ elseAst)
}

protected def astForFieldAccess(ctx: ParserRuleContext, baseNode: NewNode): Ast = {
val fieldAccess =
callNode(ctx, ctx.getText, Operators.fieldAccess, Operators.fieldAccess, DispatchTypes.STATIC_DISPATCH)
val fieldIdentifier = newFieldIdentifier(ctx)
val astChildren = Seq(baseNode, fieldIdentifier)
callAst(fieldAccess, astChildren.map(Ast.apply))
}

protected def newFieldIdentifier(ctx: ParserRuleContext): NewFieldIdentifier = {
val code = ctx.getText
val name = code.replaceAll("@", "")
NewFieldIdentifier()
.code(code)
.canonicalName(name)
.lineNumber(ctx.start.getLine)
.columnNumber(ctx.start.getCharPositionInLine)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.{
ClassOrModuleReferenceContext,
ScopedConstantReferenceContext
}
import io.shiftleft.codepropertygraph.generated.ModifierTypes
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.Ast
import io.shiftleft.codepropertygraph.generated.nodes.*
Expand Down Expand Up @@ -92,14 +93,18 @@ trait AstForTypesCreator { this: AstCreator =>
.collect { case i: NewIdentifier if i.name.startsWith("@") => i }
.map { i =>
val code = ast.root.collect { case c: NewCall => c.code }.getOrElse(i.name)
val modifierType = i.name match
case x if x.startsWith("@@") => ModifierTypes.STATIC
case _ => ModifierTypes.VIRTUAL
val modifierAst = Ast(NewModifier().modifierType(modifierType))
Ast(
NewMember()
.code(code)
.name(i.name.replaceAll("@", ""))
.typeFullName(i.typeFullName)
.lineNumber(i.lineNumber)
.columnNumber(i.columnNumber)
)
).withChild(modifierAst)
}
.toSeq

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ class TypeDeclAstCreationPassTest extends RubyCode2CpgFixture {
song.name shouldBe "Song"
song.fullName shouldBe "Test0.rb::program:Song"

val List(plays, artist, duration, name) = song.member.l
val List(artist, duration, name, plays) = song.member.l

plays.name shouldBe "plays"
name.name shouldBe "name"
artist.name shouldBe "artist"
duration.name shouldBe "duration"

cpg.fieldAccess.fieldIdentifier.canonicalName.l shouldBe List("plays", "name", "artist", "duration")
}

"generate members for various class members when using the `attr_reader` and `attr_writer` idioms" ignore {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class FunctionTests extends RubyCode2CpgFixture {
|""".stripMargin)

"recognise all identifier nodes" in {
cpg.identifier.name("name").l.size shouldBe 1
cpg.identifier.name("age").l.size shouldBe 1
cpg.identifier.name("@name").l.size shouldBe 2
cpg.identifier.name("@age").l.size shouldBe 4
cpg.identifier.name("name").size shouldBe 1
cpg.identifier.name("age").size shouldBe 1
cpg.fieldAccess.fieldIdentifier.canonicalName("name").size shouldBe 2
cpg.fieldAccess.fieldIdentifier.canonicalName("age").size shouldBe 4
cpg.identifier.size shouldBe 11
}

Expand Down Expand Up @@ -80,11 +80,11 @@ class FunctionTests extends RubyCode2CpgFixture {
cpg.call.name(Operators.assignment).size shouldBe 3
cpg.call.name("to_s").size shouldBe 2
cpg.call.name("new").size shouldBe 1
cpg.call.size shouldBe 8
cpg.call.size shouldBe 11
}

"recognize all identifier nodes" in {
cpg.identifier.name("@my_hash").size shouldBe 3
cpg.fieldAccess.fieldIdentifier.canonicalName("my_hash").size shouldBe 3
cpg.identifier.name("key").size shouldBe 2
cpg.identifier.name("value").size shouldBe 1
cpg.identifier.name("my_object").size shouldBe 1
Expand Down

0 comments on commit f253a0e

Please sign in to comment.