Skip to content

Commit

Permalink
Enhance VTK export
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Pechnikov committed Mar 17, 2024
1 parent 3a18a5f commit 89499fc
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions pygmtsar/pygmtsar/Stack_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def as_vtk(dataset):
#print ('bands', bands)
if bands in [3,4]:
# RGB or RGBA, select 3 bands only
array = vn.numpy_to_vtk(da[:3].values.reshape(3,-1).T, deep=True, array_type=VTK_UNSIGNED_CHAR)
array = vn.numpy_to_vtk(da[:3].round().astype(np.uint8).values.reshape(3,-1).T, deep=True, array_type=VTK_UNSIGNED_CHAR)
array.SetName(da.name)
sgrid.GetPointData().AddArray(array)
elif bands == 1:
Expand Down Expand Up @@ -358,11 +358,12 @@ def export_vtk(self, data, name, caption='Exporting WGS84 VTK(s)', topo='auto',
from tqdm.auto import tqdm
import os

assert isinstance(data, xr.DataArray), 'Argument data is not an xr.DataArray object'

assert data is None or isinstance(data, xr.DataArray), 'Argument data is not an xr.DataArray object or None'
assert data is not None or image is not None, 'One of arguments "data" or "image" needs to be specified'

# determine if data has a stack dimension and what it is
stackvar = data.dims[0] if len(data.dims) == 3 else None
#print ('stackvar', stackvar)
stackvar = data.dims[0] if data is not None and len(data.dims) == 3 else None
print ('stackvar', stackvar)
if stackvar is not None and np.issubdtype(data[stackvar].dtype, np.datetime64):
stackvals = data[stackvar].dt.date.astype(str).values
elif stackvar is not None:
Expand All @@ -375,54 +376,65 @@ def export_vtk(self, data, name, caption='Exporting WGS84 VTK(s)', topo='auto',
if os.path.exists(filename):
os.remove(filename)

# convert the data to geographic coordinates if necessary
if not self.is_geo(data):
data_ll = self.ra2ll(data)
else:
data_ll = data
# define 2D grid for interpolation
grid2d = data_ll.min(stackvar) if stackvar is not None else data_ll
dss = []
grid2d = None

if data is not None:
# convert the data to geographic coordinates if necessary
data_ll = self.ra2ll(data) if not self.is_geo(data) else data
# define 2D grid for interpolation
grid2d = data_ll.min(stackvar) if stackvar is not None else data_ll
dss.append(data_ll)

if image is not None:
image_ll = self.ra2ll(image) if not self.is_geo(image) else image
if not 'band' in image.dims:
image_ll = image_ll.expand_dims('band')
if grid2d is not None:
image_ll = image_ll.interp_like(grid2d, method='linear')
#.round().astype(np.uint8)
else:
grid2d = image_ll.isel(band=0)
dss.append(image_ll.rename('colors'))

#print ('dss', dss)
#print ('grid2d', grid2d)

dss = []
if isinstance(topo, str) and topo == 'auto':
dem = self.get_dem()
elif topo is not None:
# convert topography to geographic coordinates if necessary
dem = self.ra2ll(topo) if not self.is_geo(topo) else topo
if topo is not None:
dem = dem.reindex_like(grid2d, method='nearest')
dem = dem.interp_like(grid2d, method='linear')
if isinstance(mask, str) and mask == 'auto':
dem = dem.where(np.isfinite(grid2d))
elif mask is not None:
dem = dem.where(mask.reindex_like(grid2d, method='nearest'))
dss.append(dem.rename('z'))

if image is not None:
dss.append(image.reindex_like(grid2d, method='nearest').round().astype(np.uint8).rename('colors'))
#print ('dss', dss)
ds = xr.merge(dss, compat='override').rename({'lat': 'y', 'lon': 'x'})
#print ('ds', ds)

# prepare the progress bar
with tqdm(desc=caption, total=len(stackvals)) as pbar:
for stackidx, stackval in enumerate(stackvals):
#print (stackidx, stackval)
if stackval is not None:
ds = xr.merge([*dss, data_ll.isel({stackvar: stackidx})], compat='override')\
.rename({'lat': 'y', 'lon': 'x'})\
.drop_vars(stackvar)
out = ds.isel({stackvar: stackidx}).drop_vars(stackvar)
filename = f'{name}.{stackidx}.vtk'
else:
ds = xr.merge([*dss, data_ll], compat='override').rename({'lat': 'y', 'lon': 'x'})
out = ds
filename = f'{name}.vtk'
#print ('ds', ds)
vtk_grid = self.as_vtk(ds)
#print ('out', out)

vtk_grid = self.as_vtk(out)
if stackval is not None:
metadata = vtkStringArray()
metadata.SetName(stackvar)
metadata.InsertNextValue(stackval)
vtk_grid.GetFieldData().AddArray(metadata)

# convert to VTK structure and save to file
writer = vtkStructuredGridWriter()
writer.SetFileName(filename)
Expand Down

0 comments on commit 89499fc

Please sign in to comment.