diff --git a/src/py4vasp/__init__.py b/src/py4vasp/__init__.py index 812a9621..60643461 100644 --- a/src/py4vasp/__init__.py +++ b/src/py4vasp/__init__.py @@ -1,5 +1,6 @@ # Copyright © VASP Software GmbH, # Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +from py4vasp._analysis.mlff import MLFFErrorAnalysis from py4vasp._calculation import Calculation from py4vasp._calculations import Calculations from py4vasp._third_party.graph import plot diff --git a/src/py4vasp/_analysis/mlff.py b/src/py4vasp/_analysis/mlff.py index 4b78b221..a6bdfd29 100644 --- a/src/py4vasp/_analysis/mlff.py +++ b/src/py4vasp/_analysis/mlff.py @@ -5,7 +5,8 @@ import numpy as np -from py4vasp import Calculations, exception +import py4vasp +from py4vasp import exception class MLFFErrorAnalysis: @@ -71,7 +72,9 @@ def from_paths(cls, dft_data, mlff_data): Path to the MLFF data. Accepts wildcards. """ mlff_error_analysis = cls(_internal=True) - calculations = Calculations.from_paths(dft_data=dft_data, mlff_data=mlff_data) + calculations = py4vasp.Calculations.from_paths( + dft_data=dft_data, mlff_data=mlff_data + ) mlff_error_analysis._calculations = calculations set_appropriate_attrs(mlff_error_analysis) return mlff_error_analysis @@ -92,7 +95,9 @@ def from_files(cls, dft_data, mlff_data): Path to the MLFF data. Accepts wildcards. """ mlff_error_analysis = cls(_internal=True) - calculations = Calculations.from_files(dft_data=dft_data, mlff_data=mlff_data) + calculations = py4vasp.Calculations.from_files( + dft_data=dft_data, mlff_data=mlff_data + ) mlff_error_analysis._calculations = calculations set_appropriate_attrs(mlff_error_analysis) return mlff_error_analysis diff --git a/src/py4vasp/_data/energy.py b/src/py4vasp/_data/energy.py index f84dd8d6..30aa691b 100644 --- a/src/py4vasp/_data/energy.py +++ b/src/py4vasp/_data/energy.py @@ -18,16 +18,16 @@ def _selection_string(default): _SELECTIONS = { - "ion-electron TOTEN ": ["ion_electron", "TOTEN"], - "kinetic energy EKIN ": ["kinetic_energy", "EKIN"], + "ion-electron TOTEN": ["ion_electron", "TOTEN"], + "kinetic energy EKIN": ["kinetic_energy", "EKIN"], "kin. lattice EKIN_LAT": ["kinetic_lattice", "EKIN_LAT"], - "temperature TEIN ": ["temperature", "TEIN"], - "nose potential ES ": ["nose_potential", "ES"], - "nose kinetic EPS ": ["nose_kinetic", "EPS"], - "total energy ETOTAL ": ["total_energy", "ETOTAL"], - "free energy TOTEN ": ["free_energy", "TOTEN"], - "energy without entropy ": ["without_entropy", "ENOENT"], - "energy(sigma->0) ": ["sigma_0", "ESIG0"], + "temperature TEIN": ["temperature", "TEIN"], + "nose potential ES": ["nose_potential", "ES"], + "nose kinetic EPS": ["nose_kinetic", "EPS"], + "total energy ETOTAL": ["total_energy", "ETOTAL"], + "free energy TOTEN": ["free_energy", "TOTEN"], + "energy without entropy": ["without_entropy", "ENOENT"], + "energy(sigma->0)": ["sigma_0", "ESIG0"], } @@ -168,7 +168,7 @@ def _init_selection_dict(self): return { selection: index for index, label in enumerate(self._raw_data.labels) - for selection in _SELECTIONS.get(convert.text_to_string(label), ()) + for selection in _SELECTIONS.get(convert.text_to_string(label).strip(), ()) } def _make_series(self, yaxes, tree): @@ -188,9 +188,9 @@ def __init__(self, tree): self.y2label = "Temperature (K)" if self.use_both else None def _is_temperature(self, selection): - choices = _SELECTIONS["temperature TEIN "] + choices = _SELECTIONS["temperature TEIN"] return any(select.contains(selection, choice) for choice in choices) def use_y2(self, label): - choices = _SELECTIONS["temperature TEIN "] + choices = _SELECTIONS["temperature TEIN"] return self.use_both and label in choices diff --git a/src/py4vasp/_data/structure.py b/src/py4vasp/_data/structure.py index a90fe7c7..04d5fad6 100644 --- a/src/py4vasp/_data/structure.py +++ b/src/py4vasp/_data/structure.py @@ -291,6 +291,8 @@ def _lattice_vectors(self): return self._scale() * lattice_vectors[self._get_steps()] def _scale(self): + if isinstance(self._raw_data.cell.scale, np.float_): + return self._raw_data.cell.scale if not self._raw_data.cell.scale.is_none(): return self._raw_data.cell.scale[()] else: @@ -324,6 +326,7 @@ def _step_string(self): else: return f" (step {self._steps + 1})" + @base.data_access def __getitem__(self, steps): if not self._is_trajectory: message = "The structure is not a Trajectory so accessing individual elements is not allowed." diff --git a/src/py4vasp/_third_party/graph/plot.py b/src/py4vasp/_third_party/graph/plot.py index 94ef0522..fa89a10e 100644 --- a/src/py4vasp/_third_party/graph/plot.py +++ b/src/py4vasp/_third_party/graph/plot.py @@ -37,12 +37,19 @@ def plot(*args, **kwargs): def _parse_series(*args, **kwargs): + if series := _parse_multiple_series(*args, **kwargs): + return series + else: + return _parse_single_series(*args, **kwargs) + + +def _parse_multiple_series(*args, **kwargs): try: return [Series(*arg) for arg in args] except TypeError: - # A TypeError is raised, if plot(x, y) is called instead of plot((x, y)). - # Because we creating the Series may raise another error, we leave the - # exception handling first to avoid reraising the TypeError. - pass + return [] + + +def _parse_single_series(*args, **kwargs): for_series = {key: val for key, val in kwargs.items() if key in Series._fields} return Series(*args, **for_series) diff --git a/tests/conftest.py b/tests/conftest.py index 30991aff..f3acc139 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -438,13 +438,13 @@ def _polarization(): def _MD_energy(randomize: bool = False): labels = ( - "ion-electron TOTEN ", - "kinetic energy EKIN ", + "ion-electron TOTEN", + "kinetic energy EKIN", "kin. lattice EKIN_LAT", - "temperature TEIN ", - "nose potential ES ", - "nose kinetic EPS ", - "total energy ETOTAL ", + "temperature TEIN", + "nose potential ES", + "nose kinetic EPS", + "total energy ETOTAL", ) return _create_energy(labels, randomize=randomize) diff --git a/tests/data/conftest.py b/tests/data/conftest.py index e968a3f2..6e9e27c3 100644 --- a/tests/data/conftest.py +++ b/tests/data/conftest.py @@ -36,7 +36,7 @@ def inner(cls, data, parameters={}): def check_instance_accesses_data(instance, data, parameters, file=None): failed = [] for name, method in inspect.getmembers(instance, inspect.ismethod): - if should_test_method(name): + if should_test_method(name, parameters): kwargs = parameters.get(name, {}) try: check_method_accesses_data(data, method, file, **kwargs) @@ -50,7 +50,9 @@ def check_instance_accesses_data(instance, data, parameters, file=None): raise AssertionError(message) -def should_test_method(name): +def should_test_method(name, parameters): + if name in parameters: + return True if name in ("__str__", "_repr_html_"): return True if name.startswith("from") or name.startswith("_"): diff --git a/tests/data/test_structure.py b/tests/data/test_structure.py index aac6c07d..c16598cc 100644 --- a/tests/data/test_structure.py +++ b/tests/data/test_structure.py @@ -360,4 +360,5 @@ def test_print_Ca3AsBr3(Ca3AsBr3, format_): def test_factory_methods(raw_data, check_factory_methods): data = raw_data.structure("Sr2TiO4") - check_factory_methods(Structure, data) + parameters = {"__getitem__": {"steps": slice(None)}} + check_factory_methods(Structure, data, parameters) diff --git a/tests/third_party/graph/test_plot.py b/tests/third_party/graph/test_plot.py index 11d289ab..3a424d3d 100644 --- a/tests/third_party/graph/test_plot.py +++ b/tests/third_party/graph/test_plot.py @@ -19,6 +19,8 @@ def test_plot(): assert plot((x1, y1)) == Graph([series0]) assert plot((x1, y1), (x2, y2, "label2")) == Graph([series0, series2]) assert plot((x1, y1), xlabel="xaxis") == Graph([series0], xlabel="xaxis") + assert plot(x1, y=y1) == Graph(series0) + assert plot(x=x1, y=y1) == Graph(series0) def test_plot_small_dataset():