Skip to content

Commit

Permalink
Refactor clib to avoid checking GMT version repeatedly and only check…
Browse files Browse the repository at this point in the history
… once when loading the GMT library (#3254)
  • Loading branch information
seisman authored Jul 29, 2024
1 parent 1df8f19 commit 3a589fa
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/5-bump_gmt_checklist.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using the following command:
**To-Do for bumping the minimum required GMT version**:

- [ ] Bump the minimum required GMT version (1 PR)
- [ ] Update `required_version` in `pygmt/clib/session.py`
- [ ] Update `required_gmt_version` in `pygmt/clib/__init__.py`
- [ ] Update `test_get_default` in `pygmt/tests/test_clib.py`
- [ ] Update minimum required versions in `doc/minversions.md`
- [ ] Remove unsupported GMT version from `.github/workflows/ci_tests_legacy.yaml`
Expand Down
5 changes: 2 additions & 3 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

# ruff: isort: off
from sphinx_gallery.sorting import ExplicitOrder, ExampleTitleSortKey
import pygmt
from pygmt.clib import required_gmt_version
from pygmt import __commit__, __version__
from pygmt.sphinx_gallery import PyGMTScraper

# ruff: isort: on

requires_python = metadata("pygmt")["Requires-Python"]
with pygmt.clib.Session() as lib:
requires_gmt = f">={lib.required_version}"
requires_gmt = f">={required_gmt_version}"

extensions = [
"myst_parser",
Expand Down
15 changes: 12 additions & 3 deletions pygmt/clib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
interface. Access to the C library is done through ctypes.
"""

from pygmt.clib.session import Session
from packaging.version import Version
from pygmt.clib.session import Session, __gmt_version__
from pygmt.exceptions import GMTVersionError

with Session() as lib:
__gmt_version__ = lib.info["version"]
required_gmt_version = "6.3.0"

# Check if the GMT version is older than the required version.
if Version(__gmt_version__) < Version(required_gmt_version):
msg = (
f"Using an incompatible GMT version {__gmt_version__}. "
f"Must be equal or newer than {required_gmt_version}."
)
raise GMTVersionError(msg)
27 changes: 27 additions & 0 deletions pygmt/clib/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ def load_libgmt(lib_fullnames: Iterator[str] | None = None) -> ctypes.CDLL:
return libgmt


def get_gmt_version(libgmt: ctypes.CDLL) -> str:
"""
Get the GMT version string of the GMT shared library.
Parameters
----------
libgmt
The GMT shared library.
Returns
-------
The GMT version string in *major.minor.patch* format.
"""
func = libgmt.GMT_Get_Version
func.argtypes = (
ctypes.c_void_p, # Unused parameter, so it can be None.
ctypes.POINTER(ctypes.c_uint), # major
ctypes.POINTER(ctypes.c_uint), # minor
ctypes.POINTER(ctypes.c_uint), # patch
)
# The function return value is the current library version as a float, e.g., 6.5.
func.restype = ctypes.c_float
major, minor, patch = ctypes.c_uint(0), ctypes.c_uint(0), ctypes.c_uint(0)
func(None, major, minor, patch)
return f"{major.value}.{minor.value}.{patch.value}"


def clib_names(os_name: str) -> list[str]:
"""
Return the name(s) of GMT's shared library for the current operating system.
Expand Down
31 changes: 4 additions & 27 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,9 @@
strings_to_ctypes_array,
vectors_to_arrays,
)
from pygmt.clib.loading import load_libgmt
from pygmt.clib.loading import get_gmt_version, load_libgmt
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE
from pygmt.exceptions import (
GMTCLibError,
GMTCLibNoSessionError,
GMTInvalidInput,
GMTVersionError,
)
from pygmt.exceptions import GMTCLibError, GMTCLibNoSessionError, GMTInvalidInput
from pygmt.helpers import (
_validate_data_input,
data_kind,
Expand Down Expand Up @@ -98,6 +93,7 @@

# Load the GMT library outside the Session class to avoid repeated loading.
_libgmt = load_libgmt()
__gmt_version__ = get_gmt_version(_libgmt)


class Session:
Expand Down Expand Up @@ -155,9 +151,6 @@ class Session:
-55 -47 -24 -10 190 981 1 1 8 14 1 1
"""

# The minimum supported GMT version.
required_version = "6.3.0"

@property
def session_pointer(self):
"""
Expand Down Expand Up @@ -212,27 +205,11 @@ def info(self):

def __enter__(self):
"""
Create a GMT API session and check the libgmt version.
Create a GMT API session.
Calls :meth:`pygmt.clib.Session.create`.
Raises
------
GMTVersionError
If the version reported by libgmt is less than
``Session.required_version``. Will destroy the session before
raising the exception.
"""
self.create("pygmt-session")
# Need to store the version info because 'get_default' won't work after
# the session is destroyed.
version = self.info["version"]
if Version(version) < Version(self.required_version):
self.destroy()
raise GMTVersionError(
f"Using an incompatible GMT version {version}. "
f"Must be equal or newer than {self.required_version}."
)
return self

def __exit__(self, exc_type, exc_value, traceback):
Expand Down
35 changes: 16 additions & 19 deletions pygmt/tests/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,27 +577,24 @@ def mock_defaults(api, name, value): # noqa: ARG001
ses.destroy()


def test_fails_for_wrong_version():
def test_fails_for_wrong_version(monkeypatch):
"""
Make sure the clib.Session raises an exception if GMT is too old.
Make sure that importing clib raise an exception if GMT is too old.
"""
import importlib

# Mock GMT_Get_Default to return an old version
def mock_defaults(api, name, value): # noqa: ARG001
"""
Return an old version.
"""
if name == b"API_VERSION":
value.value = b"5.4.3"
else:
value.value = b"bla"
return 0
with monkeypatch.context() as mpatch:
# Make sure the current GMT major version is 6.
assert clib.__gmt_version__.split(".")[0] == "6"

lib = clib.Session()
with mock(lib, "GMT_Get_Default", mock_func=mock_defaults):
# Monkeypatch the version string returned by pygmt.clib.loading.get_gmt_version.
mpatch.setattr(clib.loading, "get_gmt_version", lambda libgmt: "5.4.3") # noqa: ARG005

# Reload clib.session and check the __gmt_version__ string.
importlib.reload(clib.session)
assert clib.session.__gmt_version__ == "5.4.3"

# Should raise an exception when pygmt.clib is loaded/reloaded.
with pytest.raises(GMTVersionError):
with lib:
assert lib.info["version"] != "5.4.3"
# Make sure the session is closed when the exception is raised.
with pytest.raises(GMTCLibNoSessionError):
assert lib.session_pointer
importlib.reload(clib)
assert clib.__gmt_version__ == "5.4.3" # Make sure it's still the old version
20 changes: 19 additions & 1 deletion pygmt/tests/test_clib_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from pathlib import PurePath

import pytest
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
from pygmt.clib.loading import (
check_libgmt,
clib_full_names,
clib_names,
get_gmt_version,
load_libgmt,
)
from pygmt.clib.session import Session
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError

Expand Down Expand Up @@ -360,3 +366,15 @@ def test_clib_full_names_gmt_library_path_incorrect_path_included(
# Windows: find_library() searches the library in PATH, so one more
npath = 2 if sys.platform == "win32" else 1
assert list(lib_fullpaths) == [gmt_lib_realpath] * npath + gmt_lib_names


###############################################################################
# Test get_gmt_version
def test_get_gmt_version():
"""
Test if get_gmt_version returns a version string in major.minor.patch format.
"""
version = get_gmt_version(load_libgmt())
assert isinstance(version, str)
assert len(version.split(".")) == 3 # In major.minor.patch format
assert version.split(".")[0] == "6" # Is GMT 6.x.x

0 comments on commit 3a589fa

Please sign in to comment.