Skip to content

Commit

Permalink
Wrap GMT's standard data type GMT_IMAGE for images (#3338)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>
  • Loading branch information
seisman and weiji14 authored Jul 27, 2024
1 parent ff246c6 commit 537d684
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 11 deletions.
28 changes: 17 additions & 11 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
vectors_to_arrays,
)
from pygmt.clib.loading import load_libgmt
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE
from pygmt.exceptions import (
GMTCLibError,
GMTCLibNoSessionError,
Expand Down Expand Up @@ -1071,7 +1071,7 @@ def put_matrix(self, dataset, matrix, pad=0):
def read_data(
self,
infile: str,
kind: Literal["dataset", "grid"],
kind: Literal["dataset", "grid", "image"],
family: str | None = None,
geometry: str | None = None,
mode: str = "GMT_READ_NORMAL",
Expand All @@ -1089,8 +1089,8 @@ def read_data(
infile
The input file name.
kind
The data kind of the input file. Valid values are ``"dataset"`` and
``"grid"``.
The data kind of the input file. Valid values are ``"dataset"``, ``"grid"``
and ``"image"``.
family
A valid GMT data family name (e.g., ``"GMT_IS_DATASET"``). See the
``FAMILIES`` attribute for valid names. If ``None``, will determine the data
Expand Down Expand Up @@ -1141,6 +1141,7 @@ def read_data(
_family, _geometry, dtype = {
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP", _GMT_DATASET),
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE", _GMT_GRID),
"image": ("GMT_IS_IMAGE", "GMT_IS_SURFACE", _GMT_IMAGE),
}[kind]
if family is None:
family = _family
Expand Down Expand Up @@ -1797,7 +1798,9 @@ def virtualfile_from_data(

@contextlib.contextmanager
def virtualfile_out(
self, kind: Literal["dataset", "grid"] = "dataset", fname: str | None = None
self,
kind: Literal["dataset", "grid", "image"] = "dataset",
fname: str | None = None,
) -> Generator[str, None, None]:
r"""
Create a virtual file or an actual file for storing output data.
Expand All @@ -1810,8 +1813,8 @@ def virtualfile_out(
Parameters
----------
kind
The data kind of the virtual file to create. Valid values are ``"dataset"``
and ``"grid"``. Ignored if ``fname`` is specified.
The data kind of the virtual file to create. Valid values are ``"dataset"``,
``"grid"``, and ``"image"``. Ignored if ``fname`` is specified.
fname
The name of the actual file to write the output data. No virtual file will
be created.
Expand Down Expand Up @@ -1854,8 +1857,10 @@ def virtualfile_out(
family, geometry = {
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP"),
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE"),
"image": ("GMT_IS_IMAGE", "GMT_IS_SURFACE"),
}[kind]
with self.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile:
direction = "GMT_OUT|GMT_IS_REFERENCE" if kind == "image" else "GMT_OUT"
with self.open_virtualfile(family, geometry, direction, None) as vfile:
yield vfile

def inquire_virtualfile(self, vfname: str) -> int:
Expand Down Expand Up @@ -1901,7 +1906,8 @@ def read_virtualfile(
Name of the virtual file to read.
kind
Cast the data into a GMT data container. Valid values are ``"dataset"``,
``"grid"`` and ``None``. If ``None``, will return a ctypes void pointer.
``"grid"``, ``"image"`` and ``None``. If ``None``, will return a ctypes void
pointer.
Returns
-------
Expand Down Expand Up @@ -1951,9 +1957,9 @@ def read_virtualfile(
# _GMT_DATASET).
if kind is None: # Return the ctypes void pointer
return pointer
if kind in {"image", "cube"}:
if kind == "cube":
raise NotImplementedError(f"kind={kind} is not supported yet.")
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID}[kind]
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID, "image": _GMT_IMAGE}[kind]
return ctp.cast(pointer, ctp.POINTER(dtype))

def virtualfile_to_dataset(
Expand Down
1 change: 1 addition & 0 deletions pygmt/datatypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@

from pygmt.datatypes.dataset import _GMT_DATASET
from pygmt.datatypes.grid import _GMT_GRID
from pygmt.datatypes.image import _GMT_IMAGE
94 changes: 94 additions & 0 deletions pygmt/datatypes/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Wrapper for the GMT_IMAGE data type.
"""

import ctypes as ctp
from typing import ClassVar

from pygmt.datatypes.grid import _GMT_GRID_HEADER


class _GMT_IMAGE(ctp.Structure): # noqa: N801
"""
GMT image data structure.
Examples
--------
>>> import numpy as np
>>> from pygmt.clib import Session
>>> with Session() as lib:
... with lib.virtualfile_out(kind="image") as voutimg:
... lib.call_module("read", ["@earth_day_01d", voutimg, "-Ti"])
... # Read the image from the virtual file
... image = lib.read_virtualfile(vfname=voutimg, kind="image").contents
... # The image header
... header = image.header.contents
... # Access the header properties
... print(header.n_rows, header.n_columns, header.registration)
... print(header.wesn[:], header.inc[:])
... print(header.z_scale_factor, header.z_add_offset)
... print(header.x_units, header.y_units, header.z_units)
... print(header.title)
... print(header.command)
... print(header.remark)
... print(header.nm, header.size, header.complex_mode)
... print(header.type, header.n_bands, header.mx, header.my)
... print(header.pad[:])
... print(header.mem_layout, header.nan_value, header.xy_off)
... # Image-specific attributes.
... print(image.type, image.n_indexed_colors)
... # The x and y coordinates
... x = image.x[: header.n_columns]
... y = image.y[: header.n_rows]
... # The data array (with paddings)
... data = np.reshape(
... image.data[: header.n_bands * header.mx * header.my],
... (header.my, header.mx, header.n_bands),
... )
... # The data array (without paddings)
... pad = header.pad[:]
... data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :]
180 360 1
[-180.0, 180.0, -90.0, 90.0] [1.0, 1.0]
1.0 0.0
b'x' b'y' b'z'
b''
b''
b''
64800 66976 0
0 3 364 184
[2, 2, 2, 2]
b'BRPa' 0.0 0.5
1 0
>>> x
[-179.5, -178.5, ..., 178.5, 179.5]
>>> y
[89.5, 88.5, ..., -88.5, -89.5]
>>> data.shape
(180, 360, 3)
>>> data.min(), data.max()
(10, 255)
"""

_fields_: ClassVar = [
# Data type, e.g. GMT_FLOAT
("type", ctp.c_int),
# Array with color lookup values
("colormap", ctp.POINTER(ctp.c_int)),
# Number of colors in a paletted image
("n_indexed_colors", ctp.c_int),
# Pointer to full GMT header for the image
("header", ctp.POINTER(_GMT_GRID_HEADER)),
# Pointer to actual image
("data", ctp.POINTER(ctp.c_ubyte)),
# Pointer to an optional transparency layer stored in a separate variable
("alpha", ctp.POINTER(ctp.c_ubyte)),
# Color interpolation
("color_interp", ctp.c_char_p),
# Pointer to the x-coordinate vector
("x", ctp.POINTER(ctp.c_double)),
# Pointer to the y-coordinate vector
("y", ctp.POINTER(ctp.c_double)),
# Book-keeping variables "hidden" from the API
("hidden", ctp.c_void_p),
]
37 changes: 37 additions & 0 deletions pygmt/tests/test_clib_read_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,43 @@ def test_clib_read_data_grid_actual_image():
)


# Note: Simplify the tests for images after GMT_IMAGE.to_dataarray() is implemented.
def test_clib_read_data_image():
"""
Test the Session.read_data method for images.
"""
with Session() as lib:
image = lib.read_data("@earth_day_01d_p", kind="image").contents
header = image.header.contents
assert header.n_rows == 180
assert header.n_columns == 360
assert header.n_bands == 3
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
assert image.data


def test_clib_read_data_image_two_steps():
"""
Test the Session.read_data method for images in two steps, first reading the header
and then the data.
"""
infile = "@earth_day_01d_p"
with Session() as lib:
# Read the header first
data_ptr = lib.read_data(infile, kind="image", mode="GMT_CONTAINER_ONLY")
image = data_ptr.contents
header = image.header.contents
assert header.n_rows == 180
assert header.n_columns == 360
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
assert header.n_bands == 3 # Explicitly check n_bands
assert not image.data # The data is not read yet

# Read the data
lib.read_data(infile, kind="image", mode="GMT_DATA_ONLY", data=data_ptr)
assert image.data


def test_clib_read_data_fails():
"""
Test that the Session.read_data method raises an exception if there are errors.
Expand Down

0 comments on commit 537d684

Please sign in to comment.