Skip to content

Commit

Permalink
[RFR] Heap Sort on Tiles (#162)
Browse files Browse the repository at this point in the history
Working solution with tests. Assumes that all of the tile sorts can be loaded into memory. If they cannot, we will need to bring back the two heaps.
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Introduce heap sort for tile sorting in Deepscatter with `TileSorter` and `MinHeap`, and add tests for sorting functionality.
> 
>   - **Behavior**:
>     - Introduces `TileSorter` class in `selection.ts` for heap-based sorting of tiles.
>     - Adds `iterator()` method in `SortedDataSelection` to iterate over sorted data using `TileSorter`.
>     - Implements `MinHeap` class in `utilityFunctions.ts` for managing heap operations.
>   - **Tests**:
>     - Adds tests in `dataset.spec.js` for sorting and iterated sorting of selections using the new heap sort mechanism.
>     - Tests cover both normal and reverse iteration over sorted data.
>   - **Misc**:
>     - Minor import adjustments in `selection.ts` and `deepscatter.ts`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=nomic-ai%2Fdeepscatter&utm_source=github&utm_medium=referral)<sup> for f1ceb61. It will automatically update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
rguo123 authored Oct 24, 2024
1 parent e6c2a61 commit c76696c
Show file tree
Hide file tree
Showing 5 changed files with 510 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/deepscatter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ export type {
OpChannel,
TileProxy,
DeeptableCreateParams,
Transformation,
Transformation
} from './types';
266 changes: 234 additions & 32 deletions src/selection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { getTileFromRow } from './tixrixqid';
import type * as DS from './types';
import { Bool, StructRowProxy, Utf8, Vector, makeData } from 'apache-arrow';
import { bisectLeft, bisectRight, range } from 'd3-array';
import { MinHeap } from './utilityFunctions';
interface SelectParams {
name: string;
useNameCache?: boolean; // If true and a selection with that name already exists, use it and ignore all passed parameters. Otherwise, throw an error.
Expand Down Expand Up @@ -350,20 +351,15 @@ class SelectionTile {
return this._matchCount;
}

addSort(
key: string,
getter: (row: StructRowProxy) => number,
order: 'ascending' | 'descending',
) {
addSort(key: string, getter: (row: StructRowProxy) => number) {
const { bitmask } = this;
const indices = Bitmask.from_arrow(bitmask).which();
const pairs: [number, number][] = new Array(indices.length);
for (let i = 0; i < indices.length; i++) {
const v = getter(this.tile.record_batch.get(indices[i]));
pairs[i] = [v, indices[i]];
}
// Sort according to the specified order
pairs.sort((a, b) => (order === 'ascending' ? a[0] - b[0] : b[0] - a[0]));
pairs.sort((a, b) => a[0] - b[0]);
const values = new Float64Array(indices.length);
for (let i = 0; i < indices.length; i++) {
indices[i] = pairs[i][1];
Expand Down Expand Up @@ -1059,7 +1055,7 @@ export class SortedDataSelection extends DataSelection {
const withSort = sel.tiles.map(
async (tile: SelectionTile): Promise<SelectionTile> => {
await Promise.all(neededFields.map((f) => tile.tile.get_column(f)));
tile.addSort(key, sortOperation, order);
tile.addSort(key, sortOperation);
return tile;
},
);
Expand All @@ -1081,17 +1077,17 @@ export class SortedDataSelection extends DataSelection {

// Store the indices and values in the tile

let ix = this.tiles.findIndex((having) => having.tile === tile);
const ix = this.tiles.findIndex((having) => having.tile === tile);
let t: SelectionTile;
if (ix !== -1) {
t = this.tiles[ix];
t.addSort(this.key, this.comparisonGetter, this.order);
t.addSort(this.key, this.comparisonGetter);
} else {
t = new SelectionTile({
arrowBitmask: array,
tile,
});
t.addSort(this.key, this.comparisonGetter, this.order);
t.addSort(this.key, this.comparisonGetter);
this.selectionSize += t.matchCount;
this.evaluationSetSize += tile.manifest.nPoints;
this.tiles.push(t);
Expand All @@ -1107,56 +1103,257 @@ export class SortedDataSelection extends DataSelection {
* This implementation uses Quickselect with a pivot selected from actual data.
*/
get(k: number): StructRowProxy | undefined {
if (k < 0 || k >= this.selectionSize) {
console.error('Index out of bounds');
return undefined;
const actualK = this.order === 'ascending' ? k : this.selectionSize - k - 1;
// Implement Quickselect over the combined data
const result = quickSelect(actualK, this.tiles, this.key, true);
return result ? result.row : undefined;
}

// Given a point, returns cursor number that would select it in this selection
// which(row: StructRowProxy) {}

/**
 * Builds a heap-backed iterator over the rows of this sorted selection.
 *
 * @param k Starting position, forwarded unchanged to the TileSorter.
 * @param reverse When true, the iterator runs in the opposite direction
 *   to this selection's own `order`.
 * @returns A new TileSorter over this selection's tiles and sort key.
 * @throws Error when the selection holds no tiles.
 */
iterator(k: number = 0, reverse: boolean = false): TileSorter {
  if (this.tiles.length === 0) {
    throw new Error('No tiles in sorted selection to iterate over');
  }

  // Start from the selection's own order and flip it only on request.
  let direction: 'ascending' | 'descending' = this.order;
  if (reverse) {
    direction = this.order === 'ascending' ? 'descending' : 'ascending';
  }

  return new TileSorter(this.tiles, this.key, k, direction);
}
}

// Adjust k based on the order
const targetIndex =
this.order === 'ascending' ? k : this.selectionSize - k - 1;
type SortInfoWithPointer = SelectionSortInfo & { pointer: number; tile: Tile };

// Implement Quickselect over the combined data
return quickSelect(targetIndex, this.tiles, this.key);
type CompareFunction = (a, b) => number;

// Kept as a separate, disposable class for now: each instance maintains its
// own traversal position through the sorted arrays as a per-tile pointer.
// Time complexity per iteration: O(log m), where m is the number of tiles.
export class TileSorter
implements Iterator<StructRowProxy>, Iterable<StructRowProxy>
{
public tiles: SelectionTile[];
public sortKey: string;
public startingK: number;
public order: 'ascending' | 'descending';
public compare: CompareFunction;
private valueHeap: MinHeap<SortInfoWithPointer>;

/**
 * Sets up the sorter and immediately seeds its heap via `init()`.
 *
 * @param tiles Selection tiles to merge. The array is stored by reference,
 *   and `init()` sorts `this.tiles` in place, so the caller's array is
 *   reordered as a side effect.
 * @param sortKey Key used to look up each tile's entry in `tile.sorts`.
 * @param startingK Position to start iterating from; 0 starts at the beginning.
 * @param order Direction in which rows are produced.
 */
constructor(
tiles: SelectionTile[],
sortKey: string,
startingK: number = 0,
order: 'ascending' | 'descending' = 'ascending',
) {
this.startingK = startingK;
this.sortKey = sortKey;
this.order = order;
// getCompareFunction() reads this.order, so order must be assigned first.
this.compare = this.getCompareFunction();

this.valueHeap = new MinHeap<SortInfoWithPointer>(this.compare);
this.tiles = tiles;
// init() relies on every field assigned above, so it has to run last.
this.init();
}

// Given a point, returns cursor number that would select it in this selection
which(row: StructRowProxy) {}
/**
 * Builds the heap comparator for the current `order`.
 *
 * Entries are compared by the value under their pointer. When the two
 * values differ by less than Number.EPSILON, the tile id (`tix`) breaks the
 * tie. The descending comparator is the ascending one with its arguments
 * swapped, which flips both the value comparison and the tie-break.
 */
getCompareFunction(): CompareFunction {
  const headValue = (info: SortInfoWithPointer): number =>
    info.values[info.pointer];

  const ascending = (
    a: SortInfoWithPointer,
    b: SortInfoWithPointer,
  ): number => {
    const delta = headValue(a) - headValue(b);
    if (Math.abs(delta) < Number.EPSILON) {
      return a.tile.tix - b.tile.tix;
    }
    return delta;
  };

  if (this.order !== 'descending') {
    return ascending;
  }
  return (a: SortInfoWithPointer, b: SortInfoWithPointer) => ascending(b, a);
}

/**
 * Totals `matchCount` across every tile currently held by the sorter.
 *
 * @returns The sum of the tiles' match counts (0 when there are no tiles).
 */
selectionSizeFromTiles() {
  let total = 0;
  for (const selectionTile of this.tiles) {
    total += selectionTile.matchCount;
  }
  return total;
}

/**
 * Seeds the heap so that iteration starts at position `startingK` instead
 * of at the first element.
 *
 * Steps:
 *  1. Translate `startingK` into an index and locate that element with
 *     quickSelect (tile id, position inside the tile's sort, and value).
 *  2. Give every tile a starting pointer: the exact position for the tile
 *     that holds the element, and a binary-search position for the others.
 *  3. For tiles that come before the target tile in `this.tiles`, move the
 *     pointer while the value under it is <= the selected value.
 *  4. Insert every tile whose pointer is still below `values.length`.
 */
initFromKHelper() {
// For descending order the index is counted from the other end of the selection.
const actualK =
this.order === 'ascending'
? this.startingK
: this.selectionSizeFromTiles() - this.startingK - 1;
// NOTE(review): quickSelect is declared to return `QuickSelectResult | undefined`;
// destructuring an undefined result here would throw a TypeError.
// sortTiles is passed as false; in the visible part of quickSelect the `order`
// argument is only used when sortTiles is true — TODO confirm it is needed here.
const { tix, sortIndex, sortValue } = quickSelect(
actualK,
this.tiles,
this.sortKey,
false,
this.order,
);
const sortInfos: SortInfoWithPointer[] = [];
// Init all pointers with binary search
for (const tile of this.tiles) {
const rawSortInfo = tile.sorts[this.sortKey];
// The tile containing the k-th element: its pointer is known exactly.
if (tile.tile.tix === tix) {
sortInfos.push({
...rawSortInfo,
pointer: sortIndex,
tile: tile.tile,
});
continue;
}
// Assuming `values` is sorted ascending (TODO confirm):
//  - bisectLeft:  every value left of the pointer is <  sortValue.
//  - bisectRight: every value left of the pointer is <= sortValue.
let pointer;
if (this.order === 'ascending') {
pointer = bisectLeft(rawSortInfo.values, sortValue);
} else {
pointer = bisectRight(rawSortInfo.values, sortValue);
}
sortInfos.push({
...rawSortInfo,
pointer: pointer,
tile: tile.tile,
});
}

// Now skip the pointer past values <= kth val until we reach tile with k
for (const sortInfo of sortInfos) {
const { tile, values } = sortInfo;
// Found the target tile; tiles after it keep their binary-search pointer.
if (tile.tix === tix) {
break;
}
// Before the target tile, skip equal values.
// The loop ends when the pointer leaves the array, because comparing
// `undefined <= sortValue` is false.
// NOTE(review): the same `<=` test is used for both directions. For
// 'descending' the pointer comes from bisectRight, so this loop presumably
// never moves it, and the pointer may start on a value greater than
// sortValue (or at values.length) — verify the descending start position.
while (values[sortInfo.pointer] <= sortValue) {
if (this.order === 'ascending') {
sortInfo.pointer++;
} else {
sortInfo.pointer--;
}
}
}

// Finally add to heap
for (const sortInfo of sortInfos) {
// Only add if there are indices and we haven't reached the end of the values
// NOTE(review): only the upper bound is checked; a pointer below 0 is not
// filtered out here — confirm that cannot occur.
if (sortInfo.indices.length > 0 && sortInfo.pointer < sortInfo.values.length) {
this.valueHeap.insert(sortInfo);
}
}
}

init() {
// First sort the tiles to guarantee consistency
if (this.order === 'ascending') {
this.tiles.sort((a, b) => a.tile.tix - b.tile.tix);
} else {
this.tiles.sort((a, b) => b.tile.tix - a.tile.tix);
}

*yieldSorted(start = undefined, direction = 'up') {
if (start !== undefined) {
this.cursor = start;
if (this.startingK > 0) {
this.initFromKHelper();
} else {
for (const tile of this.tiles) {
const rawSortInfo = tile.sorts[this.sortKey];
// Skip tiles with empty selections
if (rawSortInfo.indices.length <= 0) {
continue
}
const sortInfo: SortInfoWithPointer = {
...rawSortInfo,
pointer:
this.order === 'ascending' ? 0 : rawSortInfo.values.length - 1,
tile: tile.tile,
};
this.valueHeap.insert(sortInfo);
}
}
}

/**
 * Produces the next row in sorted order.
 *
 * Removes the top entry from the heap, reads the row its pointer refers
 * to, and puts the same tile back with its pointer moved one step (forward
 * for 'ascending', backward otherwise) as long as the new pointer is still
 * inside the tile's index array.
 *
 * @returns `{ done: true }` once the heap is empty, otherwise the row.
 */
next(): IteratorResult<StructRowProxy> {
  if (this.valueHeap.isEmpty()) {
    return { done: true, value: undefined };
  }

  const top = this.valueHeap.extractMin();
  const { tile, indices, pointer } = top;
  const row = tile.record_batch.get(indices[pointer]);

  // Step direction follows the iteration order.
  const step = this.order === 'ascending' ? 1 : -1;
  const advanced: number = pointer + step;

  const stillInside = advanced >= 0 && advanced < indices.length;
  if (stillInside) {
    this.valueHeap.insert({ ...top, pointer: advanced });
  }

  return { value: row, done: false };
}

/**
 * Makes the sorter usable in `for...of` and spread. It returns itself, so
 * iterating a second time continues from the current state, not the start.
 */
[Symbol.iterator](): Iterator<StructRowProxy> {
return this;
}
}

/**
 * Minimal tile shape consumed by quickSelect and selectInEqualTiles.
 */
interface QuickSortTile {
// Tile whose record batch the rows are read from.
tile: Tile;
// Per-key sort information, looked up by sort key.
sorts: Record<string, SelectionSortInfo>;
}

/**
 * What quickSelect / selectInEqualTiles report about the selected element.
 */
type QuickSelectResult = {
// The selected row, read from the owning tile's record batch.
row: StructRowProxy;
// Id (`tix`) of the tile that holds the row.
tix: number;
// Needed if you want to set pointers:
// position of the element inside that tile's sorted indices/values arrays.
sortIndex: number;
// Sort value stored at `sortIndex`.
sortValue: number;
};

function quickSelect(
k: number,
tiles: QuickSortTile[],
key: string,
): StructRowProxy | undefined {
sortTiles: boolean = true, // Optional because don't need to sort on recursive calls
order: 'ascending' | 'descending' = 'ascending',
): QuickSelectResult | undefined {
// Recalculate size based on the current tiles
const size = tiles.reduce(
(acc, t) => acc + (t.sorts[key].end - t.sorts[key].start),
0,
);

if (k < 0 || k >= size) {
throw new Error('Index out of bounds');
return undefined;
}

if (size === 0) {
return undefined;
}

if (size === 1) {
for (const t of tiles) {
const { indices, start, end } = t.sorts[key];
const { indices, values, start, end } = t.sorts[key];
if (end - start > 0) {
const recordIndex = indices[start];
return t.tile.record_batch.get(recordIndex);
return {
row: t.tile.record_batch.get(recordIndex),
sortIndex: start,
sortValue: values[start],
tix: t.tile.tix,
};
}
}
return undefined;
}

// Sort the tiles to guarantee consistency
if (sortTiles) {
tiles.sort((a, b) =>
order === 'ascending'
? a.tile.tix - b.tile.tix
: -1 * (a.tile.tix - b.tile.tix),
);
}

// Select a random pivot from actual data
const pivot = randomPivotFromData(tiles, key);

Expand Down Expand Up @@ -1215,29 +1412,34 @@ function quickSelect(
}

if (k < countLess) {
return quickSelect(k, lessTiles, key);
return quickSelect(k, lessTiles, key, false);
} else if (k < countLess + countEqual) {
const indexInEqual = k - countLess;
return selectInEqualTiles(indexInEqual, equalTiles, key);
} else {
const newK = k - (countLess + countEqual);
return quickSelect(newK, greaterTiles, key);
return quickSelect(newK, greaterTiles, key, false);
}
}

function selectInEqualTiles(
indexInEqual: number,
tiles: QuickSortTile[],
key: string,
): StructRowProxy | undefined {
): QuickSelectResult | undefined {
let count = 0;
for (const t of tiles) {
const { indices, start, end } = t.sorts[key];
const { indices, values, start, end } = t.sorts[key];
const numValues = end - start;
if (indexInEqual < count + numValues) {
const idxInTile = start + (indexInEqual - count);
const recordIndex = indices[idxInTile];
return t.tile.record_batch.get(recordIndex);
return {
row: t.tile.record_batch.get(recordIndex),
tix: t.tile.tix,
sortIndex: idxInTile,
sortValue: values[idxInTile],
};
}
count += numValues;
}
Expand Down
Loading

0 comments on commit c76696c

Please sign in to comment.