Skip to content

Commit

Permalink
fix for mutlicoordinate selection
Browse files Browse the repository at this point in the history
  • Loading branch information
jreadey committed Sep 20, 2024
1 parent b983832 commit e8d7a02
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 61 deletions.
27 changes: 0 additions & 27 deletions hsds/dset_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,33 +355,6 @@ def getChunkItem(chunkid):
return chunkinfo_map


def get_chunkmap_selections(chunk_map, chunk_ids, slices, dset_json):
""" Update chunk_map with chunk and data selections for the
given set of slices
"""
log.debug(f"get_chunkmap_selections - {len(chunk_ids)} chunk_ids")
if not slices:
log.debug("no slices set, returning")
return # nothing to do
log.debug(f"slices: {slices}")
layout = getChunkLayout(dset_json)
for chunk_id in chunk_ids:
if chunk_id in chunk_map:
item = chunk_map[chunk_id]
else:
item = {}
chunk_map[chunk_id] = item

chunk_sel = getChunkCoverage(chunk_id, slices, layout)
log.debug(
f"get_chunk_selections - chunk_id: {chunk_id}, chunk_sel: {chunk_sel}"
)
item["chunk_sel"] = chunk_sel
data_sel = getDataCoverage(chunk_id, slices, layout)
log.debug(f"get_chunk_selections - data_sel: {data_sel}")
item["data_sel"] = data_sel


def get_chunk_selections(chunk_map, chunk_ids, slices, dset_json):
"""Update chunk_map with chunk and data selections for the
given set of slices
Expand Down
43 changes: 27 additions & 16 deletions hsds/util/chunkUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,7 @@ def getChunkIdForPartition(chunk_id, dset_json):

def getChunkIds(dset_id, selection, layout, prefix=None):
"""Get the all the chunk ids for chunks that lie in the
selection of the
given dataset.
selection of the given dataset.
"""

def chunk_index_to_id(indices):
Expand Down Expand Up @@ -598,10 +597,29 @@ def getChunkSelection(chunk_id, slices, layout):
Return the intersection of the chunk with the given slices
selection of the array.
"""
# print("getChunkSelection - chunk_id:", chunk_id, "slices:", slices)
chunk_index = getChunkIndex(chunk_id)
rank = len(layout)
sel = []

coord_mask = None
# compute a boolean mask for the coordinates that apply to the given chunk_id
for dim in range(rank):
s = slices[dim]
c = layout[dim]
n = chunk_index[dim] * c
if isinstance(s, slice):
continue
if coord_mask is None:
coord_mask = [True,] * len(s)
if len(s) != len(coord_mask):
raise ValueError("mismatched number of coordinates for fancy selection")

for i in range(len(s)):
if not coord_mask[i]:
continue
if s[i] < n or s[i] >= n + c:
coord_mask[i] = False

for dim in range(rank):
s = slices[dim]
c = layout[dim]
Expand Down Expand Up @@ -629,9 +647,9 @@ def getChunkSelection(chunk_id, slices, layout):
else:
# coord list
coords = []
for j in s:
if j >= n and j < n + c:
coords.append(j)
for i in range(len(s)):
if coord_mask[i]:
coords.append(s[i])
sel.append(coords)

return sel
Expand All @@ -646,6 +664,7 @@ def getChunkCoverage(chunk_id, slices, layout):
if not chunk_sel:
log.warn(f"slices: {slices} does intersect chunk: {chunk_id}")
return None

rank = len(layout)
if len(slices) != rank:
raise ValueError(f"invalid slices value for dataset of rank: {rank}")
Expand All @@ -670,16 +689,8 @@ def getChunkCoverage(chunk_id, slices, layout):
sel.append(slice(start, stop, step))
else:
coord = []
for j in s:
if j - offset < 0:
msg = "Unexpected chunk selection"
log.error(msg)
raise ValueError(msg)
elif j - offset >= w:
msg = "Unexpected chunk selection"
log.error(msg)
raise ValueError(msg)
coord.append(j - offset)
for i in range(len(s)):
coord.append(s[i] - offset)
sel.append(tuple(coord))

return sel
Expand Down
54 changes: 36 additions & 18 deletions tests/unit/chunk_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def testExpandChunk(self):
num_bytes = getChunkSize(layout, "H5T_VARIABLE")
self.assertTrue(num_bytes < CHUNK_MIN)
expanded = expandChunk(layout, "H5T_VARIABLE", shape, chunk_min=CHUNK_MIN)
print("expanded:", expanded)
num_bytes = getChunkSize(expanded, "H5T_VARIABLE")
self.assertTrue(num_bytes > CHUNK_MIN)
self.assertTrue(num_bytes < CHUNK_MAX)
Expand Down Expand Up @@ -593,27 +592,25 @@ def testGetChunkIds(self):
selection = getHyperslabSelection(datashape, (0, 0), (7639, 6307))
chunk_ids = getChunkIds(dset_id, selection, layout)
self.assertEqual(len(chunk_ids), 7639)
chunk_ids.reverse() # so we can pop off the front
index_set = set()
for i in range(7639):
chunk_id = chunk_ids.pop()
self.assertTrue(chunk_id.startswith("c-"))
fields = chunk_id.split("_")
self.assertEqual(len(fields), 3)
index1 = int(fields[1])
index2 = int(fields[2])
self.assertEqual(index1, i)
index_set.add(index1)
self.assertEqual(index2, 0)
self.assertEqual(len(index_set), 7639)

def testGetChunkIndex(self):
chunk_id = "c-12345678-1234-1234-1234-1234567890ab_6_4"
index = getChunkIndex(chunk_id)
self.assertEqual(index, [6, 4])
chunk_id = "c-12345678-1234-1234-1234-1234567890ab_64"
index = getChunkIndex(chunk_id)
self.assertEqual(
index,
[64,],
)
self.assertEqual(index, [64,])

def testGetChunkSelection(self):
# 1-d test
Expand Down Expand Up @@ -712,7 +709,6 @@ def testGetChunkSelection(self):
chunk_ids = getChunkIds(dset_id, selection, layout)

self.assertEqual(len(chunk_ids), 2)
print("x:", chunk_ids)
chunk_id = f"c-{dset_id[2:]}_1"
self.assertTrue(chunk_id in chunk_ids)

Expand Down Expand Up @@ -804,20 +800,21 @@ def testGetChunkSelection(self):
layout = (10,)
selection = getHyperslabSelection(datashape, 92, 102)
chunk_ids = getChunkIds(dset_id, selection, layout)
chunk_ids.sort()
self.assertEqual(len(chunk_ids), 2)

chunk_id = chunk_ids[0]
sel = getChunkSelection(chunk_id, selection, layout)
sel = sel[0]
self.assertEqual(sel.start, 92)
self.assertEqual(sel.stop, 100)
self.assertEqual(sel.start, 100)
self.assertEqual(sel.stop, 102)
self.assertEqual(sel.step, 1)

chunk_id = chunk_ids[1]
sel = getChunkSelection(chunk_id, selection, layout)
sel = sel[0]
self.assertEqual(sel.start, 100)
self.assertEqual(sel.stop, 102)
self.assertEqual(sel.start, 92)
self.assertEqual(sel.stop, 100)
self.assertEqual(sel.step, 1)

# 3d test
Expand Down Expand Up @@ -886,7 +883,6 @@ def testGetChunkCoverage(self):
chunk_id = f"c-{dset_id[2:]}_3"
self.assertTrue(chunk_id in chunk_ids)
sel = getChunkCoverage(chunk_id, selection, layout)
print("sel:", sel)
self.assertEqual(sel[0], (2, 9))

# 1-d with step
Expand Down Expand Up @@ -1004,25 +1000,48 @@ def testGetChunkCoverage(self):
self.assertEqual(sel[0].step, 1)
self.assertEqual(sel[1], (2, 9))

# 3-d test with coodinates
datashape = (5, 1000, 1000)
layout = (3, 500, 500)
selection = (slice(0, 5, 1), [1, 10, 100], [10, 100, 500])
chunk_ids = getChunkIds(dset_id, selection, layout)
chunk_ids.sort()
self.assertEqual(len(chunk_ids), 4)
chunk_id = chunk_ids[0]
sel = getChunkCoverage(chunk_id, selection, layout)
self.assertEqual(sel[0].start, 0)
self.assertEqual(sel[0].stop, 3)
self.assertEqual(sel[0].step, 1)
self.assertEqual(sel[1], (1, 10))
self.assertEqual(sel[2], (10, 100))
chunk_id = chunk_ids[1]
sel = getChunkCoverage(chunk_id, selection, layout)
self.assertEqual(sel[0].start, 0)
self.assertEqual(sel[0].stop, 3)
self.assertEqual(sel[0].step, 1)
self.assertEqual(sel[1], (100,))
self.assertEqual(sel[2], (0,))

# 1-d test with fractional chunks
datashape = [104,]
layout = (10,)
selection = getHyperslabSelection(datashape, 92, 102)
chunk_ids = getChunkIds(dset_id, selection, layout)
chunk_ids.sort()
self.assertEqual(len(chunk_ids), 2)

chunk_id = chunk_ids[0]
sel = getChunkCoverage(chunk_id, selection, layout)
sel = sel[0]
self.assertEqual(sel.start, 2)
self.assertEqual(sel.stop, 10)
self.assertEqual(sel.start, 0)
self.assertEqual(sel.stop, 2)
self.assertEqual(sel.step, 1)

chunk_id = chunk_ids[1]
sel = getChunkCoverage(chunk_id, selection, layout)
sel = sel[0]
self.assertEqual(sel.start, 0)
self.assertEqual(sel.stop, 2)
self.assertEqual(sel.start, 2)
self.assertEqual(sel.stop, 10)
self.assertEqual(sel.step, 1)

def testGetDataCoverage(self):
Expand Down Expand Up @@ -1236,7 +1255,6 @@ def testGetDataCoverage(self):
datashape = [104,]
layout = (10,)
selection = getHyperslabSelection(datashape, 92, 102)
print("selection:", selection)
chunk_ids = getChunkIds(dset_id, selection, layout)

self.assertEqual(len(chunk_ids), 2)
Expand Down

0 comments on commit e8d7a02

Please sign in to comment.