Skip to content

Commit

Permalink
Figure.plot & Figure.plot3d: Move common codes into _common.py (#3461)
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Sep 29, 2024
1 parent d2fbb38 commit f97c3a4
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 32 deletions.
41 changes: 41 additions & 0 deletions pygmt/src/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Common functions used in multiple PyGMT functions/methods.
"""

from pathlib import Path
from typing import Any

from pygmt.src.which import which


def _data_geometry_is_point(data: Any, kind: str) -> bool:
"""
Check if the geometry of the input data is Point or MultiPoint.
The inptu data can be a GeoJSON object or a OGR_GMT file.
This function is used in ``Figure.plot`` and ``Figure.plot3d``.
Parameters
----------
data
The data being plotted.
kind
The data kind.
Returns
-------
bool
``True`` if the geometry is Point/MultiPoint, ``False`` otherwise.
"""
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
return True
if kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
try:
with Path(which(data)).open(encoding="utf-8") as file:
line = file.readline()
if "@GMULTIPOINT" in line or "@GPOINT" in line:
return True
except FileNotFoundError:
pass
return False
21 changes: 4 additions & 17 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
plot - Plot in two dimensions.
"""

from pathlib import Path

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
Expand All @@ -14,7 +12,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which
from pygmt.src._common import _data_geometry_is_point


@fmt_docstring
Expand Down Expand Up @@ -50,9 +48,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot( # noqa: PLR0912
self, data=None, x=None, y=None, size=None, direction=None, **kwargs
):
def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
r"""
Plot lines, polygons, and symbols in 2-D.
Expand Down Expand Up @@ -242,17 +238,8 @@ def plot( # noqa: PLR0912
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")

# Set the default style if data has a geometry of Point or MultiPoint
if kwargs.get("S") is None:
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
kwargs["S"] = "s0.2c"
elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
try:
with Path(which(data)).open(encoding="utf-8") as file:
line = file.readline()
if "@GMULTIPOINT" in line or "@GPOINT" in line:
kwargs["S"] = "s0.2c"
except FileNotFoundError:
pass
if kwargs.get("S") is None and _data_geometry_is_point(data, kind):
kwargs["S"] = "s0.2c"

with Session() as lib:
with lib.virtualfile_in(
Expand Down
19 changes: 4 additions & 15 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
plot3d - Plot in three dimensions.
"""

from pathlib import Path

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
Expand All @@ -14,7 +12,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which
from pygmt.src._common import _data_geometry_is_point


@fmt_docstring
Expand Down Expand Up @@ -51,7 +49,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot3d( # noqa: PLR0912
def plot3d(
self, data=None, x=None, y=None, z=None, size=None, direction=None, **kwargs
):
r"""
Expand Down Expand Up @@ -218,17 +216,8 @@ def plot3d( # noqa: PLR0912
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")

# Set the default style if data has a geometry of Point or MultiPoint
if kwargs.get("S") is None:
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
kwargs["S"] = "u0.2c"
elif kind == "file" and str(data).endswith(".gmt"): # OGR_GMT file
try:
with Path(which(data)).open(encoding="utf-8") as file:
line = file.readline()
if "@GMULTIPOINT" in line or "@GPOINT" in line:
kwargs["S"] = "u0.2c"
except FileNotFoundError:
pass
if kwargs.get("S") is None and _data_geometry_is_point(data, kind):
kwargs["S"] = "u0.2c"

with Session() as lib:
with lib.virtualfile_in(
Expand Down

0 comments on commit f97c3a4

Please sign in to comment.