Skip to content

Commit

Permalink
Inline color field into map constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
s-and-witch committed Oct 14, 2024
1 parent 3b1471e commit 12daced
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 105 deletions.
194 changes: 108 additions & 86 deletions src/PersistentOrderedMap.mo
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ import O "Order";

module {

/// Node color: Either red (`#R`) or black (`#B`).
public type Color = { #R; #B };

/// Red-black tree of nodes with key-value entries, ordered by the keys.
/// The keys have the generic type `K` and the values the generic type `V`.
/// Leaves are considered implicitly black.
public type Map<K, V> = {
#node : (Color, Map<K, V>, (K, V), Map<K, V>);
#red : (Map<K, V>, (K, V), Map<K, V>);
#black : (Map<K, V>, (K, V), Map<K, V>);
#leaf
};

Expand Down Expand Up @@ -351,7 +349,11 @@ module {
trees := ts;
?xy
};
case (?(#tr(#node(_, l, xy, r)), ts)) {
case (?(#tr(#red(l, xy, r)), ts)) {
trees := mapTraverser(l, xy, r, ts);
next()
};
case (?(#tr(#black(l, xy, r)), ts)) {
trees := mapTraverser(l, xy, r, ts);
next()
}
Expand Down Expand Up @@ -467,8 +469,11 @@ module {
func mapRec(m : Map<K, V1>) : Map<K, V2> {
switch m {
case (#leaf) { #leaf };
case (#node(c, l, xy, r)) {
#node(c, mapRec l, (xy.0, f xy), mapRec r) // TODO: try destination-passing style to avoid non tail-call recursion
case (#red(l, xy, r)) {
#red(mapRec l, (xy.0, f xy), mapRec r)
};
case (#black(l, xy, r)) {
#black(mapRec l, (xy.0, f xy), mapRec r)
};
}
};
Expand Down Expand Up @@ -497,7 +502,10 @@ module {
public func size<K, V>(t : Map<K, V>) : Nat {
switch t {
case (#leaf) { 0 };
case (#node(_, l, _, r)) {
case (#red(l, _, r)) {
size(l) + size(r) + 1
};
case (#black(l, _, r)) {
size(l) + size(r) + 1
}
}
Expand Down Expand Up @@ -538,7 +546,12 @@ module {
{
switch (rbMap) {
case (#leaf) { base };
case (#node(_, l, (k, v), r)) {
case (#red(l, (k, v), r)) {
let left = foldLeft(l, base, combine);
let middle = combine(k, v, left);
foldLeft(r, middle, combine)
};
case (#black(l, (k, v), r)) {
let left = foldLeft(l, base, combine);
let middle = combine(k, v, left);
foldLeft(r, middle, combine)
Expand Down Expand Up @@ -581,7 +594,12 @@ module {
{
switch (rbMap) {
case (#leaf) { base };
case (#node(_, l, (k, v), r)) {
case (#red(l, (k, v), r)) {
let right = foldRight(r, base, combine);
let middle = combine(k, v, right);
foldRight(l, middle, combine)
};
case (#black(l, (k, v), r)) {
let right = foldRight(r, base, combine);
let middle = combine(k, v, right);
foldRight(l, middle, combine)
Expand Down Expand Up @@ -616,7 +634,14 @@ module {
public func get<K, V>(t : Map<K, V>, compare : (K, K) -> O.Order, x : K) : ?V {
switch t {
case (#leaf) { null };
case (#node(_c, l, xy, r)) {
case (#red(l, xy, r)) {
switch (compare(x, xy.0)) {
case (#less) { get(l, compare, x) };
case (#equal) { ?xy.1 };
case (#greater) { get(r, compare, x) }
}
};
case (#black(l, xy, r)) {
switch (compare(x, xy.0)) {
case (#less) { get(l, compare, x) };
case (#equal) { ?xy.1 };
Expand All @@ -628,8 +653,8 @@ module {

func redden<K, V>(t : Map<K, V>) : Map<K, V> {
switch t {
case (#node (#B, l, xy, r)) {
(#node (#R, l, xy, r))
case (#black (l, xy, r)) {
(#red (l, xy, r))
};
case _ {
Debug.trap "RBTree.red"
Expand All @@ -639,44 +664,40 @@ module {

func lbalance<K,V>(left : Map<K, V>, xy : (K,V), right : Map<K, V>) : Map<K,V> {
switch (left, right) {
case (#node(#R, #node(#R, l1, xy1, r1), xy2, r2), r) {
#node(
#R,
#node(#B, l1, xy1, r1),
case (#red(#red(l1, xy1, r1), xy2, r2), r) {
#red(
#black(l1, xy1, r1),
xy2,
#node(#B, r2, xy, r))
#black(r2, xy, r))
};
case (#node(#R, l1, xy1, #node(#R, l2, xy2, r2)), r) {
#node(
#R,
#node(#B, l1, xy1, l2),
case (#red(l1, xy1, #red(l2, xy2, r2)), r) {
#red(
#black(l1, xy1, l2),
xy2,
#node(#B, r2, xy, r))
#black(r2, xy, r))
};
case _ {
#node(#B, left, xy, right)
#black(left, xy, right)
}
}
};

func rbalance<K,V>(left : Map<K, V>, xy : (K,V), right : Map<K, V>) : Map<K,V> {
switch (left, right) {
case (l, #node(#R, l1, xy1, #node(#R, l2, xy2, r2))) {
#node(
#R,
#node(#B, l, xy, l1),
case (l, #red(l1, xy1, #red(l2, xy2, r2))) {
#red(
#black(l, xy, l1),
xy1,
#node(#B, l2, xy2, r2))
#black(l2, xy2, r2))
};
case (l, #node(#R, #node(#R, l1, xy1, r1), xy2, r2)) {
#node(
#R,
#node(#B, l, xy, l1),
case (l, #red(#red(l1, xy1, r1), xy2, r2)) {
#red(
#black(l, xy, l1),
xy1,
#node(#B, r1, xy2, r2))
#black(r1, xy2, r2))
};
case _ {
#node(#B, left, xy, right)
#black(left, xy, right)
};
}
};
Expand All @@ -694,9 +715,9 @@ module {
func ins(tree : Map<K,V>) : Map<K,V> {
switch tree {
case (#leaf) {
#node(#R, #leaf, (key,val), #leaf)
#red(#leaf, (key,val), #leaf)
};
case (#node(#B, left, xy, right)) {
case (#black(left, xy, right)) {
switch (compare (key, xy.0)) {
case (#less) {
lbalance(ins left, xy, right)
Expand All @@ -706,29 +727,29 @@ module {
};
case (#equal) {
let newVal = onClash({ new = val; old = xy.1 });
#node(#B, left, (key,newVal), right)
#black(left, (key,newVal), right)
}
}
};
case (#node(#R, left, xy, right)) {
case (#red(left, xy, right)) {
switch (compare (key, xy.0)) {
case (#less) {
#node(#R, ins left, xy, right)
#red(ins left, xy, right)
};
case (#greater) {
#node(#R, left, xy, ins right)
#red(left, xy, ins right)
};
case (#equal) {
let newVal = onClash { new = val; old = xy.1 };
#node(#R, left, (key,newVal), right)
#red(left, (key,newVal), right)
}
}
}
};
};
switch (ins m) {
case (#node(#R, left, xy, right)) {
#node(#B, left, xy, right);
case (#red(left, xy, right)) {
#black(left, xy, right);
};
case other { other };
};
Expand Down Expand Up @@ -761,19 +782,18 @@ module {

func balLeft<K,V>(left : Map<K, V>, xy : (K,V), right : Map<K, V>) : Map<K,V> {
switch (left, right) {
case (#node(#R, l1, xy1, r1), r) {
#node(
#R,
#node(#B, l1, xy1, r1),
case (#red(l1, xy1, r1), r) {
#red(
#black(l1, xy1, r1),
xy,
r)
};
case (_, #node(#B, l2, xy2, r2)) {
rbalance(left, xy, #node(#R, l2, xy2, r2))
case (_, #black(l2, xy2, r2)) {
rbalance(left, xy, #red(l2, xy2, r2))
};
case (_, #node(#R, #node(#B, l2, xy2, r2), xy3, r3)) {
#node(#R,
#node(#B, left, xy, l2),
case (_, #red(#black(l2, xy2, r2), xy3, r3)) {
#red(
#black(left, xy, l2),
xy2,
rbalance(r2, xy3, redden r3))
};
Expand All @@ -783,20 +803,20 @@ module {

func balRight<K,V>(left : Map<K, V>, xy : (K,V), right : Map<K, V>) : Map<K,V> {
switch (left, right) {
case (l, #node(#R, l1, xy1, r1)) {
#node(#R,
case (l, #red(l1, xy1, r1)) {
#red(
l,
xy,
#node(#B, l1, xy1, r1))
#black(l1, xy1, r1))
};
case (#node(#B, l1, xy1, r1), r) {
lbalance(#node(#R, l1, xy1, r1), xy, r);
case (#black(l1, xy1, r1), r) {
lbalance(#red(l1, xy1, r1), xy, r);
};
case (#node(#R, l1, xy1, #node(#B, l2, xy2, r2)), r3) {
#node(#R,
case (#red(l1, xy1, #black(l2, xy2, r2)), r3) {
#red(
lbalance(redden l1, xy1, l2),
xy2,
#node(#B, r2, xy, r3))
#black(r2, xy, r3))
};
case _ { Debug.trap "balRight" };
}
Expand All @@ -806,40 +826,39 @@ module {
switch (left, right) {
case (#leaf, _) { right };
case (_, #leaf) { left };
case (#node (#R, l1, xy1, r1),
#node (#R, l2, xy2, r2)) {
case (#red (l1, xy1, r1),
#red (l2, xy2, r2)) {
switch (append (r1, l2)) {
case (#node (#R, l3, xy3, r3)) {
#node(
#R,
#node(#R, l1, xy1, l3),
case (#red (l3, xy3, r3)) {
#red(
#red(l1, xy1, l3),
xy3,
#node(#R, r3, xy2, r2))
#red(r3, xy2, r2))
};
case r1l2 {
#node(#R, l1, xy1, #node(#R, r1l2, xy2, r2))
#red(l1, xy1, #red(r1l2, xy2, r2))
}
}
};
case (t1, #node(#R, l2, xy2, r2)) {
#node(#R, append(t1, l2), xy2, r2)
case (t1, #red(l2, xy2, r2)) {
#red(append(t1, l2), xy2, r2)
};
case (#node(#R, l1, xy1, r1), t2) {
#node(#R, l1, xy1, append(r1, t2))
case (#red(l1, xy1, r1), t2) {
#red(l1, xy1, append(r1, t2))
};
case (#node(#B, l1, xy1, r1), #node (#B, l2, xy2, r2)) {
case (#black(l1, xy1, r1), #black (l2, xy2, r2)) {
switch (append (r1, l2)) {
case (#node (#R, l3, xy3, r3)) {
#node(#R,
#node(#B, l1, xy1, l3),
case (#red (l3, xy3, r3)) {
#red(
#black(l1, xy1, l3),
xy3,
#node(#B, r3, xy2, r2))
#black(r3, xy2, r2))
};
case r1l2 {
balLeft (
l1,
xy1,
#node(#B, r1l2, xy2, r2)
#black(r1l2, xy2, r2)
)
}
}
Expand All @@ -857,22 +876,22 @@ module {
case (#less) {
let newLeft = del left;
switch left {
case (#node(#B, _, _, _)) {
case (#black(_, _, _)) {
balLeft(newLeft, xy, right)
};
case _ {
#node(#R, newLeft, xy, right)
#red(newLeft, xy, right)
}
}
};
case (#greater) {
let newRight = del right;
switch right {
case (#node(#B, _, _, _)) {
case (#black(_, _, _)) {
balRight(left, xy, newRight)
};
case _ {
#node(#R, left, xy, newRight)
#red(left, xy, newRight)
}
}
};
Expand All @@ -887,14 +906,17 @@ module {
case (#leaf) {
tree
};
case (#node(_, left, xy, right)) {
case (#red(left, xy, right)) {
delNode(left, xy, right)
};
case (#black(left, xy, right)) {
delNode(left, xy, right)
}
};
};
switch (del(tree)) {
case (#node(#R, left, xy, right)) {
(#node(#B, left, xy, right), y0);
case (#red(left, xy, right)) {
(#black(left, xy, right), y0);
};
case other { (other, y0) };
};
Expand Down
Loading

0 comments on commit 12daced

Please sign in to comment.