From 955f06c0bfd396b13634492484fcfc7679df9260 Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sat, 11 Nov 2023 23:16:05 +0100 Subject: [PATCH 1/2] added readvelocity and trajectory format class --- PQAnalysis/io/trajectoryReader.py | 81 ++++++++++++++++++++++++++++--- PQAnalysis/traj/trajectory.py | 19 ++++++++ tests/io/test_trajectoryReader.py | 12 ++--- 3 files changed, 98 insertions(+), 14 deletions(-) diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index 9bd32f23..13a475bb 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -14,10 +14,11 @@ import numpy as np from beartype.typing import Tuple, List +from enum import Enum from .base import BaseReader from ..traj.frame import Frame -from ..traj.trajectory import Trajectory +from ..traj.trajectory import Trajectory, TrajectoryFormat from ..core.cell import Cell from ..core.atom import Atom from ..core.atomicSystem import AtomicSystem @@ -41,7 +42,7 @@ class TrajectoryReader(BaseReader): The list of frames read from the file. """ - def __init__(self, filename: str) -> None: + def __init__(self, filename: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None: """ Initializes the TrajectoryReader with the given filename. @@ -52,6 +53,7 @@ def __init__(self, filename: str) -> None: """ super().__init__(filename) self.frames = [] + self.format = format def read(self) -> Trajectory: """ @@ -79,7 +81,8 @@ def read(self) -> Trajectory: frame_string += line elif line.split()[0].isdigit(): if frame_string != '': - self.frames.append(frame_reader.read(frame_string)) + self.frames.append(frame_reader.read( + frame_string, format=self.format)) frame_string = line else: frame_string += line @@ -99,10 +102,38 @@ class FrameReader: FrameReader reads a frame from a string. """ - def read(self, frame_string: str) -> Frame: + def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> Frame: """ Reads a frame from a string. + Parameters + ---------- + frame_string : str + The string to read the frame from. + format : TrajectoryFormat | str, optional + The format of the trajectory. Default is TrajectoryFormat.XYZ. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + ValueError + If the given format is not valid. + """ + if format == TrajectoryFormat.XYZ: + return self.read_positions(frame_string) + elif format == TrajectoryFormat.VELOCS: + return self.read_velocities(frame_string) + else: + raise ValueError('Invalid format.') + + def read_positions(self, frame_string: str) -> Frame: + """ + Reads the positions of the atoms in a frame from a string. + Parameters ---------- frame_string : str @@ -122,9 +153,9 @@ def read(self, frame_string: str) -> Frame: splitted_frame_string = frame_string.split('\n') header_line = splitted_frame_string[0] - n_atoms, cell = self.__read_header_line__(header_line) + n_atoms, cell = self._read_header_line(header_line) - xyz, atoms = self.__read_xyz__(splitted_frame_string, n_atoms) + xyz, atoms = self._read_xyz(splitted_frame_string, n_atoms) try: atoms = [Atom(atom) for atom in atoms] @@ -133,7 +164,41 @@ def read(self, frame_string: str) -> Frame: return Frame(AtomicSystem(atoms=atoms, pos=xyz, cell=cell)) - def __read_header_line__(self, header_line: str) -> Tuple[int, Cell | None]: + def read_velocities(self, frame_string: str) -> Frame: + """ + Reads the velocities of the atoms in a frame from a string. + + Parameters + ---------- + frame_string : str + The string to read the frame from. + + Returns + ------- + Frame + The frame read from the string. + + Raises + ------ + TypeError + If the given frame_string is not a string. + """ + + splitted_frame_string = frame_string.split('\n') + header_line = splitted_frame_string[0] + + n_atoms, cell = self._read_header_line(header_line) + + velocs, atoms = self._read_xyz(splitted_frame_string, n_atoms) + + try: + atoms = [Atom(atom) for atom in atoms] + except ElementNotFoundError: + atoms = [Atom(atom, use_guess_element=False) for atom in atoms] + + return Frame(AtomicSystem(atoms=atoms, velocs=velocs, cell=cell)) + + def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: """ Reads the header line of a frame. @@ -178,7 +243,7 @@ def __read_header_line__(self, header_line: str) -> Tuple[int, Cell | None]: return n_atoms, cell - def __read_xyz__(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]: + def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]: """ Reads the xyz coordinates and the atom names from the given string. diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index b3507dc5..3d2b9263 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -14,10 +14,29 @@ import numpy as np from beartype.typing import List, Iterator, Any +from enum import Enum from .frame import Frame +class TrajectoryFormat(Enum): + """ + An enumeration of the supported trajectory formats. + + ... + + Attributes + ---------- + XYZ : str + The XYZ format. + VELOCS : str + The VELOCS format. + """ + + XYZ = "xyz" + VELOCS = "velocs" + + class Trajectory: """ A trajectory object is a sequence of frames. diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 204ad49b..293a5c9d 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -60,22 +60,22 @@ def test__read_header_line__(self): reader = FrameReader() with pytest.raises(ValueError) as exception: - reader.__read_header_line__("1 2.0 3.0") + reader._read_header_line("1 2.0 3.0") assert str( exception.value) == "Invalid file format in header line of Frame." - n_atoms, cell = reader.__read_header_line__( + n_atoms, cell = reader._read_header_line( "1 2.0 3.0 4.0 5.0 6.0 7.0") assert n_atoms == 1 assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) assert np.allclose(cell.box_angles, [5.0, 6.0, 7.0]) - n_atoms, cell = reader.__read_header_line__("2 2.0 3.0 4.0") + n_atoms, cell = reader._read_header_line("2 2.0 3.0 4.0") assert n_atoms == 2 assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0]) assert np.allclose(cell.box_angles, [90.0, 90.0, 90.0]) - n_atoms, cell = reader.__read_header_line__("3") + n_atoms, cell = reader._read_header_line("3") assert n_atoms == 3 assert cell is None @@ -83,12 +83,12 @@ def test__read_xyz__(self): reader = FrameReader() with pytest.raises(ValueError) as exception: - reader.__read_xyz__( + reader._read_xyz( ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0"], n_atoms=2) assert str( exception.value) == "Invalid file format in xyz coordinates of Frame." - xyz, atoms = reader.__read_xyz__( + xyz, atoms = reader._read_xyz( ["", "", "h 1.0 2.0 3.0", "o 2.0 2.0 2.0"], n_atoms=2) assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert atoms == ["h", "o"] From ebd8356799b18a958241423914524c8bd8f5cd9a Mon Sep 17 00:00:00 2001 From: Jakob Gamper <97gamjak@gmail.com> Date: Sun, 12 Nov 2023 00:00:28 +0100 Subject: [PATCH 2/2] velocity reading fully tested --- PQAnalysis/io/trajectoryReader.py | 13 +++++++------ PQAnalysis/traj/frame.py | 12 ++++++++++++ PQAnalysis/traj/trajectory.py | 29 +++++++++++++++++++++++++---- tests/io/test_trajectoryReader.py | 17 +++++++++++++++++ 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/PQAnalysis/io/trajectoryReader.py b/PQAnalysis/io/trajectoryReader.py index 13a475bb..f1b62dd6 100644 --- a/PQAnalysis/io/trajectoryReader.py +++ b/PQAnalysis/io/trajectoryReader.py @@ -123,12 +123,13 @@ def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFor ValueError If the given format is not valid. """ - if format == TrajectoryFormat.XYZ: + + # Note: TrajectoryFormat(format) automatically gives an error if format is not a valid TrajectoryFormat + + if TrajectoryFormat(format) is TrajectoryFormat.XYZ: return self.read_positions(frame_string) - elif format == TrajectoryFormat.VELOCS: + elif TrajectoryFormat(format) is TrajectoryFormat.VEL: return self.read_velocities(frame_string) - else: - raise ValueError('Invalid format.') def read_positions(self, frame_string: str) -> Frame: """ @@ -189,14 +190,14 @@ def read_velocities(self, frame_string: str) -> Frame: n_atoms, cell = self._read_header_line(header_line) - velocs, atoms = self._read_xyz(splitted_frame_string, n_atoms) + vel, atoms = self._read_xyz(splitted_frame_string, n_atoms) try: atoms = [Atom(atom) for atom in atoms] except ElementNotFoundError: atoms = [Atom(atom, use_guess_element=False) for atom in atoms] - return Frame(AtomicSystem(atoms=atoms, velocs=velocs, cell=cell)) + return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell)) def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]: """ diff --git a/PQAnalysis/traj/frame.py b/PQAnalysis/traj/frame.py index 7ebd3118..dff1554f 100644 --- a/PQAnalysis/traj/frame.py +++ b/PQAnalysis/traj/frame.py @@ -190,6 +190,18 @@ def pos(self) -> Numpy2DFloatArray: """ return self.system.pos + @property + def vel(self) -> Numpy2DFloatArray: + """ + The positions of the atoms in the system. + + Returns + ------- + Numpy2DFloatArray + The positions of the atoms in the system. + """ + return self.system.vel + @property def atoms(self) -> List[Atom]: """ diff --git a/PQAnalysis/traj/trajectory.py b/PQAnalysis/traj/trajectory.py index 3d2b9263..2c36b4e5 100644 --- a/PQAnalysis/traj/trajectory.py +++ b/PQAnalysis/traj/trajectory.py @@ -29,12 +29,33 @@ class TrajectoryFormat(Enum): ---------- XYZ : str The XYZ format. - VELOCS : str - The VELOCS format. + VEL : str + The VEL format. """ - XYZ = "xyz" - VELOCS = "velocs" + XYZ = "XYZ" + VEL = "VEL" + + @classmethod + def _missing_(cls, value: object) -> Any: + """ + This method allows a trajectory format to be retrieved from a string. + + Parameters + ---------- + value : str + _description_ + + Returns + ------- + Any + _description_ + """ + value = value.lower() + for member in cls: + if member.value.lower() == value: + return member + return None class Trajectory: diff --git a/tests/io/test_trajectoryReader.py b/tests/io/test_trajectoryReader.py index 293a5c9d..67cfe5f9 100644 --- a/tests/io/test_trajectoryReader.py +++ b/tests/io/test_trajectoryReader.py @@ -115,3 +115,20 @@ def test_read(self): assert np.allclose(frame.pos, [ [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + frame = reader.read( + "2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="vel") + assert frame.n_atoms == 2 + assert frame.atoms == [Atom(atom, use_guess_element=False) + for atom in ["h", "o1"]] + assert np.allclose(frame.vel, [ + [1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]) + assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0) + + def test_read_invalid_format(self): + reader = FrameReader() + + with pytest.raises(ValueError) as exception: + reader.read("", format="invalid") + assert str( + exception.value) == "'invalid' is not a valid TrajectoryFormat"