Skip to content

Commit

Permalink
[ruby] require* and load import nodes (joernio#4565)
Browse files Browse the repository at this point in the history
Creating import nodes for `require_all`, `require_relative`, and `load` calls. `require_all` is interpreted as a wildcard import as it imports a whole directory.

Additionally, `require_relative` and `require_all` are now handled according to their semantics, with accompanying tests.
  • Loading branch information
DavidBakerEffendi authored May 17, 2024
1 parent 6a42e46 commit 592e17a
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case arg: StaticLiteral if arg.isString => Option(arg.innerText)
case _ => None
}
pathOpt.foreach(path => scope.addRequire(path, node.isRelative))
pathOpt.foreach(path => scope.addRequire(projectRoot.get, fileName, path, node.isRelative, node.isWildCard))
astForSimpleCall(node.asSimpleCall)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ object RubyIntermediateAst {
final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan)
extends RubyNode(span)
with RubyIdentifier
with SingletonMethodIdentifier
with SingletonMethodIdentifier {
override def toString: String = s"SimpleIdentifier(${span.text}, $typeFullName)"
}

/** Represents an InstanceFieldIdentifier, e.g. `@x` */
final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyFieldIdentifier
Expand All @@ -248,7 +250,7 @@ object RubyIntermediateAst {
def isString: Boolean = text.startsWith("\"") || text.startsWith("'")

def innerText: String = {
val strRegex = "[./:]?['\"]([\\w\\d_-]+)(?:\\.rb)?['\"]".r
val strRegex = "['\"]([./:]{0,3}[\\w\\d_-]+)(?:\\.rb)?['\"]".r
text match {
case s":'$content'" => content
case s":$symbol" => symbol
Expand Down Expand Up @@ -298,7 +300,12 @@ object RubyIntermediateAst {
extends RubyNode(span)
with RubyCall

final case class RequireCall(target: RubyNode, argument: RubyNode, isRelative: Boolean)(span: TextSpan)
final case class RequireCall(
target: RubyNode,
argument: RubyNode,
isRelative: Boolean = false,
isWildCard: Boolean = false
)(span: TextSpan)
extends RubyNode(span)
with RubyCall {
def arguments: List[RubyNode] = List(argument)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,37 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String])
super.pushNewScope(mappedScopeNode)
}

def addRequire(rawPath: String, isRelative: Boolean): Unit = {
val path = rawPath.stripSuffix(":<global>") // Sometimes the require call provides a processed path
// We assume the project root is the sole LOAD_PATH of the project sources for now
val relativizedPath =
def addRequire(
projectRoot: String,
currentFilePath: String,
requiredPath: String,
isRelative: Boolean,
isWildCard: Boolean = false
): Unit = {
val path = requiredPath.stripSuffix(":<global>") // Sometimes the require call provides a processed path
// We assume the project root is the sole LOAD_PATH of the project sources
// NB: Tracking whatever has been added to $LOADER is dynamic and requires a post-processing step!
val resolvedPath =
if (isRelative) {
Try {
val parentDir = File(surrounding[ProgramScope].get.fileName).parentOption.get
val absPath = (parentDir / path).path.toAbsolutePath
projectRoot.map(File(_).path.toAbsolutePath.relativize(absPath).toString)
}.getOrElse(Option(path))
Try((File(currentFilePath).parent / path).pathAsString).toOption
.map(_.stripPrefix(s"$projectRoot/"))
.getOrElse(path)
} else {
Option(path)
path
}

relativizedPath.iterator.flatMap(summary.pathToType.getOrElse(_, Set())) match {
val pathsToImport =
if (isWildCard) {
val dir = File(projectRoot) / resolvedPath
if (dir.isDirectory)
dir.list
.map(_.pathAsString.stripPrefix(s"$projectRoot/").stripSuffix(".rb"))
.toList
else Nil
} else {
resolvedPath :: Nil
}
pathsToImport.flatMap(summary.pathToType.getOrElse(_, Set())) match {
case x if x.nonEmpty =>
x.foreach { ty => addImportedTypeOrModule(ty.name) }
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,11 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
val arguments = ctx.commandArgument().arguments.map(visit)
(identifierCtx.getText, arguments) match {
case ("require", List(argument)) =>
RequireCall(visit(identifierCtx), argument, false)(ctx.toTextSpan)
RequireCall(visit(identifierCtx), argument)(ctx.toTextSpan)
case ("require_relative", List(argument)) =>
RequireCall(visit(identifierCtx), argument, true)(ctx.toTextSpan)
case ("require_all", List(argument)) =>
RequireCall(visit(identifierCtx), argument, true, true)(ctx.toTextSpan)
case ("include", List(argument)) =>
IncludeCall(visit(identifierCtx), argument)(ctx.toTextSpan)
case (idAssign, arguments) if idAssign.endsWith("=") =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ import io.shiftleft.semanticcpg.language.*

class ImportsPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) {

private val importCallName: String = "require"
private val importCallName: Seq[String] = Seq("require", "load", "require_relative", "require_all")

override def generateParts(): Array[Call] = cpg.call.nameExact(importCallName).toArray
override def generateParts(): Array[Call] = cpg.call.nameExact(importCallName*).toArray

override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = {
val importedEntity = stripQuotes(call.argument.isLiteral.code.l match {
case s :: _ => s
case _ => ""
})
createImportNodeAndLink(importedEntity, importedEntity, Some(call), diffGraph)
val importNode = createImportNodeAndLink(importedEntity, importedEntity, Some(call), diffGraph)
if (call.name == "require_all") importNode.isWildcard(true)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,113 @@ class ImportTests extends RubyCode2CpgFixture with Inspectors {
}
}
}

// Scenario: the entry snippet calls `require_all './dir'` and then invokes `foo` on `Module1` and
// `Module2`, each of which is defined in its own file under `dir/`.
"`require_all` on a directory" should {
val cpg = code("""
|require_all './dir'
|Module1.foo
|Module2.foo
|""".stripMargin)
.moreCode(
"""
|module Module1
| def foo
| end
|end
|""".stripMargin,
"dir/module1.rb"
)
.moreCode(
"""
|module Module2
| def foo
| end
|end
|""".stripMargin,
"dir/module2.rb"
)

// Each `foo` call is expected to carry the full name of the method defined in the matching
// file under `dir/`, listed in the order the calls appear in the entry snippet.
"allow the resolution for all modules in that directory" in {
cpg.call("foo").methodFullName.l shouldBe List(
"dir/module1.rb:<global>::program.Module1:foo",
"dir/module2.rb:<global>::program.Module2:foo"
)
}
}

// Checks that `require_all`, `require_relative`, and `load` calls each yield an import node, and
// that the import created for `require_all` is flagged as a wildcard import.
"`require_all`, `require_relative`, and `load`" should {
  val cpg = code("""
      |require_all './dir'
      |require_relative '../foo'
      |load 'pp'
      |""".stripMargin)

  "also create import nodes" in {
    inside(cpg.imports.l) {
      // Exactly three import nodes are expected, one per call in the snippet above.
      case requireAll :: requireRelative :: load :: Nil =>
        requireAll.importedAs shouldBe Option("./dir")
        requireAll.isWildcard shouldBe Option(true)
        requireRelative.importedAs shouldBe Option("../foo")
        load.importedAs shouldBe Option("pp")
      // The failure message states the same count the pattern above requires (three, not two).
      case xs => fail(s"Expected three imports, got [${xs.code.mkString(",")}] instead")
    }
  }
}

// Scenario: `main.rb` appends the `lib` and `src` directories to the `$LOADER` / `$LOAD_PATH`
// globals, then requires `file1`, `file2`, and `file3` by bare name and calls `foo` on the
// module each of those files defines.
"Modifying `$LOADER` with an additional entry" should {
val cpg = code(
"""
|lib_dir = File.expand_path('lib', __dir__)
|src_dir = File.expand_path('src', File.dirname(__FILE__))
|
|$LOADER << lib_dir unless $LOADER.include?(lib_dir)
|$LOAD_PATH.unshift(src_dir) unless $LOAD_PATH.include?(src_dir)
|
|require 'file1'
|require 'file2'
|require 'file3'
|
|File1::foo # lib/file1.rb::program:foo
|File2::foo # lib/file2.rb::program:foo
|File3::foo # src/file3.rb::program:foo
|""".stripMargin,
"main.rb"
).moreCode(
"""
|module File1
| def self.foo
| end
|end
|""".stripMargin,
"lib/file1.rb"
).moreCode(
"""
|module File2
| def self.foo
| end
|end
|""".stripMargin,
"lib/file2.rb"
).moreCode(
"""
|module File3
| def self.foo
| end
|end
|""".stripMargin,
"src/file3.rb"
)

// TODO: This works because of an over-approximation of the type resolver assuming that classes may have been
// implicitly loaded elsewhere
// The three calls whose names match `foo.*` are expected to resolve to the `foo` methods
// defined in `lib/file1.rb`, `lib/file2.rb`, and `src/file3.rb` respectively.
"resolve the calls directly" in {
inside(cpg.call.name("foo.*").l) {
case foo1 :: foo2 :: foo3 :: Nil =>
foo1.methodFullName shouldBe "lib/file1.rb:<global>::program.File1:foo"
foo2.methodFullName shouldBe "lib/file2.rb:<global>::program.File2:foo"
foo3.methodFullName shouldBe "src/file3.rb:<global>::program.File3:foo"
case xs => fail(s"Expected 3 calls, got [${xs.code.mkString(",")}] instead")
}
}
}
}

0 comments on commit 592e17a

Please sign in to comment.