diff --git a/core/src/main/scala/flatgraph/traversal/RepeatBehaviour.scala b/core/src/main/scala/flatgraph/traversal/RepeatBehaviour.scala index 26cea413..267a29a9 100644 --- a/core/src/main/scala/flatgraph/traversal/RepeatBehaviour.scala +++ b/core/src/main/scala/flatgraph/traversal/RepeatBehaviour.scala @@ -67,9 +67,11 @@ object RepeatBehaviour { this } - /** Emit intermediate elements (along the way), if they meet the given condition. Note that this does not apply a filter on the final - * elements of the traversal. - */ + /** + * Emit intermediate elements (along the way), if they meet the given condition. + * Note: this does not apply a filter on the final elements of the traversal! Quite likely that you want to reuse + * the given condition as a filter step at the end of your traversal... See `RepeatTraversalTests` for an example. + */ def emit(condition: Traversal[A] => Traversal[?]): Builder[A] = { _shouldEmit = (element, _) => condition(Iterator.single(element)).hasNext this diff --git a/tests/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala b/tests/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala index 25c7f40c..b8af84f1 100644 --- a/tests/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala +++ b/tests/src/test/scala/flatgraph/traversal/RepeatTraversalTests.scala @@ -51,24 +51,31 @@ class RepeatTraversalTests extends AnyWordSpec with FlatlineGraphFixture { centerTrav.repeat(_.out)(_.emit.breadthFirstSearch).toSet shouldBe expectedResults } - "emit everything but the first element (starting point)" in { + "emit everything along the way but the first element (starting point)" in { val expectedResults = Set(l3, l2, l1, r1, r2, r3, r4, r5) centerTrav.repeat(_.out)(_.emitAllButFirst).toSet shouldBe expectedResults centerTrav.repeat(_.out)(_.emitAllButFirst.breadthFirstSearch).toSet shouldBe expectedResults } - "emit nodes that meet given condition" in { - val expectedResults = Set("L1", "L2", "L3") + "emit nodes along the way that meet given condition" in { centerTrav - .repeat(_.out)(_.emit(_.where(_.property(StringMandatory).filter(_.startsWith("L")))).breadthFirstSearch) + .repeat(_.out)(_.emit(_.where(_.property(StringMandatory).filter(_.startsWith("L"))))) .property(StringMandatory) - .toSet shouldBe expectedResults + .toSet shouldBe Set("L1", "L2", "L3") // with domain specific language centerTrav - .repeat(_.connectedTo)(_.emit(_.where(_.stringMandatory("L.*"))).breadthFirstSearch) + .repeat(_.connectedTo)(_.emit(_.stringMandatory("L.*"))) .stringMandatory - .toSet shouldBe expectedResults + .toSet shouldBe Set("L1", "L2", "L3") + + // note: the emit condition only applies as a filter to what's emitted _along the way_, i.e. if the repeat + // traversal ends somewhere with results (e.g. because of `maxDepth` or `until`), you'll get those results also + // example: this traversal ends at `L2/R2` due to `maxDepth=2` and it emitted `L1` along the way + centerTrav + .repeat(_.connectedTo)(_.maxDepth(2).emit(_.stringMandatory("L.*"))) + .stringMandatory + .toSet shouldBe Set("L1", "L2", "R2") } "going through multiple steps in repeat traversal" in {