Skip to content

Commit

Permalink
[javasrc2cpg] - add ability to cache JdkTypeSolver (#3965)
Browse files Browse the repository at this point in the history
  • Loading branch information
xavierpinho authored Dec 15, 2023
1 parent 9d79faa commit 162abf5
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ final case class Config(
jdkPath: Option[String] = None,
showEnv: Boolean = false,
skipTypeInfPass: Boolean = false,
dumpJavaparserAsts: Boolean = false
dumpJavaparserAsts: Boolean = false,
cacheJdkTypeSolver: Boolean = false
) extends X2CpgConfig[Config]
with TypeRecoveryParserConfig[Config] {
def withInferenceJarPaths(paths: Set[String]): Config = {
Expand Down Expand Up @@ -62,6 +63,10 @@ final case class Config(
def withDumpJavaparserAsts(value: Boolean): Config = {
copy(dumpJavaparserAsts = value).withInheritedFields(this)
}

def withCacheJdkTypeSolver(value: Boolean): Config = {
copy(cacheJdkTypeSolver = value).withInheritedFields(this)
}
}

private object Frontend {
Expand Down Expand Up @@ -111,7 +116,11 @@ private object Frontend {
opt[Unit]("dump-javaparser-asts")
.hidden()
.action((_, c) => c.withDumpJavaparserAsts(true))
.text("Dump the javaparser asts for the given input files and terminate (for debugging).")
.text("Dump the javaparser asts for the given input files and terminate (for debugging)."),
opt[Unit]("cache-jdk-type-solver")
.hidden()
.action((_, c) => c.withCacheJdkTypeSolver(true))
.text("Re-use JDK type solver between scans.")
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ class AstCreationPass(config: Config, cpg: Cpg, sourcesOverride: Option[List[Str
jdkPath
}

combinedTypeSolver.addNonCachingTypeSolver(JdkJarTypeSolver.fromJdkPath(jdkPath))
combinedTypeSolver.addNonCachingTypeSolver(
JdkJarTypeSolver.fromJdkPath(jdkPath, useCache = config.cacheJdkTypeSolver)
)

val relativeSourceFilenames =
sourceFilenames.map(filename => Path.of(config.inputPath).relativize(Path.of(filename)).toString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@ import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.util.{Failure, Success, Try, Using}

class JdkJarTypeSolver extends TypeSolver {

private val logger = LoggerFactory.getLogger(this.getClass())
class JdkJarTypeSolver(classPool: NonCachingClassPool, knownPackagePrefixes: Set[String]) extends TypeSolver {

private var parent: Option[TypeSolver] = None
private val classPool = new NonCachingClassPool()

private val knownPackagePrefixes: mutable.Set[String] = mutable.Set.empty

private type RefType = ResolvedReferenceTypeDeclaration

Expand Down Expand Up @@ -76,6 +71,17 @@ class JdkJarTypeSolver extends TypeSolver {
private def refTypeToSymbolReference(refType: RefType): SymbolReference[RefType] = {
SymbolReference.solved[RefType, RefType](refType)
}
}

class JdkJarTypeSolverBuilder {

private val logger = LoggerFactory.getLogger(this.getClass)
private val classPool = new NonCachingClassPool()
private val knownPackagePrefixes: mutable.Set[String] = mutable.Set.empty

def build: JdkJarTypeSolver = {
new JdkJarTypeSolver(classPool, knownPackagePrefixes.toSet)
}

private def addPathToClassPool(archivePath: String): Try[ClassPath] = {
if (archivePath.isJarPath) {
Expand All @@ -88,12 +94,12 @@ class JdkJarTypeSolver extends TypeSolver {
}
}

def withJars(archivePaths: Seq[String]): JdkJarTypeSolver = {
def withJars(archivePaths: Seq[String]): JdkJarTypeSolverBuilder = {
addArchives(archivePaths)
this
}

def addArchives(archivePaths: Seq[String]): Unit = {
private def addArchives(archivePaths: Seq[String]): Unit = {
archivePaths.foreach { archivePath =>
addPathToClassPool(archivePath) match {
case Success(_) => registerPackagesForJar(archivePath)
Expand Down Expand Up @@ -124,24 +130,34 @@ class JdkJarTypeSolver extends TypeSolver {
}

object JdkJarTypeSolver {
val ClassExtension: String = ".class"
val JmodClassPrefix: String = "classes/"
val JarExtension: String = ".jar"
val JmodExtension: String = ".jmod"
val ClassExtension: String = ".class"
val JmodClassPrefix: String = "classes/"
val JarExtension: String = ".jar"
val JmodExtension: String = ".jmod"
private val cache: mutable.Map[String, JdkJarTypeSolverBuilder] = mutable.Map.empty

extension (path: String) {
def isJarPath: Boolean = path.endsWith(JarExtension)
def isJmodPath: Boolean = path.endsWith(JmodExtension)
}

def fromJdkPath(jdkPath: String): JdkJarTypeSolver = {
private def determineJarPaths(jdkPath: String): List[String] = {
// not following symlinks, because some setups might have a loop, e.g. AWS's Corretto
// see https://github.com/joernio/joern/pull/3871
val jarPaths = SourceFiles.determine(jdkPath, Set(JarExtension, JmodExtension))(VisitOptions.default)
if (jarPaths.isEmpty) {
throw new IllegalArgumentException(s"No .jar or .jmod files found at JDK path ${jdkPath}")
}
new JdkJarTypeSolver().withJars(jarPaths)
jarPaths
}

def fromJdkPath(jdkPath: String, useCache: Boolean = false): JdkJarTypeSolver = {
def createBuilder = new JdkJarTypeSolverBuilder().withJars(determineJarPaths(jdkPath))
if (useCache) {
cache.getOrElseUpdate(jdkPath, createBuilder).build
} else {
createBuilder.build
}
}

/** Convert JavaParser class name foo.bar.qux.Baz to package prefix foo.bar Only use first 2 parts since this is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class JavaSrc2CpgTestContext {
val config = Config(inferenceJarPaths = inferenceJarPaths)
.withInputPath(writeCodeToFile(code, "javasrc2cpgTest", ".java").getAbsolutePath)
.withOutputPath("")
.withCacheJdkTypeSolver(true)
val cpg = javaSrc2Cpg.createCpgWithOverlays(config)
if (runDataflow) {
val context = new LayerCreatorContext(cpg.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ trait JavaSrcFrontend extends LanguageFrontend {
override val fileSuffix: String = ".java"

override def execute(sourceCodeFile: File): Cpg = {
val config = getConfig().map(_.asInstanceOf[Config]).getOrElse(JavaSrc2Cpg.DefaultConfig)
val config =
getConfig().map(_.asInstanceOf[Config]).getOrElse(JavaSrc2Cpg.DefaultConfig).withCacheJdkTypeSolver(true)
new JavaSrc2Cpg().createCpg(sourceCodeFile.getAbsolutePath)(config).get
}
}
Expand Down

0 comments on commit 162abf5

Please sign in to comment.