From c93a366b9159b8d6215de6b57aac011262aa3c12 Mon Sep 17 00:00:00 2001 From: Michael Pollmeier Date: Thu, 28 Mar 2024 19:22:56 +0100 Subject: [PATCH] bring back path tracking - used in codescience (#167) --- .../scala/flatgraph/traversal/Language.scala | 75 ++++++++------- .../traversal/PathAwareRepeatStep.scala | 91 +++++++++++++++++++ .../traversal/PathAwareTraversal.scala | 80 ++++++++++++++++ .../traversal/RepeatTraversalTests.scala | 21 ++--- .../flatgraph/traversal/TraversalTests.scala | 8 +- .../simple/ExampleGraphSetup.scala | 39 +++++++- 6 files changed, 255 insertions(+), 59 deletions(-) create mode 100644 core/src/main/scala/flatgraph/traversal/PathAwareRepeatStep.scala create mode 100644 core/src/main/scala/flatgraph/traversal/PathAwareTraversal.scala diff --git a/core/src/main/scala/flatgraph/traversal/Language.scala b/core/src/main/scala/flatgraph/traversal/Language.scala index 09d86708..15ec06f2 100644 --- a/core/src/main/scala/flatgraph/traversal/Language.scala +++ b/core/src/main/scala/flatgraph/traversal/Language.scala @@ -167,8 +167,8 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { @Doc(info = "perform side effect without changing the contents of the traversal") def sideEffect(fun: A => _): Iterator[A] = iterator match { - // TODO bring back PathAwareTraversal? -// case pathAwareTraversal: PathAwareTraversal[A] => pathAwareTraversal._sideEffect(fun) + case pathAwareTraversal: PathAwareTraversal[A] => + pathAwareTraversal._sideEffect(fun) case _ => iterator.map { a => fun(a); a @@ -240,8 +240,8 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { */ @Doc(info = "union/sum/aggregate/join given traversals from the current point") def union[B](traversals: (Iterator[A] => Iterator[B])*): Iterator[B] = iterator match { - // TODO bring back PathAwareTraversal? -// case pathAwareTraversal: PathAwareTraversal[A] => pathAwareTraversal._union(traversals: _*) + case pathAwareTraversal: PathAwareTraversal[A] => + pathAwareTraversal._union(traversals: _*) case _ => iterator.flatMap { (a: A) => traversals.flatMap(_.apply(Iterator.single(a))) @@ -275,8 +275,8 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { def choose[BranchOn >: Null, NewEnd]( on: Iterator[A] => Iterator[BranchOn] )(options: PartialFunction[BranchOn, Iterator[A] => Iterator[NewEnd]]): Iterator[NewEnd] = iterator match { - // TODO bring back PathAwareTraversal? -// case pathAwareTraversal: PathAwareTraversal[A] => pathAwareTraversal._choose[BranchOn, NewEnd](on)(options) + case pathAwareTraversal: PathAwareTraversal[A] => + pathAwareTraversal._choose[BranchOn, NewEnd](on)(options) case _ => iterator.flatMap { (a: A) => val branchOnValue: BranchOn = on(Iterator.single(a)).nextOption().getOrElse(null) @@ -288,8 +288,8 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { @Doc(info = "evaluates the provided traversals in order and returns the first traversal that emits at least one element") def coalesce[NewEnd](options: (Iterator[A] => Iterator[NewEnd])*): Iterator[NewEnd] = iterator match { - // TODO bring back PathAwareTraversal? -// case pathAwareTraversal: PathAwareTraversal[A] => pathAwareTraversal._coalesce(options: _*) + case pathAwareTraversal: PathAwareTraversal[A] => + pathAwareTraversal._coalesce(options: _*) case _ => iterator.flatMap { (a: A) => options.iterator @@ -301,22 +301,22 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { } } - // @Doc(info = "enable path tracking - prerequisite for path/simplePath steps") - // TODO bring back PathAwareTraversal? -// def enablePathTracking: PathAwareTraversal[A] = -// iterator match { -// case pathAwareTraversal: PathAwareTraversal[_] => throw new RuntimeException("path tracking is already enabled") -// case _ => new PathAwareTraversal[A](iterator.map { a => (a, Vector.empty) }) -// } + @Doc(info = "enable path tracking - prerequisite for path/simplePath steps") + def enablePathTracking: PathAwareTraversal[A] = + iterator match { + case pathAwareTraversal: PathAwareTraversal[_] => throw new RuntimeException("path tracking is already enabled") + case _ => new PathAwareTraversal[A](iterator.map { a => (a, Vector.empty) }) + } - // @Doc(info = "enable path tracking - prerequisite for path/simplePath steps") -// def discardPathTracking: Iterator[A] = -// iterator match { -// case pathAwareTraversal: PathAwareTraversal[A] => pathAwareTraversal.wrapped.map { _._1 } -// case _ => iterator -// } + @Doc(info = "enable path tracking - prerequisite for path/simplePath steps") + def discardPathTracking: Iterator[A] = + iterator match { + case pathAwareTraversal: PathAwareTraversal[A] => pathAwareTraversal.wrapped.map { _._1 } + case _ => iterator + } -// def isPathTracking: Boolean = iterator.isInstanceOf[PathAwareTraversal[_]] + def isPathTracking: Boolean = + iterator.isInstanceOf[PathAwareTraversal[_]] /** retrieve entire path that has been traversed thus far prerequisite: enablePathTracking has been called previously * @@ -328,11 +328,10 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { */ @Doc(info = "retrieve entire path that has been traversed thus far") def path: Iterator[Vector[Any]] = iterator match { - // TODO bring back PathAwareTraversal? -// case tracked: PathAwareTraversal[A] => -// tracked.wrapped.map { case (a, p) => -// p.appended(a) -// } + case tracked: PathAwareTraversal[A] => + tracked.wrapped.map { case (a, p) => + p.appended(a) + } case _ => throw new AssertionError( "path tracking not enabled, please make sure you have a `PathAwareTraversal`, e.g. via `Traversal.enablePathTracking`" @@ -340,11 +339,10 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { } // fixme: I think ClassCastException is the correct result when the user forgot to enable path tracking. But a better error message to go along with it would be nice. def simplePath: Iterator[A] = iterator match { - // TODO bring back PathAwareTraversal? -// case tracked: PathAwareTraversal[A] => -// new PathAwareTraversal(tracked.wrapped.filter { case (a, p) => -// mutable.Set.from(p).addOne(a).size == 1 + p.size -// }) + case tracked: PathAwareTraversal[A] => + new PathAwareTraversal(tracked.wrapped.filter { case (a, p) => + mutable.Set.from(p).addOne(a).size == 1 + p.size + }) case _ => throw new AssertionError( "path tracking not enabled, please make sure you have a `PathAwareTraversal`, e.g. via `Traversal.enablePathTracking`" @@ -387,12 +385,13 @@ class GenericSteps[A](iterator: Iterator[A]) extends AnyVal { repeatTraversal .asInstanceOf[Iterator[B] => Iterator[B]] // this cast usually :tm: safe, because `B` is a supertype of `A` iterator match { -// case tracked: PathAwareTraversal[A] => -// val step = PathAwareRepeatStep(_repeatTraversal, behaviour) -// new PathAwareTraversal(tracked.wrapped.flatMap { case (a, p) => -// step.apply(a).wrapped.map { case (aa, pp) => (aa, p ++ pp) } -// }) - case _ => iterator.flatMap(RepeatStep(_repeatTraversal, behaviour)) + case tracked: PathAwareTraversal[A] => + val step = PathAwareRepeatStep(_repeatTraversal, behaviour) + new PathAwareTraversal(tracked.wrapped.flatMap { case (a, p) => + step.apply(a).wrapped.map { case (aa, pp) => (aa, p ++ pp) } + }) + case _ => + iterator.flatMap(RepeatStep(_repeatTraversal, behaviour)) } } } diff --git a/core/src/main/scala/flatgraph/traversal/PathAwareRepeatStep.scala b/core/src/main/scala/flatgraph/traversal/PathAwareRepeatStep.scala new file mode 100644 index 00000000..eec0fc3b --- /dev/null +++ b/core/src/main/scala/flatgraph/traversal/PathAwareRepeatStep.scala @@ -0,0 +1,91 @@ +package flatgraph.traversal + +import flatgraph.traversal.RepeatBehaviour.SearchAlgorithm + +import scala.collection.{mutable, Iterator} + +object PathAwareRepeatStep { + import RepeatStep._ + + case class WorklistItem[A](traversal: Iterator[A], depth: Int) + + /** @see + * [[Traversal.repeat]] for a detailed overview + * + * Implementation note: using recursion results in nicer code, but uses the JVM stack, which only has enough space for ~10k steps. So + * instead, this uses a programmatic Stack which is semantically identical. The RepeatTraversalTests cover this case. + */ + def apply[A](repeatTraversal: Iterator[A] => Iterator[A], behaviour: RepeatBehaviour[A]): A => PathAwareTraversal[A] = { (element: A) => + new PathAwareTraversal[A](new Iterator[(A, Vector[Any])] { + val visited = mutable.Set.empty[A] + val emitSack: mutable.Queue[(A, Vector[Any])] = mutable.Queue.empty + val worklist: Worklist[WorklistItem[A]] = behaviour.searchAlgorithm match { + case SearchAlgorithm.DepthFirst => new LifoWorklist() + case SearchAlgorithm.BreadthFirst => new FifoWorklist() + } + + worklist.addItem(WorklistItem(new PathAwareTraversal(Iterator.single((element, Vector.empty))), 0)) + + def hasNext: Boolean = { + if (emitSack.isEmpty) { + // this may add elements to the emit sack and/or modify the worklist + traverseOnWorklist + } + emitSack.nonEmpty || worklistTopHasNext + } + + private def traverseOnWorklist: Unit = { + var stop = false + while (worklist.nonEmpty && !stop) { + val WorklistItem(trav0, depth) = worklist.head + val trav = trav0.asInstanceOf[PathAwareTraversal[A]].wrapped + if (trav.isEmpty) worklist.removeHead() + else if (behaviour.maxDepthReached(depth)) stop = true + else { + val (element, path1) = trav.next() + if (behaviour.dedupEnabled) visited.addOne(element) + + if ( // `while/repeat` behaviour, i.e. check every time + behaviour.whileConditionIsDefinedAndEmpty(element) || + // `repeat/until` behaviour, i.e. only checking the `until` condition from depth 1 + (depth > 0 && behaviour.untilConditionReached(element)) + ) { + // we just consumed an element from the traversal, so in lieu adding to the emit sack + emitSack.enqueue((element, path1)) + stop = true + } else { + val nextLevelTraversal = { + val repeat = + repeatTraversal(new PathAwareTraversal(Iterator.single((element, path1)))) + if (behaviour.dedupEnabled) repeat.filterNot(visited.contains) + else repeat + } + worklist.addItem(WorklistItem(nextLevelTraversal, depth + 1)) + + if (behaviour.shouldEmit(element, depth)) + emitSack.enqueue((element, path1)) + + if (emitSack.nonEmpty) + stop = true + } + } + } + } + + private def worklistTopHasNext: Boolean = + worklist.nonEmpty && worklist.head.traversal.hasNext + + override def next(): (A, Vector[Any]) = { + val result = { + if (emitSack.nonEmpty) emitSack.dequeue() + else if (worklistTopHasNext) { + worklist.head.traversal.asInstanceOf[PathAwareTraversal[A]].wrapped.next() + } else throw new NoSuchElementException("next on empty iterator") + } + if (behaviour.dedupEnabled) visited.addOne(result._1) + result + } + }) + } + +} diff --git a/core/src/main/scala/flatgraph/traversal/PathAwareTraversal.scala b/core/src/main/scala/flatgraph/traversal/PathAwareTraversal.scala new file mode 100644 index 00000000..0fc7d9e6 --- /dev/null +++ b/core/src/main/scala/flatgraph/traversal/PathAwareTraversal.scala @@ -0,0 +1,80 @@ +package flatgraph.traversal + +import scala.collection.{IterableOnce, Iterator} + +class PathAwareTraversal[A](val wrapped: Iterator[(A, Vector[Any])]) extends Iterator[A] { + override def hasNext: Boolean = wrapped.hasNext + + override def next(): A = wrapped.next()._1 + + override def map[B](f: A => B): PathAwareTraversal[B] = new PathAwareTraversal[B](wrapped.map { case (a, p) => + (f(a), p.appended(a)) + }) + + override def flatMap[B](f: A => IterableOnce[B]): PathAwareTraversal[B] = + new PathAwareTraversal[B](wrapped.flatMap { case (a, p) => + val ap = p.appended(a) + f(a).iterator.map { + (_, ap) + } + }) + + override def distinctBy[B](f: A => B): PathAwareTraversal[A] = new PathAwareTraversal[A](wrapped.distinctBy { case (a, p) => + f(a) + }) + + override def collect[B](pf: PartialFunction[A, B]): PathAwareTraversal[B] = flatMap(pf.lift) + + override def filter(p: A => Boolean): PathAwareTraversal[A] = new PathAwareTraversal(wrapped.filter(ap => p(ap._1))) + + override def filterNot(p: A => Boolean): PathAwareTraversal[A] = new PathAwareTraversal(wrapped.filterNot(ap => p(ap._1))) + + override def duplicate: (Iterator[A], Iterator[A]) = { + val tmp = wrapped.duplicate + (new PathAwareTraversal(tmp._1), new PathAwareTraversal(tmp._2)) + } + + private[traversal] def _union[B](traversals: (Iterator[A] => Iterator[B])*): Iterator[B] = + new PathAwareTraversal(wrapped.flatMap { case (a, p) => + traversals.iterator.flatMap { inner => + inner(new PathAwareTraversal(Iterator.single((a, p)))) match { + case stillPathAware: PathAwareTraversal[B] => stillPathAware.wrapped + // do we really want to allow the following, or is it an error? + case notPathAware => notPathAware.iterator.map { (b: B) => (b, p.appended(a)) } + } + } + }) + + private[traversal] def _choose[BranchOn >: Null, NewEnd](on: Iterator[A] => Iterator[BranchOn])( + options: PartialFunction[BranchOn, Iterator[A] => Iterator[NewEnd]] + ): Iterator[NewEnd] = + new PathAwareTraversal(wrapped.flatMap { case (a, p) => + val branchOnValue: BranchOn = on(Iterator.single(a)).nextOption().getOrElse(null) + options + .applyOrElse(branchOnValue, (failState: BranchOn) => (unused: Iterator[A]) => Iterator.empty[NewEnd]) + .apply(new PathAwareTraversal(Iterator.single((a, p)))) match { + case stillPathAware: PathAwareTraversal[NewEnd] => stillPathAware.wrapped + // do we really want to allow the following, or is it an error? + case notPathAware => notPathAware.iterator.map { (b: NewEnd) => (b, p.appended(a)) } + } + }) + + private[traversal] def _coalesce[NewEnd](options: (Iterator[A] => Iterator[NewEnd])*): Iterator[NewEnd] = + new PathAwareTraversal(wrapped.flatMap { case (a, p) => + options.iterator + .map { inner => + inner(new PathAwareTraversal(Iterator.single((a, p)))) match { + case stillPathAware: PathAwareTraversal[NewEnd] => stillPathAware.wrapped + // do we really want to allow the following, or is it an error? + case notPathAware => notPathAware.iterator.map { (b: NewEnd) => (b, p.appended(a)) } + } + } + .find(_.nonEmpty) + .getOrElse(Iterator.empty) + }) + + private[traversal] def _sideEffect(f: A => _): PathAwareTraversal[A] = new PathAwareTraversal(wrapped.map { case (a, p) => + f(a); (a, p) + }) + +} diff --git a/core/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala b/core/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala index d29265d7..97cc437a 100644 --- a/core/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala +++ b/core/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala @@ -3,6 +3,7 @@ package flatgraph.traversal import org.scalatest.matchers.should.Matchers._ import org.scalatest.wordspec.AnyWordSpec import flatgraph.traversal.testdomains.simple.ExampleGraphSetup +import flatgraph.traversal.testdomains.simple.ExampleGraphSetup.Properties import flatgraph.traversal.Language.* import scala.collection.mutable @@ -48,19 +49,15 @@ class RepeatTraversalTests extends AnyWordSpec with ExampleGraphSetup { centerTrav.repeat(_.out)(_.emitAllButFirst.breadthFirstSearch).toSetMutable shouldBe expectedResults } - // TODO continue here -// "emit nodes that meet given condition" in { -// val expectedResults = Set("L1", "L2", "L3") -// centerTrav -// .repeat(_.out)(_.emit(_.has(Name.where(_.startsWith("L"))))) -// .property(Name) -// .toSetMutable shouldBe expectedResults -// centerTrav -// .repeat(_.out)(_.emit(_.has(Name.where(_.startsWith("L")))).breadthFirstSearch) -// .property(Name) -// .toSetMutable shouldBe expectedResults -// } + "emit nodes that meet given condition" in { + val expectedResults = Set("L1", "L2", "L3") + centerTrav + .repeat(_.out)(_.emit(_.where(_.property(Properties.Name).filter(_.startsWith("L")))).breadthFirstSearch) + .property(Properties.Name) + .toSetMutable shouldBe expectedResults + } + // TODO continue here // "going through multiple steps in repeat traversal" in { // r1.start.repeat(_.out.out)(_.emit).l shouldBe Seq(r1, r3, r5) // r1.start.enablePathTracking.repeat(_.out.out)(_.emit).path.l shouldBe Seq( diff --git a/core/src/test/scala/flatgraph/traversal/TraversalTests.scala b/core/src/test/scala/flatgraph/traversal/TraversalTests.scala index 4f98e425..c46a3421 100644 --- a/core/src/test/scala/flatgraph/traversal/TraversalTests.scala +++ b/core/src/test/scala/flatgraph/traversal/TraversalTests.scala @@ -5,7 +5,7 @@ import flatgraph.GNode import flatgraph.help.{DocSearchPackages, Table} import flatgraph.help.Table.AvailableWidthProvider import flatgraph.traversal.Language.* -import flatgraph.traversal.testdomains.simple.SimpleDomain.Thing +import flatgraph.traversal.testdomains.simple.SimpleDomain.{Connection, Thing} import flatgraph.traversal.testdomains.simple.{ExampleGraphSetup, SimpleDomain} import org.scalatest.matchers.should.Matchers.* import org.scalatest.wordspec.AnyWordSpec @@ -19,11 +19,11 @@ class TraversalTests extends AnyWordSpec with ExampleGraphSetup { def centerTrav = Iterator.single(center) "GNode traversals" in { - centerTrav.label.l shouldBe Seq("V0") + centerTrav.label.l shouldBe Seq(Thing.Label) centerTrav.outE.size shouldBe 2 centerTrav.inE.size shouldBe 0 - centerTrav.outE("0").size shouldBe 2 - centerTrav.inE("0").size shouldBe 0 + centerTrav.outE(Connection.Label).size shouldBe 2 + centerTrav.inE(Connection.Label).size shouldBe 0 } "can only be iterated once" in { diff --git a/core/src/test/scala/flatgraph/traversal/testdomains/simple/ExampleGraphSetup.scala b/core/src/test/scala/flatgraph/traversal/testdomains/simple/ExampleGraphSetup.scala index 5649b27f..b3b35bef 100644 --- a/core/src/test/scala/flatgraph/traversal/testdomains/simple/ExampleGraphSetup.scala +++ b/core/src/test/scala/flatgraph/traversal/testdomains/simple/ExampleGraphSetup.scala @@ -1,7 +1,7 @@ package flatgraph.traversal.testdomains.simple import flatgraph.help.Table.AvailableWidthProvider -import flatgraph.{DiffGraphApplier, DiffGraphBuilder, GNode, GenericDNode, Graph, TestSchema} +import flatgraph.{DiffGraphApplier, DiffGraphBuilder, FreeSchema, GNode, GenericDNode, Graph, SinglePropertyKey, TestSchema} import flatgraph.help.{Doc, DocSearchPackages, Traversal, TraversalHelp, TraversalSource} import flatgraph.traversal.testdomains.simple.SimpleDomain.Thing import flatgraph.traversal.Language.* @@ -10,8 +10,8 @@ import flatgraph.traversal.Language.* * L3 <- L2 <- L1 <- Center -> R1 -> R2 -> R3 -> R4 -> R5 * */ trait ExampleGraphSetup { - // val nonExistingLabel = "this label does not exist" - // val nonExistingPropertyKey = new PropertyKey[String]("this property key does not exist") + import ExampleGraphSetup.Properties.* + val nonExistingLabel = "this label does not exist" val graph = SimpleDomain.newGraph val l3 = addNode() @@ -26,7 +26,6 @@ trait ExampleGraphSetup { val diff = new DiffGraphBuilder(graph.schema) // TODO reimplement arrow synax from odb - // TODO bring back properties as well // center --- Connection.Label --> l1 // l1 --- Connection.Label --> l2 // l2 --- Connection.Label --> l3 @@ -44,6 +43,15 @@ trait ExampleGraphSetup { ._addEdge(r2, r3, 0) ._addEdge(r3, r4, 0) ._addEdge(r4, r5, 0) + .setNodeProperty(l3, Name.name, "L3") + .setNodeProperty(l2, Name.name, "L2") + .setNodeProperty(l1, Name.name, "L1") + .setNodeProperty(center, Name.name, "Center") + .setNodeProperty(r1, Name.name, "R1") + .setNodeProperty(r2, Name.name, "R2") + .setNodeProperty(r3, Name.name, "R3") + .setNodeProperty(r4, Name.name, "R4") + .setNodeProperty(r5, Name.name, "R5") DiffGraphApplier.applyDiff(graph, diff) def addNode(): GNode = { @@ -52,11 +60,25 @@ trait ExampleGraphSetup { newNode.storedRef.get // that reference is set by DiffGraphApplier } } +object ExampleGraphSetup { + // property keys etc are normally generated by DomainClassesGenerator for a given schema + object Properties { + val Name = SinglePropertyKey(kind = 0, name = "name", default = "") + val NonExisting = SinglePropertyKey(kind = 10, name = "this property key does not exist", default = "default value 0") + } +} object SimpleDomain { class Thing(graph: Graph, nodeKind: Short, seqId: Int) extends GNode(graph, nodeKind, seqId) { def name: String = ??? } + object Thing { + val Label = "Thing" + } + + object Connection { + val Label = "Connection" + } val defaultDocSearchPackage: DocSearchPackages = DocSearchPackages(getClass.getPackage.getName) def help(using AvailableWidthProvider) = @@ -65,7 +87,14 @@ object SimpleDomain { TraversalHelp(defaultDocSearchPackage).forTraversalSources(verbose = true) def newGraph: Graph = { - val schema = TestSchema.make(1, 1) + val edgeLabels = Array(Connection.Label) + val schema = new FreeSchema( + nodeLabels = Array(Thing.Label), + edgeLabels = edgeLabels, + propertyLabels = Array(ExampleGraphSetup.Properties.Name.name), + edgePropertyPrototypes = new Array(edgeLabels.length), + nodePropertyPrototypes = Array(Array.empty[String]) + ) Graph(schema) }