From e8d7a026433c7ab574b5fe85a8337799e5a90ac4 Mon Sep 17 00:00:00 2001 From: John Readey Date: Fri, 20 Sep 2024 11:15:05 -0500 Subject: [PATCH] fix for multicoordinate selection --- hsds/dset_lib.py | 27 ------------------ hsds/util/chunkUtil.py | 43 +++++++++++++++++----------- tests/unit/chunk_util_test.py | 54 +++++++++++++++++++++++------------ 3 files changed, 63 insertions(+), 61 deletions(-) diff --git a/hsds/dset_lib.py b/hsds/dset_lib.py index ece44e72..060bb703 100755 --- a/hsds/dset_lib.py +++ b/hsds/dset_lib.py @@ -355,33 +355,6 @@ def getChunkItem(chunkid): return chunkinfo_map -def get_chunkmap_selections(chunk_map, chunk_ids, slices, dset_json): - """ Update chunk_map with chunk and data selections for the - given set of slices - """ - log.debug(f"get_chunkmap_selections - {len(chunk_ids)} chunk_ids") - if not slices: - log.debug("no slices set, returning") - return # nothing to do - log.debug(f"slices: {slices}") - layout = getChunkLayout(dset_json) - for chunk_id in chunk_ids: - if chunk_id in chunk_map: - item = chunk_map[chunk_id] - else: - item = {} - chunk_map[chunk_id] = item - - chunk_sel = getChunkCoverage(chunk_id, slices, layout) - log.debug( - f"get_chunk_selections - chunk_id: {chunk_id}, chunk_sel: {chunk_sel}" - ) - item["chunk_sel"] = chunk_sel - data_sel = getDataCoverage(chunk_id, slices, layout) - log.debug(f"get_chunk_selections - data_sel: {data_sel}") - item["data_sel"] = data_sel - - def get_chunk_selections(chunk_map, chunk_ids, slices, dset_json): """Update chunk_map with chunk and data selections for the given set of slices diff --git a/hsds/util/chunkUtil.py b/hsds/util/chunkUtil.py index 581a9eda..170ea007 100644 --- a/hsds/util/chunkUtil.py +++ b/hsds/util/chunkUtil.py @@ -445,8 +445,7 @@ def getChunkIdForPartition(chunk_id, dset_json): def getChunkIds(dset_id, selection, layout, prefix=None): """Get the all the chunk ids for chunks that lie in the - selection of the - given dataset. + selection of the given dataset. 
""" def chunk_index_to_id(indices): @@ -598,10 +597,29 @@ def getChunkSelection(chunk_id, slices, layout): Return the intersection of the chunk with the given slices selection of the array. """ - # print("getChunkSelection - chunk_id:", chunk_id, "slices:", slices) chunk_index = getChunkIndex(chunk_id) rank = len(layout) sel = [] + + coord_mask = None + # compute a boolean mask for the coordinates that apply to the given chunk_id + for dim in range(rank): + s = slices[dim] + c = layout[dim] + n = chunk_index[dim] * c + if isinstance(s, slice): + continue + if coord_mask is None: + coord_mask = [True,] * len(s) + if len(s) != len(coord_mask): + raise ValueError("mismatched number of coordinates for fancy selection") + + for i in range(len(s)): + if not coord_mask[i]: + continue + if s[i] < n or s[i] >= n + c: + coord_mask[i] = False + for dim in range(rank): s = slices[dim] c = layout[dim] @@ -629,9 +647,9 @@ def getChunkSelection(chunk_id, slices, layout): else: # coord list coords = [] - for j in s: - if j >= n and j < n + c: - coords.append(j) + for i in range(len(s)): + if coord_mask[i]: + coords.append(s[i]) sel.append(coords) return sel @@ -646,6 +664,7 @@ def getChunkCoverage(chunk_id, slices, layout): if not chunk_sel: log.warn(f"slices: {slices} does intersect chunk: {chunk_id}") return None + rank = len(layout) if len(slices) != rank: raise ValueError(f"invalid slices value for dataset of rank: {rank}") @@ -670,16 +689,8 @@ def getChunkCoverage(chunk_id, slices, layout): sel.append(slice(start, stop, step)) else: coord = [] - for j in s: - if j - offset < 0: - msg = "Unexpected chunk selection" - log.error(msg) - raise ValueError(msg) - elif j - offset >= w: - msg = "Unexpected chunk selection" - log.error(msg) - raise ValueError(msg) - coord.append(j - offset) + for i in range(len(s)): + coord.append(s[i] - offset) sel.append(tuple(coord)) return sel diff --git a/tests/unit/chunk_util_test.py b/tests/unit/chunk_util_test.py index 75ebfb5a..e7450d4b 100755 
--- a/tests/unit/chunk_util_test.py +++ b/tests/unit/chunk_util_test.py @@ -187,7 +187,6 @@ def testExpandChunk(self): num_bytes = getChunkSize(layout, "H5T_VARIABLE") self.assertTrue(num_bytes < CHUNK_MIN) expanded = expandChunk(layout, "H5T_VARIABLE", shape, chunk_min=CHUNK_MIN) - print("expanded:", expanded) num_bytes = getChunkSize(expanded, "H5T_VARIABLE") self.assertTrue(num_bytes > CHUNK_MIN) self.assertTrue(num_bytes < CHUNK_MAX) @@ -593,7 +592,7 @@ def testGetChunkIds(self): selection = getHyperslabSelection(datashape, (0, 0), (7639, 6307)) chunk_ids = getChunkIds(dset_id, selection, layout) self.assertEqual(len(chunk_ids), 7639) - chunk_ids.reverse() # so we can pop off the front + index_set = set() for i in range(7639): chunk_id = chunk_ids.pop() self.assertTrue(chunk_id.startswith("c-")) @@ -601,8 +600,9 @@ def testGetChunkIds(self): self.assertEqual(len(fields), 3) index1 = int(fields[1]) index2 = int(fields[2]) - self.assertEqual(index1, i) + index_set.add(index1) self.assertEqual(index2, 0) + self.assertEqual(len(index_set), 7639) def testGetChunkIndex(self): chunk_id = "c-12345678-1234-1234-1234-1234567890ab_6_4" @@ -610,10 +610,7 @@ def testGetChunkIndex(self): self.assertEqual(index, [6, 4]) chunk_id = "c-12345678-1234-1234-1234-1234567890ab_64" index = getChunkIndex(chunk_id) - self.assertEqual( - index, - [64,], - ) + self.assertEqual(index, [64,]) def testGetChunkSelection(self): # 1-d test @@ -712,7 +709,6 @@ def testGetChunkSelection(self): chunk_ids = getChunkIds(dset_id, selection, layout) self.assertEqual(len(chunk_ids), 2) - print("x:", chunk_ids) chunk_id = f"c-{dset_id[2:]}_1" self.assertTrue(chunk_id in chunk_ids) @@ -804,20 +800,21 @@ def testGetChunkSelection(self): layout = (10,) selection = getHyperslabSelection(datashape, 92, 102) chunk_ids = getChunkIds(dset_id, selection, layout) + chunk_ids.sort() self.assertEqual(len(chunk_ids), 2) chunk_id = chunk_ids[0] sel = getChunkSelection(chunk_id, selection, layout) sel = sel[0] - 
self.assertEqual(sel.start, 92) - self.assertEqual(sel.stop, 100) + self.assertEqual(sel.start, 100) + self.assertEqual(sel.stop, 102) self.assertEqual(sel.step, 1) chunk_id = chunk_ids[1] sel = getChunkSelection(chunk_id, selection, layout) sel = sel[0] - self.assertEqual(sel.start, 100) - self.assertEqual(sel.stop, 102) + self.assertEqual(sel.start, 92) + self.assertEqual(sel.stop, 100) self.assertEqual(sel.step, 1) # 3d test @@ -886,7 +883,6 @@ def testGetChunkCoverage(self): chunk_id = f"c-{dset_id[2:]}_3" self.assertTrue(chunk_id in chunk_ids) sel = getChunkCoverage(chunk_id, selection, layout) - print("sel:", sel) self.assertEqual(sel[0], (2, 9)) # 1-d with step @@ -1004,25 +1000,48 @@ def testGetChunkCoverage(self): self.assertEqual(sel[0].step, 1) self.assertEqual(sel[1], (2, 9)) + # 3-d test with coordinates + datashape = (5, 1000, 1000) + layout = (3, 500, 500) + selection = (slice(0, 5, 1), [1, 10, 100], [10, 100, 500]) + chunk_ids = getChunkIds(dset_id, selection, layout) + chunk_ids.sort() + self.assertEqual(len(chunk_ids), 4) + chunk_id = chunk_ids[0] + sel = getChunkCoverage(chunk_id, selection, layout) + self.assertEqual(sel[0].start, 0) + self.assertEqual(sel[0].stop, 3) + self.assertEqual(sel[0].step, 1) + self.assertEqual(sel[1], (1, 10)) + self.assertEqual(sel[2], (10, 100)) + chunk_id = chunk_ids[1] + sel = getChunkCoverage(chunk_id, selection, layout) + self.assertEqual(sel[0].start, 0) + self.assertEqual(sel[0].stop, 3) + self.assertEqual(sel[0].step, 1) + self.assertEqual(sel[1], (100,)) + self.assertEqual(sel[2], (0,)) + # 1-d test with fractional chunks datashape = [104,] layout = (10,) selection = getHyperslabSelection(datashape, 92, 102) chunk_ids = getChunkIds(dset_id, selection, layout) + chunk_ids.sort() self.assertEqual(len(chunk_ids), 2) chunk_id = chunk_ids[0] sel = getChunkCoverage(chunk_id, selection, layout) sel = sel[0] - self.assertEqual(sel.start, 2) - self.assertEqual(sel.stop, 10) + self.assertEqual(sel.start, 0) + 
self.assertEqual(sel.stop, 2) self.assertEqual(sel.step, 1) chunk_id = chunk_ids[1] sel = getChunkCoverage(chunk_id, selection, layout) sel = sel[0] - self.assertEqual(sel.start, 0) - self.assertEqual(sel.stop, 2) + self.assertEqual(sel.start, 2) + self.assertEqual(sel.stop, 10) self.assertEqual(sel.step, 1) def testGetDataCoverage(self): @@ -1236,7 +1255,6 @@ def testGetDataCoverage(self): datashape = [104,] layout = (10,) selection = getHyperslabSelection(datashape, 92, 102) - print("selection:", selection) chunk_ids = getChunkIds(dset_id, selection, layout) self.assertEqual(len(chunk_ids), 2)