Skip to content

Commit

Permalink
Rework lambda implicit parameter handling.
Browse files Browse the repository at this point in the history
- Removed implicitParameterName and hasApplyOrAlsoScopeFunctionParent
  APIs from TypeInfoProvider.
- Removed wrong parameter deconstruction in lambda. New implemention is
  missing.
- Some lambda to builtin `apply` and `also` functions did not get return
  statements generated. That is now fixed.
  • Loading branch information
ml86 committed Oct 18, 2024
1 parent 325b742 commit b271eaa
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,23 @@ import io.shiftleft.codepropertygraph.generated.ModifierTypes
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.{ClassDescriptor, DescriptorVisibilities, FunctionDescriptor, Modality}
import org.jetbrains.kotlin.descriptors.{
ClassDescriptor,
DescriptorVisibilities,
FunctionDescriptor,
Modality,
ParameterDescriptor,
ReceiverParameterDescriptor
}
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.resolve.calls.model.ResolvedCallArgument
import org.jetbrains.kotlin.resolve.calls.tower.{NewAbstractResolvedCall, PSIFunctionKotlinCallArgument}
import org.jetbrains.kotlin.resolve.sam.{SamConstructorDescriptor, SamConversionResolverImplKt}
import org.jetbrains.kotlin.resolve.DescriptorUtils
import org.jetbrains.kotlin.resolve.source.KotlinSourceElement

import java.util.UUID.nameUUIDFromBytes
import scala.collection.mutable
import scala.jdk.CollectionConverters.*

trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
Expand Down Expand Up @@ -309,6 +318,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
.withChildren(annotations.map(astForAnnotationEntry))
}

// TODO Handling for destructuring of lambda parameters is missing.
// More specifically the creation and initialisation of the thereby introduced variables.
def astForLambda(
expr: KtLambdaExpression,
argIdxMaybe: Option[Int],
Expand Down Expand Up @@ -353,50 +364,33 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
scope.addToScope(capturedNodeContext.name, node)
node
}
val parametersAsts = typeInfoProvider.implicitParameterName(expr) match {
case Some(implicitParamName) =>
val node = parameterInNode(
expr,
implicitParamName,
implicitParamName,
1,
false,
EvaluationStrategies.BY_REFERENCE,
TypeConstants.any
)
scope.addToScope(implicitParamName, node)
Seq(Ast(node))
case None =>
withIndex(expr.getValueParameters.asScala.toSeq) { (p, idx) =>
val destructuringEntries =
Option(p.getDestructuringDeclaration)
.map(_.getEntries.asScala)
.getOrElse(Seq())
if (destructuringEntries.nonEmpty)
destructuringEntries.filterNot(_.getText == Constants.unusedDestructuringEntryText).zipWithIndex.map {
case (entry, innerIdx) =>
val name = entry.getName
val typeFullName =
bindingUtils
.getVariableDesc(entry)
.flatMap(desc => nameRenderer.typeFullName(desc.getType))
.getOrElse {
val explicitTypeName = Option(entry.getTypeReference).map(_.getText).getOrElse(TypeConstants.any)
explicitTypeName
}
registerType(typeFullName)
val node =
parameterInNode(entry, name, name, innerIdx + idx, false, EvaluationStrategies.BY_VALUE, typeFullName)
scope.addToScope(name, node)
Ast(node)
}
else Seq(astForParameter(p, idx))
}.flatten

val paramAsts = mutable.ArrayBuffer.empty[Ast]
val valueParamStartIndex =
if (funcDesc.getExtensionReceiverParameter != null) {
// Lambdas which are arguments to function parameters defined
// like `func: extendedType.(argTypes) -> returnType` have an implicit extension receiver parameter
// which can be accessed as `this`
paramAsts.append(createImplicitParamNode(expr, funcDesc.getExtensionReceiverParameter, "this", 1))
2
} else {
1
}

funcDesc.getValueParameters.asScala match {
case parameters if parameters.size == 1 && !parameters.head.getSource.isInstanceOf[KotlinSourceElement] =>
// Here we handle the implicit `it` parameter.
paramAsts.append(createImplicitParamNode(expr, parameters.head, "it", valueParamStartIndex))
case parameters =>
parameters.zipWithIndex.foreach { (paramDesc, idx) =>
val param = paramDesc.getSource.asInstanceOf[KotlinSourceElement].getPsi.asInstanceOf[KtParameter]
paramAsts.append(astForParameter(param, valueParamStartIndex + idx))
}
}

val lastChildNotReturnExpression = !expr.getBodyExpression.getLastChild.isInstanceOf[KtReturnExpression]
val needsReturnExpression =
lastChildNotReturnExpression && !typeInfoProvider.hasApplyOrAlsoScopeFunctionParent(expr)
lastChildNotReturnExpression
val bodyAsts = Option(expr.getBodyExpression)
.map(
astsForBlock(
Expand Down Expand Up @@ -424,7 +418,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
}
val lambdaMethodAst = methodAst(
lambdaMethodNode,
parametersAsts,
paramAsts.toSeq,
bodyAst,
newMethodReturnNode(returnTypeFullName, None, line(expr), column(expr)),
newModifierNode(ModifierTypes.VIRTUAL) :: newModifierNode(ModifierTypes.LAMBDA) :: Nil
Expand Down Expand Up @@ -460,6 +454,25 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) {
.withChildren(annotations.map(astForAnnotationEntry))
}

private def createImplicitParamNode(
expr: KtLambdaExpression,
paramDesc: ParameterDescriptor,
paramName: String,
index: Int
): Ast = {
val node = parameterInNode(
expr,
paramName,
paramName,
index,
false,
EvaluationStrategies.BY_REFERENCE,
nameRenderer.typeFullName(paramDesc.getType).getOrElse(TypeConstants.any)
)
scope.addToScope(paramName, node)
Ast(node)
}

// SAM stands for: single abstraction method
private def getSamInterface(expr: KtLambdaExpression | KtNamedFunction): Option[ClassDescriptor] = {
getSurroundingCallTarget(expr) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,39 +216,6 @@ class DefaultTypeInfoProvider(val bindingContext: BindingContext, typeRenderer:
.getOrElse(CallKind.Unknown)
}

def hasApplyOrAlsoScopeFunctionParent(expr: KtLambdaExpression): Boolean = {
expr.getParent.getParent match {
case callExpr: KtCallExpression =>
resolvedCallDescriptor(callExpr) match {
case Some(desc) =>
val rendered = typeRenderer.renderFqNameForDesc(desc.getOriginal)
rendered.startsWith(TypeConstants.kotlinApplyPrefix) || rendered.startsWith(TypeConstants.kotlinAlsoPrefix)
case _ => false
}
case _ => false
}
}

private def renderedReturnType(fnDesc: FunctionDescriptor): String = {
val returnT = fnDesc.getReturnType.getConstructor.getDeclarationDescriptor.getDefaultType
val typeParams = fnDesc.getTypeParameters.asScala.toList

val typesInTypeParams = typeParams.map(_.getDefaultType.getConstructor.getDeclarationDescriptor.getDefaultType)
val hasReturnTypeFromTypeParams = typesInTypeParams.contains(returnT)
if (hasReturnTypeFromTypeParams) {
if (returnT.getConstructor.getSupertypes.asScala.nonEmpty) {
val firstSuperType = returnT.getConstructor.getSupertypes.asScala.toList.head
typeRenderer.render(firstSuperType)
} else {
val renderedReturnT = typeRenderer.render(returnT)
if (renderedReturnT == TypeConstants.tType) TypeConstants.javaLangObject
else renderedReturnT
}
} else {
typeRenderer.render(fnDesc.getReturnType)
}
}

def isReferenceToClass(expr: KtNameReferenceExpression): Boolean = {
descriptorForNameReference(expr).exists {
case _: LazyJavaClassDescriptor => true
Expand Down Expand Up @@ -292,45 +259,6 @@ class DefaultTypeInfoProvider(val bindingContext: BindingContext, typeRenderer:
else None
}.headOption
}

def implicitParameterName(expr: KtLambdaExpression): Option[String] = {
if (!expr.getValueParameters.isEmpty) {
None
} else {
val hasSingleImplicitParameter =
Option(bindingContext.get(BindingContext.EXPECTED_EXPRESSION_TYPE, expr)).exists { desc =>
// 1 for the parameter + 1 for the return type == 2
desc.getConstructor.getParameters.size() == 2
}
val containingQualifiedExpression = Option(expr.getParent)
.map(_.getParent)
.flatMap(_.getParent match {
case q: KtQualifiedExpression => Some(q)
case _ => None
})
containingQualifiedExpression match {
case Some(qualifiedExpression) =>
resolvedCallDescriptor(qualifiedExpression) match {
case Some(fnDescriptor) =>
val originalDesc = fnDescriptor.getOriginal
val vps = originalDesc.getValueParameters
val renderedFqName = typeRenderer.renderFqNameForDesc(originalDesc)
if (
hasSingleImplicitParameter &&
(renderedFqName.startsWith(TypeConstants.kotlinRunPrefix) ||
renderedFqName.startsWith(TypeConstants.kotlinApplyPrefix))
) {
Some(TypeConstants.scopeFunctionThisParameterName)
// https://kotlinlang.org/docs/lambdas.html#it-implicit-name-of-a-single-parameter
} else if (hasSingleImplicitParameter) {
Some(TypeConstants.lambdaImplicitParameterName)
} else None
case None => None
}
case None => None
}
}
}
}

object DefaultTypeInfoProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ object TypeConstants {
val kotlinSuspendFunctionXPrefix = "kotlin.coroutines.SuspendFunction"
val kotlinAlsoPrefix = "kotlin.also"
val kotlinApplyPrefix = "kotlin.apply"
val kotlinRunPrefix = "kotlin.run"
val lambdaImplicitParameterName = "it"
val scopeFunctionThisParameterName = "this"
val kotlinUnit = "kotlin.Unit"
val javaLangBoolean = "boolean"
val javaLangClass = "java.lang.Class"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,12 @@ trait TypeInfoProvider(val typeRenderer: TypeRenderer = new TypeRenderer()) {

def anySignature(args: Seq[Any]): String

def hasApplyOrAlsoScopeFunctionParent(expr: KtLambdaExpression): Boolean

def isConstructorCall(expr: KtExpression): Option[Boolean]

def typeFullName(expr: KtTypeReference, defaultValue: String): String

def hasStaticDesc(expr: KtQualifiedExpression): Boolean

def implicitParameterName(expr: KtLambdaExpression): Option[String]

def isCompanionObject(expr: KtClassOrObject): Boolean

def isRefToCompanionObject(expr: KtNameReferenceExpression): Boolean
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.joern.kotlin2cpg.querying

import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.DispatchTypes
Expand Down Expand Up @@ -194,11 +195,29 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef
"should contain a METHOD_PARAMETER_IN for the lambda with the correct properties set" in {
val List(p) = cpg.method.fullName(".*lambda.*").parameter.l
p.code shouldBe "this"
p.typeFullName shouldBe "ANY"
p.typeFullName shouldBe "java.lang.String"
p.index shouldBe 1
}
}

"lambda should contain METHOD_PARAMETER_IN for both implicit lambda parameters" in {
val cpg = code("""
|package mypkg
|public fun myFunc(block: String.(Int) -> Unit): Unit {}
|fun outer(param: String): Unit {
| myFunc { println(it); println(this)}
|}
||""".stripMargin)

val List(thisParam, itParam) = cpg.method.fullName(".*lambda.*").parameter.l
thisParam.code shouldBe "this"
thisParam.typeFullName shouldBe "java.lang.String"
thisParam.index shouldBe 1
itParam.code shouldBe "it"
itParam.typeFullName shouldBe "int"
itParam.index shouldBe 2
}

"CPG for code containing a lambda with parameter destructuring" should {
val cpg = code("""|package mypkg
|
Expand All @@ -217,14 +236,13 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef
}

"should contain METHOD_PARAMETER_IN nodes for the lambda with the correct properties set" in {
val List(p1, p2) = cpg.method.fullName(".*lambda.*").parameter.l
p1.code shouldBe "k"
val List(p1) = cpg.method.fullName(".*lambda.*").parameter.l
p1.code shouldBe Constants.paramNameLambdaDestructureDecl
p1.index shouldBe 1
p1.typeFullName shouldBe "java.lang.String"
p2.code shouldBe "v"
p2.index shouldBe 2
p2.typeFullName shouldBe "int"
p1.typeFullName shouldBe "java.util.Map$Entry"
}

// TODO add tests for initialisation of destructured parameter
}

"CPG for code containing a lambda with parameter destructuring and an `_` entry" should {
Expand All @@ -247,9 +265,9 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef

"should contain one METHOD_PARAMETER_IN node for the lambda with the correct properties set" in {
val List(p1) = cpg.method.fullName(".*lambda.*").parameter.l
p1.code shouldBe "k"
p1.code shouldBe Constants.paramNameLambdaDestructureDecl
p1.index shouldBe 1
p1.typeFullName shouldBe "java.lang.String"
p1.typeFullName shouldBe "java.util.Map$Entry"
}
}

Expand Down Expand Up @@ -552,6 +570,7 @@ class LambdaTests extends KotlinCode2CpgFixture(withOssDataflow = false, withDef
m.signature shouldBe "void(java.lang.String)"
val List(p) = m.parameter.l
p.name shouldBe "it"
p.index shouldBe 1
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false)
p.name shouldBe "it"
}

"should NOT contain a RETURN node around as the last child of the lambda's BLOCK" in {
"should contain a RETURN node around as the last child of the lambda's BLOCK" in {
val List(b: Block) = cpg.method.fullName(".*lambda.*").block.l
val hasReturnAsLastChild = b.astChildren.last match {
case _: Return => true
case _ => false
}
hasReturnAsLastChild shouldBe false
hasReturnAsLastChild shouldBe true
}
}

Expand All @@ -31,13 +31,13 @@ class ScopeFunctionTests extends KotlinCode2CpgFixture(withOssDataflow = false)
p.name shouldBe "this"
}

"should NOT contain a RETURN node around as the last child of the lambda's BLOCK" in {
"should contain a RETURN node around as the last child of the lambda's BLOCK" in {
val List(b: Block) = cpg.method.fullName(".*lambda.*").block.l
val hasReturnAsLastChild = b.astChildren.last match {
case _: Return => true
case _ => false
}
hasReturnAsLastChild shouldBe false
hasReturnAsLastChild shouldBe true
}
}

Expand Down

0 comments on commit b271eaa

Please sign in to comment.