From a9c2b6219079252c36c08e9ccd4a2ded49b4ddc5 Mon Sep 17 00:00:00 2001 From: Sergey Gulin Date: Fri, 4 Oct 2024 18:51:52 +0300 Subject: [PATCH] [DMS-74] Implement missing functionality for set Add following functions: union(rbSet1 : Set, rbSet2 : Set) : Set intersect(rbSet1 : Set, rbSet2 : Set) : Set diff(rbSet1 : Set, rbSet2 : Set) : Set map(rbSet : Set, f : T1 -> T) : Set mapFilter(rbSet: Set, f : T1 -> ?T) : Set isSubset(rbSet1 : Set, rbSet2 : Set) : Bool equals (rbSet1 : Set, rbSet2 : Set) : Bool isEmpty (rbSet : Set) : Bool Add tests for the functions above. --- src/PersistentOrderedMap.mo | 104 +++++++++++- src/PersistentOrderedSet.mo | 263 +++++++++++++++++++++++++++++- test/PersistentOrderedSet.test.mo | 246 ++++++++++++++++++++++++++++ 3 files changed, 609 insertions(+), 4 deletions(-) diff --git a/src/PersistentOrderedMap.mo b/src/PersistentOrderedMap.mo index 222b93ea..46512b17 100644 --- a/src/PersistentOrderedMap.mo +++ b/src/PersistentOrderedMap.mo @@ -574,8 +574,7 @@ module { acc }; - - module Internal { + public module Internal { public func fromIter(i : I.Iter<(K,V)>, compare : (K, K) -> O.Order) : Map { @@ -887,6 +886,105 @@ module { }; case other { (other, y0) }; }; - } + }; + + // TODO: Instead, consider storing the black height in the node constructor + public func blackHeight (rbMap : Map) : Nat { + func f (node : Map, acc : Nat) : Nat { + switch node { + case (#leaf) { acc }; + case (#node (#R, l1, _, _)) { f(l1, acc) }; + case (#node (#B, l1, _, _)) { f(l1, acc + 1) } + } + }; + f (rbMap, 0) + }; + + public func joinL(l : Map, x : (K, V), r : Map) : Map { + if (blackHeight r <= blackHeight l) { (#node (#R, l, x, r)) } + else { + switch r { + case (#node (#R, rl, rx, rr)) { (#node (#R, joinL(l, x, rl) , rx, rr)) }; + case (#node (#B, rl, rx, rr)) { balLeft (joinL(l, x, rl), rx, rr) }; + case _ { Debug.trap "joinL" }; + } + } + }; + + public func joinR(l : Map, x : (K, V), r : Map) : Map { + if (blackHeight l <= blackHeight r) { (#node (#R, l, x, r)) } + else { + switch l { + case (#node (#R, ll, lx, lr)) { (#node (#R, ll , lx, joinR (lr, x, r))) }; + case (#node (#B, ll, lx, lr)) { balRight (ll, lx, joinR (lr, x, r)) }; + case _ { Debug.trap "joinR" }; + } + } + }; + + public func paint(color : Color, rbMap : Map) : Map { + switch rbMap { + case (#leaf) { #leaf }; + case (#node (_, l, x, r)) { (#node (color, l, x, r)) }; + } + }; + + public func splitMin (rbMap : Map) : ((K,V), Map) { + switch rbMap { + case (#leaf) { Debug.trap "splitMin" }; + case (#node(_, #leaf, x, r)) { (x, r) }; + case (#node(_, l, x, r)) { + let (m, l2) = splitMin l; + (m, join(l2, x, r)) + }; + } + }; + + // Joins an element and two trees. + // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 + public func join(l : Map, x : (K, V), r : Map) : Map { + if (Internal.blackHeight r < Internal.blackHeight l) { + return Internal.paint(#B, Internal.joinR(l, x, r)) + }; + if (Internal.blackHeight l < Internal.blackHeight r) { + return Internal.paint(#B, Internal.joinL(l, x, r)) + }; + return (#node (#B, l, x, r)) + }; + + // Joins two trees. + // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 + public func join2(l : Map, r : Map) : Map { + switch r { + case (#leaf) { l }; + case _ { + let (m, r2) = Internal.splitMin r; + join(l, m, r2) + }; + } + }; + + // Splits `rbMap` with respect to a given element `x`, into tuple `(l, b, r)` + // such that `l` contains the elements less than `x`, `r` contains the elements greater than `x` + // and `b` is `true` if `x` was in the `rbMap`. + // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 + public func split(x : K, rbMap : Map, compare : (K, K) -> O.Order) : (Map, Bool, Map) { + switch rbMap { + case (#leaf) { (#leaf, false, #leaf)}; + case (#node (_, l, (k, v), r)) { + switch (compare(x, k)) { + case (#less) { + let (l1, b, l2) = split(x, l, compare); + (l1, b, join(l2, (k, v), r)) + }; + case (#equal) { (l, true, r) }; + case (#greater) { + let (r1, b, r2) = split(x, r, compare); + (join(l, (k, v), r1), b, r2) + }; + }; + }; + }; + }; } } diff --git a/src/PersistentOrderedSet.mo b/src/PersistentOrderedSet.mo index 7453ef0f..0ed84fa0 100644 --- a/src/PersistentOrderedSet.mo +++ b/src/PersistentOrderedSet.mo @@ -4,6 +4,9 @@ /// * Runtime: `O(log(n))` worst case cost per insertion, removal, and retrieval operation. /// * Space: `O(n)` for storing the entire tree. /// `n` denotes the number of elements (i.e. nodes) stored in the tree. +/// +/// The set operations implementation is derived from: +/// Tobias Nipkow's "Functional Data Structures and Algorithms", 10: 117-125 (2024). import Map "PersistentOrderedMap"; @@ -135,6 +138,242 @@ module { /// /// Note: Creates `O(log(n))` temporary objects that will be collected as garbage. public func contains(rbSet : Set, value : T) : Bool = Option.isSome(mapOps.get(rbSet, value)); + + /// [Set union](https://en.wikipedia.org/wiki/Union_(set_theory)) operation. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter" + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet1 = setOps.fromIter(Iter.fromArray([0, 1, 2])); + /// let rbSet2 = setOps.fromIter(Iter.fromArray([2, 3, 4])); + /// + /// Debug.print(debug_show Iter.toArray(Set.elements(setOps.union(rbSet1, rbSet2)))); + /// + /// // [0, 1, 2, 3, 4] + /// ``` + /// + /// Runtime: `O(m * log(n/m + 1))`. + /// Space: `O(m * log(n/m + 1))`, where `m` and `n` denote the number of elements + /// in the sets, and `m <= n`. + public func union(rbSet1 : Set, rbSet2 : Set) : Set { + switch (rbSet1, rbSet2) { + case (#leaf, rbSet) { rbSet }; + case (rbSet, #leaf) { rbSet }; + case (#node (_,l1, (k, v), r1), _) { + let (l2, _, r2) = Map.Internal.split(k, rbSet2, compare); + Map.Internal.join(union(l1, l2), (k, v), union(r1, r2)) + }; + }; + }; + + /// [Set intersection](https://en.wikipedia.org/wiki/Intersection_(set_theory)) operation. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter"; + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet1 = setOps.fromIter(Iter.fromArray([0, 1, 2])); + /// let rbSet2 = setOps.fromIter(Iter.fromArray([1, 2, 3])); + /// + /// Debug.print(debug_show Iter.toArray(Set.elements(setOps.intersect(rbSet1, rbSet2)))); + /// + /// // [1, 2] + /// ``` + /// + /// Runtime: `O(m * log(n/m + 1))`. + /// Space: `O(m * log(n/m + 1))`, where `m` and `n` denote the number of elements + /// in the sets, and `m <= n`. + public func intersect(rbSet1 : Set, rbSet2 : Set) : Set { + switch (rbSet1, rbSet2) { + case (#leaf, _) { #leaf }; + case (_, #leaf) { #leaf }; + case (#node (_, l1, (k, v), r1), _) { + let (l2, b2, r2) = Map.Internal.split(k, rbSet2, compare); + let l = intersect(l1, l2); + let r = intersect(r1, r2); + if b2 { Map.Internal.join (l, (k, v), r) } + else { Map.Internal.join2(l, r) }; + }; + }; + }; + + /// [Set difference](https://en.wikipedia.org/wiki/Difference_(set_theory)). + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter" + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet1 = setOps.fromIter(Iter.fromArray([0, 1, 2])); + /// let rbSet2 = setOps.fromIter(Iter.fromArray([1, 2, 3])); + /// + /// Debug.print(debug_show Iter.toArray(Set.elements(setOps.diff(rbSet1, rbSet2)))); + /// + /// // [0] + /// ``` + /// + /// Runtime: `O(m * log(n/m + 1))`. + /// Space: `O(m * log(n/m + 1))`, where `m` and `n` denote the number of elements + /// in the sets, and `m <= n`. + public func diff(rbSet1 : Set, rbSet2 : Set) : Set { + switch (rbSet1, rbSet2) { + case (#leaf, _) { #leaf }; + case (rbSet, #leaf) { rbSet }; + case (_, (#node(_, l2, (k, _), r2))) { + let (l1, _, r1) = Map.Internal.split(k, rbSet1, compare); + Map.Internal.join2(diff(l1, l2), diff(r1, r2)); + } + } + }; + + /// Creates a new Set by applying `f` to each entry in `rbSet`. Each element + /// `x` in the old set is transformed into a new entry `x2`, where + /// the new value `x2` is created by applying `f` to `x`. + /// The result set may be smaller than the original set due to duplicate elements. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/PersistentOrderedMap"; + /// import Nat "mo:base/Nat" + /// import Iter "mo:base/Iter" + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet = setOps.fromIter(Iter.fromArray([0, 1, 2, 3])); + /// + /// func f(x : Nat) : Nat = if (x < 2) { x } else { 0 }; + /// + /// let resSet = setOps.map(rbSet, f); + /// + /// Debug.print(debug_show(Iter.toArray(Set.elements(resSet)))); + /// // [0, 1] + /// ``` + /// + /// Cost of mapping all the elements: + /// Runtime: `O(n)`. + /// Space: `O(n)` retained memory + /// where `n` denotes the number of elements stored in the set. + public func map(rbSet : Set, f : T1 -> T) : Set = fromIter(I.map(elements(rbSet), f)); + + /// Creates a new map by applying `f` to each element in `rbSet`. For each element + /// `x` in the old set, if `f` evaluates to `null`, the element is discarded. + /// Otherwise, the entry is transformed into a new entry `x2`, where + /// the new value `x2` is the result of applying `f` to `x`. + /// + /// Example: + /// ```motoko + /// import Map "mo:base/PersistentOrderedMap"; + /// import Nat "mo:base/Nat" + /// import Iter "mo:base/Iter"; + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet = setOps.fromIter(Iter.fromArray([0, 1, 2, 3])); + /// + /// func f(x : Nat) : ?Nat { + /// if(x == 0) {null} + /// else { ?( x * 2 )} + /// }; + /// + /// let newRbSet = setOps.mapFilter(rbSet, f); + /// + /// Debug.print(debug_show(Iter.toArray(Set.elements(newRbSet)))); + /// + /// // [2, 4, 6] + /// ``` + /// + /// Runtime: `O(n)`. + /// Space: `O(n)` retained memory plus garbage, see the note below. + /// where `n` denotes the number of elements stored in the set and + /// assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(log(n))` temporary objects that will be collected as garbage. + public func mapFilter(rbSet: Set, f : T1 -> ?T) : Set { + var set = #leaf : Set; + for(x in elements(rbSet)) { + switch(f x){ + case null {}; + case (?x2) { + set := put(set, x2); + } + } + }; + set + }; + + /// Test if `rbSet1` is subset of `rbSet2`. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter" + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet1 = setOps.fromIter(Iter.fromArray([1, 2])); + /// let rbSet2 = setOps.fromIter(Iter.fromArray([0, 2, 1])); + /// + /// Debug.print(debug_show setOps.isSubset(rbSet1, rbSet2)); + /// + /// // true + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(1)` retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements stored in the sets rbSet1 and rbSet2, respectively, + /// and assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(m * log(n))` temporary objects that will be collected as garbage. + public func isSubset(rbSet1 : Set, rbSet2 : Set) : Bool { + if (size(rbSet1) > size(rbSet2)) { return false; }; + isSubsetHelper(rbSet1, rbSet2) + }; + + /// Test if two sets are equal. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// import Nat "mo:base/Nat"; + /// import Iter "mo:base/Iter" + /// + /// let setOps = Set.SetOps(Nat.compare); + /// let rbSet1 = setOps.fromIter(Iter.fromArray([0, 2, 1])); + /// let rbSet2 = setOps.fromIter(Iter.fromArray([1, 2])); + /// + /// Debug.print(debug_show setOps.equals(rbSet1, rbSet1)); + /// Debug.print(debug_show setOps.equals(rbSet1, rbSet2)); + /// + /// // true + /// // false + /// ``` + /// + /// Runtime: `O(m * log(n))`. + /// Space: `O(1)` retained memory plus garbage, see the note below. + /// where `m` and `n` denote the number of elements stored in the sets rbSet1 and rbSet2, respectively, + /// and assuming that the `compare` function implements an `O(1)` comparison. + /// + /// Note: Creates `O(m * log(n))` temporary objects that will be collected as garbage. + public func equals (rbSet1 : Set, rbSet2 : Set) : Bool { + if (size(rbSet1) != size(rbSet2)) { return false; }; + isSubsetHelper(rbSet1, rbSet2) + }; + + func isSubsetHelper(rbSet1 : Set, rbSet2 : Set) : Bool { + for (x in elements(rbSet1)) { + if (not (contains(rbSet2, x))) { + return false; + } + }; + return true; + }; }; /// Returns an Iterator (`Iter`) over the elements of the set. @@ -175,7 +414,7 @@ module { /// // 0 /// ``` /// - /// Cost of empty map creation + /// Cost of empty set creation /// Runtime: `O(1)`. /// Space: `O(1)` public func empty() : Set = Map.empty(); @@ -278,4 +517,26 @@ module { func (x : T , _ : (), acc : Accum) : Accum { combine(x, acc) } ) }; + + /// Test if set is empty. + /// + /// Example: + /// ```motoko + /// import Set "mo:base/PersistentOrderedSet"; + /// + /// let rbSet = Set.empty(); + /// + /// Debug.print(debug_show(Set.isEmpty(rbSet))); + /// + /// // true + /// ``` + /// + /// Runtime: `O(1)`. + /// Space: `O(1)` + public func isEmpty (rbSet : Set) : Bool { + switch rbSet { + case (#leaf) { true }; + case _ { false }; + }; + }; } diff --git a/test/PersistentOrderedSet.test.mo b/test/PersistentOrderedSet.test.mo index e2408032..3b37b219 100644 --- a/test/PersistentOrderedSet.test.mo +++ b/test/PersistentOrderedSet.test.mo @@ -95,6 +95,16 @@ func clear(initialRbSet : Set.Set) : Set.Set { rbSet }; +func add1(x : Nat) : Nat { x + 1 }; + +func ifElemLessThan(threshold : Nat, f : Nat -> Nat) : Nat -> ?Nat + = func (x) { + if(x < threshold) + ?f(x) + else null + }; + + /* --------------------------------------- */ var buildTestSet = func() : Set.Set { @@ -135,6 +145,21 @@ run( Set.foldLeft(buildTestSet(), "", concatenateKeys), M.equals(T.text("")) ), + test( + "traverse empty set", + natSetOps.map(buildTestSet(), add1), + SetMatcher([]) + ), + test( + "empty map filter", + natSetOps.mapFilter(buildTestSet(), ifElemLessThan(0, add1)), + SetMatcher([]) + ), + test( + "is empty", + Set.isEmpty(buildTestSet()), + M.equals(T.bool(true)) + ), ] ) ); @@ -186,6 +211,26 @@ run( Set.foldLeft(buildTestSet(), "", concatenateKeys), M.equals(T.text("0")) ), + test( + "traverse set", + natSetOps.map(buildTestSet(), add1), + SetMatcher([1]) + ), + test( + "map filter/filter all", + natSetOps.mapFilter(buildTestSet(), ifElemLessThan(0, add1)), + SetMatcher([]) + ), + test( + "map filter/no filer", + natSetOps.mapFilter(buildTestSet(), ifElemLessThan(1, add1)), + SetMatcher([1]) + ), + test( + "is empty", + Set.isEmpty(buildTestSet()), + M.equals(T.bool(false)) + ), ] ) ); @@ -240,6 +285,36 @@ func rebalanceTests(buildTestSet : () -> Set.Set) : [Suite.Suite] = Set.foldLeft(buildTestSet(), "", concatenateKeys), M.equals(T.text("012")) ), + test( + "traverse set", + natSetOps.map(buildTestSet(), add1), + SetMatcher([1, 2, 3]) + ), + test( + "traverse set/reshape", + natSetOps.map(buildTestSet(), func (x : Nat) : Nat {5}), + SetMatcher([5]) + ), + test( + "map filter/filter all", + natSetOps.mapFilter(buildTestSet(), ifElemLessThan(0, add1)), + SetMatcher([]) + ), + test( + "map filter/filter one", + natSetOps.mapFilter(buildTestSet(), ifElemLessThan(1, add1)), + SetMatcher([1]) + ), + test( + "map filter/no filer", + natSetOps.mapFilter(buildTestSet(), ifElemLessThan(3, add1)), + SetMatcher([1, 2, 3]) + ), + test( + "is empty", + Set.isEmpty(buildTestSet()), + M.equals(T.bool(false)) + ), ]; buildTestSet := func() : Set.Set { @@ -316,3 +391,174 @@ run( ] ) ); + +/* --------------------------------------- */ + +let buildTestSet012 = func() : Set.Set { + var rbSet = Set.empty(); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 1); + rbSet := insert(rbSet, 2); + rbSet +}; + +let buildTestSet01 = func() : Set.Set { + var rbSet = Set.empty(); + rbSet := insert(rbSet, 0); + rbSet := insert(rbSet, 1); + rbSet +}; + +let buildTestSet234 = func() : Set.Set { + var rbSet = Set.empty(); + rbSet := insert(rbSet, 2); + rbSet := insert(rbSet, 3); + rbSet := insert(rbSet, 4); + rbSet +}; + +let buildTestSet345 = func() : Set.Set { + var rbSet = Set.empty(); + rbSet := insert(rbSet, 5); + rbSet := insert(rbSet, 3); + rbSet := insert(rbSet, 4); + rbSet +}; + +run( + suite( + "set operations", + [ + test( + "subset/subset of itself", + natSetOps.isSubset(buildTestSet012(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "subset/empty set is subset of itself", + natSetOps.isSubset(Set.empty(), Set.empty()), + M.equals(T.bool(true)) + ), + test( + "subset/empty set is subset of another set", + natSetOps.isSubset(Set.empty(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "subset/subset", + natSetOps.isSubset(buildTestSet01(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "subset/not subset", + natSetOps.isSubset(buildTestSet012(), buildTestSet01()), + M.equals(T.bool(false)) + ), + test( + "equals/empty set", + natSetOps.equals(Set.empty(), Set.empty()), + M.equals(T.bool(true)) + ), + test( + "equals/equals", + natSetOps.equals(buildTestSet012(), buildTestSet012()), + M.equals(T.bool(true)) + ), + test( + "equals/not equals", + natSetOps.equals(buildTestSet012(), buildTestSet01()), + M.equals(T.bool(false)) + ), + test( + "union/empty set", + natSetOps.union(Set.empty(), Set.empty()), + SetMatcher([]) + ), + test( + "union/union with empty set", + natSetOps.union(buildTestSet012(), Set.empty()), + SetMatcher([0, 1, 2]) + ), + test( + "union/union with itself", + natSetOps.union(buildTestSet012(), buildTestSet012()), + SetMatcher([0, 1, 2]) + ), + test( + "union/union with subset", + natSetOps.union(buildTestSet012(), buildTestSet01()), + SetMatcher([0, 1, 2]) + ), + test( + "union/union expand", + natSetOps.union(buildTestSet012(), buildTestSet234()), + SetMatcher([0, 1, 2, 3, 4]) + ), + test( + "intersect/empty set", + natSetOps.intersect(Set.empty(), Set.empty()), + SetMatcher([]) + ), + test( + "intersect/intersect with empty set", + natSetOps.intersect(buildTestSet012(), Set.empty()), + SetMatcher([]) + ), + test( + "intersect/intersect with itself", + natSetOps.intersect(buildTestSet012(), buildTestSet012()), + SetMatcher([0, 1, 2]) + ), + test( + "intersect/intersect with subset", + natSetOps.intersect(buildTestSet012(), buildTestSet01()), + SetMatcher([0, 1]) + ), + test( + "intersect/intersect", + natSetOps.intersect(buildTestSet012(), buildTestSet234()), + SetMatcher([2]) + ), + test( + "intersect/no intersection", + natSetOps.intersect(buildTestSet012(), buildTestSet345()), + SetMatcher([]) + ), + test( + "diff/empty set", + natSetOps.diff(Set.empty(), Set.empty()), + SetMatcher([]) + ), + test( + "diff/diff with empty set", + natSetOps.diff(buildTestSet012(), Set.empty()), + SetMatcher([0, 1, 2]) + ), + test( + "diff/diff with empty set 2", + natSetOps.diff(Set.empty(), buildTestSet012()), + SetMatcher([]) + ), + test( + "diff/diff with subset", + natSetOps.diff(buildTestSet012(), buildTestSet01()), + SetMatcher([2]) + ), + test( + "diff/diff with subset 2", + natSetOps.diff(buildTestSet01(), buildTestSet012()), + SetMatcher([]) + ), + test( + "diff/diff", + natSetOps.diff(buildTestSet012(), buildTestSet234()), + SetMatcher([0, 1]) + ), + test( + "diff/diff no intersection", + natSetOps.diff(buildTestSet012(), buildTestSet345()), + SetMatcher([0, 1, 2]) + ), + ] + ) +);