From 3a589fa58bdf5a179ee46b61a1b2de3f3d9092a4 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 29 Jul 2024 10:25:02 +0800 Subject: [PATCH] Refactor clib to avoid checking GMT version repeatedly and only check once when loading the GMT library (#3254) --- .../ISSUE_TEMPLATE/5-bump_gmt_checklist.md | 2 +- doc/conf.py | 5 ++- pygmt/clib/__init__.py | 15 ++++++-- pygmt/clib/loading.py | 27 ++++++++++++++ pygmt/clib/session.py | 31 +++------------- pygmt/tests/test_clib.py | 35 +++++++++---------- pygmt/tests/test_clib_loading.py | 20 ++++++++++- 7 files changed, 81 insertions(+), 54 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/5-bump_gmt_checklist.md b/.github/ISSUE_TEMPLATE/5-bump_gmt_checklist.md index a4591f12847..9652f2150ee 100644 --- a/.github/ISSUE_TEMPLATE/5-bump_gmt_checklist.md +++ b/.github/ISSUE_TEMPLATE/5-bump_gmt_checklist.md @@ -35,7 +35,7 @@ using the following command: **To-Do for bumping the minimum required GMT version**: - [ ] Bump the minimum required GMT version (1 PR) - - [ ] Update `required_version` in `pygmt/clib/session.py` + - [ ] Update `required_gmt_version` in `pygmt/clib/__init__.py` - [ ] Update `test_get_default` in `pygmt/tests/test_clib.py` - [ ] Update minimum required versions in `doc/minversions.md` - [ ] Remove unsupported GMT version from `.github/workflows/ci_tests_legacy.yaml` diff --git a/doc/conf.py b/doc/conf.py index b8a6c6ce4fc..9348960f325 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -7,15 +7,14 @@ # ruff: isort: off from sphinx_gallery.sorting import ExplicitOrder, ExampleTitleSortKey -import pygmt +from pygmt.clib import required_gmt_version from pygmt import __commit__, __version__ from pygmt.sphinx_gallery import PyGMTScraper # ruff: isort: on requires_python = metadata("pygmt")["Requires-Python"] -with pygmt.clib.Session() as lib: - requires_gmt = f">={lib.required_version}" +requires_gmt = f">={required_gmt_version}" extensions = [ "myst_parser", diff --git a/pygmt/clib/__init__.py b/pygmt/clib/__init__.py index 868616f2345..9f145716f0e 100644 --- a/pygmt/clib/__init__.py +++ b/pygmt/clib/__init__.py @@ -5,7 +5,16 @@ interface. Access to the C library is done through ctypes. """ -from pygmt.clib.session import Session +from packaging.version import Version +from pygmt.clib.session import Session, __gmt_version__ +from pygmt.exceptions import GMTVersionError -with Session() as lib: - __gmt_version__ = lib.info["version"] +required_gmt_version = "6.3.0" + +# Check if the GMT version is older than the required version. +if Version(__gmt_version__) < Version(required_gmt_version): + msg = ( + f"Using an incompatible GMT version {__gmt_version__}. " + f"Must be equal or newer than {required_gmt_version}." + ) + raise GMTVersionError(msg) diff --git a/pygmt/clib/loading.py b/pygmt/clib/loading.py index 7bcf576b9b6..9b785fe826e 100644 --- a/pygmt/clib/loading.py +++ b/pygmt/clib/loading.py @@ -64,6 +64,33 @@ def load_libgmt(lib_fullnames: Iterator[str] | None = None) -> ctypes.CDLL: return libgmt +def get_gmt_version(libgmt: ctypes.CDLL) -> str: + """ + Get the GMT version string of the GMT shared library. + + Parameters + ---------- + libgmt + The GMT shared library. + + Returns + ------- + The GMT version string in *major.minor.patch* format. + """ + func = libgmt.GMT_Get_Version + func.argtypes = ( + ctypes.c_void_p, # Unused parameter, so it can be None. + ctypes.POINTER(ctypes.c_uint), # major + ctypes.POINTER(ctypes.c_uint), # minor + ctypes.POINTER(ctypes.c_uint), # patch + ) + # The function return value is the current library version as a float, e.g., 6.5. + func.restype = ctypes.c_float + major, minor, patch = ctypes.c_uint(0), ctypes.c_uint(0), ctypes.c_uint(0) + func(None, major, minor, patch) + return f"{major.value}.{minor.value}.{patch.value}" + + def clib_names(os_name: str) -> list[str]: """ Return the name(s) of GMT's shared library for the current operating system. diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index ec5bdcf3be2..829381d339d 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -25,14 +25,9 @@ strings_to_ctypes_array, vectors_to_arrays, ) -from pygmt.clib.loading import load_libgmt +from pygmt.clib.loading import get_gmt_version, load_libgmt from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE -from pygmt.exceptions import ( - GMTCLibError, - GMTCLibNoSessionError, - GMTInvalidInput, - GMTVersionError, -) +from pygmt.exceptions import GMTCLibError, GMTCLibNoSessionError, GMTInvalidInput from pygmt.helpers import ( _validate_data_input, data_kind, @@ -98,6 +93,7 @@ # Load the GMT library outside the Session class to avoid repeated loading. _libgmt = load_libgmt() +__gmt_version__ = get_gmt_version(_libgmt) class Session: @@ -155,9 +151,6 @@ class Session: -55 -47 -24 -10 190 981 1 1 8 14 1 1 """ - # The minimum supported GMT version. - required_version = "6.3.0" - @property def session_pointer(self): """ @@ -212,27 +205,11 @@ def info(self): def __enter__(self): """ - Create a GMT API session and check the libgmt version. + Create a GMT API session. Calls :meth:`pygmt.clib.Session.create`. - - Raises - ------ - GMTVersionError - If the version reported by libgmt is less than - ``Session.required_version``. Will destroy the session before - raising the exception. """ self.create("pygmt-session") - # Need to store the version info because 'get_default' won't work after - # the session is destroyed. - version = self.info["version"] - if Version(version) < Version(self.required_version): - self.destroy() - raise GMTVersionError( - f"Using an incompatible GMT version {version}. " - f"Must be equal or newer than {self.required_version}." - ) return self def __exit__(self, exc_type, exc_value, traceback): diff --git a/pygmt/tests/test_clib.py b/pygmt/tests/test_clib.py index f833e01a37b..7591aff0729 100644 --- a/pygmt/tests/test_clib.py +++ b/pygmt/tests/test_clib.py @@ -577,27 +577,24 @@ def mock_defaults(api, name, value): # noqa: ARG001 ses.destroy() -def test_fails_for_wrong_version(): +def test_fails_for_wrong_version(monkeypatch): """ - Make sure the clib.Session raises an exception if GMT is too old. + Make sure that importing clib raise an exception if GMT is too old. """ + import importlib - # Mock GMT_Get_Default to return an old version - def mock_defaults(api, name, value): # noqa: ARG001 - """ - Return an old version. - """ - if name == b"API_VERSION": - value.value = b"5.4.3" - else: - value.value = b"bla" - return 0 + with monkeypatch.context() as mpatch: + # Make sure the current GMT major version is 6. + assert clib.__gmt_version__.split(".")[0] == "6" - lib = clib.Session() - with mock(lib, "GMT_Get_Default", mock_func=mock_defaults): + # Monkeypatch the version string returned by pygmt.clib.loading.get_gmt_version. + mpatch.setattr(clib.loading, "get_gmt_version", lambda libgmt: "5.4.3") # noqa: ARG005 + + # Reload clib.session and check the __gmt_version__ string. + importlib.reload(clib.session) + assert clib.session.__gmt_version__ == "5.4.3" + + # Should raise an exception when pygmt.clib is loaded/reloaded. with pytest.raises(GMTVersionError): - with lib: - assert lib.info["version"] != "5.4.3" - # Make sure the session is closed when the exception is raised. - with pytest.raises(GMTCLibNoSessionError): - assert lib.session_pointer + importlib.reload(clib) + assert clib.__gmt_version__ == "5.4.3" # Make sure it's still the old version diff --git a/pygmt/tests/test_clib_loading.py b/pygmt/tests/test_clib_loading.py index df92a73fbd0..3c65c5caf46 100644 --- a/pygmt/tests/test_clib_loading.py +++ b/pygmt/tests/test_clib_loading.py @@ -11,7 +11,13 @@ from pathlib import PurePath import pytest -from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt +from pygmt.clib.loading import ( + check_libgmt, + clib_full_names, + clib_names, + get_gmt_version, + load_libgmt, +) from pygmt.clib.session import Session from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError @@ -360,3 +366,15 @@ def test_clib_full_names_gmt_library_path_incorrect_path_included( # Windows: find_library() searches the library in PATH, so one more npath = 2 if sys.platform == "win32" else 1 assert list(lib_fullpaths) == [gmt_lib_realpath] * npath + gmt_lib_names + + +############################################################################### +# Test get_gmt_version +def test_get_gmt_version(): + """ + Test if get_gmt_version returns a version string in major.minor.patch format. + """ + version = get_gmt_version(load_libgmt()) + assert isinstance(version, str) + assert len(version.split(".")) == 3 # In major.minor.patch format + assert version.split(".")[0] == "6" # Is GMT 6.x.x