diff --git a/src/yt_napari/_data_model.py b/src/yt_napari/_data_model.py index 4330a23..1b4b90f 100644 --- a/src/yt_napari/_data_model.py +++ b/src/yt_napari/_data_model.py @@ -63,6 +63,27 @@ class Region(BaseModel): ) +class CoveringGrid(BaseModel): + fields: List[ytField] = Field( + None, description="list of fields to load for this selection" + ) + left_edge: Optional[Left_Edge] = Field( + None, + description="the left edge (min x, min y, min z)", + ) + right_edge: Optional[Right_Edge] = Field( + None, + description="the right edge (max x, max y, max z)", + ) + level: Optional[int] = (Field(0, description="Grid level to sample at"),) + num_ghost_zones: Optional[int] = ( + Field(None, description="Number of ghost zones to include"), + ) + rescale: Optional[bool] = Field( + False, description="rescale the final image between 0,1" + ) + + class Slice(BaseModel): fields: List[ytField] = Field( None, description="list of fields to load for this selection" @@ -93,6 +114,9 @@ class SelectionObject(BaseModel): regions: Optional[List[Region]] = Field( None, description="a list of regions to load" ) + covering_grids: Optional[List[CoveringGrid]] = Field( + None, description="a list of covering grids to load" + ) slices: Optional[List[Slice]] = Field(None, description="a list of slices to load") diff --git a/src/yt_napari/_gui_utilities.py b/src/yt_napari/_gui_utilities.py index d52cc76..648de8b 100644 --- a/src/yt_napari/_gui_utilities.py +++ b/src/yt_napari/_gui_utilities.py @@ -101,6 +101,8 @@ def get_pydantic_attr(self, pydantic_model, field: str, widget_instance): if self.is_registered(pydantic_model, field, required=True): func, args, kwargs = self.registry[pydantic_model][field]["pydantic"] return func(widget_instance, *args, **kwargs) + else: + raise RuntimeError("unexpected") def add_pydantic_to_container( self, @@ -196,6 +198,15 @@ def get_filename(file_widget: widgets.FileEdit): return str(file_widget.value) +def get_int_box_widget(*args, **kwargs): + # could remove the need for this if the model uses pathlib.Path for typing + return widgets.IntText(*args, **kwargs) + + +def get_int_val(int_box: widgets.IntText): + return int(int_box.value) + + def get_magicguidefault(field_def: pydantic.fields.ModelField): # returns an instance of the default widget selected by magicgui ftype = field_def.type_ @@ -233,8 +244,10 @@ def _get_pydantic_model_field(py_model, field: str) -> pydantic.fields.ModelFiel _models_to_embed_in_list = ( (_data_model.Slice, "fields"), (_data_model.Region, "fields"), + (_data_model.CoveringGrid, "fields"), (_data_model.DataContainer, "selections"), (_data_model.SelectionObject, "regions"), + (_data_model.SelectionObject, "covering_grids"), (_data_model.SelectionObject, "slices"), ) @@ -273,6 +286,21 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): pydantic_attr_factory=split_comma_sep_string, ) + translator.register( + _data_model.CoveringGrid, + "level", + magicgui_factory=get_int_box_widget, + magicgui_kwargs={"name": "level"}, + pydantic_attr_factory=get_int_val, + ) + translator.register( + _data_model.CoveringGrid, + "num_ghost_zones", + magicgui_factory=get_int_box_widget, + magicgui_kwargs={"name": "num_ghost_zones"}, + pydantic_attr_factory=get_int_val, + ) + translator = MagicPydanticRegistry() _register_yt_data_model(translator) @@ -296,7 +324,7 @@ def get_yt_data_container( return data_container -_valid_selections = ("Region", "Slice") +_valid_selections = ("Region", "Slice", "CoveringGrid") def get_yt_selection_container(selection_type: str, return_native: bool = False): diff --git a/src/yt_napari/_model_ingestor.py b/src/yt_napari/_model_ingestor.py index 6779482..a9e4c08 100644 --- a/src/yt_napari/_model_ingestor.py +++ b/src/yt_napari/_model_ingestor.py @@ -8,6 +8,7 @@ from yt_napari import _special_loaders from yt_napari._data_model import ( + CoveringGrid, DataContainer, InputModel, MetadataModel, @@ -434,7 +435,13 @@ def _load_3D_regions( layer_list: list, timeseries_container: Optional[TimeseriesContainer] = None, ) -> list: - for sel in selections.regions: + + sels = [] + for seltype in ("regions", "covering_grids"): + if getattr(selections, seltype) is not None: + sels += [sel for sel in getattr(selections, seltype)] + + for sel in sels: # get the left, right edge as a unitful array, initialize the layer # domain tracking for this layer and update the global domain extent if sel.left_edge is None: @@ -446,16 +453,27 @@ def _load_3D_regions( RE = ds.domain_right_edge else: RE = ds.arr(sel.right_edge.value, sel.right_edge.unit) - res = sel.resolution - layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res) - # create the fixed resolution buffer - frb = ds.r[ - LE[0] : RE[0] : complex(0, res[0]), # noqa: E203 - LE[1] : RE[1] : complex(0, res[1]), # noqa: E203 - LE[2] : RE[2] : complex(0, res[2]), # noqa: E203 - ] + if isinstance(sel, Region): + res = sel.resolution + frb = ds.r[ + LE[0] : RE[0] : complex(0, res[0]), # noqa: E203 + LE[1] : RE[1] : complex(0, res[1]), # noqa: E203 + LE[2] : RE[2] : complex(0, res[2]), # noqa: E203 + ] + elif isinstance(sel, CoveringGrid): + # get a temp covering grid with specified ghost zones then + # recalcuate dims at correct dds + dims = (4, 4, 4) + nghostzones = sel.num_ghost_zones + temp_cg = ds.covering_grid(sel.level, LE, dims, num_ghost_zones=nghostzones) + effective_dds = temp_cg.dds + dims = (RE - LE) / effective_dds + # get the actual covering grid + frb = ds.covering_grid(sel.level, LE, dims, num_ghost_zones=nghostzones) + res = dims + layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res) for field_container in sel.fields: field = (field_container.field_type, field_container.field_name) @@ -600,7 +618,7 @@ def _load_selections_from_ds( layer_list: List[SpatialLayer], timeseries_container: Optional[TimeseriesContainer] = None, ) -> List[SpatialLayer]: - if selections.regions is not None: + if selections.regions is not None or selections.covering_grids is not None: layer_list = _load_3D_regions( ds, selections, layer_list, timeseries_container=timeseries_container ) diff --git a/src/yt_napari/_widget_reader.py b/src/yt_napari/_widget_reader.py index 5d77a3f..7a9d59b 100644 --- a/src/yt_napari/_widget_reader.py +++ b/src/yt_napari/_widget_reader.py @@ -122,7 +122,6 @@ def add_load_group_widgets(self): load_group.addWidget(ss.native) def save_selection(self): - py_kwargs = {} py_kwargs = self._validate_data_model() file_dialog = QFileDialog() @@ -145,13 +144,11 @@ def load_data(self): # instantiate pydantic objects, which are then handed off to the # same data ingestion function as the json loader. - py_kwargs = {} py_kwargs = self._validate_data_model() model = _data_model.InputModel.parse_obj(py_kwargs) # process each layer layer_list, _ = _model_ingestor._process_validated_model(model) - # align all layers after checking for or setting the reference layer ref_layer = _check_for_reference_layer(self.viewer.layers) if ref_layer is None: @@ -167,11 +164,13 @@ def load_data(self): self.viewer.add_image(im_arr, **im_kwargs) def _validate_data_model(self): - # this function save json data + selections_by_type = defaultdict(list) for selection in self.active_selections.values(): py_kwargs = selection.get_current_pydantic_kwargs() sel_key = selection.selection_type.lower() + "s" + if "covering" in sel_key: + sel_key = "covering_grids" selections_by_type[sel_key].append(py_kwargs) # next, process remaining arguments (skipping selections): @@ -281,7 +280,6 @@ def save_selection(self): json.dump(py_kwargs, json_file, indent=4) def load_data(self): - py_kwargs = {} py_kwargs = self._validate_data_model() model = _data_model.InputModel.parse_obj(py_kwargs)