diff --git a/pygmt/src/_common.py b/pygmt/src/_common.py new file mode 100644 index 00000000000..96ae9791b71 --- /dev/null +++ b/pygmt/src/_common.py @@ -0,0 +1,41 @@ +""" +Common functions used in multiple PyGMT functions/methods. +""" + +from pathlib import Path +from typing import Any + +from pygmt.src.which import which + + +def _data_geometry_is_point(data: Any, kind: str) -> bool: + """ + Check if the geometry of the input data is Point or MultiPoint. + + The inptu data can be a GeoJSON object or a OGR_GMT file. + + This function is used in ``Figure.plot`` and ``Figure.plot3d``. + + Parameters + ---------- + data + The data being plotted. + kind + The data kind. + + Returns + ------- + bool + ``True`` if the geometry is Point/MultiPoint, ``False`` otherwise. + """ + if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all(): + return True + if kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file + try: + with Path(which(data)).open(encoding="utf-8") as file: + line = file.readline() + if "@GMULTIPOINT" in line or "@GPOINT" in line: + return True + except FileNotFoundError: + pass + return False diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index e66f08438e5..61db357bfde 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -2,8 +2,6 @@ plot - Plot in two dimensions. """ -from pathlib import Path - from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( @@ -14,7 +12,7 @@ kwargs_to_strings, use_alias, ) -from pygmt.src.which import which +from pygmt.src._common import _data_geometry_is_point @fmt_docstring @@ -50,9 +48,7 @@ w="wrap", ) @kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence") -def plot( # noqa: PLR0912 - self, data=None, x=None, y=None, size=None, direction=None, **kwargs -): +def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs): r""" Plot lines, polygons, and symbols in 2-D. @@ -242,17 +238,8 @@ def plot( # noqa: PLR0912 raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.") # Set the default style if data has a geometry of Point or MultiPoint - if kwargs.get("S") is None: - if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all(): - kwargs["S"] = "s0.2c" - elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file - try: - with Path(which(data)).open(encoding="utf-8") as file: - line = file.readline() - if "@GMULTIPOINT" in line or "@GPOINT" in line: - kwargs["S"] = "s0.2c" - except FileNotFoundError: - pass + if kwargs.get("S") is None and _data_geometry_is_point(data, kind): + kwargs["S"] = "s0.2c" with Session() as lib: with lib.virtualfile_in( diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index c86e5e259f1..7c7a78f2ab3 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -2,8 +2,6 @@ plot3d - Plot in three dimensions. """ -from pathlib import Path - from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( @@ -14,7 +12,7 @@ kwargs_to_strings, use_alias, ) -from pygmt.src.which import which +from pygmt.src._common import _data_geometry_is_point @fmt_docstring @@ -51,7 +49,7 @@ w="wrap", ) @kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence") -def plot3d( # noqa: PLR0912 +def plot3d( self, data=None, x=None, y=None, z=None, size=None, direction=None, **kwargs ): r""" @@ -218,17 +216,8 @@ def plot3d( # noqa: PLR0912 raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.") # Set the default style if data has a geometry of Point or MultiPoint - if kwargs.get("S") is None: - if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all(): - kwargs["S"] = "u0.2c" - elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file - try: - with Path(which(data)).open(encoding="utf-8") as file: - line = file.readline() - if "@GMULTIPOINT" in line or "@GPOINT" in line: - kwargs["S"] = "u0.2c" - except FileNotFoundError: - pass + if kwargs.get("S") is None and _data_geometry_is_point(data, kind): + kwargs["S"] = "u0.2c" with Session() as lib: with lib.virtualfile_in(