diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index 36e3890836d5..ead95f983084 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -241,9 +241,21 @@ class AstCreator( } def astForSingleAssignmentExpressionContext(ctx: SingleAssignmentExpressionContext): Seq[Ast] = { - val rightAst = astForMultipleRightHandSideContext(ctx.multipleRightHandSide()) - val leftAst = astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) - val operatorName = getOperatorName(ctx.op) + val rightAst = astForMultipleRightHandSideContext(ctx.multipleRightHandSide()) + val leftAst = astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) + + val operatorName = getOperatorName(ctx.op) + val isSelfFieldAccess = ctx.singleLeftHandSide().getText.startsWith("@") + + // Very basic field detection + // TODO: Create a + if (isSelfFieldAccess) { + fieldReferences.updateWith(classStack.top) { + case Some(xs) => Option(xs ++ Set(ctx.singleLeftHandSide())) + case None => Option(Set(ctx.singleLeftHandSide())) + } + } + if (leftAst.size == 1 && rightAst.size > 1) { /* * This is multiple RHS packed into a single LHS. That is, packing left hand side. @@ -1053,20 +1065,39 @@ class AstCreator( def astForClassBody(ctx: BodyStatementContext): Seq[Ast] = { val rootStatements = Option(ctx).map(_.compoundStatement()).map(_.statements()).map(astForStatements).getOrElse(Seq()) + retrieveAndGenerateClassChildren(ctx, rootStatements) + } - val (methods, blockStmts) = - rootStatements - .flatMap { ast => - ast.root match - case Some(x: NewMethod) => Seq(ast) - case Some(x: NewCall) if x.name == Operators.assignment => Seq(ast) :+ astsForClassMembers(ast) - case _ => Seq(ast) - } - .partition(_.root match - case Some(_: NewMethod) => true - case _ => false - ) - Seq(blockAst(blockNode(ctx), blockStmts.toList)) ++ methods + /** As class bodies are not treated much differently to other procedure bodies, we need to retrieve certain components + * that would result in the creation of interprocedural constructs. + * + * TODO: This is pretty hacky and the parser could benefit from more specific tokens + */ + private def retrieveAndGenerateClassChildren(classCtx: BodyStatementContext, rootStatements: Seq[Ast]): Seq[Ast] = { + val (memberLikeStmts, blockStmts) = rootStatements + .flatMap { ast => + ast.root match + case Some(x: NewMethod) => Seq(ast) + case Some(x: NewCall) if x.name == Operators.assignment => Seq(ast) ++ membersFromStatementAsts(ast) + case _ => Seq(ast) + } + .partition(_.root match + case Some(_: NewMethod) => true + case Some(_: NewMember) => true + case _ => false + ) + + val methodStmts = memberLikeStmts.filter(_.root.exists(_.isInstanceOf[NewMethod])) + val memberNodes = memberLikeStmts.flatMap(_.root).collect { case m: NewMember => m } + + val uniqueMemberReferences = + (memberNodes ++ fieldReferences.getOrElse(classStack.top, Set.empty).groupBy(_.getText).map { case (code, ctxs) => + NewMember() + .name(code.replaceAll("@", "")) + .code(code) + .typeFullName(Defines.Any) + }).toList.distinctBy(_.name).map(Ast.apply) + Seq(blockAst(blockNode(classCtx), blockStmts.toList)) ++ uniqueMemberReferences ++ methodStmts } private def convertLastStmtToReturn(compoundStatementAsts: Seq[Ast], ctxStmt: StatementsContext): Seq[Ast] = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala index ee53a9e8246b..15b8d9480b07 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala @@ -7,10 +7,16 @@ import io.joern.rubysrc2cpg.parser.RubyParser.{ } import io.joern.rubysrc2cpg.passes.Defines import io.joern.x2cpg.Ast -import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewIdentifier, NewTypeDecl} +import io.shiftleft.codepropertygraph.generated.nodes.* +import org.antlr.v4.runtime.ParserRuleContext + +import scala.collection.mutable trait AstForTypesCreator { this: AstCreator => + // Maps field references of known types + protected val fieldReferences = mutable.HashMap.empty[String, Set[ParserRuleContext]] + def astForClassDeclaration(ctx: ClassDefinitionPrimaryContext): Seq[Ast] = { val baseClassName = if (ctx.classDefinition().expressionOrCommand() != null) { val parentClassNameAst = astForExpressionOrCommand(ctx.classDefinition().expressionOrCommand()) @@ -81,10 +87,21 @@ trait AstForTypesCreator { this: AstCreator => } } - def astsForClassMembers(ast: Ast): Ast = { - // TODO: Handle members - Ast() - } + def membersFromStatementAsts(ast: Ast): Seq[Ast] = + ast.nodes + .collect { case i: NewIdentifier if i.name.startsWith("@") => i } + .map { i => + val code = ast.root.collect { case c: NewCall => c.code }.getOrElse(i.name) + Ast( + NewMember() + .code(code) + .name(i.name.replaceAll("@", "")) + .typeFullName(i.typeFullName) + .lineNumber(i.lineNumber) + .columnNumber(i.columnNumber) + ) + } + .toSeq implicit class ClassDefinitionPrimaryContextExt(val ctx: ClassDefinitionPrimaryContext) { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/TypeDeclAstCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/TypeDeclAstCreationPassTest.scala index 67b92ef9a72a..7ee446b5b1c5 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/TypeDeclAstCreationPassTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/TypeDeclAstCreationPassTest.scala @@ -2,7 +2,7 @@ package io.joern.rubysrc2cpg.passes.ast import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class TypeDeclAstCreationPassTest extends RubyCode2CpgFixture { @@ -50,7 +50,7 @@ class TypeDeclAstCreationPassTest extends RubyCode2CpgFixture { driving.fullName shouldBe "Test0.rb::program:Vehicle:driving" } - "generate members for various class members under the respective type declaration" ignore { + "generate members for various class members under the respective type declaration" in { val cpg = code(""" |class Song | @@plays = 0 @@ -65,7 +65,7 @@ class TypeDeclAstCreationPassTest extends RubyCode2CpgFixture { song.name shouldBe "Song" song.fullName shouldBe "Test0.rb::program:Song" - val List(plays, name, artist, duration) = song.member.l + val List(plays, artist, duration, name) = song.member.l plays.name shouldBe "plays" name.name shouldBe "name" @@ -192,12 +192,12 @@ class TypeDeclAstCreationPassTest extends RubyCode2CpgFixture { | def initialize | puts "This is Superclass" | end - | + | | def super_method | puts "Method of superclass" | end |end - | + | |class Sudo_Placement < GeeksforGeeks | def initialize | puts "This is Subclass"