Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ruby_ast_gen] Handling for Singleton & Anon Classes #5006

Merged
merged 4 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ object RubyIntermediateAst {
extends RubyExpression(span)
with RubyStatement {

/** Wraps this node's `body` in a single-element [[StatementList]] that reuses this node's span. */
def toStatementList: StatementList = {
  val statements = List(body)
  StatementList(statements)(span)
}

def toMethodDeclaration(name: String, parameters: Option[List[RubyExpression]]): MethodDeclaration =
parameters match {
case Some(givenParameters) => MethodDeclaration(name, givenParameters, body)(span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ import better.files.File
import io.joern.rubysrc2cpg.Config
import io.joern.x2cpg.astgen.AstGenRunner.{AstGenProgramMetaData, executableDir}
import io.joern.x2cpg.astgen.AstGenRunnerBase
import io.joern.x2cpg.utils.ExternalCommand
import org.jruby.{Ruby, RubyHash, RubyInstanceConfig, RubyRuntimeAdapter}
import org.jruby.javasupport.JavaEmbedUtils
import org.jruby.RubyInstanceConfig
import org.slf4j.LoggerFactory

import java.io.{ByteArrayOutputStream, PrintStream}
import java.io.File.separator
import java.nio.file.{Files, Paths}
import java.io.{ByteArrayOutputStream, PrintStream}
import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -78,6 +75,7 @@ class RubyAstGenRunner(config: Config) extends AstGenRunnerBase(config) {
config.setEnvironment(Map("GEM_PATH" -> gemPath, "GEM_FILE" -> gemPath).asJava)
config.setHasShebangLine(true)
config.setScriptFileName(mainScript)
config.setHardExit(false)

try {
org.jruby.Main(config).run(Array.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object ParserKeys {
val Condition = "condition"
val ElseClause = "else_clause"
val ElseBranch = "else_branch"
val End = "end"
val ExecList = "exec_list"
val ExecVar = "exec_var"
val FilePath = "file_path"
Expand All @@ -37,6 +38,7 @@ object ParserKeys {
val Right = "right"
val Rhs = "rhs"
val Statement = "statement"
val Start = "start"
val SuperClass = "superclass"
val ThenBranch = "then_branch"
val Type = "type"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
package io.joern.rubysrc2cpg.parser

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
AllowedTypeDeclarationChild,
ClassFieldIdentifier,
MemberAccess,
MethodDeclaration,
RubyExpression,
RubyFieldIdentifier,
SelfIdentifier,
SimpleIdentifier,
SingleAssignment,
StatementList,
StaticLiteral,
TextSpan,
TypeDeclBodyCall
}
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import upickle.core.*
import upickle.default.*

Expand Down Expand Up @@ -44,26 +50,76 @@ object RubyJsonHelpers {

}

/** Creates a [[StaticLiteral]] whose type is the built-in type resolved for `Defines.NilClass`.
  *
  * @param span
  *   the text span attached to the resulting literal.
  */
protected def nilLiteral(span: TextSpan): StaticLiteral = {
  val nilType = getBuiltInType(Defines.NilClass)
  StaticLiteral(nilType)(span)
}

def createClassBodyAndFields(
obj: ujson.Obj
)(implicit visit: ujson.Value => RubyExpression): (StatementList, List[RubyExpression & RubyFieldIdentifier]) = {

def createBodyMethod(fieldStatements: List[ujson.Obj]): MethodDeclaration = {
MethodDeclaration(
Defines.TypeDeclBody,
Nil,
StatementList(fieldStatements.map(visit))(obj.toTextSpan.spanStart(s"(...)"))
)(obj.toTextSpan.spanStart(s"def <body>; (...); end"))
/** Builds the synthetic `Defines.TypeDeclBody` method that hosts the given statements.
  *
  * Each statement is normalised before being placed in the method body:
  *   - a bare [[SimpleIdentifier]] becomes `<field> = nil`, with a fresh [[ClassFieldIdentifier]] as the target;
  *   - a bare [[RubyFieldIdentifier]] becomes `<field> = nil`, keeping the identifier as the target;
  *   - an assignment to a [[RubyFieldIdentifier]] is kept unchanged;
  *   - an assignment to a [[SimpleIdentifier]] has its target replaced by a [[ClassFieldIdentifier]];
  *   - any other expression is passed through untouched.
  *
  * @param fieldStatements
  *   the statements to place inside the body method, in their given order.
  * @return
  *   a parameterless [[MethodDeclaration]] named `Defines.TypeDeclBody`.
  */
def bodyMethod(fieldStatements: List[RubyExpression]): MethodDeclaration = {

  // Shared construction of `<field> = nil`; the nil literal is positioned at the field's own span.
  // NOTE(review): the code text interpolates the span object itself (i.e. its `toString`), not the field's source
  // text. Confirm this renders the identifier rather than a case-class dump; `field.text` may be what is intended.
  def nilAssignment(lhs: RubyExpression, fieldSpan: TextSpan): SingleAssignment =
    SingleAssignment(lhs, "=", nilLiteral(fieldSpan))(fieldSpan.spanStart(s"$fieldSpan = nil"))

  val body = fieldStatements.map {
    case field: SimpleIdentifier                                     => nilAssignment(ClassFieldIdentifier()(field.span), field.span)
    case field: RubyFieldIdentifier                                  => nilAssignment(field, field.span)
    case assignment @ SingleAssignment(_: RubyFieldIdentifier, _, _) => assignment
    case assignment @ SingleAssignment(lhs: SimpleIdentifier, _, _) =>
      assignment.copy(lhs = ClassFieldIdentifier()(lhs.span))(assignment.span)
    case otherExpr => otherExpr
  }

  MethodDeclaration(Defines.TypeDeclBody, Nil, StatementList(body)(obj.toTextSpan.spanStart("(...)")))(
    obj.toTextSpan.spanStart("def <body>; (...); end")
  )
}

val bodyMethod = createBodyMethod(Nil)
/** Decides whether a direct child of a class or module is treated as field-related.
  *
  * @param expr
  *   An expression that is a direct child to a class or module.
  * @return
  *   true for single assignments, simple identifiers and field identifiers; false for anything else.
  */
def isFieldStmt(expr: RubyExpression): Boolean = expr match {
  case _: SingleAssignment | _: SimpleIdentifier | _: RubyFieldIdentifier => true
  case _                                                                  => false
}

/** Extracts the field identifier that an expression declares, if any.
  *
  * A simple identifier (on its own or as an assignment target) is converted to a [[ClassFieldIdentifier]] carrying
  * the identifier's span; an existing field identifier is returned as is.
  *
  * @param expr
  *   An expression that is a direct child to a class or module.
  * @return
  *   the field identifier, or `None` when the expression is neither an identifier nor an assignment to one.
  */
def getField(expr: RubyExpression): Option[RubyExpression & RubyFieldIdentifier] = expr match {
  case identifier: SimpleIdentifier                        => Some(ClassFieldIdentifier()(identifier.span))
  case field: RubyFieldIdentifier                          => Some(field)
  case SingleAssignment(target: RubyFieldIdentifier, _, _) => Some(target)
  case SingleAssignment(target: SimpleIdentifier, _, _)    => Some(ClassFieldIdentifier()(target.span))
  case _                                                   => None
}

obj.visitOption(ParserKeys.Body) match {
case Some(stmtList @ StatementList(expression :: Nil)) if expression.isInstanceOf[AllowedTypeDeclarationChild] =>
(stmtList, Nil)
case Some(stmtList @ StatementList(expression :: Nil)) if isFieldStmt(expression) =>
(StatementList(bodyMethod(expression :: Nil) :: Nil)(stmtList.span), getField(expression).toList)
case Some(stmtList: StatementList) =>
val body = stmtList.copy(statements = bodyMethod +: stmtList.statements)(stmtList.span)
(body, Nil)
case Some(expression) => (StatementList(bodyMethod :: expression :: Nil)(obj.toTextSpan), Nil)
case None => (StatementList(bodyMethod :: Nil)(obj.toTextSpan.spanStart("<empty>")), Nil)
val (fieldStmts, otherStmts) = stmtList.statements.partition(isFieldStmt)
val (typeDeclStmts, bodyStmts) = otherStmts.partition(_.isInstanceOf[AllowedTypeDeclarationChild])
val body = stmtList.copy(statements = bodyMethod(fieldStmts ++ bodyStmts) +: typeDeclStmts)(stmtList.span)
val fields = fieldStmts.flatMap(getField)
(body, fields)
case Some(expression) if isFieldStmt(expression) || !expression.isInstanceOf[AllowedTypeDeclarationChild] =>
(StatementList(bodyMethod(expression :: Nil) :: Nil)(obj.toTextSpan), getField(expression).toList)
case Some(expression) =>
(StatementList(bodyMethod(Nil) :: expression :: Nil)(obj.toTextSpan), Nil)
case None => (StatementList(bodyMethod(Nil) :: Nil)(obj.toTextSpan.spanStart("<empty>")), Nil)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ class RubyJsonToNodeCreator(
val body = visit(obj(ParserKeys.Body))
val block = Block(parameters, body)(obj.toTextSpan)
visit(obj(ParserKeys.CallName)) match {
case simpleCall: RubyCall => simpleCall.withBlock(block)
case classNew: ObjectInstantiation if classNew.target.text == "Class.new" =>
AnonymousClassDeclaration(freshClassName(obj.toTextSpan), None, block.toStatementList)(obj.toTextSpan)
case simpleCall: RubyCall =>
simpleCall.withBlock(block)
case x =>
logger.warn(s"Unexpected call type used for block ${x.getClass}, ignoring block")
x
Expand Down Expand Up @@ -263,7 +266,12 @@ class RubyJsonToNodeCreator(

private def visitExclusiveFlipFlop(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitExclusiveRange(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts an exclusive range node into a [[RangeExpression]].
  *
  * The bounds are read from the `start` and `end` entries of the JSON object; the operator node is given the code
  * text `...` and is built with the flag `true` (presumably marking the range as exclusive; confirm against
  * `RangeOperator`).
  */
private def visitExclusiveRange(obj: Obj): RubyExpression = {
  val lowerBound = visit(obj(ParserKeys.Start))
  val upperBound = visit(obj(ParserKeys.End))
  val rangeSpan  = obj.toTextSpan
  RangeExpression(lowerBound, upperBound, RangeOperator(true)(rangeSpan.spanStart("...")))(rangeSpan)
}

private def visitExecutableString(obj: Obj): RubyExpression = {
val callName =
Expand All @@ -274,6 +282,11 @@ class RubyJsonToNodeCreator(

private def visitFalse(obj: Obj): RubyExpression = StaticLiteral(getBuiltInType(Defines.FalseClass))(obj.toTextSpan)

/** Converts a field declaration node into a [[FieldsDeclaration]] built from the node's visited arguments. */
private def visitFieldDeclaration(obj: Obj): RubyExpression =
  FieldsDeclaration(obj.visitArray(ParserKeys.Arguments))(obj.toTextSpan)

private def visitFindPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitFloat(obj: Obj): RubyExpression = StaticLiteral(getBuiltInType(Defines.Float))(obj.toTextSpan)
Expand Down Expand Up @@ -326,7 +339,12 @@ class RubyJsonToNodeCreator(

private def visitInclusiveFlipFlop(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitInclusiveRange(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts an inclusive range node into a [[RangeExpression]].
  *
  * The bounds are read from the `start` and `end` entries of the JSON object; the operator node is given the code
  * text `..` and is built with the flag `false` (presumably marking the range as non-exclusive; confirm against
  * `RangeOperator`).
  */
private def visitInclusiveRange(obj: Obj): RubyExpression = {
  val lowerBound = visit(obj(ParserKeys.Start))
  val upperBound = visit(obj(ParserKeys.End))
  val rangeSpan  = obj.toTextSpan
  RangeExpression(lowerBound, upperBound, RangeOperator(false)(rangeSpan.spanStart("..")))(rangeSpan)
}

private def visitInPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

Expand Down Expand Up @@ -425,7 +443,11 @@ class RubyJsonToNodeCreator(

private def visitOrAssign(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitPair(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts a pair node into an [[Association]] of its visited `key` and `value` entries. */
private def visitPair(obj: Obj): RubyExpression =
  Association(visit(obj(ParserKeys.Key)), visit(obj(ParserKeys.Value)))(obj.toTextSpan)

private def visitPostExpression(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

Expand Down Expand Up @@ -484,14 +506,20 @@ class RubyJsonToNodeCreator(
case "new" => visitObjectInstantiation(obj)
case "raise" => visitRaise(obj)
case "include" => visitInclude(obj)
case "attr_reader" | "attr_writer" | "attr_accessor" => visitFieldDeclaration(obj)
case requireLike if ImportCallNames.contains(requireLike) => visitRequireLike(obj)
case _ if BinaryOperators.isBinaryOperatorName(callName) =>
val lhs = visit(obj(ParserKeys.Receiver))
val rhs = obj.visitArray(ParserKeys.Arguments).head
BinaryExpression(lhs, callName, rhs)(obj.toTextSpan)
case _ =>
val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName))
val arguments = obj.visitArray(ParserKeys.Arguments)
val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName))
val argumentArr = obj.visitArray(ParserKeys.Arguments)
val arguments = argumentArr.zipWithIndex.flatMap {
case (hashLiteral: HashLiteral, idx) =>
hashLiteral.elements // a hash is likely named arguments
case (x, _) => x :: Nil
}
if (obj.contains(ParserKeys.Receiver)) {
val base = visit(obj(ParserKeys.Receiver))
MemberCall(base, ".", callName, arguments)(obj.toTextSpan)
Expand All @@ -503,12 +531,26 @@ class RubyJsonToNodeCreator(

private def visitShadowArg(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitSingletonMethodDefinition(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts a singleton method definition node into a [[SingletonMethodDeclaration]].
  *
  * The receiver comes from the `base` entry, the method name from the `name` string, and the parameters from the
  * children of the `arguments` entry. A missing body is replaced by an empty [[StatementList]] whose code text is
  * `<empty>`.
  */
private def visitSingletonMethodDefinition(obj: Obj): RubyExpression = {
  val receiver   = visit(obj(ParserKeys.Base))
  val methodName = obj(ParserKeys.Name).str
  // The cast fails with a ClassCastException if the arguments entry is not a JSON object.
  val argumentsNode = obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj]
  val methodParams  = argumentsNode.visitArray(ParserKeys.Children)
  val methodBody = obj.visitOption(ParserKeys.Body) match {
    case Some(parsedBody) => parsedBody
    case None             => StatementList(Nil)(obj.toTextSpan.spanStart("<empty>"))
  }
  SingletonMethodDeclaration(receiver, methodName, methodParams, methodBody)(obj.toTextSpan)
}

private def visitSingletonClassDefinition(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts a singleton class definition node into a [[SingletonClassDeclaration]].
  *
  * The class name is the visited `name` entry and the optional base class is the visited `superclass` entry. A
  * missing body is replaced by an empty [[StatementList]] whose code text is `<empty>`. The body member call is
  * created from the name's text and the node's span.
  */
private def visitSingletonClassDefinition(obj: Obj): RubyExpression = {
  val className  = visit(obj(ParserKeys.Name))
  val superClass = obj.visitOption(ParserKeys.SuperClass)
  val classBody = obj.visitOption(ParserKeys.Body) match {
    case Some(parsedBody) => parsedBody
    case None             => StatementList(Nil)(obj.toTextSpan.spanStart("<empty>"))
  }
  val memberCall = createBodyMemberCall(className.text, obj.toTextSpan)
  SingletonClassDeclaration(
    name = className,
    baseClass = superClass,
    body = classBody,
    bodyMemberCall = Option(memberCall)
  )(obj.toTextSpan)
}

private def visitSingleAssignment(obj: Obj): RubyExpression = {
val lhs = visit(obj(ParserKeys.Lhs))
val lhs = SimpleIdentifier()(obj.toTextSpan.spanStart(obj(ParserKeys.Lhs).str))
val rhs = visit(obj(ParserKeys.Rhs))
SingleAssignment(lhs, "=", rhs)(obj.toTextSpan)
}
Expand Down
Loading
Loading