Skip to content

Commit

Permalink
feat: faster splitAt (#919)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored Aug 16, 2024
1 parent b0587b2 commit a36e34f
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions Batteries/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ Split a list at an index.
splitAt 2 [a, b, c] = ([a, b], [c])
```
-/
def splitAt (n : Nat) (l : List α) : List α × List α := go l n #[] where
/-- Auxiliary for `splitAt`: `splitAt.go l n xs acc = (acc.toList ++ take n xs, drop n xs)`
def splitAt (n : Nat) (l : List α) : List α × List α := go l n [] where
/-- Auxiliary for `splitAt`: `splitAt.go l n xs acc = (acc.reverse ++ take n xs, drop n xs)`
if `n < length xs`, else `(l, [])`. -/
go : List α → Nat → Array α → List α × List α
go : List α → Nat → List α → List α × List α
| [], _, _ => (l, [])
| x :: xs, n+1, acc => go xs n (acc.push x)
| xs, _, acc => (acc.toList, xs)
| x :: xs, n+1, acc => go xs n (x :: acc)
| xs, _, acc => (acc.reverse, xs)

/--
Split a list at an index. Ensures the left list always has the specified length
Expand All @@ -132,13 +132,13 @@ splitAtD 2 [a, b, c] x = ([a, b], [c])
splitAtD 4 [a, b, c] x = ([a, b, c, x], [])
```
-/
def splitAtD (n : Nat) (l : List α) (dflt : α) : List α × List α := go n l #[] where
/-- Auxiliary for `splitAtD`: `splitAtD.go dflt n l acc = (acc.toList ++ left, right)`
def splitAtD (n : Nat) (l : List α) (dflt : α) : List α × List α := go n l [] where
/-- Auxiliary for `splitAtD`: `splitAtD.go dflt n l acc = (acc.reverse ++ left, right)`
if `splitAtD n l dflt = (left, right)`. -/
go : Nat → List α → Array α → List α × List α
| n+1, x :: xs, acc => go n xs (acc.push x)
| 0, xs, acc => (acc.toList, xs)
| n, [], acc => (acc.toListAppend (replicate n dflt), [])
go : Nat → List α → List α → List α × List α
| n+1, x :: xs, acc => go n xs (x :: acc)
| 0, xs, acc => (acc, xs)
| n, [], acc => (acc.reverseAux (replicate n dflt), [])

/--
Split a list at every element satisfying a predicate. The separators are not in the result.
Expand Down

0 comments on commit a36e34f

Please sign in to comment.