Abstraction for Fold and Scan

Jacob Hinkle edited this page Apr 18, 2024 · 13 revisions

Magmas and monoids

Let $(M, \cdot, e)$ be a unital magma; that is, $M$ is a set closed under the binary operation $\cdot: M\times M\to M$, and $e\in M$ is a left and right identity with respect to $\cdot$. A monoid is a unital magma whose operation is also associative. We will discuss associativity of $\cdot$ and the computation of folds and scans in this setting. However, unless stated otherwise, associativity is not required for most of this document.
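As a concrete illustration of a unital magma that is not a monoid, the sketch below treats floating-point addition as the operation with `0.0` as the identity. The helper names `is_identity` and `is_associative` are hypothetical, the checks only test a finite list of samples (so they can refute a property but not prove it), and special values such as NaN are ignored.

```python
import itertools

def is_identity(op, e, samples):
    """Check that e is a left and right identity on the given samples."""
    return all(op(e, x) == x and op(x, e) == x for x in samples)

def is_associative(op, samples):
    """Check (a . b) . c == a . (b . c) for every triple of samples."""
    return all(
        op(op(a, b), c) == op(a, op(b, c))
        for a, b, c in itertools.product(samples, repeat=3)
    )

def add(a, b):
    return a + b

float_samples = [0.1, 0.2, 0.3]
int_samples = [1, 2, 3]

print(is_identity(add, 0.0, float_samples))   # True
print(is_associative(add, float_samples))     # False: rounding depends on grouping
print(is_associative(add, int_samples))       # True on these samples
```

This is why the order of evaluation can matter in practice even for something as ordinary as a sum.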

Binary trees

First, consider a rooted binary tree $T$ where each node has at most two children which are ordered (left and right). We refer to a node without children as a "leaf".

Flattening

Any rooted binary tree $T$ as we've defined it so far has a natural ordering of its leaves. The relation $\forall n, \textrm{left}(n) < \textrm{right}(n)$ is a partial order on nodes of $T$ that is total when restricted to the children of a single node. Any such partial order extends to a total order over all nodes of $T$ using lexicographic order: each node $n$ has a unique path from the root, which is a word of directions $p(n)\in \lbrace L, R \rbrace^{m(n)}$, where $m(n)$ is the distance between $n$ and the root of the tree. We extend the order to arbitrary nodes $a$ and $b$ by $a < b \iff p(a) < p(b)$, where paths are compared in lexicographic order using $L < R$ (note that this orders every ancestor of a node before that node). In particular, the tree's structure implies a total order of its leaves, so there is a well-defined map from the tree to the sequence of its leaves in that order: we call this "flattening" the tree.

$$ \begin{align} \textrm{flatten}(T) = \left(n_0, \dots, n_{|\textrm{leaves}(T)|-1}\right) &\in \textrm{leaves}(T)^{|\textrm{leaves}(T)|} \\ \textrm{flatten}(T)_i < \textrm{flatten}(T)_j &\iff i < j \end{align} $$
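The flattening can be sketched directly from the path-based order. In this minimal sketch the `Node` class, its `label` field, and the helper `paths` are assumptions made for illustration; paths are encoded as Python strings over `"L"` and `"R"`, whose built-in string comparison is lexicographic with `"L" < "R"` and orders a prefix before its extensions.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(eq=False)
class Node:
    label: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

def paths(node, prefix=""):
    """Yield (p(n), n) for every node n in the subtree rooted at node."""
    yield prefix, node
    if node.left is not None:
        yield from paths(node.left, prefix + "L")
    if node.right is not None:
        yield from paths(node.right, prefix + "R")

def flatten(root):
    """Return the leaves of the tree sorted by their root-to-leaf paths."""
    ordered = sorted(paths(root), key=lambda pair: pair[0])
    return [node for _, node in ordered if node.is_leaf()]

# Leaves x0 and x1 sit under the left child, x2 is the right child of the root.
tree = Node("root", Node("a", Node("x0"), Node("x1")), Node("x2"))
print([leaf.label for leaf in flatten(tree)])  # ['x0', 'x1', 'x2']
```

A plain left-to-right depth-first traversal visits the leaves in the same order; sorting by path is used here only to mirror the definition.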

Folds

Further assume that each leaf node $l \in T$ contains an element of $M$ determined by some given function $d:\textrm{leaves}(T)\to M$. Then we can define the fold of the data $d$ with respect to the tree $T$ by introducing a recursively-defined function $f_d: \textrm{nodes}(T)\to M$:

$$ \begin{align} f_d(n) &= \begin{cases} d(n) & \mbox{$n$ is a leaf} \\ f_d(\textrm{left}(n)) \cdot e & \mbox{$n$ has left but not right child} \\ e \cdot f_d(\textrm{right}(n)) & \mbox{$n$ has right but not left child} \\ f_d(\textrm{left}(n)) \cdot f_d(\textrm{right}(n)) & \mbox{otherwise} \end{cases} \\ \textrm{fold}(T, d) &= f_d(\textrm{root}(T)) \end{align} $$

Notice that this fold operation essentially takes a tree with data attached to the leaves and propagates it such that data is now defined at every node (as the value $f_d(n)$). It then extracts the data corresponding to the root node. This is a very familiar way to represent computation (cf. abstract syntax trees).
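The recursive definition of $f_d$ translates almost line for line into code. The following is a minimal sketch, not a reference implementation: the `Node` class and the example data are assumptions, and string concatenation with the empty string as identity stands in for $(M, \cdot, e)$ so that the order of operands is visible in the result.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(eq=False)
class Node:
    label: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def f_d(n, d, op, e):
    """Value propagated to node n, following the four cases of the definition."""
    if n.left is None and n.right is None:
        return d(n)  # n is a leaf
    # A missing child contributes the identity element e.
    left = f_d(n.left, d, op, e) if n.left is not None else e
    right = f_d(n.right, d, op, e) if n.right is not None else e
    return op(left, right)

def fold(root, d, op, e):
    return f_d(root, d, op, e)

data = {"x0": "a", "x1": "b", "x2": "c"}

def d(leaf):
    return data[leaf.label]

def concat(a, b):
    return a + b

# Node "t" has only a left child, so its value is f_d(x2) . e.
tree = Node("root", Node("s", Node("x0"), Node("x1")), Node("t", left=Node("x2")))
print(fold(tree, d, concat, ""))  # abc
```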

Consider a case where we are given an ordered array of data $x\in M^n$; i.e. $x$ consists of $n$ elements from $M$, $x[0]$ through $x[n-1]$. If we would like to reduce this to a single element of $M$, we must decide on a computational strategy: the order in which we'll apply the binary operation $\cdot$. Given that the only thing we can compute in this abstract case is that binary operation, our entire computational strategy is equivalent to a choice of binary tree $T$ and a mapping from its leaves to $x$. Any binary tree $T$ with $n$ leaves corresponds to an operation we call $T$-unflattening, which takes a word $w\in M^n$ and maps it to a function $d:\textrm{leaves}(T)\to M$ such that the following holds:

$$ w_i = \left(\textrm{flatten}_T(\cdots)\right)_i $$
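One way to read the condition above is that the $i$-th leaf of $T$ in flattened order carries $w_i$. Under that reading (an assumption of this sketch), $T$-unflattening amounts to pairing the flattened leaves with the entries of the word. The `Node` class and the function names below are hypothetical, and floating-point addition is used so that two trees over the same word visibly disagree.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can be dict keys
class Node:
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def leaves(n):
    """Yield the leaves below n from left to right (the flattened order)."""
    if n.left is None and n.right is None:
        yield n
        return
    if n.left is not None:
        yield from leaves(n.left)
    if n.right is not None:
        yield from leaves(n.right)

def unflatten(tree, w):
    """Assign w[i] to the i-th leaf of the tree; returns d as a dict."""
    ordered = list(leaves(tree))
    if len(ordered) != len(w):
        raise ValueError("word length must equal the number of leaves")
    return dict(zip(ordered, w))

def fold(n, d, op, e):
    if n.left is None and n.right is None:
        return d[n]
    left = fold(n.left, d, op, e) if n.left is not None else e
    right = fold(n.right, d, op, e) if n.right is not None else e
    return op(left, right)

def add(a, b):
    return a + b

w = [0.1, 0.2, 0.3]
left_leaning = Node(Node(Node(), Node()), Node())    # (w0 . w1) . w2
right_leaning = Node(Node(), Node(Node(), Node()))   # w0 . (w1 . w2)

a = fold(left_leaning, unflatten(left_leaning, w), add, 0.0)
b = fold(right_leaning, unflatten(right_leaning, w), add, 0.0)
print(a == b)  # False: the choice of tree changes the rounded result
```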

Sequential sum

A non-parallel sum of three elements would look like this:

graph TD;
  s0["sum 0"] --> 0;
  s0 --> x0["x[0]"];
  s1["sum 1"] --> s0;
  s1 --> x1["x[1]"];
  s2["sum 2 (final sum)"] --> s1;
  s2 --> x2["x[2]"];
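The tree in this diagram is the left-leaning tree whose deepest left leaf holds the identity. A minimal sketch, assuming trees are encoded as nested Python tuples `(left, right)` with any non-tuple value acting as a leaf (so elements of $M$ must not themselves be tuples), shows that folding it matches an ordinary left-to-right loop as expressed by `functools.reduce`.

```python
from functools import reduce
import operator

def sequential_tree(x, e):
    """Build (((e, x[0]), x[1]), ...), the tree of a left-to-right accumulation."""
    tree = e
    for value in x:
        tree = (tree, value)
    return tree

def fold_tree(tree, op):
    """Fold a nested-tuple tree; non-tuple values are leaves."""
    if isinstance(tree, tuple):
        left, right = tree
        return op(fold_tree(left, op), fold_tree(right, op))
    return tree

x = [1, 2, 3]
tree = sequential_tree(x, 0)
print(tree)                            # (((0, 1), 2), 3)
print(fold_tree(tree, operator.add))   # 6
print(reduce(operator.add, x, 0))      # 6
```

Each partial sum depends on the previous one, so this tree has depth equal to the number of elements.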

Hierarchical parallel sum

A parallel sum of four elements might look like this:

graph TD;
  s0["sum X"] --> x0["x[0]"];
  s0 --> x1["x[1]"];
  s1["sum Y"] --> x2["x[2]"];
  s1 --> x3["x[3]"];
  s2["sum Z (total sum)"] --> s0;
  s2 --> s1;
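A balanced tree like the one in this diagram can be built by repeatedly splitting the array in half. The sketch below uses a nested-tuple encoding `(left, right)` with non-tuple values as leaves, which is an assumption of the example, and the functions `balanced_tree`, `fold_tree`, and `depth` are hypothetical names. It evaluates the tree sequentially; the point is only that sibling subtrees share no data and could be computed independently.

```python
import operator

def balanced_tree(x):
    """Split x in half recursively, producing a tree of depth about log2(len(x))."""
    if not x:
        raise ValueError("need at least one element")
    if len(x) == 1:
        return x[0]
    mid = len(x) // 2
    return (balanced_tree(x[:mid]), balanced_tree(x[mid:]))

def fold_tree(tree, op):
    """Fold a nested-tuple tree; non-tuple values are leaves."""
    if isinstance(tree, tuple):
        left, right = tree
        return op(fold_tree(left, op), fold_tree(right, op))
    return tree

def depth(tree):
    """Number of edges on the longest root-to-leaf path."""
    if isinstance(tree, tuple):
        return 1 + max(depth(tree[0]), depth(tree[1]))
    return 0

x = [1, 2, 3, 4]
tree = balanced_tree(x)
print(tree)                            # ((1, 2), (3, 4))
print(fold_tree(tree, operator.add))   # 10
print(depth(tree))                     # 2
```

For an associative operation the result agrees with the sequential fold; without associativity the two trees may give different answers.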