From efec0670e52d325cd1ed91cd3574997b411acd72 Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 21:22:58 +0100 Subject: [PATCH] reading of velocities, forces and charges implemented --- PQAnalysis/core/atom.py | 4 +- PQAnalysis/core/atomicSystem.py | 2 +- PQAnalysis/core/cell.py | 2 +- PQAnalysis/exceptions.py | 52 ++++++++++ PQAnalysis/io/trajectoryReader.py | 114 +++++++++++++++++++++- PQAnalysis/io/trajectoryWriter.py | 2 +- PQAnalysis/physicalData/energy.py | 2 +- PQAnalysis/traj/frame.py | 26 ++++- PQAnalysis/traj/trajectory.py | 38 +++++++- PQAnalysis/{utils/mytypes.py => types.py} | 3 - PQAnalysis/utils/exceptions.py | 17 ---- tests/core/test_atom.py | 2 +- tests/io/test_trajectoryReader.py | 45 +++++++-- 13 files changed, 271 insertions(+), 38 deletions(-) create mode 100644 PQAnalysis/exceptions.py rename PQAnalysis/{utils/mytypes.py => types.py} (92%) delete mode 100644 PQAnalysis/utils/exceptions.py diff --git a/PQAnalysis/core/atom.py b/PQAnalysis/core/atom.py index 76e1c7ca..173ba5a4 100644 --- a/PQAnalysis/core/atom.py +++ b/PQAnalysis/core/atom.py @@ -28,13 +28,11 @@ """ -import numpy as np - from multimethod import multimethod from beartype.typing import Any, Tuple from numbers import Real -from PQAnalysis.utils.exceptions import ElementNotFoundError +from PQAnalysis.exceptions import ElementNotFoundError def guess_element(id: int | str) -> Tuple[str, int, Real]: diff --git a/PQAnalysis/core/atomicSystem.py b/PQAnalysis/core/atomicSystem.py index 49b50560..d325bca7 100644 --- a/PQAnalysis/core/atomicSystem.py +++ b/PQAnalysis/core/atomicSystem.py @@ -23,7 +23,7 @@ from .atom import Atom from .cell import Cell -from ..utils.mytypes import Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray def check_atoms_pos(func): diff --git a/PQAnalysis/core/cell.py b/PQAnalysis/core/cell.py index d67074c2..3c6baa3a 100644 --- a/PQAnalysis/core/cell.py +++ b/PQAnalysis/core/cell.py @@ -16,7 +16,7 @@ from beartype.typing import Any from numbers import Real -from ..utils.mytypes import Numpy3x3FloatArray, Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy3x3FloatArray, Numpy2DFloatArray, Numpy1DFloatArray class Cell: diff --git a/PQAnalysis/exceptions.py b/PQAnalysis/exceptions.py new file mode 100644 index 00000000..af08cf1d --- /dev/null +++ b/PQAnalysis/exceptions.py @@ -0,0 +1,52 @@ +""" +A module containing different exceptions which could be useful. + +... + +Classes +------- +PQException + Base class for exceptions in this module. +ElementNotFoundError + Exception raised if the given element id is not valid +TrajectoryFormatError + Exception raised if the given enum is not valid +""" + +from beartype.typing import Any + + +class PQException(Exception): + """ + Base class for exceptions in this module. + """ + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + +class ElementNotFoundError(PQException): + """ + Exception raised if the given element id is not valid + """ + + def __init__(self, id: Any) -> None: + self.id = id + self.message = f"""Id {self.id} is not a valid element identifier.""" + super().__init__(self.message) + + +class TrajectoryFormatError(PQException): + """ + Exception raised if the given enum is not valid + """ + + def __init__(self, value: object, enum: object) -> None: + self.enum = enum + self.value = value + self.message = f""" +'{self.value}' is not a valid TrajectoryFormat. +Possible values are: {enum.member_repr()} +or their case insensitive string representation: {enum.value_repr()}""" + super().__init__(self.message) diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index f1b62dd6..ea2b3d80 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -22,8 +22,8 @@ from ..core.cell import Cell from ..core.atom import Atom from ..core.atomicSystem import AtomicSystem -from ..utils.exceptions import ElementNotFoundError -from ..utils.mytypes import Numpy2DFloatArray +from ..exceptions import ElementNotFoundError +from ..types import Numpy2DFloatArray, Numpy1DFloatArray class TrajectoryReader(BaseReader): @@ -130,6 +130,10 @@ def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFor return self.read_positions(frame_string) elif TrajectoryFormat(format) is TrajectoryFormat.VEL: return self.read_velocities(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.FORCE: + return self.read_forces(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.CHARGE: + return self.read_charges(frame_string) def read_positions(self, frame_string: str) -> Frame: """ @@ -199,6 +203,74 @@ def read_velocities(self, frame_string: str) -> Frame: return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell)) + def read_forces(self, frame_string: str) -> Frame: + """ + Reads the forces of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + forces, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, forces=forces, cell=cell)) + + def read_charges(self, frame_string: str) -> Frame: + """ + Reads the charge values of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + charges, atoms = self._read_scalar(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, charges=charges, cell=cell)) + def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: """ Reads the header line of a frame. @@ -281,3 +353,41 @@ def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Num atoms.append(line.split()[0]) return xyz, atoms + + def _read_scalar(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy1DFloatArray, List[str]]: + """ + Reads the scalar values and the atom names from the given string. + + Parameters + ---------- + splitted_frame_string : str + The string to read the scalar values and the atom names from. + n_atoms : int + The number of atoms in the frame. + + Returns + ------- + scalar : np.array + The scalar values of the atoms. + atoms : list of str + The names of the atoms. + + Raises + ------ + ValueError + If the given string does not contain the correct number of lines. + """ + + scalar = np.zeros((n_atoms)) + atoms = [] + for i in range(n_atoms): + line = splitted_frame_string[2+i] + + if len(line.split()) != 2: + raise ValueError( + 'Invalid file format in scalar values of Frame.') + + scalar[i] = float(line.split()[1]) + atoms.append(line.split()[0]) + + return scalar, atoms diff --git a/PQAnalysis/io/trajectoryWriter.py b/PQAnalysis/io/trajectoryWriter.py index c73bcaef..158f8c99 100644 --- a/PQAnalysis/io/trajectoryWriter.py +++ b/PQAnalysis/io/trajectoryWriter.py @@ -17,7 +17,7 @@ from ..traj.trajectory import Trajectory from ..core.cell import Cell from ..core.atom import Atom -from ..utils.mytypes import Numpy2DFloatArray +from ..types import Numpy2DFloatArray def write_trajectory(traj, filename: str | None = None, format: str | None = None) -> None: diff --git a/PQAnalysis/physicalData/energy.py b/PQAnalysis/physicalData/energy.py index 2d9181c8..a3fe8999 100644 --- a/PQAnalysis/physicalData/energy.py +++ b/PQAnalysis/physicalData/energy.py @@ -14,7 +14,7 @@ from beartype.typing import Dict from collections import defaultdict -from ..utils.mytypes import Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray class Energy(): diff --git a/PQAnalysis/traj/frame.py b/PQAnalysis/traj/frame.py index dff1554f..4a2a9e68 100644 --- a/PQAnalysis/traj/frame.py +++ b/PQAnalysis/traj/frame.py @@ -19,7 +19,7 @@ from ..core.atomicSystem import AtomicSystem from ..core.atom import Atom from ..core.cell import Cell -from ..utils.mytypes import Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray class Frame: @@ -202,6 +202,30 @@ def vel(self) -> Numpy2DFloatArray: """ return self.system.vel + @property + def forces(self) -> Numpy2DFloatArray: + """ + The forces on the atoms in the system. + + Returns + ------- + Numpy2DFloatArray + The forces on the atoms in the system. + """ + return self.system.forces + + @property + def charges(self) -> Numpy1DFloatArray: + """ + The charges of the atoms in the system. + + Returns + ------- + Numpy1DFloatArray + The charges of the atoms in the system. + """ + return self.system.charges + @property def atoms(self) -> List[Atom]: """ diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index 2c36b4e5..76321be3 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -5,6 +5,8 @@ Classes ------- +TrajectoryFormat + An enumeration of the supported trajectory formats. Trajectory A trajectory is a sequence of frames. """ @@ -17,6 +19,7 @@ from enum import Enum from .frame import Frame +from ..exceptions import TrajectoryFormatError class TrajectoryFormat(Enum): @@ -31,10 +34,16 @@ class TrajectoryFormat(Enum): The XYZ format. VEL : str The VEL format. + FORCE : str + The FORCE format. + CHARGE : str + The CHARGE format. """ XYZ = "XYZ" VEL = "VEL" + FORCE = "FORCE" + CHARGE = "CHARGE" @classmethod def _missing_(cls, value: object) -> Any: @@ -55,7 +64,34 @@ def _missing_(cls, value: object) -> Any: for member in cls: if member.value.lower() == value: return member - return None + + raise TrajectoryFormatError(value, cls) + + @classmethod + def member_repr(cls) -> str: + """ + This method returns a string representation of the members of the enumeration. + + Returns + ------- + str + A string representation of the members of the enumeration. + """ + + return ', '.join([str(member) for member in cls]) + + @classmethod + def value_repr(cls) -> str: + """ + This method returns a string representation of the values of the members of the enumeration. + + Returns + ------- + str + A string representation of the values of the members of the enumeration. + """ + + return ', '.join([str(member.value) for member in cls]) class Trajectory: diff --git a/PQAnalysis/utils/mytypes.py b/PQAnalysis/types.py similarity index 92% rename from PQAnalysis/utils/mytypes.py rename to PQAnalysis/types.py index 12ad8ab7..3679eb8a 100644 --- a/PQAnalysis/utils/mytypes.py +++ b/PQAnalysis/types.py @@ -3,9 +3,6 @@ from beartype.vale import Is from typing import Annotated -# Import the requisite machinery. -from beartype import beartype, BeartypeConf - Numpy2DFloatArray = Annotated[np.ndarray, Is[lambda array: array.ndim == 2 and diff --git a/PQAnalysis/utils/exceptions.py b/PQAnalysis/utils/exceptions.py deleted file mode 100644 index f8652b15..00000000 --- a/PQAnalysis/utils/exceptions.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -A module containing different exceptions which could be useful. -""" - -from beartype.typing import Any - - -class ElementNotFoundError(Exception): - """ - Exception raised if the given element id is not valid - """ - - def __init__(self, id: Any) -> None: - self.id = id - self.message = f"""Id { - self.id} is not a valid element identifier.""" - super().__init__(self.message) diff --git a/tests/core/test_atom.py b/tests/core/test_atom.py index 14c08d19..b9b1aa1f 100644 --- a/tests/core/test_atom.py +++ b/tests/core/test_atom.py @@ -4,7 +4,7 @@ from multimethod import DispatchError from PQAnalysis.core.atom import Atom, guess_element -from PQAnalysis.utils.exceptions import ElementNotFoundError +from PQAnalysis.exceptions import ElementNotFoundError def test_guess_element(): diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 67cfe5f9..06d36f6c 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -3,11 +3,12 @@ from beartype.roar import BeartypeException -from PQAnalysis.io.trajectoryReader import TrajectoryReader, FrameReader -from PQAnalysis.core.cell import Cell +from PQAnalysis.io.trajectoryReader import TrajectoryReader, FrameReader, TrajectoryFormat from PQAnalysis.traj.frame import Frame +from PQAnalysis.core.cell import Cell from PQAnalysis.core.atomicSystem import AtomicSystem from PQAnalysis.core.atom import Atom +from PQAnalysis.exceptions import TrajectoryFormatError class TestTrajectoryReader: @@ -56,7 +57,7 @@ def test_read(self): class TestFrameReader: - def test__read_header_line__(self): + def test__read_header_line(self): reader = FrameReader() with pytest.raises(ValueError) as exception: @@ -79,7 +80,7 @@ def test__read_header_line__(self): assert n_atoms == 3 assert cell is None - def test__read_xyz__(self): + def test__read_xyz(self): reader = FrameReader() with pytest.raises(ValueError) as exception: @@ -93,6 +94,18 @@ def test__read_xyz__(self): assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert atoms == ["h", "o"] + def test__read_scalar(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader._read_scalar(["", "", "h 1.0 2.0 3.0"], n_atoms=1) + assert str( + exception.value) == "Invalid file format in scalar values of Frame." + + scalar, atoms = reader._read_scalar(["", "", "h 1.0"], n_atoms=1) + assert np.allclose(scalar, [1.0]) + assert atoms == ["h"] + def test_read(self): reader = FrameReader() @@ -125,10 +138,30 @@ def test_read(self): [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="force") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.forces, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0\no1 2.0", format="charge") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.charges, [1.0, 2.0]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + def test_read_invalid_format(self): reader = FrameReader() - with pytest.raises(ValueError) as exception: + with pytest.raises(TrajectoryFormatError) as exception: reader.read("", format="invalid") assert str( - exception.value) == "'invalid' is not a valid TrajectoryFormat" + exception.value) == f""" +'invalid' is not a valid TrajectoryFormat. +Possible values are: {TrajectoryFormat.member_repr()} +or their case insensitive string representation: {TrajectoryFormat.value_repr()}"""