Skip to content

Commit

Permalink
no tests, but json and GUI work
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Jan 30, 2024
1 parent 43b2ff5 commit 4e5484b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 16 deletions.
24 changes: 24 additions & 0 deletions src/yt_napari/_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ class Region(BaseModel):
)


class CoveringGrid(BaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
)
left_edge: Optional[Left_Edge] = Field(
None,
description="the left edge (min x, min y, min z)",
)
right_edge: Optional[Right_Edge] = Field(
None,
description="the right edge (max x, max y, max z)",
)
level: Optional[int] = (Field(0, description="Grid level to sample at"),)
num_ghost_zones: Optional[int] = (
Field(None, description="Number of ghost zones to include"),
)
rescale: Optional[bool] = Field(
False, description="rescale the final image between 0,1"
)


class Slice(BaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
Expand Down Expand Up @@ -93,6 +114,9 @@ class SelectionObject(BaseModel):
regions: Optional[List[Region]] = Field(
None, description="a list of regions to load"
)
covering_grids: Optional[List[CoveringGrid]] = Field(
None, description="a list of covering grids to load"
)
slices: Optional[List[Slice]] = Field(None, description="a list of slices to load")


Expand Down
30 changes: 29 additions & 1 deletion src/yt_napari/_gui_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def get_pydantic_attr(self, pydantic_model, field: str, widget_instance):
if self.is_registered(pydantic_model, field, required=True):
func, args, kwargs = self.registry[pydantic_model][field]["pydantic"]
return func(widget_instance, *args, **kwargs)
else:
raise RuntimeError("unexpected")

def add_pydantic_to_container(
self,
Expand Down Expand Up @@ -196,6 +198,15 @@ def get_filename(file_widget: widgets.FileEdit):
return str(file_widget.value)


def get_int_box_widget(*args, **kwargs):
# could remove the need for this if the model uses pathlib.Path for typing
return widgets.IntText(*args, **kwargs)


def get_int_val(int_box: widgets.IntText):
return int(int_box.value)


def get_magicguidefault(field_def: pydantic.fields.ModelField):
# returns an instance of the default widget selected by magicgui
ftype = field_def.type_
Expand Down Expand Up @@ -233,8 +244,10 @@ def _get_pydantic_model_field(py_model, field: str) -> pydantic.fields.ModelFiel
_models_to_embed_in_list = (
(_data_model.Slice, "fields"),
(_data_model.Region, "fields"),
(_data_model.CoveringGrid, "fields"),
(_data_model.DataContainer, "selections"),
(_data_model.SelectionObject, "regions"),
(_data_model.SelectionObject, "covering_grids"),
(_data_model.SelectionObject, "slices"),
)

Expand Down Expand Up @@ -273,6 +286,21 @@ def _register_yt_data_model(translator: MagicPydanticRegistry):
pydantic_attr_factory=split_comma_sep_string,
)

translator.register(
_data_model.CoveringGrid,
"level",
magicgui_factory=get_int_box_widget,
magicgui_kwargs={"name": "level"},
pydantic_attr_factory=get_int_val,
)
translator.register(
_data_model.CoveringGrid,
"num_ghost_zones",
magicgui_factory=get_int_box_widget,
magicgui_kwargs={"name": "num_ghost_zones"},
pydantic_attr_factory=get_int_val,
)


translator = MagicPydanticRegistry()
_register_yt_data_model(translator)
Expand All @@ -296,7 +324,7 @@ def get_yt_data_container(
return data_container


_valid_selections = ("Region", "Slice")
_valid_selections = ("Region", "Slice", "CoveringGrid")


def get_yt_selection_container(selection_type: str, return_native: bool = False):
Expand Down
38 changes: 28 additions & 10 deletions src/yt_napari/_model_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from yt_napari import _special_loaders
from yt_napari._data_model import (
CoveringGrid,
DataContainer,
InputModel,
MetadataModel,
Expand Down Expand Up @@ -434,7 +435,13 @@ def _load_3D_regions(
layer_list: list,
timeseries_container: Optional[TimeseriesContainer] = None,
) -> list:
for sel in selections.regions:

sels = []
for seltype in ("regions", "covering_grids"):
if getattr(selections, seltype) is not None:
sels += [sel for sel in getattr(selections, seltype)]

for sel in sels:
# get the left, right edge as a unitful array, initialize the layer
# domain tracking for this layer and update the global domain extent
if sel.left_edge is None:
Expand All @@ -446,16 +453,27 @@ def _load_3D_regions(
RE = ds.domain_right_edge
else:
RE = ds.arr(sel.right_edge.value, sel.right_edge.unit)
res = sel.resolution
layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res)

# create the fixed resolution buffer
frb = ds.r[
LE[0] : RE[0] : complex(0, res[0]), # noqa: E203
LE[1] : RE[1] : complex(0, res[1]), # noqa: E203
LE[2] : RE[2] : complex(0, res[2]), # noqa: E203
]
if isinstance(sel, Region):
res = sel.resolution
frb = ds.r[
LE[0] : RE[0] : complex(0, res[0]), # noqa: E203
LE[1] : RE[1] : complex(0, res[1]), # noqa: E203
LE[2] : RE[2] : complex(0, res[2]), # noqa: E203
]
elif isinstance(sel, CoveringGrid):
# get a temp covering grid with specified ghost zones then
# recalcuate dims at correct dds
dims = (4, 4, 4)
nghostzones = sel.num_ghost_zones
temp_cg = ds.covering_grid(sel.level, LE, dims, num_ghost_zones=nghostzones)
effective_dds = temp_cg.dds
dims = (RE - LE) / effective_dds
# get the actual covering grid
frb = ds.covering_grid(sel.level, LE, dims, num_ghost_zones=nghostzones)
res = dims

layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res)
for field_container in sel.fields:
field = (field_container.field_type, field_container.field_name)

Expand Down Expand Up @@ -600,7 +618,7 @@ def _load_selections_from_ds(
layer_list: List[SpatialLayer],
timeseries_container: Optional[TimeseriesContainer] = None,
) -> List[SpatialLayer]:
if selections.regions is not None:
if selections.regions is not None or selections.covering_grids is not None:
layer_list = _load_3D_regions(
ds, selections, layer_list, timeseries_container=timeseries_container
)
Expand Down
8 changes: 3 additions & 5 deletions src/yt_napari/_widget_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def add_load_group_widgets(self):
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()

file_dialog = QFileDialog()
Expand All @@ -145,13 +144,11 @@ def load_data(self):
# instantiate pydantic objects, which are then handed off to the
# same data ingestion function as the json loader.

py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
Expand All @@ -167,11 +164,13 @@ def load_data(self):
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# this function save json data

selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
py_kwargs = selection.get_current_pydantic_kwargs()
sel_key = selection.selection_type.lower() + "s"
if "covering" in sel_key:
sel_key = "covering_grids"
selections_by_type[sel_key].append(py_kwargs)

# next, process remaining arguments (skipping selections):
Expand Down Expand Up @@ -281,7 +280,6 @@ def save_selection(self):
json.dump(py_kwargs, json_file, indent=4)

def load_data(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

Expand Down

0 comments on commit 4e5484b

Please sign in to comment.