diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala index 1c6fcd163a45..842ce65d0d74 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/language/package.scala @@ -1,8 +1,12 @@ package io.joern.dataflowengineoss -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.joern.dataflowengineoss.language.dotextension.DdgNodeDot import io.joern.dataflowengineoss.language.nodemethods.{ExpressionMethods, ExtendedCfgNodeMethods} +import overflowdb.traversal.help.Doc + +import scala.language.implicitConversions package object language { @@ -21,4 +25,21 @@ package object language { implicit def toDdgNodeDotSingle(method: Method): DdgNodeDot = new DdgNodeDot(Iterator.single(method)) + implicit def toExtendedPathsTrav[NodeType <: Path](traversal: IterableOnce[NodeType]): PassesExt = + new PassesExt(traversal.iterator) + + class PassesExt(traversal: Iterator[Path]) { + + @Doc(info = "Filters in paths that pass though the given paths") + def passes(trav: Iterator[AstNode] => Iterator[?]): Iterator[Path] = { + traversal.filter(_.elements.exists(_.start.where(trav).nonEmpty)) + } + + @Doc(info = "Filters out paths that pass though the given paths") + def passesNot(trav: Iterator[AstNode] => Iterator[?]): Iterator[Path] = { + traversal.filter(_.elements.forall(_.start.where(trav).isEmpty)) + } + + } + } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala index fc61285a889b..a15949550b9c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala @@ -83,26 +83,6 @@ class CfgNodeMethods(val node: CfgNode) extends AnyVal with NodeExtension { } } - /** Using the post dominator tree, will determine if this node passes through the included set of nodes and filter it - * in. - * @param included - * the nodes this node must pass through. - * @return - * the traversal of this node if it passes through the included set. - */ - def passes(included: Set[CfgNode]): Iterator[CfgNode] = - Iterator.single(node).filter(_.postDominatedBy.exists(included.contains)) - - /** Using the post dominator tree, will determine if this node passes through the excluded set of nodes and filter it - * out. - * @param excluded - * the nodes this node must not pass through. - * @return - * the traversal of this node if it does not pass through the excluded set. - */ - def passesNot(excluded: Set[CfgNode]): Iterator[CfgNode] = - Iterator.single(node).filterNot(_.postDominatedBy.exists(excluded.contains)) - private def expandExhaustively(expand: CfgNode => Iterator[StoredNode]): Iterator[CfgNode] = { var controllingNodes = List.empty[CfgNode] var visited = Set.empty + node diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala index 3ac6a86b9b33..3a3d22f7840f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala @@ -86,16 +86,4 @@ class CfgNodeTraversal[A <: CfgNode](val traversal: Iterator[A]) extends AnyVal def address: Iterator[Option[String]] = traversal.map(_.address) - @Doc(info = "Filters in paths that pass though the given traversal") - def passes(included: Iterator[CfgNode]): Iterator[CfgNode] = { - val in = included.toSet - traversal.flatMap(_.passes(in)) - } - - @Doc(info = "Filters out paths that pass though the given traversal") - def passesNot(excluded: Iterator[CfgNode]): Iterator[CfgNode] = { - val ex = excluded.toSet - traversal.flatMap(_.passesNot(ex)) - } - }