Skip to content

Commit

Permalink
Merge branch 'feature/readVelocsFile' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
97gamjak committed Nov 11, 2023
2 parents 15ae9c1 + ebd8356 commit 95dd3fd
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 14 deletions.
82 changes: 74 additions & 8 deletions PQAnalysis/io/trajectoryReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import numpy as np

from beartype.typing import Tuple, List
from enum import Enum

from .base import BaseReader
from ..traj.frame import Frame
from ..traj.trajectory import Trajectory
from ..traj.trajectory import Trajectory, TrajectoryFormat
from ..core.cell import Cell
from ..core.atom import Atom
from ..core.atomicSystem import AtomicSystem
Expand All @@ -41,7 +42,7 @@ class TrajectoryReader(BaseReader):
The list of frames read from the file.
"""

def __init__(self, filename: str) -> None:
def __init__(self, filename: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None:
"""
Initializes the TrajectoryReader with the given filename.
Expand All @@ -52,6 +53,7 @@ def __init__(self, filename: str) -> None:
"""
super().__init__(filename)
self.frames = []
self.format = format

def read(self) -> Trajectory:
"""
Expand Down Expand Up @@ -79,7 +81,8 @@ def read(self) -> Trajectory:
frame_string += line
elif line.split()[0].isdigit():
if frame_string != '':
self.frames.append(frame_reader.read(frame_string))
self.frames.append(frame_reader.read(
frame_string, format=self.format))
frame_string = line
else:
frame_string += line
Expand All @@ -99,10 +102,39 @@ class FrameReader:
FrameReader reads a frame from a string.
"""

def read(self, frame_string: str) -> Frame:
def read(self, frame_string: str, format: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> Frame:
"""
Reads a frame from a string.
Parameters
----------
frame_string : str
The string to read the frame from.
format : TrajectoryFormat | str, optional
The format of the trajectory. Default is TrajectoryFormat.XYZ.
Returns
-------
Frame
The frame read from the string.
Raises
------
ValueError
If the given format is not valid.
"""

# Note: TrajectoryFormat(format) automatically gives an error if format is not a valid TrajectoryFormat

if TrajectoryFormat(format) is TrajectoryFormat.XYZ:
return self.read_positions(frame_string)
elif TrajectoryFormat(format) is TrajectoryFormat.VEL:
return self.read_velocities(frame_string)

def read_positions(self, frame_string: str) -> Frame:
"""
Reads the positions of the atoms in a frame from a string.
Parameters
----------
frame_string : str
Expand All @@ -122,9 +154,9 @@ def read(self, frame_string: str) -> Frame:
splitted_frame_string = frame_string.split('\n')
header_line = splitted_frame_string[0]

n_atoms, cell = self.__read_header_line__(header_line)
n_atoms, cell = self._read_header_line(header_line)

xyz, atoms = self.__read_xyz__(splitted_frame_string, n_atoms)
xyz, atoms = self._read_xyz(splitted_frame_string, n_atoms)

try:
atoms = [Atom(atom) for atom in atoms]
Expand All @@ -133,7 +165,41 @@ def read(self, frame_string: str) -> Frame:

return Frame(AtomicSystem(atoms=atoms, pos=xyz, cell=cell))

def __read_header_line__(self, header_line: str) -> Tuple[int, Cell | None]:
def read_velocities(self, frame_string: str) -> Frame:
"""
Reads the velocities of the atoms in a frame from a string.
Parameters
----------
frame_string : str
The string to read the frame from.
Returns
-------
Frame
The frame read from the string.
Raises
------
TypeError
If the given frame_string is not a string.
"""

splitted_frame_string = frame_string.split('\n')
header_line = splitted_frame_string[0]

n_atoms, cell = self._read_header_line(header_line)

vel, atoms = self._read_xyz(splitted_frame_string, n_atoms)

try:
atoms = [Atom(atom) for atom in atoms]
except ElementNotFoundError:
atoms = [Atom(atom, use_guess_element=False) for atom in atoms]

return Frame(AtomicSystem(atoms=atoms, vel=vel, cell=cell))

def _read_header_line(self, header_line: str) -> Tuple[int, Cell | None]:
"""
Reads the header line of a frame.
Expand Down Expand Up @@ -178,7 +244,7 @@ def __read_header_line__(self, header_line: str) -> Tuple[int, Cell | None]:

return n_atoms, cell

def __read_xyz__(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]:
def _read_xyz(self, splitted_frame_string: List[str], n_atoms: int) -> Tuple[Numpy2DFloatArray, List[str]]:
"""
Reads the xyz coordinates and the atom names from the given string.
Expand Down
12 changes: 12 additions & 0 deletions PQAnalysis/traj/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ def pos(self) -> Numpy2DFloatArray:
"""
return self.system.pos

@property
def vel(self) -> Numpy2DFloatArray:
"""
The positions of the atoms in the system.
Returns
-------
Numpy2DFloatArray
The positions of the atoms in the system.
"""
return self.system.vel

@property
def atoms(self) -> List[Atom]:
"""
Expand Down
40 changes: 40 additions & 0 deletions PQAnalysis/traj/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,50 @@
import numpy as np

from beartype.typing import List, Iterator, Any
from enum import Enum

from .frame import Frame


class TrajectoryFormat(Enum):
"""
An enumeration of the supported trajectory formats.
...
Attributes
----------
XYZ : str
The XYZ format.
VEL : str
The VEL format.
"""

XYZ = "XYZ"
VEL = "VEL"

@classmethod
def _missing_(cls, value: object) -> Any:
"""
This method allows a trajectory format to be retrieved from a string.
Parameters
----------
value : str
_description_
Returns
-------
Any
_description_
"""
value = value.lower()
for member in cls:
if member.value.lower() == value:
return member
return None


class Trajectory:
"""
A trajectory object is a sequence of frames.
Expand Down
29 changes: 23 additions & 6 deletions tests/io/test_trajectoryReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,35 +60,35 @@ def test__read_header_line__(self):
reader = FrameReader()

with pytest.raises(ValueError) as exception:
reader.__read_header_line__("1 2.0 3.0")
reader._read_header_line("1 2.0 3.0")
assert str(
exception.value) == "Invalid file format in header line of Frame."

n_atoms, cell = reader.__read_header_line__(
n_atoms, cell = reader._read_header_line(
"1 2.0 3.0 4.0 5.0 6.0 7.0")
assert n_atoms == 1
assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0])
assert np.allclose(cell.box_angles, [5.0, 6.0, 7.0])

n_atoms, cell = reader.__read_header_line__("2 2.0 3.0 4.0")
n_atoms, cell = reader._read_header_line("2 2.0 3.0 4.0")
assert n_atoms == 2
assert np.allclose(cell.box_lengths, [2.0, 3.0, 4.0])
assert np.allclose(cell.box_angles, [90.0, 90.0, 90.0])

n_atoms, cell = reader.__read_header_line__("3")
n_atoms, cell = reader._read_header_line("3")
assert n_atoms == 3
assert cell is None

def test__read_xyz__(self):
reader = FrameReader()

with pytest.raises(ValueError) as exception:
reader.__read_xyz__(
reader._read_xyz(
["", "", "h 1.0 2.0 3.0", "o 2.0 2.0"], n_atoms=2)
assert str(
exception.value) == "Invalid file format in xyz coordinates of Frame."

xyz, atoms = reader.__read_xyz__(
xyz, atoms = reader._read_xyz(
["", "", "h 1.0 2.0 3.0", "o 2.0 2.0 2.0"], n_atoms=2)
assert np.allclose(xyz, [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]])
assert atoms == ["h", "o"]
Expand All @@ -115,3 +115,20 @@ def test_read(self):
assert np.allclose(frame.pos, [
[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]])
assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0)

frame = reader.read(
"2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no1 2.0 2.0 2.0", format="vel")
assert frame.n_atoms == 2
assert frame.atoms == [Atom(atom, use_guess_element=False)
for atom in ["h", "o1"]]
assert np.allclose(frame.vel, [
[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]])
assert frame.cell == Cell(2.0, 3.0, 4.0, 5.0, 6.0, 7.0)

def test_read_invalid_format(self):
reader = FrameReader()

with pytest.raises(ValueError) as exception:
reader.read("", format="invalid")
assert str(
exception.value) == "'invalid' is not a valid TrajectoryFormat"

0 comments on commit 95dd3fd

Please sign in to comment.