From c8bf2dacd7700a1036c2956a35107323f265206d Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sat, 8 Jul 2023 08:23:17 +0800 Subject: [PATCH] Figure.meca: Fix beachball offsetting for ndarray input (requires GMT>=6.5.0) (#2576) --- pygmt/src/meca.py | 244 +++++++++++++++++++++++++++++---------- pygmt/tests/test_meca.py | 59 +++++++++- 2 files changed, 236 insertions(+), 67 deletions(-) diff --git a/pygmt/src/meca.py b/pygmt/src/meca.py index d22c24552b0..f304bfe920e 100644 --- a/pygmt/src/meca.py +++ b/pygmt/src/meca.py @@ -8,44 +8,71 @@ from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias -def data_format_code(convention, component="full"): +def convention_code(convention, component="full"): """ - Determine the data format code for meca's -S option. + Determine the convention code for focal mechanisms. - See the meca() method for explanations of the parameters. + The convention code can be used in meca's -S option. + + Parameters + ---------- + convention : str + The focal mechanism convention. Can be one of the following: + + - ``"aki"``: Aki and Richards + - ``"gcmt"``: Global Centroid Moment Tensor + - ``"partial"``: Partial focal mechanism + - ``"mt"``: Moment tensor + - ``"principal_axis"``: Principal axis + + Single letter convention codes like ``"a"`` and ``"c"`` are also + supported but undocumented. + + component : str + The component of the focal mechanism. Only used when ``convention`` is + ``"mt"`` or ``"principal_axis"``. Can be one of the following: + + - ``"full"``: Full moment tensor + - ``"deviatoric"``: Deviatoric moment tensor + - ``"dc"``: Double couple + + Returns + ------- + str + The single-letter convention code used in meca's -S option. Examples -------- - >>> data_format_code("aki") + >>> convention_code("aki") 'a' - >>> data_format_code("gcmt") + >>> convention_code("gcmt") 'c' - >>> data_format_code("partial") + >>> convention_code("partial") 'p' - >>> data_format_code("mt", component="full") + >>> convention_code("mt", component="full") 'm' - >>> data_format_code("mt", component="deviatoric") + >>> convention_code("mt", component="deviatoric") 'z' - >>> data_format_code("mt", component="dc") + >>> convention_code("mt", component="dc") 'd' - >>> data_format_code("principal_axis", component="full") + >>> convention_code("principal_axis", component="full") 'x' - >>> data_format_code("principal_axis", component="deviatoric") + >>> convention_code("principal_axis", component="deviatoric") 't' - >>> data_format_code("principal_axis", component="dc") + >>> convention_code("principal_axis", component="dc") 'y' >>> for code in ["a", "c", "m", "d", "z", "p", "x", "y", "t"]: - ... assert data_format_code(code) == code + ... assert convention_code(code) == code ... - >>> data_format_code("invalid") + >>> convention_code("invalid") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Invalid convention 'invalid'. - >>> data_format_code("mt", "invalid") # doctest: +NORMALIZE_WHITESPACE + >>> convention_code("mt", "invalid") # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: @@ -73,6 +100,90 @@ def data_format_code(convention, component="full"): raise GMTInvalidInput(f"Invalid convention '{convention}'.") +def convention_name(code): + """ + Determine the name of a focal mechanism convention from its code. + + Parameters + ---------- + code : str + The single-letter convention code. + + Returns + ------- + str + The name of the focal mechanism convention. + + Examples + -------- + >>> convention_name("a") + 'aki' + >>> convention_name("aki") + 'aki' + """ + name = { + "a": "aki", + "c": "gcmt", + "p": "partial", + "z": "mt", + "d": "mt", + "m": "mt", + "x": "principal_axis", + "y": "principal_axis", + "t": "principal_axis", + }.get(code) + return name if name is not None else code + + +def convention_params(convention): + """ + Return the list of focal mechanism parameters for a given convention. + + Parameters + ---------- + convention : str + The focal mechanism convention. Can be one of the following: + + - ``"aki"``: Aki and Richards + - ``"gcmt"``: Global Centroid Moment Tensor + - ``"partial"``: Partial focal mechanism + - ``"mt"``: Moment tensor + - ``"principal_axis"``: Principal axis + + Returns + ------- + list + The list of focal mechanism parameters. + """ + return { + "aki": ["strike", "dip", "rake", "magnitude"], + "gcmt": [ + "strike1", + "dip1", + "rake1", + "strike2", + "dip2", + "rake2", + "mantissa", + "exponent", + ], + "mt": ["mrr", "mtt", "mff", "mrt", "mrf", "mtf", "exponent"], + "partial": ["strike1", "dip1", "strike2", "fault_type", "magnitude"], + "pricipal_axis": [ + "t_value", + "t_azimuth", + "t_plunge", + "n_value", + "n_azimuth", + "n_plunge", + "p_value", + "p_azimuth", + "p_plunge", + "exponent", + ], + }[convention] + + @fmt_docstring @use_alias( A="offset", @@ -287,38 +398,14 @@ def meca( {transparency} """ # pylint: disable=too-many-arguments,too-many-locals,too-many-branches + # pylint: disable=too-many-statements kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access + + # Convert spec to pandas.DataFrame unless it's a file if isinstance(spec, (dict, pd.DataFrame)): # spec is a dict or pd.DataFrame - param_conventions = { - "aki": ["strike", "dip", "rake", "magnitude"], - "gcmt": [ - "strike1", - "dip1", - "rake1", - "strike2", - "dip2", - "rake2", - "mantissa", - "exponent", - ], - "mt": ["mrr", "mtt", "mff", "mrt", "mrf", "mtf", "exponent"], - "partial": ["strike1", "dip1", "strike2", "fault_type", "magnitude"], - "pricipal_axis": [ - "t_value", - "t_azimuth", - "t_plunge", - "n_value", - "n_azimuth", - "n_plunge", - "p_value", - "p_azimuth", - "p_plunge", - "exponent", - ], - } # determine convention from dict keys or pd.DataFrame column names - for conv, paras in param_conventions.items(): - if set(paras).issubset(set(spec.keys())): + for conv in ["aki", "gcmt", "mt", "partial", "pricipal_axis"]: + if set(convention_params(conv)).issubset(set(spec.keys())): convention = conv break else: @@ -328,8 +415,42 @@ def meca( msg = "Column names in pd.DataFrame 'spec' do not match known conventions." raise GMTError(msg) - # override the values in dict/pd.DataFrame if parameters are explicity - # specified + # convert dict to pd.DataFrame so columns can be reordered + if isinstance(spec, dict): + # convert values to ndarray so pandas doesn't complain about "all + # scalar values". See + # https://github.com/GenericMappingTools/pygmt/pull/2174 + spec = pd.DataFrame( + {key: np.atleast_1d(value) for key, value in spec.items()} + ) + elif isinstance(spec, np.ndarray): # spec is a numpy array + if convention is None: + raise GMTInvalidInput("'convention' must be specified for an array input.") + # make sure convention is a name, not a code + convention = convention_name(convention) + + # Convert array to pd.DataFrame and assign column names + spec = pd.DataFrame(np.atleast_2d(spec)) + colnames = ["longitude", "latitude", "depth"] + convention_params(convention) + # check if spec has the expected number of columns + ncolsdiff = len(spec.columns) - len(colnames) + if ncolsdiff == 0: + pass + elif ncolsdiff == 1: + colnames += ["event_name"] + elif ncolsdiff == 2: + colnames += ["plot_longitude", "plot_latitude"] + elif ncolsdiff == 3: + colnames += ["plot_longitude", "plot_latitude", "event_name"] + else: + raise GMTInvalidInput( + f"Input array must have {len(colnames)} to {len(colnames) + 3} columns." + ) + spec.columns = colnames + + # Now spec is a pd.DataFrame or a file + if isinstance(spec, pd.DataFrame): + # override the values in pd.DataFrame if parameters are given if longitude is not None: spec["longitude"] = np.atleast_1d(longitude) if latitude is not None: @@ -341,38 +462,33 @@ def meca( if plot_latitude is not None: spec["plot_latitude"] = np.atleast_1d(plot_latitude) if event_name is not None: - spec["event_name"] = np.atleast_1d(event_name).astype(str) + spec["event_name"] = np.atleast_1d(event_name) - # convert dict to pd.DataFrame so columns can be reordered - if isinstance(spec, dict): - # convert values to ndarray so pandas doesn't complain about "all - # scalar values". See - # https://github.com/GenericMappingTools/pygmt/pull/2174 - spec = {key: np.atleast_1d(value) for key, value in spec.items()} - spec = pd.DataFrame(spec) + # Due to the internal implementation of the meca module, we need to + # convert the following columns to strings if they exist + if "plot_longitude" in spec.columns and "plot_latitude" in spec.columns: + spec["plot_longitude"] = spec["plot_longitude"].astype(str) + spec["plot_latitude"] = spec["plot_latitude"].astype(str) + if "event_name" in spec.columns: + spec["event_name"] = spec["event_name"].astype(str) + # Reorder columns in DataFrame to match convention if necessary # expected columns are: # longitude, latitude, depth, focal_parameters, # [plot_longitude, plot_latitude] [event_name] - newcols = ["longitude", "latitude", "depth"] + param_conventions[convention] + newcols = ["longitude", "latitude", "depth"] + convention_params(convention) if "plot_longitude" in spec.columns and "plot_latitude" in spec.columns: newcols += ["plot_longitude", "plot_latitude"] - spec[["plot_longitude", "plot_latitude"]] = spec[ - ["plot_longitude", "plot_latitude"] - ].astype(str) if kwargs.get("A") is None: kwargs["A"] = True if "event_name" in spec.columns: newcols += ["event_name"] - spec["event_name"] = spec["event_name"].astype(str) # reorder columns in DataFrame - spec = spec.reindex(newcols, axis=1) - elif isinstance(spec, np.ndarray) and spec.ndim == 1: - # Convert 1-D array into 2-D array - spec = np.atleast_2d(spec) + if spec.columns.tolist() != newcols: + spec = spec.reindex(newcols, axis=1) # determine data_format from convention and component - data_format = data_format_code(convention=convention, component=component) + data_format = convention_code(convention=convention, component=component) # Assemble -S flag kwargs["S"] = f"{data_format}{scale}" diff --git a/pygmt/tests/test_meca.py b/pygmt/tests/test_meca.py index f54f6b7d16a..e63f346587c 100644 --- a/pygmt/tests/test_meca.py +++ b/pygmt/tests/test_meca.py @@ -6,6 +6,7 @@ import pytest from packaging.version import Version from pygmt import Figure, __gmt_version__ +from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -143,12 +144,25 @@ def test_meca_spec_multiple_focalmecha(inputtype): @pytest.mark.mpl_image_compare(filename="test_meca_offset.png") -@pytest.mark.parametrize("inputtype", ["offset_args", "offset_dict"]) +@pytest.mark.parametrize( + "inputtype", + [ + "args", + "dict", + pytest.param( + "ndarray", + marks=pytest.mark.skipif( + condition=Version(__gmt_version__) < Version("6.5.0"), + reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/7557", + ), + ), + ], +) def test_meca_offset(inputtype): """ Test offsetting beachballs. """ - if inputtype == "offset_args": + if inputtype == "args": args = { "spec": {"strike": 330, "dip": 30, "rake": 90, "magnitude": 3}, "longitude": -124, @@ -157,7 +171,7 @@ def test_meca_offset(inputtype): "plot_longitude": -124.5, "plot_latitude": 47.5, } - elif inputtype == "offset_dict": + elif inputtype == "dict": # Test https://github.com/GenericMappingTools/pygmt/issues/2016 # offset parameters are in the dict. args = { @@ -173,6 +187,13 @@ def test_meca_offset(inputtype): "latitude": 48, "depth": 12.0, } + elif inputtype == "ndarray": + # Test ndarray input reported in + # https://github.com/GenericMappingTools/pygmt/issues/2016 + args = { + "spec": np.array([[-124, 48, 12.0, 330, 30, 90, 3, -124.5, 47.5]]), + "convention": "aki", + } fig = Figure() fig.basemap(region=[-125, -122, 47, 49], projection="M6c", frame=True) @@ -277,3 +298,35 @@ def test_meca_spec_dict_all_scalars(): scale=1.0, # make sure a non-str scale works ) return fig + + +def test_meca_spec_ndarray_no_convention(): + """ + Raise an exception if convention is not given for an ndarray input. + """ + with pytest.raises(GMTInvalidInput): + fig = Figure() + fig.basemap(region=[-125, -122, 47, 49], projection="M6c", frame=True) + fig.meca(spec=np.array([[-124, 48, 12.0, 330, 30, 90, 3]]), scale="1c") + + +def test_meca_spec_ndarray_mismatched_columns(): + """ + Raise an exception if the ndarray input doesn't have the expected number of + columns. + """ + with pytest.raises(GMTInvalidInput): + fig = Figure() + fig.basemap(region=[-125, -122, 47, 49], projection="M6c", frame=True) + fig.meca( + spec=np.array([[-124, 48, 12.0, 330, 30, 90]]), convention="aki", scale="1c" + ) + + with pytest.raises(GMTInvalidInput): + fig = Figure() + fig.basemap(region=[-125, -122, 47, 49], projection="M6c", frame=True) + fig.meca( + spec=np.array([[-124, 48, 12.0, 330, 30, 90, 3, -124.5, 47.5, 30.0, 50.0]]), + convention="aki", + scale="1c", + )