From 955f06c0bfd396b13634492484fcfc7679df9260 Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sat, 11 Nov 2023 23:16:05 +0100 Subject: [PATCH 1/6] added readvelocity and trajectory format class --- PQAnalysis/io/trajectoryReader.py | 81 ++++++++++++++++++++++++++++--- PQAnalysis/traj/trajectory.py | 19 ++++++++ tests/io/test_trajectoryReader.py | 12 ++--- 3 files changed, 98 insertions(+), 14 deletions(-) diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index 9bd32f23..13a475bb 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -14,10 +14,11 @@ import numpy as np from beartype.typing import Tuple, List +from enum import Enum from .base import BaseReader from ..traj.frame import Frame -from ..traj.trajectory import Trajectory +from ..traj.trajectory import Trajectory, TrajectoryFormat from ..core.cell import Cell from ..core.atom import Atom from ..core.atomicSystem import AtomicSystem @@ -41,7 +42,7 @@ class TrajectoryReader(BaseReader): The list of frames read from the file. """ - def __init__(self, filename: str) -> None: + def __init__(self, filename: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None: """ Initializes the TrajectoryReader with the given filename. @@ -52,6 +53,7 @@ def __init__(self, filename: str) -> None: """ super().__init__(filename) self.frames = [] + self.format = format def read(self) -> Trajectory: """ @@ -79,7 +81,8 @@ def read(self) -> Trajectory: frame_string += line elif line.split()[0].isdigit(): if frame_string != '': - self.frames.append(frame_reader.read(frame_string)) + self.frames.append(frame_reader.read( + frame_string, format=self.format)) frame_string = line else: frame_string += line @@ -99,10 +102,38 @@ class FrameReader: FrameReader reads a frame from a string. """ - def read(self, frame_string: str) -> Frame: + def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> Frame: """ Reads a frame from a string. + Parameters + ---------- + frame_string : str + The string to read the frame from. + format : TrajectoryFormat | str, optional + The format of the trajectory. Default is TrajectoryFormat.XYZ. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + ValueError + If the given format is not valid. + """ + if format == TrajectoryFormat.XYZ: + return self.read_positions(frame_string) + elif format == TrajectoryFormat.VELOCS: + return self.read_velocities(frame_string) + else: + raise ValueError('Invalid format.') + + def read_positions(self, frame_string: str) -> Frame: + """ + Reads the positions of the atoms in a frame from a string. + Parameters ---------- frame_string : str @@ -122,9 +153,9 @@ def read(self, frame_string: str) -> Frame: splitted_frame_string = frame_string.split('\n') header_line = splitted_frame_string[0] - n_atoms, cell = self.__read_header_line__(header_line) + n_atoms, cell = self._read_header_line(header_line) - xyz, atoms = self.__read_xyz__(splitted_frame_string, n_atoms) + xyz, atoms = self._read_xyz(splitted_frame_string, n_atoms) try: atoms = [Atom(atom) for atom in atoms] @@ -133,7 +164,41 @@ def read(self, frame_string: str) -> Frame: return Frame(AtomicSystem(atoms=atoms, pos=xyz, cell=cell)) - def __read_header_line__(self, header_line: str) -> Tuple[int, Cell | None]: + def read_velocities(self, frame_string: str) -> Frame: + """ + Reads the velocities of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + velocs, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, velocs=velocs, cell=cell)) + + def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: """ Reads the header line of a frame. @@ -178,7 +243,7 @@ def __read_header_line__(self, header_line: str) -> Tuple[int, Cell | None]: return n_atoms, cell - def __read_xyz__(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]: + def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]: """ Reads the xyz coordinates and the atom names from the given string. diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index b3507dc5..3d2b9263 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -14,10 +14,29 @@ import numpy as np from beartype.typing import List, Iterator, Any +from enum import Enum from .frame import Frame +class TrajectoryFormat(Enum): + """ + An enumeration of the supported trajectory formats. + + ... + + Attributes + ---------- + XYZ : str + The XYZ format. + VELOCS : str + The VELOCS format. + """ + + XYZ = "xyz" + VELOCS = "velocs" + + class Trajectory: """ A trajectory object is a sequence of frames. diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 204ad49b..293a5c9d 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -60,22 +60,22 @@ def test__read_header_line__(self): reader = FrameReader() with pytest.raises(ValueError) as exception: - reader.__read_header_line__("1 2.0 3.0") + reader._read_header_line("1 2.0 3.0") assert str( exception.value) == "Invalid file format in header line of Frame." - n_atoms, cell = reader.__read_header_line__( + n_atoms, cell = reader._read_header_line( "1 2.0 3.0 4.0 5.0 6.0 7.0") assert n_atoms == 1 assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) assert np.allclose(cell.box_angles, [5.0, 6.0, 7.0]) - n_atoms, cell = reader.__read_header_line__("2 2.0 3.0 4.0") + n_atoms, cell = reader._read_header_line("2 2.0 3.0 4.0") assert n_atoms == 2 assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) assert np.allclose(cell.box_angles, [90.0, 90.0, 90.0]) - n_atoms, cell = reader.__read_header_line__("3") + n_atoms, cell = reader._read_header_line("3") assert n_atoms == 3 assert cell is None @@ -83,12 +83,12 @@ def test__read_xyz__(self): reader = FrameReader() with pytest.raises(ValueError) as exception: - reader.__read_xyz__( + reader._read_xyz( ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0"], n_atoms=2) assert str( exception.value) == "Invalid file format in xyz coordinates of Frame." - xyz, atoms = reader.__read_xyz__( + xyz, atoms = reader._read_xyz( ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0 2.0"], n_atoms=2) assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert atoms == ["h", "o"] From ebd8356799b18a958241423914524c8bd8f5cd9a Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 00:00:28 +0100 Subject: [PATCH 2/6] velocity reading fully tested --- PQAnalysis/io/trajectoryReader.py | 13 +++++++------ PQAnalysis/traj/frame.py | 12 ++++++++++++ PQAnalysis/traj/trajectory.py | 29 +++++++++++++++++++++++++---- tests/io/test_trajectoryReader.py | 17 +++++++++++++++++ 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index 13a475bb..f1b62dd6 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -123,12 +123,13 @@ def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFor ValueError If the given format is not valid. """ - if format == TrajectoryFormat.XYZ: + + # Note: TrajectoryFormat(format) automatically gives an error if format is not a valid TrajectoryFormat + + if TrajectoryFormat(format) is TrajectoryFormat.XYZ: return self.read_positions(frame_string) - elif format == TrajectoryFormat.VELOCS: + elif TrajectoryFormat(format) is TrajectoryFormat.VEL: return self.read_velocities(frame_string) - else: - raise ValueError('Invalid format.') def read_positions(self, frame_string: str) -> Frame: """ @@ -189,14 +190,14 @@ def read_velocities(self, frame_string: str) -> Frame: n_atoms, cell = self._read_header_line(header_line) - velocs, atoms = self._read_xyz(splitted_frame_string, n_atoms) + vel, atoms = self._read_xyz(splitted_frame_string, n_atoms) try: atoms = [Atom(atom) for atom in atoms] except ElementNotFoundError: atoms = [Atom(atom, use_guess_element=False) for atom in atoms] - return Frame(AtomicSystem(atoms=atoms, velocs=velocs, cell=cell)) + return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell)) def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: """ diff --git a/PQAnalysis/traj/frame.py b/PQAnalysis/traj/frame.py index 7ebd3118..dff1554f 100644 --- a/PQAnalysis/traj/frame.py +++ b/PQAnalysis/traj/frame.py @@ -190,6 +190,18 @@ def pos(self) -> Numpy2DFloatArray: """ return self.system.pos + @property + def vel(self) -> Numpy2DFloatArray: + """ + The positions of the atoms in the system. + + Returns + ------- + Numpy2DFloatArray + The positions of the atoms in the system. + """ + return self.system.vel + @property def atoms(self) -> List[Atom]: """ diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index 3d2b9263..2c36b4e5 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -29,12 +29,33 @@ class TrajectoryFormat(Enum): ---------- XYZ : str The XYZ format. - VELOCS : str - The VELOCS format. + VEL : str + The VEL format. """ - XYZ = "xyz" - VELOCS = "velocs" + XYZ = "XYZ" + VEL = "VEL" + + @classmethod + def _missing_(cls, value: object) -> Any: + """ + This method allows a trajectory format to be retrieved from a string. + + Parameters + ---------- + value : str + _description_ + + Returns + ------- + Any + _description_ + """ + value = value.lower() + for member in cls: + if member.value.lower() == value: + return member + return None class Trajectory: diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 293a5c9d..67cfe5f9 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -115,3 +115,20 @@ def test_read(self): assert np.allclose(frame.pos, [ [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="vel") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.vel, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + def test_read_invalid_format(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader.read("", format="invalid") + assert str( + exception.value) == "'invalid' is not a valid TrajectoryFormat" From efec0670e52d325cd1ed91cd3574997b411acd72 Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 21:22:58 +0100 Subject: [PATCH 3/6] reading of velocities, forces and charges implemented --- PQAnalysis/core/atom.py | 4 +- PQAnalysis/core/atomicSystem.py | 2 +- PQAnalysis/core/cell.py | 2 +- PQAnalysis/exceptions.py | 52 ++++++++++ PQAnalysis/io/trajectoryReader.py | 114 +++++++++++++++++++++- PQAnalysis/io/trajectoryWriter.py | 2 +- PQAnalysis/physicalData/energy.py | 2 +- PQAnalysis/traj/frame.py | 26 ++++- PQAnalysis/traj/trajectory.py | 38 +++++++- PQAnalysis/{utils/mytypes.py => types.py} | 3 - PQAnalysis/utils/exceptions.py | 17 ---- tests/core/test_atom.py | 2 +- tests/io/test_trajectoryReader.py | 45 +++++++-- 13 files changed, 271 insertions(+), 38 deletions(-) create mode 100644 PQAnalysis/exceptions.py rename PQAnalysis/{utils/mytypes.py => types.py} (92%) delete mode 100644 PQAnalysis/utils/exceptions.py diff --git a/PQAnalysis/core/atom.py b/PQAnalysis/core/atom.py index 76e1c7ca..173ba5a4 100644 --- a/PQAnalysis/core/atom.py +++ b/PQAnalysis/core/atom.py @@ -28,13 +28,11 @@ """ -import numpy as np - from multimethod import multimethod from beartype.typing import Any, Tuple from numbers import Real -from PQAnalysis.utils.exceptions import ElementNotFoundError +from PQAnalysis.exceptions import ElementNotFoundError def guess_element(id: int | str) -> Tuple[str, int, Real]: diff --git a/PQAnalysis/core/atomicSystem.py b/PQAnalysis/core/atomicSystem.py index 49b50560..d325bca7 100644 --- a/PQAnalysis/core/atomicSystem.py +++ b/PQAnalysis/core/atomicSystem.py @@ -23,7 +23,7 @@ from .atom import Atom from .cell import Cell -from ..utils.mytypes import Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray def check_atoms_pos(func): diff --git a/PQAnalysis/core/cell.py b/PQAnalysis/core/cell.py index d67074c2..3c6baa3a 100644 --- a/PQAnalysis/core/cell.py +++ b/PQAnalysis/core/cell.py @@ -16,7 +16,7 @@ from beartype.typing import Any from numbers import Real -from ..utils.mytypes import Numpy3x3FloatArray, Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy3x3FloatArray, Numpy2DFloatArray, Numpy1DFloatArray class Cell: diff --git a/PQAnalysis/exceptions.py b/PQAnalysis/exceptions.py new file mode 100644 index 00000000..af08cf1d --- /dev/null +++ b/PQAnalysis/exceptions.py @@ -0,0 +1,52 @@ +""" +A module containing different exceptions which could be useful. + +... + +Classes +------- +PQException + Base class for exceptions in this module. +ElementNotFoundError + Exception raised if the given element id is not valid +TrajectoryFormatError + Exception raised if the given enum is not valid +""" + +from beartype.typing import Any + + +class PQException(Exception): + """ + Base class for exceptions in this module. + """ + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + +class ElementNotFoundError(PQException): + """ + Exception raised if the given element id is not valid + """ + + def __init__(self, id: Any) -> None: + self.id = id + self.message = f"""Id {self.id} is not a valid element identifier.""" + super().__init__(self.message) + + +class TrajectoryFormatError(PQException): + """ + Exception raised if the given enum is not valid + """ + + def __init__(self, value: object, enum: object) -> None: + self.enum = enum + self.value = value + self.message = f""" +'{self.value}' is not a valid TrajectoryFormat. +Possible values are: {enum.member_repr()} +or their case insensitive string representation: {enum.value_repr()}""" + super().__init__(self.message) diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index f1b62dd6..ea2b3d80 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -22,8 +22,8 @@ from ..core.cell import Cell from ..core.atom import Atom from ..core.atomicSystem import AtomicSystem -from ..utils.exceptions import ElementNotFoundError -from ..utils.mytypes import Numpy2DFloatArray +from ..exceptions import ElementNotFoundError +from ..types import Numpy2DFloatArray, Numpy1DFloatArray class TrajectoryReader(BaseReader): @@ -130,6 +130,10 @@ def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFor return self.read_positions(frame_string) elif TrajectoryFormat(format) is TrajectoryFormat.VEL: return self.read_velocities(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.FORCE: + return self.read_forces(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.CHARGE: + return self.read_charges(frame_string) def read_positions(self, frame_string: str) -> Frame: """ @@ -199,6 +203,74 @@ def read_velocities(self, frame_string: str) -> Frame: return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell)) + def read_forces(self, frame_string: str) -> Frame: + """ + Reads the forces of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + forces, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, forces=forces, cell=cell)) + + def read_charges(self, frame_string: str) -> Frame: + """ + Reads the charge values of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + charges, atoms = self._read_scalar(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, charges=charges, cell=cell)) + def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: """ Reads the header line of a frame. @@ -281,3 +353,41 @@ def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Num atoms.append(line.split()[0]) return xyz, atoms + + def _read_scalar(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy1DFloatArray, List[str]]: + """ + Reads the scalar values and the atom names from the given string. + + Parameters + ---------- + splitted_frame_string : str + The string to read the scalar values and the atom names from. + n_atoms : int + The number of atoms in the frame. + + Returns + ------- + scalar : np.array + The scalar values of the atoms. + atoms : list of str + The names of the atoms. + + Raises + ------ + ValueError + If the given string does not contain the correct number of lines. + """ + + scalar = np.zeros((n_atoms)) + atoms = [] + for i in range(n_atoms): + line = splitted_frame_string[2+i] + + if len(line.split()) != 2: + raise ValueError( + 'Invalid file format in scalar values of Frame.') + + scalar[i] = float(line.split()[1]) + atoms.append(line.split()[0]) + + return scalar, atoms diff --git a/PQAnalysis/io/trajectoryWriter.py b/PQAnalysis/io/trajectoryWriter.py index c73bcaef..158f8c99 100644 --- a/PQAnalysis/io/trajectoryWriter.py +++ b/PQAnalysis/io/trajectoryWriter.py @@ -17,7 +17,7 @@ from ..traj.trajectory import Trajectory from ..core.cell import Cell from ..core.atom import Atom -from ..utils.mytypes import Numpy2DFloatArray +from ..types import Numpy2DFloatArray def write_trajectory(traj, filename: str | None = None, format: str | None = None) -> None: diff --git a/PQAnalysis/physicalData/energy.py b/PQAnalysis/physicalData/energy.py index 2d9181c8..a3fe8999 100644 --- a/PQAnalysis/physicalData/energy.py +++ b/PQAnalysis/physicalData/energy.py @@ -14,7 +14,7 @@ from beartype.typing import Dict from collections import defaultdict -from ..utils.mytypes import Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray class Energy(): diff --git a/PQAnalysis/traj/frame.py b/PQAnalysis/traj/frame.py index dff1554f..4a2a9e68 100644 --- a/PQAnalysis/traj/frame.py +++ b/PQAnalysis/traj/frame.py @@ -19,7 +19,7 @@ from ..core.atomicSystem import AtomicSystem from ..core.atom import Atom from ..core.cell import Cell -from ..utils.mytypes import Numpy2DFloatArray, Numpy1DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray class Frame: @@ -202,6 +202,30 @@ def vel(self) -> Numpy2DFloatArray: """ return self.system.vel + @property + def forces(self) -> Numpy2DFloatArray: + """ + The forces on the atoms in the system. + + Returns + ------- + Numpy2DFloatArray + The forces on the atoms in the system. + """ + return self.system.forces + + @property + def charges(self) -> Numpy1DFloatArray: + """ + The charges of the atoms in the system. + + Returns + ------- + Numpy1DFloatArray + The charges of the atoms in the system. + """ + return self.system.charges + @property def atoms(self) -> List[Atom]: """ diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index 2c36b4e5..76321be3 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -5,6 +5,8 @@ Classes ------- +TrajectoryFormat + An enumeration of the supported trajectory formats. Trajectory A trajectory is a sequence of frames. """ @@ -17,6 +19,7 @@ from enum import Enum from .frame import Frame +from ..exceptions import TrajectoryFormatError class TrajectoryFormat(Enum): @@ -31,10 +34,16 @@ class TrajectoryFormat(Enum): The XYZ format. VEL : str The VEL format. + FORCE : str + The FORCE format. + CHARGE : str + The CHARGE format. """ XYZ = "XYZ" VEL = "VEL" + FORCE = "FORCE" + CHARGE = "CHARGE" @classmethod def _missing_(cls, value: object) -> Any: @@ -55,7 +64,34 @@ def _missing_(cls, value: object) -> Any: for member in cls: if member.value.lower() == value: return member - return None + + raise TrajectoryFormatError(value, cls) + + @classmethod + def member_repr(cls) -> str: + """ + This method returns a string representation of the members of the enumeration. + + Returns + ------- + str + A string representation of the members of the enumeration. + """ + + return ', '.join([str(member) for member in cls]) + + @classmethod + def value_repr(cls) -> str: + """ + This method returns a string representation of the values of the members of the enumeration. + + Returns + ------- + str + A string representation of the values of the members of the enumeration. + """ + + return ', '.join([str(member.value) for member in cls]) class Trajectory: diff --git a/PQAnalysis/utils/mytypes.py b/PQAnalysis/types.py similarity index 92% rename from PQAnalysis/utils/mytypes.py rename to PQAnalysis/types.py index 12ad8ab7..3679eb8a 100644 --- a/PQAnalysis/utils/mytypes.py +++ b/PQAnalysis/types.py @@ -3,9 +3,6 @@ from beartype.vale import Is from typing import Annotated -# Import the requisite machinery. -from beartype import beartype, BeartypeConf - Numpy2DFloatArray = Annotated[np.ndarray, Is[lambda array: array.ndim == 2 and diff --git a/PQAnalysis/utils/exceptions.py b/PQAnalysis/utils/exceptions.py deleted file mode 100644 index f8652b15..00000000 --- a/PQAnalysis/utils/exceptions.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -A module containing different exceptions which could be useful. -""" - -from beartype.typing import Any - - -class ElementNotFoundError(Exception): - """ - Exception raised if the given element id is not valid - """ - - def __init__(self, id: Any) -> None: - self.id = id - self.message = f"""Id { - self.id} is not a valid element identifier.""" - super().__init__(self.message) diff --git a/tests/core/test_atom.py b/tests/core/test_atom.py index 14c08d19..b9b1aa1f 100644 --- a/tests/core/test_atom.py +++ b/tests/core/test_atom.py @@ -4,7 +4,7 @@ from multimethod import DispatchError from PQAnalysis.core.atom import Atom, guess_element -from PQAnalysis.utils.exceptions import ElementNotFoundError +from PQAnalysis.exceptions import ElementNotFoundError def test_guess_element(): diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 67cfe5f9..06d36f6c 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -3,11 +3,12 @@ from beartype.roar import BeartypeException -from PQAnalysis.io.trajectoryReader import TrajectoryReader, FrameReader -from PQAnalysis.core.cell import Cell +from PQAnalysis.io.trajectoryReader import TrajectoryReader, FrameReader, TrajectoryFormat from PQAnalysis.traj.frame import Frame +from PQAnalysis.core.cell import Cell from PQAnalysis.core.atomicSystem import AtomicSystem from PQAnalysis.core.atom import Atom +from PQAnalysis.exceptions import TrajectoryFormatError class TestTrajectoryReader: @@ -56,7 +57,7 @@ def test_read(self): class TestFrameReader: - def test__read_header_line__(self): + def test__read_header_line(self): reader = FrameReader() with pytest.raises(ValueError) as exception: @@ -79,7 +80,7 @@ def test__read_header_line__(self): assert n_atoms == 3 assert cell is None - def test__read_xyz__(self): + def test__read_xyz(self): reader = FrameReader() with pytest.raises(ValueError) as exception: @@ -93,6 +94,18 @@ def test__read_xyz__(self): assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert atoms == ["h", "o"] + def test__read_scalar(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader._read_scalar(["", "", "h 1.0 2.0 3.0"], n_atoms=1) + assert str( + exception.value) == "Invalid file format in scalar values of Frame." + + scalar, atoms = reader._read_scalar(["", "", "h 1.0"], n_atoms=1) + assert np.allclose(scalar, [1.0]) + assert atoms == ["h"] + def test_read(self): reader = FrameReader() @@ -125,10 +138,30 @@ def test_read(self): [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="force") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.forces, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0\no1 2.0", format="charge") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.charges, [1.0, 2.0]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + def test_read_invalid_format(self): reader = FrameReader() - with pytest.raises(ValueError) as exception: + with pytest.raises(TrajectoryFormatError) as exception: reader.read("", format="invalid") assert str( - exception.value) == "'invalid' is not a valid TrajectoryFormat" + exception.value) == f""" +'invalid' is not a valid TrajectoryFormat. +Possible values are: {TrajectoryFormat.member_repr()} +or their case insensitive string representation: {TrajectoryFormat.value_repr()}""" From 11f82f272f78a69750f08b4ed6d42c6ce4023a67 Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 21:30:21 +0100 Subject: [PATCH 4/6] some refactoring --- PQAnalysis/io/frameReader.py | 320 ++++++++++++++++++++++++++++++ PQAnalysis/io/trajectoryReader.py | 310 +---------------------------- tests/io/test_frameReader.py | 122 ++++++++++++ tests/io/test_trajectoryReader.py | 117 +---------- 4 files changed, 444 insertions(+), 425 deletions(-) create mode 100644 PQAnalysis/io/frameReader.py create mode 100644 tests/io/test_frameReader.py diff --git a/PQAnalysis/io/frameReader.py b/PQAnalysis/io/frameReader.py new file mode 100644 index 00000000..1846e669 --- /dev/null +++ b/PQAnalysis/io/frameReader.py @@ -0,0 +1,320 @@ +""" +A module containing classes for reading a frame from a string. + +... + +Classes +------- +FrameReader + A class for reading a frame from a string. +""" + +from __future__ import annotations + +import numpy as np + +from beartype.typing import List, Tuple + +from ..core.atomicSystem import AtomicSystem +from ..core.atom import Atom +from ..core.cell import Cell +from ..types import Numpy2DFloatArray, Numpy1DFloatArray +from ..traj.frame import Frame +from ..traj.trajectory import TrajectoryFormat +from ..exceptions import ElementNotFoundError + + +class FrameReader: + """ + FrameReader reads a frame from a string. + """ + + def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> Frame: + """ + Reads a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + format : TrajectoryFormat | str, optional + The format of the trajectory. Default is TrajectoryFormat.XYZ. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + ValueError + If the given format is not valid. + """ + + # Note: TrajectoryFormat(format) automatically gives an error if format is not a valid TrajectoryFormat + + if TrajectoryFormat(format) is TrajectoryFormat.XYZ: + return self.read_positions(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.VEL: + return self.read_velocities(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.FORCE: + return self.read_forces(frame_string) + elif TrajectoryFormat(format) is TrajectoryFormat.CHARGE: + return self.read_charges(frame_string) + + def read_positions(self, frame_string: str) -> Frame: + """ + Reads the positions of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + xyz, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, pos=xyz, cell=cell)) + + def read_velocities(self, frame_string: str) -> Frame: + """ + Reads the velocities of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + vel, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell)) + + def read_forces(self, frame_string: str) -> Frame: + """ + Reads the forces of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + forces, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, forces=forces, cell=cell)) + + def read_charges(self, frame_string: str) -> Frame: + """ + Reads the charge values of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + charges, atoms = self._read_scalar(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, charges=charges, cell=cell)) + + def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: + """ + Reads the header line of a frame. + + It reads the number of atoms and the cell information from the header line. + If the header line contains only the number of atoms, the cell is set to None. + If the header line contains only the number of atoms and the box dimensions, + the cell is set to a Cell object with the given box dimensions and box angles set to 90°. + + Parameters + ---------- + header_line : str + The header line to read. + + Returns + ------- + n_atoms : int + The number of atoms in the frame. + cell : Cell + The cell of the frame. + + Raises + ------ + ValueError + If the header line is not valid. Either it contains too many or too few values. + """ + + header_line = header_line.split() + + if len(header_line) == 4: + n_atoms = int(header_line[0]) + a, b, c = map(float, header_line[1:4]) + cell = Cell(a, b, c) + elif len(header_line) == 7: + n_atoms = int(header_line[0]) + a, b, c, alpha, beta, gamma = map(float, header_line[1:7]) + cell = Cell(a, b, c, alpha, beta, gamma) + elif len(header_line) == 1: + n_atoms = int(header_line[0]) + cell = None + else: + raise ValueError('Invalid file format in header line of Frame.') + + return n_atoms, cell + + def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]: + """ + Reads the xyz coordinates and the atom names from the given string. + + Parameters + ---------- + splitted_frame_string : str + The string to read the xyz coordinates and the atom names from. + n_atoms : int + The number of atoms in the frame. + + Returns + ------- + xyz : np.array + The xyz coordinates of the atoms. + atoms : list of str + The names of the atoms. + + Raises + ------ + ValueError + If the given string does not contain the correct number of lines. + """ + + xyz = np.zeros((n_atoms, 3)) + atoms = [] + for i in range(n_atoms): + line = splitted_frame_string[2+i] + + if len(line.split()) != 4: + raise ValueError( + 'Invalid file format in xyz coordinates of Frame.') + + xyz[i] = np.array([float(x) for x in line.split()[1:4]]) + atoms.append(line.split()[0]) + + return xyz, atoms + + def _read_scalar(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy1DFloatArray, List[str]]: + """ + Reads the scalar values and the atom names from the given string. + + Parameters + ---------- + splitted_frame_string : str + The string to read the scalar values and the atom names from. + n_atoms : int + The number of atoms in the frame. + + Returns + ------- + scalar : np.array + The scalar values of the atoms. + atoms : list of str + The names of the atoms. + + Raises + ------ + ValueError + If the given string does not contain the correct number of lines. + """ + + scalar = np.zeros((n_atoms)) + atoms = [] + for i in range(n_atoms): + line = splitted_frame_string[2+i] + + if len(line.split()) != 2: + raise ValueError( + 'Invalid file format in scalar values of Frame.') + + scalar[i] = float(line.split()[1]) + atoms.append(line.split()[0]) + + return scalar, atoms diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index ea2b3d80..837dd2c1 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -7,23 +7,11 @@ ------- TrajectoryReader A class for reading a trajectory from a file. -FrameReader - A class for reading a frame from a string. """ -import numpy as np - -from beartype.typing import Tuple, List -from enum import Enum - from .base import BaseReader -from ..traj.frame import Frame from ..traj.trajectory import Trajectory, TrajectoryFormat -from ..core.cell import Cell -from ..core.atom import Atom -from ..core.atomicSystem import AtomicSystem -from ..exceptions import ElementNotFoundError -from ..types import Numpy2DFloatArray, Numpy1DFloatArray +from .frameReader import FrameReader class TrajectoryReader(BaseReader): @@ -95,299 +83,3 @@ def read(self) -> Trajectory: self.frames[-1].cell = self.frames[-2].cell return Trajectory(self.frames) - - -class FrameReader: - """ - FrameReader reads a frame from a string. - """ - - def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> Frame: - """ - Reads a frame from a string. - - Parameters - ---------- - frame_string : str - The string to read the frame from. - format : TrajectoryFormat | str, optional - The format of the trajectory. Default is TrajectoryFormat.XYZ. - - Returns - ------- - Frame - The frame read from the string. - - Raises - ------ - ValueError - If the given format is not valid. - """ - - # Note: TrajectoryFormat(format) automatically gives an error if format is not a valid TrajectoryFormat - - if TrajectoryFormat(format) is TrajectoryFormat.XYZ: - return self.read_positions(frame_string) - elif TrajectoryFormat(format) is TrajectoryFormat.VEL: - return self.read_velocities(frame_string) - elif TrajectoryFormat(format) is TrajectoryFormat.FORCE: - return self.read_forces(frame_string) - elif TrajectoryFormat(format) is TrajectoryFormat.CHARGE: - return self.read_charges(frame_string) - - def read_positions(self, frame_string: str) -> Frame: - """ - Reads the positions of the atoms in a frame from a string. - - Parameters - ---------- - frame_string : str - The string to read the frame from. - - Returns - ------- - Frame - The frame read from the string. - - Raises - ------ - TypeError - If the given frame_string is not a string. - """ - - splitted_frame_string = frame_string.split('\n') - header_line = splitted_frame_string[0] - - n_atoms, cell = self._read_header_line(header_line) - - xyz, atoms = self._read_xyz(splitted_frame_string, n_atoms) - - try: - atoms = [Atom(atom) for atom in atoms] - except ElementNotFoundError: - atoms = [Atom(atom, use_guess_element=False) for atom in atoms] - - return Frame(AtomicSystem(atoms=atoms, pos=xyz, cell=cell)) - - def read_velocities(self, frame_string: str) -> Frame: - """ - Reads the velocities of the atoms in a frame from a string. - - Parameters - ---------- - frame_string : str - The string to read the frame from. - - Returns - ------- - Frame - The frame read from the string. - - Raises - ------ - TypeError - If the given frame_string is not a string. - """ - - splitted_frame_string = frame_string.split('\n') - header_line = splitted_frame_string[0] - - n_atoms, cell = self._read_header_line(header_line) - - vel, atoms = self._read_xyz(splitted_frame_string, n_atoms) - - try: - atoms = [Atom(atom) for atom in atoms] - except ElementNotFoundError: - atoms = [Atom(atom, use_guess_element=False) for atom in atoms] - - return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell)) - - def read_forces(self, frame_string: str) -> Frame: - """ - Reads the forces of the atoms in a frame from a string. - - Parameters - ---------- - frame_string : str - The string to read the frame from. - - Returns - ------- - Frame - The frame read from the string. - - Raises - ------ - TypeError - If the given frame_string is not a string. - """ - - splitted_frame_string = frame_string.split('\n') - header_line = splitted_frame_string[0] - - n_atoms, cell = self._read_header_line(header_line) - - forces, atoms = self._read_xyz(splitted_frame_string, n_atoms) - - try: - atoms = [Atom(atom) for atom in atoms] - except ElementNotFoundError: - atoms = [Atom(atom, use_guess_element=False) for atom in atoms] - - return Frame(AtomicSystem(atoms=atoms, forces=forces, cell=cell)) - - def read_charges(self, frame_string: str) -> Frame: - """ - Reads the charge values of the atoms in a frame from a string. - - Parameters - ---------- - frame_string : str - The string to read the frame from. - - Returns - ------- - Frame - The frame read from the string. - - Raises - ------ - TypeError - If the given frame_string is not a string. - """ - - splitted_frame_string = frame_string.split('\n') - header_line = splitted_frame_string[0] - - n_atoms, cell = self._read_header_line(header_line) - - charges, atoms = self._read_scalar(splitted_frame_string, n_atoms) - - try: - atoms = [Atom(atom) for atom in atoms] - except ElementNotFoundError: - atoms = [Atom(atom, use_guess_element=False) for atom in atoms] - - return Frame(AtomicSystem(atoms=atoms, charges=charges, cell=cell)) - - def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: - """ - Reads the header line of a frame. - - It reads the number of atoms and the cell information from the header line. - If the header line contains only the number of atoms, the cell is set to None. - If the header line contains only the number of atoms and the box dimensions, - the cell is set to a Cell object with the given box dimensions and box angles set to 90°. - - Parameters - ---------- - header_line : str - The header line to read. - - Returns - ------- - n_atoms : int - The number of atoms in the frame. - cell : Cell - The cell of the frame. - - Raises - ------ - ValueError - If the header line is not valid. Either it contains too many or too few values. - """ - - header_line = header_line.split() - - if len(header_line) == 4: - n_atoms = int(header_line[0]) - a, b, c = map(float, header_line[1:4]) - cell = Cell(a, b, c) - elif len(header_line) == 7: - n_atoms = int(header_line[0]) - a, b, c, alpha, beta, gamma = map(float, header_line[1:7]) - cell = Cell(a, b, c, alpha, beta, gamma) - elif len(header_line) == 1: - n_atoms = int(header_line[0]) - cell = None - else: - raise ValueError('Invalid file format in header line of Frame.') - - return n_atoms, cell - - def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]: - """ - Reads the xyz coordinates and the atom names from the given string. - - Parameters - ---------- - splitted_frame_string : str - The string to read the xyz coordinates and the atom names from. - n_atoms : int - The number of atoms in the frame. - - Returns - ------- - xyz : np.array - The xyz coordinates of the atoms. - atoms : list of str - The names of the atoms. - - Raises - ------ - ValueError - If the given string does not contain the correct number of lines. - """ - - xyz = np.zeros((n_atoms, 3)) - atoms = [] - for i in range(n_atoms): - line = splitted_frame_string[2+i] - - if len(line.split()) != 4: - raise ValueError( - 'Invalid file format in xyz coordinates of Frame.') - - xyz[i] = np.array([float(x) for x in line.split()[1:4]]) - atoms.append(line.split()[0]) - - return xyz, atoms - - def _read_scalar(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy1DFloatArray, List[str]]: - """ - Reads the scalar values and the atom names from the given string. - - Parameters - ---------- - splitted_frame_string : str - The string to read the scalar values and the atom names from. - n_atoms : int - The number of atoms in the frame. - - Returns - ------- - scalar : np.array - The scalar values of the atoms. - atoms : list of str - The names of the atoms. - - Raises - ------ - ValueError - If the given string does not contain the correct number of lines. - """ - - scalar = np.zeros((n_atoms)) - atoms = [] - for i in range(n_atoms): - line = splitted_frame_string[2+i] - - if len(line.split()) != 2: - raise ValueError( - 'Invalid file format in scalar values of Frame.') - - scalar[i] = float(line.split()[1]) - atoms.append(line.split()[0]) - - return scalar, atoms diff --git a/tests/io/test_frameReader.py b/tests/io/test_frameReader.py new file mode 100644 index 00000000..0da7447b --- /dev/null +++ b/tests/io/test_frameReader.py @@ -0,0 +1,122 @@ +import pytest +import numpy as np + +from beartype.roar import BeartypeException + +from PQAnalysis.io.frameReader import FrameReader +from PQAnalysis.core.cell import Cell +from PQAnalysis.core.atom import Atom +from PQAnalysis.exceptions import TrajectoryFormatError +from PQAnalysis.traj.trajectory import TrajectoryFormat + + +class TestFrameReader: + + def test__read_header_line(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader._read_header_line("1 2.0 3.0") + assert str( + exception.value) == "Invalid file format in header line of Frame." + + n_atoms, cell = reader._read_header_line( + "1 2.0 3.0 4.0 5.0 6.0 7.0") + assert n_atoms == 1 + assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) + assert np.allclose(cell.box_angles, [5.0, 6.0, 7.0]) + + n_atoms, cell = reader._read_header_line("2 2.0 3.0 4.0") + assert n_atoms == 2 + assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) + assert np.allclose(cell.box_angles, [90.0, 90.0, 90.0]) + + n_atoms, cell = reader._read_header_line("3") + assert n_atoms == 3 + assert cell is None + + def test__read_xyz(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader._read_xyz( + ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0"], n_atoms=2) + assert str( + exception.value) == "Invalid file format in xyz coordinates of Frame." + + xyz, atoms = reader._read_xyz( + ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0 2.0"], n_atoms=2) + assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert atoms == ["h", "o"] + + def test__read_scalar(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader._read_scalar(["", "", "h 1.0 2.0 3.0"], n_atoms=1) + assert str( + exception.value) == "Invalid file format in scalar values of Frame." + + scalar, atoms = reader._read_scalar(["", "", "h 1.0"], n_atoms=1) + assert np.allclose(scalar, [1.0]) + assert atoms == ["h"] + + def test_read(self): + reader = FrameReader() + + with pytest.raises(BeartypeException): + reader.read(["tmp"]) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no 2.0 2.0 2.0") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom) for atom in ["h", "o"]] + assert np.allclose(frame.pos, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.pos, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="vel") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.vel, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="force") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.forces, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0\no1 2.0", format="charge") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.charges, [1.0, 2.0]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + def test_read_invalid_format(self): + reader = FrameReader() + + with pytest.raises(TrajectoryFormatError) as exception: + reader.read("", format="invalid") + assert str( + exception.value) == f""" +'invalid' is not a valid TrajectoryFormat. +Possible values are: {TrajectoryFormat.member_repr()} +or their case insensitive string representation: {TrajectoryFormat.value_repr()}""" diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 06d36f6c..dccbbb9c 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -1,14 +1,11 @@ import pytest import numpy as np -from beartype.roar import BeartypeException - -from PQAnalysis.io.trajectoryReader import TrajectoryReader, FrameReader, TrajectoryFormat +from PQAnalysis.io.trajectoryReader import TrajectoryReader from PQAnalysis.traj.frame import Frame from PQAnalysis.core.cell import Cell from PQAnalysis.core.atomicSystem import AtomicSystem from PQAnalysis.core.atom import Atom -from PQAnalysis.exceptions import TrajectoryFormatError class TestTrajectoryReader: @@ -53,115 +50,3 @@ def test_read(self): # NOTE: here cell is not none because of the consecutive reading of frames # Cell will be taken from the previous frame assert traj[1] == frame2 - - -class TestFrameReader: - - def test__read_header_line(self): - reader = FrameReader() - - with pytest.raises(ValueError) as exception: - reader._read_header_line("1 2.0 3.0") - assert str( - exception.value) == "Invalid file format in header line of Frame." - - n_atoms, cell = reader._read_header_line( - "1 2.0 3.0 4.0 5.0 6.0 7.0") - assert n_atoms == 1 - assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) - assert np.allclose(cell.box_angles, [5.0, 6.0, 7.0]) - - n_atoms, cell = reader._read_header_line("2 2.0 3.0 4.0") - assert n_atoms == 2 - assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) - assert np.allclose(cell.box_angles, [90.0, 90.0, 90.0]) - - n_atoms, cell = reader._read_header_line("3") - assert n_atoms == 3 - assert cell is None - - def test__read_xyz(self): - reader = FrameReader() - - with pytest.raises(ValueError) as exception: - reader._read_xyz( - ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0"], n_atoms=2) - assert str( - exception.value) == "Invalid file format in xyz coordinates of Frame." - - xyz, atoms = reader._read_xyz( - ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0 2.0"], n_atoms=2) - assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) - assert atoms == ["h", "o"] - - def test__read_scalar(self): - reader = FrameReader() - - with pytest.raises(ValueError) as exception: - reader._read_scalar(["", "", "h 1.0 2.0 3.0"], n_atoms=1) - assert str( - exception.value) == "Invalid file format in scalar values of Frame." - - scalar, atoms = reader._read_scalar(["", "", "h 1.0"], n_atoms=1) - assert np.allclose(scalar, [1.0]) - assert atoms == ["h"] - - def test_read(self): - reader = FrameReader() - - with pytest.raises(BeartypeException): - reader.read(["tmp"]) - - frame = reader.read( - "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no 2.0 2.0 2.0") - assert frame.n_atoms == 2 - assert frame.atoms == [Atom(atom) for atom in ["h", "o"]] - assert np.allclose(frame.pos, [ - [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) - assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) - - frame = reader.read( - "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0") - assert frame.n_atoms == 2 - assert frame.atoms == [Atom(atom, use_guess_element=False) - for atom in ["h", "o1"]] - assert np.allclose(frame.pos, [ - [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) - assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) - - frame = reader.read( - "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="vel") - assert frame.n_atoms == 2 - assert frame.atoms == [Atom(atom, use_guess_element=False) - for atom in ["h", "o1"]] - assert np.allclose(frame.vel, [ - [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) - assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) - - frame = reader.read( - "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="force") - assert frame.n_atoms == 2 - assert frame.atoms == [Atom(atom, use_guess_element=False) - for atom in ["h", "o1"]] - assert np.allclose(frame.forces, [ - [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) - assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) - - frame = reader.read( - "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0\no1 2.0", format="charge") - assert frame.n_atoms == 2 - assert frame.atoms == [Atom(atom, use_guess_element=False) - for atom in ["h", "o1"]] - assert np.allclose(frame.charges, [1.0, 2.0]) - assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) - - def test_read_invalid_format(self): - reader = FrameReader() - - with pytest.raises(TrajectoryFormatError) as exception: - reader.read("", format="invalid") - assert str( - exception.value) == f""" -'invalid' is not a valid TrajectoryFormat. -Possible values are: {TrajectoryFormat.member_repr()} -or their case insensitive string representation: {TrajectoryFormat.value_repr()}""" From f36881e97296b4019f8f01f878699177f22c93fa Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 22:20:25 +0100 Subject: [PATCH 5/6] included writing of traj formats for vel, forces and charges --- PQAnalysis/io/trajectoryWriter.py | 151 +++++++++++++++++++++++++++--- tests/io/test_trajectoryWriter.py | 88 +++++++++++++++-- 2 files changed, 215 insertions(+), 24 deletions(-) diff --git a/PQAnalysis/io/trajectoryWriter.py b/PQAnalysis/io/trajectoryWriter.py index 158f8c99..cbc3c476 100644 --- a/PQAnalysis/io/trajectoryWriter.py +++ b/PQAnalysis/io/trajectoryWriter.py @@ -14,13 +14,14 @@ from beartype.typing import List from .base import BaseWriter -from ..traj.trajectory import Trajectory +from ..traj.trajectory import Trajectory, TrajectoryFormat +from ..traj.frame import Frame from ..core.cell import Cell from ..core.atom import Atom -from ..types import Numpy2DFloatArray +from ..types import Numpy2DFloatArray, Numpy1DFloatArray -def write_trajectory(traj, filename: str | None = None, format: str | None = None) -> None: +def write_trajectory(traj, filename: str | None = None, format: str | None = None, type: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None: """ Wrapper for TrajectoryWriter to write a trajectory to a file. @@ -33,10 +34,15 @@ def write_trajectory(traj, filename: str | None = None, format: str | None = Non The trajectory to write. filename : str, optional The name of the file to write to. If None, the output is printed to stdout. + format : str, optional + The format of the file. If None, the default PIMD-QMCF format is used. + type : TrajectoryFormat | str, optional + The type of the data to write to the file. Default is TrajectoryFormat.XYZ. + """ - writer = TrajectoryWriter(filename, format) - writer.write(traj) + writer = TrajectoryWriter(filename, format=format) + writer.write(traj, type=type) class TrajectoryWriter(BaseWriter): @@ -64,6 +70,9 @@ class TrajectoryWriter(BaseWriter): X 0.0 0.0 0.0 coordinates of the atoms in the format 'element x y z' + _type : TrajectoryFormat + The type of the data to write to the file. Default is TrajectoryFormat.XYZ. + Attributes ---------- format : str @@ -71,6 +80,7 @@ class TrajectoryWriter(BaseWriter): """ formats = [None, 'pimd-qmcf', 'qmcfc'] + _type: TrajectoryFormat = TrajectoryFormat.XYZ def __init__(self, filename: str | None = None, @@ -101,7 +111,7 @@ def __init__(self, self.format = format - def write(self, trajectory: Trajectory) -> None: + def write(self, trajectory: Trajectory, type: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None: """ Writes the trajectory to the file. @@ -110,13 +120,91 @@ def write(self, trajectory: Trajectory) -> None: traj : Trajectory The trajectory to write. """ + self._type = TrajectoryFormat(type) + if self._type == TrajectoryFormat.XYZ: + self.write_positions(trajectory) + elif self._type == TrajectoryFormat.VEL: + self.write_velocities(trajectory) + elif self._type == TrajectoryFormat.FORCE: + self.write_forces(trajectory) + elif self._type == TrajectoryFormat.CHARGE: + self.write_charges(trajectory) + + self.close() + + def write_positions(self, trajectory: Trajectory) -> None: + """ + Writes the positions of the trajectory to the file. + + Parameters + ---------- + traj : Trajectory + The trajectory to write. + """ + self._type = TrajectoryFormat.XYZ self.open() for frame in trajectory: - self.__write_header__(frame.n_atoms, frame.cell) - self.__write_coordinates__(frame.pos, frame.atoms) + self._write_header(frame.n_atoms, frame.cell) + self._write_comment(frame) + self._write_xyz(frame.pos, frame.atoms) + self.close() - def __write_header__(self, n_atoms: int, cell: Cell | None = None) -> None: + def write_velocities(self, trajectory: Trajectory) -> None: + """ + Writes the velocities of the trajectory to the file. + + Parameters + ---------- + traj : Trajectory + The trajectory to write. + """ + self._type = TrajectoryFormat.VEL + self.open() + for frame in trajectory: + self._write_header(frame.n_atoms, frame.cell) + self._write_comment(frame) + self._write_xyz(frame.vel, frame.atoms) + + self.close() + + def write_forces(self, trajectory: Trajectory) -> None: + """ + Writes the forces of the trajectory to the file. + + Parameters + ---------- + traj : Trajectory + The trajectory to write. + """ + self._type = TrajectoryFormat.FORCE + self.open() + for frame in trajectory: + self._write_header(frame.n_atoms, frame.cell) + self._write_comment(frame) + self._write_xyz(frame.forces, frame.atoms) + + self.close() + + def write_charges(self, trajectory: Trajectory) -> None: + """ + Writes the charges of the trajectory to the file. + + Parameters + ---------- + traj : Trajectory + The trajectory to write. + """ + self._type = TrajectoryFormat.CHARGE + self.open() + for frame in trajectory: + self._write_header(frame.n_atoms, frame.cell) + self._write_comment(frame) + self._write_scalar(frame.charges, frame.atoms) + + self.close() + + def _write_header(self, n_atoms: int, cell: Cell | None = None) -> None: """ Writes the header line of the frame to the file. @@ -130,27 +218,60 @@ def __write_header__(self, n_atoms: int, cell: Cell | None = None) -> None: if cell is not None: print( - f"{n_atoms} {cell.x} {cell.y} {cell.z} {cell.alpha} {cell.beta} {cell.gamma}\n", file=self.file) + f"{n_atoms} {cell.x} {cell.y} {cell.z} {cell.alpha} {cell.beta} {cell.gamma}", file=self.file) + else: + print(f"{n_atoms}", file=self.file) + + def _write_comment(self, frame: Frame) -> None: + """ + Writes the comment line of the frame to the file. + + Parameters + ---------- + frame : Frame + The frame to write the comment line of. + """ + + if self._type == TrajectoryFormat.FORCE: + sum_forces = sum(frame.forces) + print( + f"sum of forces: {sum_forces[0]} {sum_forces[1]} {sum_forces[2]}", file=self.file) else: - print(f"{n_atoms}\n", file=self.file) + print("", file=self.file) - def __write_coordinates__(self, xyz: Numpy2DFloatArray, atoms: List[Atom]) -> None: + def _write_xyz(self, xyz: Numpy2DFloatArray, atoms: List[Atom]) -> None: """ - Writes the coordinates of the frame to the file. + Writes the xyz of the frame to the file. If format is 'qmcfc', an additional X 0.0 0.0 0.0 line is written. Parameters ---------- xyz : np.array - The xyz coordinates of the atoms. + The xyz data of the atoms (either positions, velocities or forces). atoms : Elements The elements of the frame. """ - if self.format == "qmcfc": + if self.format == "qmcfc" and self._type == TrajectoryFormat.XYZ: print("X 0.0 0.0 0.0", file=self.file) for i in range(len(atoms)): print( f"{atoms[i].name} {xyz[i][0]} {xyz[i][1]} {xyz[i][2]}", file=self.file) + + def _write_scalar(self, scalar: Numpy1DFloatArray, atoms: List[Atom]) -> None: + """ + Writes the charges of the frame to the file. + + Parameters + ---------- + scalar : np.array + scalar data of the atoms (atm only charges). + atoms : Elements + The elements of the frame. + """ + + for i in range(len(atoms)): + print( + f"{atoms[i].name} {scalar[i]}", file=self.file) diff --git a/tests/io/test_trajectoryWriter.py b/tests/io/test_trajectoryWriter.py index f0c702c0..f8f68d83 100644 --- a/tests/io/test_trajectoryWriter.py +++ b/tests/io/test_trajectoryWriter.py @@ -4,11 +4,13 @@ from PQAnalysis.io.trajectoryWriter import TrajectoryWriter, write_trajectory from PQAnalysis.traj.frame import Frame -from PQAnalysis.traj.trajectory import Trajectory +from PQAnalysis.traj.trajectory import Trajectory, TrajectoryFormat from PQAnalysis.core.cell import Cell from PQAnalysis.core.atomicSystem import AtomicSystem from PQAnalysis.core.atom import Atom +# TODO: here only one option is tested - think of a better way to test all options + def test_write_trajectory(capsys): atoms = [Atom(atom) for atom in ['h', 'o']] @@ -46,34 +48,60 @@ def test__init__(self): writer = TrajectoryWriter(format="pimd-qmcf") assert writer.format == "pimd-qmcf" - def test__write_header__(self, capsys): + def test__write_header(self, capsys): writer = TrajectoryWriter() - writer.__write_header__(1, Cell(10, 10, 10)) + writer._write_header(1, Cell(10, 10, 10)) captured = capsys.readouterr() - assert captured.out == "1 10 10 10 90 90 90\n\n" + assert captured.out == "1 10 10 10 90 90 90\n" - writer.__write_header__(1) + writer._write_header(1) captured = capsys.readouterr() - assert captured.out == "1\n\n" + assert captured.out == "1\n" - def test__write_coordinates__(self, capsys): + def test__write_comment(self, capsys): writer = TrajectoryWriter() - writer.__write_coordinates__( + writer._write_comment(Frame(AtomicSystem( + atoms=[Atom(atom) for atom in ["h", "o"]], cell=Cell(10, 10, 10)))) + + captured = capsys.readouterr() + assert captured.out == "\n" + + forces = np.array([[1, 0, 3], [0, 2, 1]]) + writer._type = TrajectoryFormat.FORCE + writer._write_comment(Frame(AtomicSystem( + atoms=[Atom(atom) for atom in ["h", "o"]], cell=Cell(10, 10, 10), forces=forces))) + + captured = capsys.readouterr() + assert captured.out == "sum of forces: 1 2 4\n" + + def test__write_xyz(self, capsys): + + writer = TrajectoryWriter() + writer._write_xyz( atoms=[Atom(atom) for atom in ["h", "o"]], xyz=np.array([[0, 0, 0], [0, 0, 1]])) captured = capsys.readouterr() assert captured.out == "h 0 0 0\no 0 0 1\n" writer.format = "qmcfc" - writer.__write_coordinates__( + writer._write_xyz( atoms=[Atom(atom) for atom in ["h", "o"]], xyz=np.array([[0, 0, 0], [0, 0, 1]])) captured = capsys.readouterr() assert captured.out == "X 0.0 0.0 0.0\nh 0 0 0\no 0 0 1\n" + def test__write_scalar(self, capsys): + + writer = TrajectoryWriter() + writer._write_scalar( + atoms=[Atom(atom) for atom in ["h", "o"]], scalar=np.array([1, 2])) + + captured = capsys.readouterr() + assert captured.out == "h 1\no 2\n" + def test_write(self, capsys): atoms = [Atom(atom) for atom in ['h', 'o']] @@ -92,3 +120,45 @@ def test_write(self, capsys): captured = capsys.readouterr() assert captured.out == "2 10 10 10 90 90 90\n\nh 0 0 0\no 0 0 1\n2 11 10 10 90 90 90\n\nh 0 0 0\no 0 0 1\n" + + frame1 = Frame(AtomicSystem( + atoms=atoms, vel=coordinates1, cell=Cell(10, 10, 10))) + frame2 = Frame(AtomicSystem( + atoms=atoms, vel=coordinates2, cell=Cell(11, 10, 10))) + + traj = Trajectory([frame1, frame2]) + writer = TrajectoryWriter() + + writer.write(traj, type="vel") + + captured = capsys.readouterr() + assert captured.out == "2 10 10 10 90 90 90\n\nh 0 0 0\no 0 0 1\n2 11 10 10 90 90 90\n\nh 0 0 0\no 0 0 1\n" + + frame1 = Frame(AtomicSystem( + atoms=atoms, forces=coordinates1, cell=Cell(10, 10, 10))) + frame2 = Frame(AtomicSystem( + atoms=atoms, forces=coordinates2, cell=Cell(11, 10, 10))) + + traj = Trajectory([frame1, frame2]) + writer = TrajectoryWriter() + + writer.write(traj, type="force") + + captured = capsys.readouterr() + assert captured.out == "2 10 10 10 90 90 90\nsum of forces: 0 0 1\nh 0 0 0\no 0 0 1\n2 11 10 10 90 90 90\nsum of forces: 0 0 1\nh 0 0 0\no 0 0 1\n" + + charges1 = np.array([1, 2]) + charges2 = np.array([3, 4]) + + frame1 = Frame(AtomicSystem( + atoms=atoms, charges=charges1, cell=Cell(10, 10, 10))) + frame2 = Frame(AtomicSystem( + atoms=atoms, charges=charges2, cell=Cell(11, 10, 10))) + + traj = Trajectory([frame1, frame2]) + writer = TrajectoryWriter() + + writer.write(traj, type="charge") + + captured = capsys.readouterr() + assert captured.out == "2 10 10 10 90 90 90\n\nh 1\no 2\n2 11 10 10 90 90 90\n\nh 3\no 4\n" From 55ad208956f15d4d9dc2ecbd0c9a387c0306ac87 Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 23:28:23 +0100 Subject: [PATCH 6/6] introduced MDEngineFormat --- PQAnalysis/exceptions.py | 22 ++++- PQAnalysis/io/energyFileReader.py | 22 ++--- PQAnalysis/io/frameReader.py | 2 +- PQAnalysis/io/infoFileReader.py | 28 ++----- PQAnalysis/io/trajectoryReader.py | 3 +- PQAnalysis/io/trajectoryWriter.py | 45 +++++----- PQAnalysis/traj/formats.py | 134 ++++++++++++++++++++++++++++++ PQAnalysis/traj/frame.py | 3 - PQAnalysis/traj/trajectory.py | 76 ----------------- tests/io/test_energyFileReader.py | 19 +++-- tests/io/test_frameReader.py | 2 +- tests/io/test_infoFileReader.py | 13 ++- tests/io/test_trajectoryWriter.py | 17 ++-- 13 files changed, 232 insertions(+), 154 deletions(-) create mode 100644 PQAnalysis/traj/formats.py diff --git a/PQAnalysis/exceptions.py b/PQAnalysis/exceptions.py index af08cf1d..8858675b 100644 --- a/PQAnalysis/exceptions.py +++ b/PQAnalysis/exceptions.py @@ -37,7 +37,7 @@ def __init__(self, id: Any) -> None: super().__init__(self.message) -class TrajectoryFormatError(PQException): +class FormatEnumError(PQException): """ Exception raised if the given enum is not valid """ @@ -46,7 +46,25 @@ def __init__(self, value: object, enum: object) -> None: self.enum = enum self.value = value self.message = f""" -'{self.value}' is not a valid TrajectoryFormat. +'{self.value}' is not a valid {enum.__name__}. Possible values are: {enum.member_repr()} or their case insensitive string representation: {enum.value_repr()}""" super().__init__(self.message) + + +class TrajectoryFormatError(FormatEnumError): + """ + Exception raised if the given enum is not valid + """ + + def __init__(self, value: object, enum: object) -> None: + super().__init__(value, enum) + + +class MDEngineFormatError(FormatEnumError): + """ + Exception raised if the given enum is not valid + """ + + def __init__(self, value: object, enum: object) -> None: + super().__init__(value, enum) diff --git a/PQAnalysis/io/energyFileReader.py b/PQAnalysis/io/energyFileReader.py index b604e1fb..446cc06e 100644 --- a/PQAnalysis/io/energyFileReader.py +++ b/PQAnalysis/io/energyFileReader.py @@ -15,6 +15,7 @@ from .base import BaseReader from .infoFileReader import InfoFileReader from ..physicalData.energy import Energy +from ..traj.formats import MDEngineFormat class EnergyFileReader(BaseReader): @@ -34,15 +35,15 @@ class EnergyFileReader(BaseReader): The name of the info file to read from. withInfoFile : bool If True, the info file was found. + format : MDEngineFormat + The format of the file. Default is MDEngineFormat.PIMD_QMCF. """ - formats = ["pimd-qmcf", "qmcfc"] - def __init__(self, filename: str, info_filename: str | None = None, use_info_file: bool = True, - format: str = "pimd-qmcf" + format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF ) -> None: """ Initializes the EnergyFileReader with the given filename. @@ -61,13 +62,8 @@ def __init__(self, The name of the info file to read from, by default None use_info_file : bool, optional If True, the info file is searched for, by default True - format : str, optional - The format of the info file, by default "pimd-qmcf" - - Raises - ------ - ValueError - If the format is not supported. + format : MDEngineFormat | str, optional + The format of the file, by default MDEngineFormat.PIMD_QMCF """ super().__init__(filename) self.info_filename = info_filename @@ -77,11 +73,7 @@ def __init__(self, else: self.withInfoFile = False - if format not in self.formats: - raise ValueError( - f"Format {format} is not supported. Supported formats are {self.formats}.") - - self.format = format + self.format = MDEngineFormat(format) def read(self) -> Energy: """ diff --git a/PQAnalysis/io/frameReader.py b/PQAnalysis/io/frameReader.py index 1846e669..54f1ded8 100644 --- a/PQAnalysis/io/frameReader.py +++ b/PQAnalysis/io/frameReader.py @@ -20,7 +20,7 @@ from ..core.cell import Cell from ..types import Numpy2DFloatArray, Numpy1DFloatArray from ..traj.frame import Frame -from ..traj.trajectory import TrajectoryFormat +from ..traj.formats import TrajectoryFormat from ..exceptions import ElementNotFoundError diff --git a/PQAnalysis/io/infoFileReader.py b/PQAnalysis/io/infoFileReader.py index b0cd97e0..b8f2d5e8 100644 --- a/PQAnalysis/io/infoFileReader.py +++ b/PQAnalysis/io/infoFileReader.py @@ -12,6 +12,7 @@ from beartype.typing import Tuple, Dict from .base import BaseReader +from ..traj.formats import MDEngineFormat class InfoFileReader(BaseReader): @@ -27,13 +28,11 @@ class InfoFileReader(BaseReader): ---------- filename : str The name of the file to read from. - format : str - The format of the info file. + format : MDEngineFormat + The format of the info file. Default is MDEngineFormat.PIMD_QMCF. """ - formats = ["pimd-qmcf", "qmcfc"] - - def __init__(self, filename: str, format: str = "pimd-qmcf") -> None: + def __init__(self, filename: str, format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF) -> None: """ Initializes the InfoFileReader with the given filename. @@ -41,21 +40,12 @@ def __init__(self, filename: str, format: str = "pimd-qmcf") -> None: ---------- filename : str The name of the file to read from. - format : str, optional - The format of the info file, by default "pimd-qmcf" - - Raises - ------ - ValueError - If the format is not supported. + format : MDEngineFormat | str, optional + The format of the info file. Default is MDEngineFormat.PIMD_QMCF. """ super().__init__(filename) - if format not in self.formats: - raise ValueError( - f"Format {format} is not supported. Supported formats are {self.formats}.") - - self.format = format + self.format = MDEngineFormat(format) def read(self) -> Tuple[Dict, Dict | None]: """ @@ -71,9 +61,9 @@ def read(self) -> Tuple[Dict, Dict | None]: The units of the info file as a dictionary. The keys are the names of the information strings. The values are the corresponding units. """ - if self.format == "pimd-qmcf": + if self.format == MDEngineFormat.PIMD_QMCF: return self.read_pimd_qmcf() - elif self.format == "qmcfc": + elif self.format == MDEngineFormat.QMCFC: return self.read_qmcfc() def read_pimd_qmcf(self) -> Tuple[Dict, Dict]: diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index 837dd2c1..0be98564 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -10,7 +10,8 @@ """ from .base import BaseReader -from ..traj.trajectory import Trajectory, TrajectoryFormat +from ..traj.trajectory import Trajectory +from ..traj.formats import TrajectoryFormat from .frameReader import FrameReader diff --git a/PQAnalysis/io/trajectoryWriter.py b/PQAnalysis/io/trajectoryWriter.py index cbc3c476..bde826eb 100644 --- a/PQAnalysis/io/trajectoryWriter.py +++ b/PQAnalysis/io/trajectoryWriter.py @@ -14,14 +14,19 @@ from beartype.typing import List from .base import BaseWriter -from ..traj.trajectory import Trajectory, TrajectoryFormat +from ..traj.trajectory import Trajectory +from ..traj.formats import TrajectoryFormat, MDEngineFormat from ..traj.frame import Frame from ..core.cell import Cell from ..core.atom import Atom from ..types import Numpy2DFloatArray, Numpy1DFloatArray -def write_trajectory(traj, filename: str | None = None, format: str | None = None, type: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None: +def write_trajectory(traj, + filename: str | None = None, + format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF, + type: TrajectoryFormat | str = TrajectoryFormat.XYZ + ) -> None: """ Wrapper for TrajectoryWriter to write a trajectory to a file. @@ -34,8 +39,8 @@ def write_trajectory(traj, filename: str | None = None, format: str | None = Non The trajectory to write. filename : str, optional The name of the file to write to. If None, the output is printed to stdout. - format : str, optional - The format of the file. If None, the default PIMD-QMCF format is used. + format : MDEngineFormat | str, optional + The format of the md engine for the output file. The default is MDEngineFormat.PIMD_QMCF. type : TrajectoryFormat | str, optional The type of the data to write to the file. Default is TrajectoryFormat.XYZ. @@ -59,6 +64,7 @@ class TrajectoryWriter(BaseWriter): formats : list of str The available formats for the trajectory file. + #TODO: put this description into formats!!! PIMD-QMCF format for one frame: header line containing the number of atoms and the cell information (if available) arbitrary comment line @@ -75,16 +81,16 @@ class TrajectoryWriter(BaseWriter): Attributes ---------- - format : str - The format of the file. + format : MDEngineFormat + The format of the md engine for the output file. The default is MDEngineFormat.PIMD_QMCF. """ - formats = [None, 'pimd-qmcf', 'qmcfc'] + _format: MDEngineFormat _type: TrajectoryFormat = TrajectoryFormat.XYZ def __init__(self, filename: str | None = None, - format: str | None = None, + format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF, mode: str = 'w' ) -> None: """ @@ -94,22 +100,15 @@ def __init__(self, ---------- filename : str, optional The name of the file to write to. If None, the output is printed to stdout. - format : str, optional - The format of the file. If None, the default PIMD-QMCF format is used. - (see TrajectoryWriter.formats for available formats) + format : MDEngineFormat | str, optional + The format of the md engine for the output file. The default is MDEngineFormat.PIMD_QMCF. mode : str, optional The mode of the file. Either 'w' for write or 'a' for append. """ super().__init__(filename, mode) - if format not in self.formats: - raise ValueError( - 'Invalid format. Has to be either \'pimd-qmcf\', \'qmcfc\' or \'None\'.') - if format is None: - format = 'pimd-qmcf' - - self.format = format + self.format = MDEngineFormat(format) def write(self, trajectory: Trajectory, type: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None: """ @@ -253,7 +252,7 @@ def _write_xyz(self, xyz: Numpy2DFloatArray, atoms: List[Atom]) -> None: The elements of the frame. """ - if self.format == "qmcfc" and self._type == TrajectoryFormat.XYZ: + if self.format == MDEngineFormat.QMCFC and self._type == TrajectoryFormat.XYZ: print("X 0.0 0.0 0.0", file=self.file) for i in range(len(atoms)): @@ -275,3 +274,11 @@ def _write_scalar(self, scalar: Numpy1DFloatArray, atoms: List[Atom]) -> None: for i in range(len(atoms)): print( f"{atoms[i].name} {scalar[i]}", file=self.file) + + @property + def format(self) -> MDEngineFormat: + return self._format + + @format.setter + def format(self, format: MDEngineFormat | str) -> None: + self._format = MDEngineFormat(format) diff --git a/PQAnalysis/traj/formats.py b/PQAnalysis/traj/formats.py new file mode 100644 index 00000000..7ac597ab --- /dev/null +++ b/PQAnalysis/traj/formats.py @@ -0,0 +1,134 @@ +""" +A module containing different format types of the trajectory. + +... + +Classes +------- +Format + An enumeration super class of the various supported trajectory formats. +TrajectoryFormat + An enumeration of the supported trajectory formats. +MDEngineFormat + An enumeration of the supported MD engine formats. +""" + +from enum import Enum +from beartype.typing import Any + +from ..exceptions import TrajectoryFormatError, MDEngineFormatError + + +class Format(Enum): + """ + An enumeration super class of the various supported trajectory formats. + """ + + @classmethod + def member_repr(cls) -> str: + """ + This method returns a string representation of the members of the enumeration. + + Returns + ------- + str + A string representation of the members of the enumeration. + """ + + return ', '.join([str(member) for member in cls]) + + @classmethod + def value_repr(cls) -> str: + """ + This method returns a string representation of the values of the members of the enumeration. + + Returns + ------- + str + A string representation of the values of the members of the enumeration. + """ + + return ', '.join([str(member.value) for member in cls]) + + +class TrajectoryFormat(Format): + """ + An enumeration of the supported trajectory formats. + + ... + + Attributes + ---------- + XYZ : str + The XYZ format. + VEL : str + The VEL format. + FORCE : str + The FORCE format. + CHARGE : str + The CHARGE format. + """ + + XYZ = "XYZ" + VEL = "VEL" + FORCE = "FORCE" + CHARGE = "CHARGE" + + @classmethod + def _missing_(cls, value: object) -> Any: + """ + This method allows a TrajectoryFormat to be retrieved from a string. + """ + value = value.lower() + for member in cls: + if member.value.lower() == value: + return member + + raise TrajectoryFormatError(value, cls) + + +class MDEngineFormat(Format): + """ + An enumeration of the supported MD engine formats. + + ... + + Attributes + ---------- + PIMD-QMCF: str + The PIMD-QMCF format. + QMCFC: str + The QMCFC format. + """ + + PIMD_QMCF = "PIMD-QMCF" + QMCFC = "QMCFC" + + @classmethod + def _missing_(cls, value: object) -> Any: + """ + This method allows an MDEngineFormat format to be retrieved from a string. + """ + value = value.lower() + for member in cls: + if member.value.lower() == value: + return member + + raise MDEngineFormatError(value, cls) + + @classmethod + def isQMCFType(cls, format: Any) -> bool: + """ + This method checks if the given format is a QMCF format. + + Parameters + ---------- + format : Any + The format to check. + + Returns + ------- + bool + True if the format is a QMCF format, False otherwise. + """ + return format in [cls.PIMD_QMCF, cls.QMCFC] diff --git a/PQAnalysis/traj/frame.py b/PQAnalysis/traj/frame.py index 4a2a9e68..313229d2 100644 --- a/PQAnalysis/traj/frame.py +++ b/PQAnalysis/traj/frame.py @@ -81,7 +81,6 @@ def compute_com_frame(self, group=None) -> Frame: print(self.n_atoms) - j = 0 for i in range(0, self.n_atoms, group): atomic_system = AtomicSystem( atoms=self.atoms[i:i+group], pos=self.pos[i:i+group], cell=self.cell) @@ -91,8 +90,6 @@ def compute_com_frame(self, group=None) -> Frame: print(pos) names.append(atomic_system.combined_name) - j += 1 - names = [Atom(name, use_guess_element=False) for name in names] print(pos) diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index 76321be3..b3507dc5 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -5,8 +5,6 @@ Classes ------- -TrajectoryFormat - An enumeration of the supported trajectory formats. Trajectory A trajectory is a sequence of frames. """ @@ -16,82 +14,8 @@ import numpy as np from beartype.typing import List, Iterator, Any -from enum import Enum from .frame import Frame -from ..exceptions import TrajectoryFormatError - - -class TrajectoryFormat(Enum): - """ - An enumeration of the supported trajectory formats. - - ... - - Attributes - ---------- - XYZ : str - The XYZ format. - VEL : str - The VEL format. - FORCE : str - The FORCE format. - CHARGE : str - The CHARGE format. - """ - - XYZ = "XYZ" - VEL = "VEL" - FORCE = "FORCE" - CHARGE = "CHARGE" - - @classmethod - def _missing_(cls, value: object) -> Any: - """ - This method allows a trajectory format to be retrieved from a string. - - Parameters - ---------- - value : str - _description_ - - Returns - ------- - Any - _description_ - """ - value = value.lower() - for member in cls: - if member.value.lower() == value: - return member - - raise TrajectoryFormatError(value, cls) - - @classmethod - def member_repr(cls) -> str: - """ - This method returns a string representation of the members of the enumeration. - - Returns - ------- - str - A string representation of the members of the enumeration. - """ - - return ', '.join([str(member) for member in cls]) - - @classmethod - def value_repr(cls) -> str: - """ - This method returns a string representation of the values of the members of the enumeration. - - Returns - ------- - str - A string representation of the values of the members of the enumeration. - """ - - return ', '.join([str(member.value) for member in cls]) class Trajectory: diff --git a/tests/io/test_energyFileReader.py b/tests/io/test_energyFileReader.py index 73693650..e58c1211 100644 --- a/tests/io/test_energyFileReader.py +++ b/tests/io/test_energyFileReader.py @@ -5,6 +5,8 @@ from PQAnalysis.io.energyFileReader import EnergyFileReader from PQAnalysis.io.infoFileReader import InfoFileReader +from PQAnalysis.traj.formats import MDEngineFormat +from PQAnalysis.exceptions import MDEngineFormatError class TestEnergyReader: @@ -18,19 +20,19 @@ def test__init__(self, test_with_data_dir): assert reader.filename == "md-01.en" assert reader.info_filename == "md-01.info" assert reader.withInfoFile == True - assert reader.format == "pimd-qmcf" + assert reader.format == MDEngineFormat.PIMD_QMCF reader = EnergyFileReader("md-01.en", use_info_file=False) assert reader.filename == "md-01.en" assert reader.info_filename == None assert reader.withInfoFile == False - assert reader.format == "pimd-qmcf" + assert reader.format == MDEngineFormat.PIMD_QMCF reader = EnergyFileReader("md-01_noinfo.en") assert reader.filename == "md-01_noinfo.en" assert reader.info_filename == None assert reader.withInfoFile == False - assert reader.format == "pimd-qmcf" + assert reader.format == MDEngineFormat.PIMD_QMCF with pytest.raises(FileNotFoundError) as exception: EnergyFileReader( @@ -43,18 +45,21 @@ def test__init__(self, test_with_data_dir): assert reader.filename == "md-01_noinfo.en" assert reader.info_filename == "md-01.info" assert reader.withInfoFile == True - assert reader.format == "pimd-qmcf" + assert reader.format == MDEngineFormat.PIMD_QMCF reader = EnergyFileReader("md-01.en", format="qmcfc") assert reader.filename == "md-01.en" assert reader.info_filename == "md-01.info" assert reader.withInfoFile == True - assert reader.format == "qmcfc" + assert reader.format == MDEngineFormat.QMCFC - with pytest.raises(ValueError) as exception: + with pytest.raises(MDEngineFormatError) as exception: EnergyFileReader("md-01.en", format="tmp") assert str( - exception.value) == "Format tmp is not supported. Supported formats are ['pimd-qmcf', 'qmcfc']." + exception.value) == f""" +'tmp' is not a valid MDEngineFormat. +Possible values are: {MDEngineFormat.member_repr()} +or their case insensitive string representation: {MDEngineFormat.value_repr()}""" @pytest.mark.parametrize("example_dir", ["readEnergyFile"], indirect=False) def test__info_file_found__(self, test_with_data_dir, capsys): diff --git a/tests/io/test_frameReader.py b/tests/io/test_frameReader.py index 0da7447b..bf17aff0 100644 --- a/tests/io/test_frameReader.py +++ b/tests/io/test_frameReader.py @@ -7,7 +7,7 @@ from PQAnalysis.core.cell import Cell from PQAnalysis.core.atom import Atom from PQAnalysis.exceptions import TrajectoryFormatError -from PQAnalysis.traj.trajectory import TrajectoryFormat +from PQAnalysis.traj.formats import TrajectoryFormat class TestFrameReader: diff --git a/tests/io/test_infoFileReader.py b/tests/io/test_infoFileReader.py index 29abdca9..942dab62 100644 --- a/tests/io/test_infoFileReader.py +++ b/tests/io/test_infoFileReader.py @@ -3,6 +3,8 @@ from beartype.roar import BeartypeException from PQAnalysis.io.infoFileReader import InfoFileReader +from PQAnalysis.traj.formats import MDEngineFormat +from PQAnalysis.exceptions import MDEngineFormatError @pytest.mark.parametrize("example_dir", ["readInfoFile"], indirect=False) @@ -15,19 +17,22 @@ def test__init__(test_with_data_dir): InfoFileReader( "md-01.info", format=None) - with pytest.raises(ValueError) as exception: + with pytest.raises(MDEngineFormatError) as exception: InfoFileReader( "md-01.info", format="tmp") assert str( - exception.value) == "Format tmp is not supported. Supported formats are ['pimd-qmcf', 'qmcfc']." + exception.value) == f""" +'tmp' is not a valid MDEngineFormat. +Possible values are: {MDEngineFormat.member_repr()} +or their case insensitive string representation: {MDEngineFormat.value_repr()}""" reader = InfoFileReader("md-01.info") assert reader.filename == "md-01.info" - assert reader.format == "pimd-qmcf" + assert reader.format == MDEngineFormat.PIMD_QMCF reader = InfoFileReader("md-01.info", format="qmcfc") assert reader.filename == "md-01.info" - assert reader.format == "qmcfc" + assert reader.format == MDEngineFormat.QMCFC @pytest.mark.parametrize("example_dir", ["readInfoFile"], indirect=False) diff --git a/tests/io/test_trajectoryWriter.py b/tests/io/test_trajectoryWriter.py index f8f68d83..2de968a4 100644 --- a/tests/io/test_trajectoryWriter.py +++ b/tests/io/test_trajectoryWriter.py @@ -4,10 +4,12 @@ from PQAnalysis.io.trajectoryWriter import TrajectoryWriter, write_trajectory from PQAnalysis.traj.frame import Frame -from PQAnalysis.traj.trajectory import Trajectory, TrajectoryFormat +from PQAnalysis.traj.trajectory import Trajectory +from PQAnalysis.traj.formats import TrajectoryFormat, MDEngineFormat from PQAnalysis.core.cell import Cell from PQAnalysis.core.atomicSystem import AtomicSystem from PQAnalysis.core.atom import Atom +from PQAnalysis.exceptions import MDEngineFormatError # TODO: here only one option is tested - think of a better way to test all options @@ -31,22 +33,25 @@ def test_write_trajectory(capsys): class TestTrajectoryWriter: def test__init__(self): - with pytest.raises(ValueError) as exception: + with pytest.raises(MDEngineFormatError) as exception: TrajectoryWriter(format="notAFormat") assert str( - exception.value) == "Invalid format. Has to be either \'pimd-qmcf\', \'qmcfc\' or \'None\'." + exception.value) == f""" +'notaformat' is not a valid MDEngineFormat. +Possible values are: {MDEngineFormat.member_repr()} +or their case insensitive string representation: {MDEngineFormat.value_repr()}""" writer = TrajectoryWriter() assert writer.file == sys.stdout assert writer.filename is None assert writer.mode == "a" - assert writer.format == "pimd-qmcf" + assert writer.format == MDEngineFormat.PIMD_QMCF writer = TrajectoryWriter(format="qmcfc") - assert writer.format == "qmcfc" + assert writer.format == MDEngineFormat.QMCFC writer = TrajectoryWriter(format="pimd-qmcf") - assert writer.format == "pimd-qmcf" + assert writer.format == MDEngineFormat.PIMD_QMCF def test__write_header(self, capsys):