From c76696c68474ccfd853c5f5927d68b6f9fd5d081 Mon Sep 17 00:00:00 2001 From: Richard Guo Date: Thu, 24 Oct 2024 12:04:42 -0700 Subject: [PATCH] [RFR] Heap Sort on Tiles (#162) Working solution with tests. Assumes all the tile sorts can be loaded into memory. If not, we need to bring back two heaps. ---- > [!IMPORTANT] > Introduce heap sort for tile sorting in Deepscatter with `TileSorter` and `MinHeap`, and add tests for sorting functionality. > > - **Behavior**: > - Introduces `TileSorter` class in `selection.ts` for heap-based sorting of tiles. > - Adds `iterator()` method in `SortedDataSelection` to iterate over sorted data using `TileSorter`. > - Implements `MinHeap` class in `utilityFunctions.ts` for managing heap operations. > - **Tests**: > - Adds tests in `dataset.spec.js` for sorting and iterated sorting of selections using the new heap sort mechanism. > - Tests cover both normal and reverse iteration over sorted data. > - **Misc**: > - Minor import adjustments in `selection.ts` and `deepscatter.ts`. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=nomic-ai%2Fdeepscatter&utm_source=github&utm_medium=referral) for f1ceb61b6e2cae1ef1f6771948e32d795d7aa3e4. It will automatically update as commits are pushed. 
--- src/deepscatter.ts | 2 +- src/selection.ts | 266 +++++++++++++++++++++++++++++++++++----- src/utilityFunctions.ts | 99 ++++++++++++++- tests/dataset.spec.js | 154 ++++++++++++++++++++++- tests/datasetHelpers.js | 25 ++++ 5 files changed, 510 insertions(+), 36 deletions(-) diff --git a/src/deepscatter.ts b/src/deepscatter.ts index 0498311ba..278ca308e 100644 --- a/src/deepscatter.ts +++ b/src/deepscatter.ts @@ -25,5 +25,5 @@ export type { OpChannel, TileProxy, DeeptableCreateParams, - Transformation, + Transformation } from './types'; diff --git a/src/selection.ts b/src/selection.ts index b73a4821b..8259683f4 100644 --- a/src/selection.ts +++ b/src/selection.ts @@ -6,6 +6,7 @@ import { getTileFromRow } from './tixrixqid'; import type * as DS from './types'; import { Bool, StructRowProxy, Utf8, Vector, makeData } from 'apache-arrow'; import { bisectLeft, bisectRight, range } from 'd3-array'; +import { MinHeap } from './utilityFunctions'; interface SelectParams { name: string; useNameCache?: boolean; // If true and a selection with that name already exists, use it and ignore all passed parameters. Otherwise, throw an error. @@ -350,11 +351,7 @@ class SelectionTile { return this._matchCount; } - addSort( - key: string, - getter: (row: StructRowProxy) => number, - order: 'ascending' | 'descending', - ) { + addSort(key: string, getter: (row: StructRowProxy) => number) { const { bitmask } = this; const indices = Bitmask.from_arrow(bitmask).which(); const pairs: [number, number][] = new Array(indices.length); @@ -362,8 +359,7 @@ class SelectionTile { const v = getter(this.tile.record_batch.get(indices[i])); pairs[i] = [v, indices[i]]; } - // Sort according to the specified order - pairs.sort((a, b) => (order === 'ascending' ? 
a[0] - b[0] : b[0] - a[0])); + pairs.sort((a, b) => a[0] - b[0]); const values = new Float64Array(indices.length); for (let i = 0; i < indices.length; i++) { indices[i] = pairs[i][1]; @@ -1059,7 +1055,7 @@ export class SortedDataSelection extends DataSelection { const withSort = sel.tiles.map( async (tile: SelectionTile): Promise => { await Promise.all(neededFields.map((f) => tile.tile.get_column(f))); - tile.addSort(key, sortOperation, order); + tile.addSort(key, sortOperation); return tile; }, ); @@ -1081,17 +1077,17 @@ export class SortedDataSelection extends DataSelection { // Store the indices and values in the tile - let ix = this.tiles.findIndex((having) => having.tile === tile); + const ix = this.tiles.findIndex((having) => having.tile === tile); let t: SelectionTile; if (ix !== -1) { t = this.tiles[ix]; - t.addSort(this.key, this.comparisonGetter, this.order); + t.addSort(this.key, this.comparisonGetter); } else { t = new SelectionTile({ arrowBitmask: array, tile, }); - t.addSort(this.key, this.comparisonGetter, this.order); + t.addSort(this.key, this.comparisonGetter); this.selectionSize += t.matchCount; this.evaluationSetSize += tile.manifest.nPoints; this.tiles.push(t); @@ -1107,27 +1103,195 @@ export class SortedDataSelection extends DataSelection { * This implementation uses Quickselect with a pivot selected from actual data. */ get(k: number): StructRowProxy | undefined { - if (k < 0 || k >= this.selectionSize) { - console.error('Index out of bounds'); - return undefined; + const actualK = this.order === 'ascending' ? k : this.selectionSize - k - 1; + // Implement Quickselect over the combined data + const result = quickSelect(actualK, this.tiles, this.key, true); + return result ? 
result.row : undefined; + } + + // Given a point, returns cursor number that would select it in this selection + // which(row: StructRowProxy) {} + + iterator(k: number = 0, reverse: boolean = false): TileSorter { + if (this.tiles.length == 0) { + throw new Error('No tiles in sorted selection to iterate over'); } + const actualOrder = reverse + ? this.order === 'ascending' + ? 'descending' + : 'ascending' + : this.order; + return new TileSorter(this.tiles, this.key, k, actualOrder); + } +} - // Adjust k based on the order - const targetIndex = - this.order === 'ascending' ? k : this.selectionSize - k - 1; +type SortInfoWithPointer = SelectionSortInfo & { pointer: number; tile: Tile }; - // Implement Quickselect over the combined data - return quickSelect(targetIndex, this.tiles, this.key); +type CompareFunction = (a, b) => number; + +// Separate class for now which is disposable because it maintains its +// own traversal order through the array as a pointer. +// Time complexity per iteration: O(log m) where m is the number of tiles +export class TileSorter + implements Iterator, Iterable +{ + public tiles: SelectionTile[]; + public sortKey: string; + public startingK: number; + public order: 'ascending' | 'descending'; + public compare: CompareFunction; + private valueHeap: MinHeap; + + constructor( + tiles: SelectionTile[], + sortKey: string, + startingK: number = 0, + order: 'ascending' | 'descending' = 'ascending', + ) { + this.startingK = startingK; + this.sortKey = sortKey; + this.order = order; + this.compare = this.getCompareFunction(); + + this.valueHeap = new MinHeap(this.compare); + this.tiles = tiles; + this.init(); } - // Given a point, returns cursor number that would select it in this selection - which(row: StructRowProxy) {} + // Needed to bake the tile id into the comparator. 
+ getCompareFunction(): CompareFunction { + if (this.order === 'descending') { + return (a: SortInfoWithPointer, b: SortInfoWithPointer) => { + const diff = b.values[b.pointer] - a.values[a.pointer]; + return Math.abs(diff) < Number.EPSILON ? b.tile.tix - a.tile.tix : diff; + }; + } else { + return (a: SortInfoWithPointer, b: SortInfoWithPointer) => { + const diff = a.values[a.pointer] - b.values[b.pointer]; + return Math.abs(diff) < Number.EPSILON ? a.tile.tix - b.tile.tix : diff; + }; + } + } + + selectionSizeFromTiles() { + return this.tiles.reduce((acc, t) => acc + t.matchCount, 0); + } + + initFromKHelper() { + const actualK = + this.order === 'ascending' + ? this.startingK + : this.selectionSizeFromTiles() - this.startingK - 1; + const { tix, sortIndex, sortValue } = quickSelect( + actualK, + this.tiles, + this.sortKey, + false, + this.order, + ); + const sortInfos: SortInfoWithPointer[] = []; + // Init all pointers with binary search + for (const tile of this.tiles) { + const rawSortInfo = tile.sorts[this.sortKey]; + // We know exactly pointer spot for the tile with k + if (tile.tile.tix === tix) { + sortInfos.push({ + ...rawSortInfo, + pointer: sortIndex, + tile: tile.tile, + }); + continue; + } + // All values left of pointer are less than targetValue. 
+ let pointer; + if (this.order === 'ascending') { + pointer = bisectLeft(rawSortInfo.values, sortValue); + } else { + pointer = bisectRight(rawSortInfo.values, sortValue); + } + sortInfos.push({ + ...rawSortInfo, + pointer: pointer, + tile: tile.tile, + }); + } + + // Now skip the pointer past values <= kth val until we reach tile with k + for (const sortInfo of sortInfos) { + const { tile, values } = sortInfo; + // Found the target tile + if (tile.tix === tix) { + break; + } + // Before the target tile, skip equal values + while (values[sortInfo.pointer] <= sortValue) { + if (this.order === 'ascending') { + sortInfo.pointer++; + } else { + sortInfo.pointer--; + } + } + } + + // Finally add to heap + for (const sortInfo of sortInfos) { + // Only add if there are indices and we haven't reached the end of the values + if (sortInfo.indices.length > 0 && sortInfo.pointer < sortInfo.values.length) { + this.valueHeap.insert(sortInfo); + } + } + } + + init() { + // First sort the tiles to guarantee consistency + if (this.order === 'ascending') { + this.tiles.sort((a, b) => a.tile.tix - b.tile.tix); + } else { + this.tiles.sort((a, b) => b.tile.tix - a.tile.tix); + } - *yieldSorted(start = undefined, direction = 'up') { - if (start !== undefined) { - this.cursor = start; + if (this.startingK > 0) { + this.initFromKHelper(); + } else { + for (const tile of this.tiles) { + const rawSortInfo = tile.sorts[this.sortKey]; + // Skip tiles with empty selections + if (rawSortInfo.indices.length <= 0) { + continue + } + const sortInfo: SortInfoWithPointer = { + ...rawSortInfo, + pointer: + this.order === 'ascending' ? 
0 : rawSortInfo.values.length - 1, + tile: tile.tile, + }; + this.valueHeap.insert(sortInfo); + } } } + + next(): IteratorResult { + if (this.valueHeap.isEmpty()) { + return { done: true, value: undefined }; + } + + const heapItem = this.valueHeap.extractMin(); + const { tile, indices, pointer } = heapItem; + const index = indices[pointer]; + const row = tile.record_batch.get(index); + + const nextPointer: number = + this.order === 'ascending' ? pointer + 1 : pointer - 1; + + if (nextPointer < indices.length && nextPointer >= 0) { + this.valueHeap.insert({ ...heapItem, pointer: nextPointer }); + } + return { value: row, done: false }; + } + + [Symbol.iterator](): Iterator { + return this; + } } interface QuickSortTile { @@ -1135,28 +1299,61 @@ interface QuickSortTile { sorts: Record; } +type QuickSelectResult = { + row: StructRowProxy; + tix: number; + // Needed if you want to set pointers + sortIndex: number; + sortValue: number; +}; + function quickSelect( k: number, tiles: QuickSortTile[], key: string, -): StructRowProxy | undefined { + sortTiles: boolean = true, // Optional because don't need to sort on recursive calls + order: 'ascending' | 'descending' = 'ascending', +): QuickSelectResult | undefined { // Recalculate size based on the current tiles const size = tiles.reduce( (acc, t) => acc + (t.sorts[key].end - t.sorts[key].start), 0, ); + if (k < 0 || k >= size) { + throw new Error('Index out of bounds'); + return undefined; + } + + if (size === 0) { + return undefined; + } + if (size === 1) { for (const t of tiles) { - const { indices, start, end } = t.sorts[key]; + const { indices, values, start, end } = t.sorts[key]; if (end - start > 0) { const recordIndex = indices[start]; - return t.tile.record_batch.get(recordIndex); + return { + row: t.tile.record_batch.get(recordIndex), + sortIndex: start, + sortValue: values[start], + tix: t.tile.tix, + }; } } return undefined; } + // Sort the tiles to guarantee consistency + if (sortTiles) { + tiles.sort((a, b) 
=> + order === 'ascending' + ? a.tile.tix - b.tile.tix + : -1 * (a.tile.tix - b.tile.tix), + ); + } + // Select a random pivot from actual data const pivot = randomPivotFromData(tiles, key); @@ -1215,13 +1412,13 @@ function quickSelect( } if (k < countLess) { - return quickSelect(k, lessTiles, key); + return quickSelect(k, lessTiles, key, false); } else if (k < countLess + countEqual) { const indexInEqual = k - countLess; return selectInEqualTiles(indexInEqual, equalTiles, key); } else { const newK = k - (countLess + countEqual); - return quickSelect(newK, greaterTiles, key); + return quickSelect(newK, greaterTiles, key, false); } } @@ -1229,15 +1426,20 @@ function selectInEqualTiles( indexInEqual: number, tiles: QuickSortTile[], key: string, -): StructRowProxy | undefined { +): QuickSelectResult | undefined { let count = 0; for (const t of tiles) { - const { indices, start, end } = t.sorts[key]; + const { indices, values, start, end } = t.sorts[key]; const numValues = end - start; if (indexInEqual < count + numValues) { const idxInTile = start + (indexInEqual - count); const recordIndex = indices[idxInTile]; - return t.tile.record_batch.get(recordIndex); + return { + row: t.tile.record_batch.get(recordIndex), + tix: t.tile.tix, + sortIndex: idxInTile, + sortValue: values[idxInTile], + }; } count += numValues; } diff --git a/src/utilityFunctions.ts b/src/utilityFunctions.ts index 9267ca47d..e94d4e982 100644 --- a/src/utilityFunctions.ts +++ b/src/utilityFunctions.ts @@ -192,7 +192,7 @@ export class TupleMap { } } -export class TupleSet { +export class TupleSet { private map = new TupleMap(); constructor(v: Some[] = []) { @@ -232,3 +232,100 @@ export class TupleSet { this.map = new TupleMap(); } } + +export class MinHeap { + private heap: T[] = []; + private comparator: (a: T, b: T) => number; + + constructor(comparator: (a: T, b: T) => number) { + this.comparator = comparator; + } + + /** Inserts a new element into the heap */ + public insert(value: T): void { + 
this.heap.push(value); + this.bubbleUp(); + } + + /** Extracts and returns the minimum element from the heap */ + public extractMin(): T | undefined { + if (this.heap.length === 0) return undefined; + + const min = this.heap[0]; + const end = this.heap.pop(); + + if (this.heap.length > 0 && end !== undefined) { + this.heap[0] = end; + this.bubbleDown(); + } + + return min; + } + + /** Returns true if the heap is empty */ + public isEmpty(): boolean { + return this.heap.length === 0; + } + + /** Returns the size of the heap */ + public size(): number { + return this.heap.length; + } + + /** Returns the minimum element without removing it */ + public peek(): T | undefined { + return this.heap[0]; + } + + private bubbleUp(): void { + let index = this.heap.length - 1; + const element = this.heap[index]; + + while (index > 0) { + const parentIndex = Math.floor((index - 1) / 2); + const parent = this.heap[parentIndex]; + + if (this.comparator(element, parent) >= 0) break; + + this.heap[index] = parent; + this.heap[parentIndex] = element; + index = parentIndex; + } + } + + private bubbleDown(): void { + let index = 0; + const length = this.heap.length; + const element = this.heap[0]; + + while (true) { + let swapIndex: number | null = null; + const leftChildIndex = 2 * index + 1; + const rightChildIndex = 2 * index + 2; + + if (leftChildIndex < length) { + const leftChild = this.heap[leftChildIndex]; + if (this.comparator(leftChild, element) < 0) { + swapIndex = leftChildIndex; + } + } + + if (rightChildIndex < length) { + const rightChild = this.heap[rightChildIndex]; + if ( + (swapIndex === null && this.comparator(rightChild, element) < 0) || + (swapIndex !== null && + this.comparator(rightChild, this.heap[swapIndex]) < 0) + ) { + swapIndex = rightChildIndex; + } + } + + if (swapIndex === null) break; + + this.heap[index] = this.heap[swapIndex]; + this.heap[swapIndex] = element; + index = swapIndex; + } + } +} diff --git a/tests/dataset.spec.js b/tests/dataset.spec.js 
index 1d20295b9..f0fffaba3 100644 --- a/tests/dataset.spec.js +++ b/tests/dataset.spec.js @@ -1,15 +1,14 @@ import { - Deeptable, DataSelection, SortedDataSelection, Bitmask, } from '../dist/deepscatter.js'; -import { Table, vectorFromArray, Utf8 } from 'apache-arrow'; import { test } from 'uvu'; import * as assert from 'uvu/assert'; import { createIntegerDataset, selectFunctionForFactorsOf, + selectRandomRows, } from './datasetHelpers.js'; test('Dataset can be created', async () => { @@ -124,4 +123,155 @@ test('Test sorting of selections', async () => { assert.ok(mid.random < 0.55); }); +test('Test iterated sorting of selections', async () => { + const dataset = createIntegerDataset(); + await dataset.root_tile.preprocessRootTileInfo(); + const selectEvens = new DataSelection(dataset, { + name: 'twos2', + tileFunction: selectFunctionForFactorsOf(2), + }); + const sortKey = 'random'; + const sorted = await SortedDataSelection.fromSelection( + selectEvens, + [sortKey], + ({ random }) => random, + ); + // Apply only to root tile + await dataset.root_tile.get_column(sorted.name); + + let size = 0; + let prevValue = Number.NEGATIVE_INFINITY; + for (const row of sorted.iterator()) { + size++; + const currValue = row[sortKey]; + assert.ok(currValue >= prevValue); + prevValue = currValue; + } + assert.ok(size, 2048); + assert.ok(size, sorted.selectionSize); + + // Now load all the tiles + await sorted.applyToAllTiles(); + + size = 0; + prevValue = Number.NEGATIVE_INFINITY; + for (const row of sorted.iterator()) { + size++; + const currValue = row[sortKey]; + assert.ok(currValue >= prevValue); + prevValue = currValue; + } + assert.is(size, 8192); + assert.is(size, sorted.selectionSize); + + // Multiple iterators + const first = sorted.iterator(0); + const second = sorted.iterator(5); + + const numbers = [0, 1, 2, 3, 5, 6, 7, 8, 9, 10]; + const firstVals = numbers.map(d => first.next().value[sortKey]) + + for (let i = 0; i < 5; i++) { + assert.ok(firstVals[5 + i] === 
second.next().value[sortKey]); + } +}); + +test ('Iterated sorting of empty selection', async() => { + const dataset = createIntegerDataset(); + await dataset.root_tile.preprocessRootTileInfo(); + const emptySelection = new DataSelection(dataset, { + name: 'empty', + tileFunction: async (t) => new Bitmask(t.record_batch.numRows).to_arrow(), + }); + + const sorted = await SortedDataSelection.fromSelection( + emptySelection, + ['random'], + ({ random }) => random, + ); + await sorted.applyToAllTiles(); + + let size = 0; + for (const row of sorted.iterator()) { + console.log(row); + size++; + } + assert.is(size, 0); + assert.is(size, sorted.selectionSize); + + let thrown = false; + // throw index out of bounds + try { + sorted.iterator(5); + } catch (e) { + thrown = true; + } + assert.ok(thrown); +}) + +test('Edge cases for iterated sorting of selections', async () => { + const dataset = createIntegerDataset(); + await dataset.root_tile.preprocessRootTileInfo(); + const selectEvens = new DataSelection(dataset, { + name: 'twos2', + tileFunction: selectFunctionForFactorsOf(2), + }); + let sortKey = 'random'; + await selectEvens.applyToAllTiles(); + + // Go reverse direction + const reverseSorted = await SortedDataSelection.fromSelection( + selectEvens, + [sortKey], + ({ random }) => random, + 'descending', + ); + + await reverseSorted.applyToAllTiles(); + + let size = 0; + let prevValue = Number.POSITIVE_INFINITY; + for (const row of reverseSorted.iterator()) { + size++; + const currValue = row[sortKey]; + assert.ok(currValue <= prevValue); + prevValue = currValue; + } + assert.ok(size, reverseSorted.selectionSize); + + // TODO: sandwich 01111112 edge case + const selectRandom = new DataSelection(dataset, { + name: 'randomSelect', + tileFunction: selectRandomRows(), + }); + sortKey = 'sandwich'; + await selectRandom.applyToAllTiles(); + + const randomSorted = await SortedDataSelection.fromSelection( + selectRandom, + [sortKey, 'random'], + ({ sandwich }) => sandwich, + 
); + await randomSorted.applyToAllTiles(); + + const randomVals = []; + prevValue = 0; + for (const row of randomSorted.iterator()) { + randomVals.push(row['random']); + assert.ok(prevValue <= row['sandwich']); + } + + let index = 0; + for (const row of randomSorted.iterator(10)) { + assert.ok( + Math.abs(row['random'] - randomVals[index + 10]) < Number.EPSILON, + ); + index++; + } + + for (const row of randomSorted.iterator(0, true)) { + assert.ok(Math.abs(row['random'] - randomVals.pop()) < Number.EPSILON); + } +}); + test.run(); diff --git a/tests/datasetHelpers.js b/tests/datasetHelpers.js index 62dfcfca6..7141d60a2 100644 --- a/tests/datasetHelpers.js +++ b/tests/datasetHelpers.js @@ -15,6 +15,19 @@ export function selectFunctionForFactorsOf(n) { }; } +// Creates a tile transformation that just takes random rows +export function selectRandomRows(p = 0.5) { + return async(tile) => { + const mask = new Bitmask(tile.record_batch.numRows); + for (let i = 0; i < tile.record_batch.numRows; i++) { + if (Math.random() < p) { + mask.set(i); + } + } + return mask.to_arrow(); + } +} + function make_batch(start = 0, length = 65536, batch_number_here = 0) { let x = new Float32Array(length); let y = new Float32Array(length); @@ -44,6 +57,17 @@ function make_batch(start = 0, length = 65536, batch_number_here = 0) { randoms[i - start] = Math.random(); } + // Create an array that looks like 0, 1, 1, ..., 1, 2 + // and shuffle it + function sandwich(n) { + const arr = [0, 2, ...Array(Math.max(n - 2, 1)).fill(1)]; + for (let i = arr.length - 1; i > 0; i--) { + const j = Math.floor(Math.random() * (i + 1)); + [arr[i], arr[j]] = [arr[j], arr[i]]; + } + return arr; + } + function num_to_string(num) { return num.toString(); } @@ -55,6 +79,7 @@ function make_batch(start = 0, length = 65536, batch_number_here = 0) { integers: vectorFromArray(integers), batch_id: vectorFromArray(batch_id), random: vectorFromArray(randoms), + sandwich: vectorFromArray(sandwich(length)), }); }