Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunking and Padding #829

Merged
merged 14 commits into from
Oct 29, 2024
16 changes: 13 additions & 3 deletions databroker/mongo_normalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
from datetime import datetime, timedelta
import functools
import inspect
import itertools
import logging
import os
Expand Down Expand Up @@ -160,7 +161,8 @@ def structure_from_descriptor(descriptor, sub_dict, max_seq_num, unicode_columns
numpy_dtype = dtype.to_numpy_dtype()
if "chunks" in field_metadata:
# If the Event Descriptor tells us a preferred chunking, use that.
suggested_chunks = tuple(tuple(chunks) for chunks in field_metadata["chunks"])
suggested_chunks = [tuple(chunk) if isinstance(chunk, list)
else chunk for chunk in field_metadata['chunks']]
elif (0 in shape) or (numpy_dtype.itemsize == 0):
# special case to avoid warning from dask
suggested_chunks = shape
Expand Down Expand Up @@ -931,6 +933,9 @@ def populate_columns(keys, min_seq_num, max_seq_num):
map(
lambda item: self.validate_shape(
key, numpy.asarray(item), expected_shape
) if 'uid' in inspect.signature(self.validate_shape).parameters
else self.validate_shape(
key, numpy.asarray(item), expected_shape, uid=self._run.metadata()['start']['uid']
),
result[key],
)
Expand Down Expand Up @@ -2204,14 +2209,14 @@ class BadShapeMetadata(Exception):
pass


def default_validate_shape(key, data, expected_shape):
def default_validate_shape(key, data, expected_shape, uid=None):
"""
Check that data.shape == expected.shape.

* If number of dimensions differ, raise BadShapeMetadata
* If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata.
* If some dimensions are smaller than expected,, pad "right" edge of each
dimension that falls short with NaN.
dimension that falls short with zeros.
"""
MAX_SIZE_DIFF = 2
if data.shape == expected_shape:
Expand Down Expand Up @@ -2241,6 +2246,11 @@ def default_validate_shape(key, data, expected_shape):
else: # margin == 0
padding.append((0, 0))
padded = numpy.pad(data, padding, "edge")

logger.warning(f"The data.shape: {data.shape} did not match the expected_shape: "
danielballan marked this conversation as resolved.
Show resolved Hide resolved
f"{expected_shape} for key: '{key}'. This data has been zero-padded "
"to match the expected_shape! RunStart UID: {uid}")

return padded


Expand Down
21 changes: 0 additions & 21 deletions databroker/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import contextlib
import getpass
import os
import pytest
import sys
Expand Down Expand Up @@ -114,22 +112,3 @@ def delete_dm():
@pytest.fixture(params=['scalar', 'image', 'external_image'])
def detector(request, hw):
return getattr(hw, SIM_DETECTORS[request.param])


@pytest.fixture
def enter_password(monkeypatch):
"""
Return a context manager that overrides getpass, used like:

>>> with enter_password(...):
... # Run code that calls getpass.getpass().
"""

@contextlib.contextmanager
def f(password):
original = getpass.getpass
monkeypatch.setattr("getpass.getpass", lambda: password)
yield
monkeypatch.setattr("getpass.getpass", original)

return f
8 changes: 4 additions & 4 deletions databroker/tests/test_access_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from bluesky.plans import count
from tiled.client import Context, from_context
from tiled.server.app import build_app_from_config

from tiled._tests.utils import enter_username_password
from ..mongo_normalized import MongoAdapter, SimpleAccessPolicy


Expand All @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs):
InstrumentedMongoAdapter.from_mongomock(access_policy=access_policy)


def test_access_policy_example(tmpdir, enter_password):
def test_access_policy_example(tmpdir):

config = {
"authentication": {
Expand Down Expand Up @@ -52,8 +52,8 @@ def test_access_policy_example(tmpdir, enter_password):
],
}
with Context.from_app(build_app_from_config(config), token_cache=tmpdir) as context:
with enter_password("secret"):
client = from_context(context, username="alice", prompt_for_reauthentication=True)
with enter_username_password("alice", "secret"):
client = from_context(context, prompt_for_reauthentication=True)

def post_document(name, doc):
client.post_document(name, doc)
Expand Down
99 changes: 71 additions & 28 deletions databroker/tests/test_validate_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,61 @@
import pytest


def test_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data
@pytest.mark.parametrize(
"shape,expected_shape",
[
((10,), (11,)), # Short by 1, 1d.
((10, 20), (10, 21)), # Short by 1, 2d.
((10, 20, 30), (10, 21, 30)), # Short by 1, 3d.
((10, 20, 30), (10, 20, 31)), # Short by 1, 3d.
((10, 20), (10, 19)), # Too-big by 1, 2d.
((20, 20, 20, 20), (20, 21, 20, 22)), # 4d example.
],
)
def test_padding(tmpdir, shape, expected_shape):
adapter = MongoAdapter.from_mongomock()

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)
direct_img = DirectImage(
func=lambda: np.array(np.ones(shape)), name="direct", labels={"detectors"}
)
direct_img.img.name = "img"

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
if name == "descriptor":
doc["data_keys"]["img"]["shape"] = expected_shape

client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
(uid,) = RE(count([direct_img]))
assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape


@pytest.mark.parametrize(
"shape,expected_shape",
"chunks,shape,expected_chunks",
[
((10,), (11,)),
((10, 20), (10, 21)),
((10, 20), (10, 19)),
((10, 20, 30), (10, 21, 30)),
((10, 20, 30), (10, 20, 31)),
((20, 20, 20, 20), (20, 21, 20, 22)),
([1, 2], (10,), ((1,), (2, 2, 2, 2, 2))), # 1D image
([1, 3], (10,), ((1,), (3, 3, 3, 1))), # not evenly divisible.
([1, 2, 2], (10, 10), ((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2))), # 2D
([1, 2, -1], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # -1 for max size.
([1, 2, "auto"], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # auto
(
((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)),
(10, 10),
((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)),
), # normalized chunks
(
[1, 5, "auto", -1, 5],
(10, 10, 10, 10),
((1,), (5, 5), (10,), (10,), (5, 5))
), # mixture of things.
],
)
def test_padding(tmpdir, shape, expected_shape):
def test_custom_chunking(tmpdir, chunks, shape, expected_chunks):
adapter = MongoAdapter.from_mongomock()

direct_img = DirectImage(
Expand All @@ -54,30 +73,30 @@ def test_padding(tmpdir, shape, expected_shape):
direct_img.img.name = "img"

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)
client = from_context(context, "dask")

def post_document(name, doc):
if name == "descriptor":
doc["data_keys"]["img"]["shape"] = expected_shape
doc["data_keys"]["img"]["chunks"] = chunks

client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([direct_img]))
assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape
assert client[uid]["primary"]["data"]["img"].chunks == expected_chunks


@pytest.mark.parametrize(
"shape,expected_shape",
[
((10,), (11, 12)),
((10, 20), (10, 200)),
((20, 20, 20, 20), (20, 21, 20, 200)),
((10, 20), (5, 20)),
((10,), (11, 12)), # Different number of dimensions.
((10, 20), (10, 200)), # Dimension sizes differ by more than 2.
((20, 20, 20, 20), (20, 21, 20, 200)), # Dimension sizes differ by more than 2.
((10, 20), (5, 20)), # Data is bigger than expected.
],
)
def test_default_validate_shape(tmpdir, shape, expected_shape):
def test_validate_shape_exceptions(tmpdir, shape, expected_shape):
adapter = MongoAdapter.from_mongomock()

direct_img = DirectImage(
Expand All @@ -99,3 +118,27 @@ def post_document(name, doc):
(uid,) = RE(count([direct_img]))
with pytest.raises(BadShapeMetadata):
client[uid]["primary"]["data"]["img"][:]


def test_custom_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
Loading