Skip to content

Commit

Permalink
Workshop fixes (#114)
Browse files Browse the repository at this point in the history
* Expose MLFFErrorAnalysis in `__init__.py`
* Fix calling plot with keyword arguments
* Fix string fix-length issue for MD energy plotting
* Replace calculations import with py4vasp import
* Fix access to HDF5 files for trajectories

---------
Co-authored-by: Martin Schlipf <martin.schlipf@gmail.com>
  • Loading branch information
sudarshanv01 authored Nov 6, 2023
1 parent ec10e3b commit 9eddc36
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/py4vasp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from py4vasp._analysis.mlff import MLFFErrorAnalysis
from py4vasp._calculation import Calculation
from py4vasp._calculations import Calculations
from py4vasp._third_party.graph import plot
Expand Down
11 changes: 8 additions & 3 deletions src/py4vasp/_analysis/mlff.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import numpy as np

from py4vasp import Calculations, exception
import py4vasp
from py4vasp import exception


class MLFFErrorAnalysis:
Expand Down Expand Up @@ -71,7 +72,9 @@ def from_paths(cls, dft_data, mlff_data):
Path to the MLFF data. Accepts wildcards.
"""
mlff_error_analysis = cls(_internal=True)
calculations = Calculations.from_paths(dft_data=dft_data, mlff_data=mlff_data)
calculations = py4vasp.Calculations.from_paths(
dft_data=dft_data, mlff_data=mlff_data
)
mlff_error_analysis._calculations = calculations
set_appropriate_attrs(mlff_error_analysis)
return mlff_error_analysis
Expand All @@ -92,7 +95,9 @@ def from_files(cls, dft_data, mlff_data):
Path to the MLFF data. Accepts wildcards.
"""
mlff_error_analysis = cls(_internal=True)
calculations = Calculations.from_files(dft_data=dft_data, mlff_data=mlff_data)
calculations = py4vasp.Calculations.from_files(
dft_data=dft_data, mlff_data=mlff_data
)
mlff_error_analysis._calculations = calculations
set_appropriate_attrs(mlff_error_analysis)
return mlff_error_analysis
Expand Down
24 changes: 12 additions & 12 deletions src/py4vasp/_data/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ def _selection_string(default):


_SELECTIONS = {
"ion-electron TOTEN ": ["ion_electron", "TOTEN"],
"kinetic energy EKIN ": ["kinetic_energy", "EKIN"],
"ion-electron TOTEN": ["ion_electron", "TOTEN"],
"kinetic energy EKIN": ["kinetic_energy", "EKIN"],
"kin. lattice EKIN_LAT": ["kinetic_lattice", "EKIN_LAT"],
"temperature TEIN ": ["temperature", "TEIN"],
"nose potential ES ": ["nose_potential", "ES"],
"nose kinetic EPS ": ["nose_kinetic", "EPS"],
"total energy ETOTAL ": ["total_energy", "ETOTAL"],
"free energy TOTEN ": ["free_energy", "TOTEN"],
"energy without entropy ": ["without_entropy", "ENOENT"],
"energy(sigma->0) ": ["sigma_0", "ESIG0"],
"temperature TEIN": ["temperature", "TEIN"],
"nose potential ES": ["nose_potential", "ES"],
"nose kinetic EPS": ["nose_kinetic", "EPS"],
"total energy ETOTAL": ["total_energy", "ETOTAL"],
"free energy TOTEN": ["free_energy", "TOTEN"],
"energy without entropy": ["without_entropy", "ENOENT"],
"energy(sigma->0)": ["sigma_0", "ESIG0"],
}


Expand Down Expand Up @@ -168,7 +168,7 @@ def _init_selection_dict(self):
return {
selection: index
for index, label in enumerate(self._raw_data.labels)
for selection in _SELECTIONS.get(convert.text_to_string(label), ())
for selection in _SELECTIONS.get(convert.text_to_string(label).strip(), ())
}

def _make_series(self, yaxes, tree):
Expand All @@ -188,9 +188,9 @@ def __init__(self, tree):
self.y2label = "Temperature (K)" if self.use_both else None

def _is_temperature(self, selection):
choices = _SELECTIONS["temperature TEIN "]
choices = _SELECTIONS["temperature TEIN"]
return any(select.contains(selection, choice) for choice in choices)

def use_y2(self, label):
choices = _SELECTIONS["temperature TEIN "]
choices = _SELECTIONS["temperature TEIN"]
return self.use_both and label in choices
3 changes: 3 additions & 0 deletions src/py4vasp/_data/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def _lattice_vectors(self):
return self._scale() * lattice_vectors[self._get_steps()]

def _scale(self):
if isinstance(self._raw_data.cell.scale, np.float_):
return self._raw_data.cell.scale
if not self._raw_data.cell.scale.is_none():
return self._raw_data.cell.scale[()]
else:
Expand Down Expand Up @@ -324,6 +326,7 @@ def _step_string(self):
else:
return f" (step {self._steps + 1})"

@base.data_access
def __getitem__(self, steps):
if not self._is_trajectory:
message = "The structure is not a Trajectory so accessing individual elements is not allowed."
Expand Down
15 changes: 11 additions & 4 deletions src/py4vasp/_third_party/graph/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,19 @@ def plot(*args, **kwargs):


def _parse_series(*args, **kwargs):
if series := _parse_multiple_series(*args, **kwargs):
return series
else:
return _parse_single_series(*args, **kwargs)


def _parse_multiple_series(*args, **kwargs):
try:
return [Series(*arg) for arg in args]
except TypeError:
# A TypeError is raised, if plot(x, y) is called instead of plot((x, y)).
# Because we creating the Series may raise another error, we leave the
# exception handling first to avoid reraising the TypeError.
pass
return []


def _parse_single_series(*args, **kwargs):
for_series = {key: val for key, val in kwargs.items() if key in Series._fields}
return Series(*args, **for_series)
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,13 @@ def _polarization():

def _MD_energy(randomize: bool = False):
labels = (
"ion-electron TOTEN ",
"kinetic energy EKIN ",
"ion-electron TOTEN",
"kinetic energy EKIN",
"kin. lattice EKIN_LAT",
"temperature TEIN ",
"nose potential ES ",
"nose kinetic EPS ",
"total energy ETOTAL ",
"temperature TEIN",
"nose potential ES",
"nose kinetic EPS",
"total energy ETOTAL",
)
return _create_energy(labels, randomize=randomize)

Expand Down
6 changes: 4 additions & 2 deletions tests/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def inner(cls, data, parameters={}):
def check_instance_accesses_data(instance, data, parameters, file=None):
failed = []
for name, method in inspect.getmembers(instance, inspect.ismethod):
if should_test_method(name):
if should_test_method(name, parameters):
kwargs = parameters.get(name, {})
try:
check_method_accesses_data(data, method, file, **kwargs)
Expand All @@ -50,7 +50,9 @@ def check_instance_accesses_data(instance, data, parameters, file=None):
raise AssertionError(message)


def should_test_method(name):
def should_test_method(name, parameters):
if name in parameters:
return True
if name in ("__str__", "_repr_html_"):
return True
if name.startswith("from") or name.startswith("_"):
Expand Down
3 changes: 2 additions & 1 deletion tests/data/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,5 @@ def test_print_Ca3AsBr3(Ca3AsBr3, format_):

def test_factory_methods(raw_data, check_factory_methods):
data = raw_data.structure("Sr2TiO4")
check_factory_methods(Structure, data)
parameters = {"__getitem__": {"steps": slice(None)}}
check_factory_methods(Structure, data, parameters)
2 changes: 2 additions & 0 deletions tests/third_party/graph/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def test_plot():
assert plot((x1, y1)) == Graph([series0])
assert plot((x1, y1), (x2, y2, "label2")) == Graph([series0, series2])
assert plot((x1, y1), xlabel="xaxis") == Graph([series0], xlabel="xaxis")
assert plot(x1, y=y1) == Graph(series0)
assert plot(x=x1, y=y1) == Graph(series0)


def test_plot_small_dataset():
Expand Down

0 comments on commit 9eddc36

Please sign in to comment.