Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GoSrc2Cpg : AST generation for switch case #3018

Merged
merged 5 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.joern.gosrc2cpg.astcreation

import io.joern.gosrc2cpg.parser.ParserAst.{ParserNode, fromString}
import io.joern.gosrc2cpg.parser.ParserAst.{Ident, ParserNode, fromString}
import ujson.Value
import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo}
import org.apache.commons.lang.StringUtils
Expand Down Expand Up @@ -66,5 +66,11 @@ trait AstCreatorHelper { this: AstCreator =>
}
.toMap
}

protected def getTypeForJsonNode(jsonNode: Value): String = {
val nodeInfo = createParserNodeInfo(jsonNode)
nodeInfo.node match {
case Ident => jsonNode.obj(ParserKeys.Name).str
case _ => Defines.anyTypeName
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo}
import io.joern.x2cpg.Ast
import io.joern.gosrc2cpg.parser.ParserAst._
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import ujson.Value

import scala.util.Try

Expand Down Expand Up @@ -50,7 +51,7 @@ trait AstForGenDeclarationCreator { this: AstCreator =>
val localParserNode = createParserNodeInfo(parserNode)

val name = parserNode(ParserKeys.Name).str
val typ = valueSpec.json(ParserKeys.Type).obj(ParserKeys.Name).str
val typ = getTypeForJsonNode(valueSpec.json(ParserKeys.Type))
val node = localNode(localParserNode, name, localParserNode.code, typ)
scope.addToScope(name, (node, typ))
Ast(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.joern.gosrc2cpg.utils.Operator
import io.joern.x2cpg.Ast
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}

import scala.annotation.tailrec
import scala.util.Try

trait AstForStatementsCreator { this: AstCreator =>
Expand All @@ -23,14 +24,20 @@ trait AstForStatementsCreator { this: AstCreator =>
blockAst(newBlockNode, childAsts.toList)
}

@tailrec
private def astsForStatement(statement: ParserNodeInfo, argIndex: Int = -1): Seq[Ast] = {
statement.node match {
case AssignStmt => astForAssignStatement(statement)
case BlockStmt => Seq(astForBlockStatement(statement, argIndex))
case DeclStmt => astForDeclStatement(statement)
case IfStmt => Seq(astForIfStatement(statement))
case IncDecStmt => Seq(astForIncDecStatement(statement))
case _ => Seq()
case AssignStmt => astForAssignStatement(statement)
case BlockStmt => Seq(astForBlockStatement(statement, argIndex))
case CaseClause => astForCaseClause(statement)
case DeclStmt => astForDeclStatement(statement)
case ExprStmt => astsForStatement(createParserNodeInfo(statement.json(ParserKeys.X)))
case IfStmt => Seq(astForIfStatement(statement))
case IncDecStmt => Seq(astForIncDecStatement(statement))
case SwitchStmt => Seq(astForSwitchStatement(statement))
case TypeAssertExpr => astForNode(statement.json(ParserKeys.X))
case TypeSwitchStmt => Seq(astForTypeSwitchStatement(statement))
case _ => astForNode(statement.json)
}
}

Expand Down Expand Up @@ -100,7 +107,7 @@ trait AstForStatementsCreator { this: AstCreator =>
private def astForConditionExpression(condStmt: ParserNodeInfo): Ast = {
condStmt.node match {
case ParenExpr => astForNode(condStmt.json(ParserKeys.X)).head
case _ => Ast()
case _ => astsForStatement(condStmt).headOption.getOrElse(Ast())
}
}

Expand Down Expand Up @@ -133,4 +140,45 @@ trait AstForStatementsCreator { this: AstCreator =>
controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst))
}

private def astForSwitchStatement(switchStmt: ParserNodeInfo): Ast = {

val conditionParserNode = Try(createParserNodeInfo(switchStmt.json(ParserKeys.Tag)))
val (code, conditionAst) = conditionParserNode.toOption match {
case Some(node) => (node.code, Some(astForConditionExpression(node)))
case _ => ("", None)
}
val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, s"switch $code")
val stmtAsts = astsForStatement(createParserNodeInfo(switchStmt.json(ParserKeys.Body)))
controlStructureAst(switchNode, conditionAst, stmtAsts)
}

private def astForTypeSwitchStatement(typeSwitchStmt: ParserNodeInfo): Ast = {

val conditionParserNode = Try(createParserNodeInfo(typeSwitchStmt.json(ParserKeys.Assign)))
val (code, conditionAst) = conditionParserNode.toOption match {
case Some(node) => (node.code, Some(astForConditionExpression(node)))
case _ => ("", None)
}
val switchNode = controlStructureNode(typeSwitchStmt, ControlStructureTypes.SWITCH, s"switch $code")
val stmtAsts = astsForStatement(createParserNodeInfo(typeSwitchStmt.json(ParserKeys.Body)))
controlStructureAst(switchNode, conditionAst, stmtAsts)
}

private def astForCaseClause(caseStmt: ParserNodeInfo): Seq[Ast] = {
val caseClauseAst = caseStmt.json(ParserKeys.List).arrOpt match {
case Some(caseConditionList) =>
caseConditionList.flatMap { caseConditionNode =>
val caseConditionParserNode = createParserNodeInfo(caseConditionNode)
val jumpTarget = jumpTargetNode(caseStmt, "case", s"case ${caseConditionParserNode.code}")
val labelAsts = astForNode(caseConditionNode).toList
Ast(jumpTarget) :: labelAsts
}
case _ =>
val target = jumpTargetNode(caseStmt, "default", "default")
Seq(Ast(target))
}

val caseBodyAst = caseStmt.json(ParserKeys.Body).arr.map(createParserNodeInfo).flatMap(astsForStatement(_)).toList
caseClauseAst ++: caseBodyAst
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,38 @@ object ParserAst {
}
sealed trait BaseExprStmt extends ParserNode

object File extends ParserNode
object GenDecl extends ParserNode
object ImportSpec extends ParserNode
object BasicLit extends ParserNode
object FuncDecl extends ParserNode
object BlockStmt extends ParserNode
object DeclStmt extends ParserNode
object ValueSpec extends ParserNode
object Ident extends ParserNode
object AssignStmt extends ParserNode
object ExprStmt extends BaseExprStmt
object BinaryExpr extends BaseExprStmt
object UnaryExpr extends BaseExprStmt
object StarExpr extends BaseExprStmt
object IncDecStmt extends ParserNode
object IfStmt extends ParserNode
object ParenExpr extends BaseExprStmt
object ReturnStmt extends ParserNode
object FuncType extends ParserNode
object Ellipsis extends ParserNode
object SelectorExpr extends ParserNode
object File extends ParserNode
object GenDecl extends ParserNode
object ImportSpec extends ParserNode
object BasicLit extends ParserNode
object FuncDecl extends ParserNode
object BlockStmt extends ParserNode
object DeclStmt extends ParserNode
object ValueSpec extends ParserNode
object Ident extends ParserNode
object AssignStmt extends ParserNode
object ExprStmt extends BaseExprStmt
object BinaryExpr extends BaseExprStmt
object UnaryExpr extends BaseExprStmt
object StarExpr extends BaseExprStmt

object IncDecStmt extends ParserNode
object IfStmt extends ParserNode
object ParenExpr extends BaseExprStmt
object SwitchStmt extends ParserNode
object CaseClause extends ParserNode
object TypeSwitchStmt extends ParserNode
object TypeAssertExpr extends BaseExprStmt
object InterfaceType extends ParserNode
object ReturnStmt extends ParserNode
object FuncType extends ParserNode
object Ellipsis extends ParserNode
object SelectorExpr extends ParserNode
}

object ParserKeys {

val Assign = "Assign"
val Body = "Body"
val Cond = "Cond"
val Decl = "Decl"
Expand All @@ -59,6 +66,7 @@ object ParserKeys {
val Path = "Path"
val Rhs = "Rhs"
val Specs = "Specs"
val Tag = "Tag"
val Tok = "Tok"
val Type = "Type"
val Value = "Value"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,45 @@ class DataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) {
}
}

"Source to sink dataflow through switch case" should {
"be reachable for expression condition" in {
val cpg = code("""
|package main
|func method() {
| var marks int = 90
| var grade string = "B"
| switch marks {
| case 90: myGrade := grade
| case 50,60,70: grade = "C"
| default: grade = "D"
| }
|}
""".stripMargin)
val source = cpg.identifier("grade").lineNumber(5)
val sink = cpg.identifier("myGrade").lineNumber(7)
sink.reachableByFlows(source).size shouldBe 1

}

"be reachable for empty condition" ignore {
// TODO (BUG)dataflow doesn't work for empty condition in switch case
val cpg = code("""
|package main
|func method() {
| var marks int = 90
| var grade string = "B"
| switch {
| case grade == "A" :
| mymarks := grade
| case grade == "B":
| marks = 80
| }
|}
""".stripMargin)
val source = cpg.identifier("grade").lineNumber(5).l
val sink = cpg.identifier("mymarks").lineNumber(8).l
sink.reachableByFlows(source).size shouldBe 1
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,149 @@ class AstCreationPassTests extends GoCodeToCpgSuite {
.headOption shouldBe Some(("y", "1"))
}
}

"be correct for switch case 1" in {

val cpg = code("""
|package main
|func method() {
| var marks int = 90
| var grade string = "B"
| switch marks {
| case 90: grade = "A"
| case 50,60,70: grade = "C"
| default: grade = "D"
| }
|}
""".stripMargin)
inside(cpg.method.name("method").controlStructure.l) { case List(controlStruct: ControlStructure) =>
controlStruct.code shouldBe "switch marks"
controlStruct.controlStructureType shouldBe ControlStructureTypes.SWITCH
inside(controlStruct.astChildren.l) { case List(cond: Identifier, switchBlock: Block) =>
cond.code shouldBe "marks"
switchBlock.astChildren.size shouldBe 12
switchBlock.astChildren.code.l shouldBe List(
"case 90",
"90",
"grade = \"A\"",
"case 50",
"50",
"case 60",
"60",
"case 70",
"70",
"grade = \"C\"",
"default",
"grade = \"D\""
)
}
}
}

"be correct for switch case 2" in {

val cpg = code("""
|package main
|func method() {
| var marks int = 90
| var grade string = "B"
| switch {
| case grade == "A" :
| marks = 95
| case grade == "B":
| marks = 80
| }
|}
""".stripMargin)
inside(cpg.method.name("method").controlStructure.l) { case List(controlStruct: ControlStructure) =>
controlStruct.code shouldBe "switch "
controlStruct.controlStructureType shouldBe ControlStructureTypes.SWITCH
inside(controlStruct.astChildren.l) { case List(switchBlock: Block) =>
switchBlock.astChildren.size shouldBe 6
switchBlock.astChildren.code.l shouldBe List(
"case grade == \"A\"",
"grade == \"A\"",
"marks = 95",
"case grade == \"B\"",
"grade == \"B\"",
"marks = 80"
)
}
}
}

"be correct for switch case 3" ignore {

val cpg = code("""
|package main
|func method() {
| var x interface{}
| var y int = 6
| switch i := x.(type) {
| case nil:
| y = 5
| case int:
| y = 8
| case float64:
| y= 12
| }
|}
""".stripMargin)
inside(cpg.method.name("method").controlStructure.l) { case List(controlStruct: ControlStructure) =>
controlStruct.code shouldBe "switch i := x.(type)"
controlStruct.controlStructureType shouldBe ControlStructureTypes.SWITCH
inside(controlStruct.astChildren.l) { case List(assignment: Call, switchBlock: Block) =>
switchBlock.astChildren.size shouldBe 9
switchBlock.astChildren.code.l shouldBe List(
"case nil",
"nil",
"y = 5",
"case int",
"int",
"y = 8",
"case float64",
"float64",
"y = 12"
)
}
}
}

"be correct for switch case 4" in {

val cpg = code("""
|package main
|func method() {
| var x interface{}
| var y int = 6
| switch x.(type) {
| case nil:
| y = 5
| case int:
| y = 8
| case float64:
| y = 12
| }
|}
""".stripMargin)
inside(cpg.method.name("method").controlStructure.l) { case List(controlStruct: ControlStructure) =>
controlStruct.code shouldBe "switch x.(type)"
controlStruct.controlStructureType shouldBe ControlStructureTypes.SWITCH
inside(controlStruct.astChildren.l) { case List(identifier: Identifier, switchBlock: Block) =>
identifier.code shouldBe "x"
switchBlock.astChildren.size shouldBe 9
switchBlock.astChildren.code.l shouldBe List(
"case nil",
"nil",
"y = 5",
"case int",
"int",
"y = 8",
"case float64",
"float64",
"y = 12"
)
}
}
}
}
Loading