Skip to content

Commit

Permalink
introduced MDEngineFormat
Browse files Browse the repository at this point in the history
  • Loading branch information
97gamjak committed Nov 12, 2023
1 parent f36881e commit 55ad208
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 154 deletions.
22 changes: 20 additions & 2 deletions PQAnalysis/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, id: Any) -> None:
super().__init__(self.message)


class TrajectoryFormatError(PQException):
class FormatEnumError(PQException):
"""
Exception raised if the given enum is not valid
"""
Expand All @@ -46,7 +46,25 @@ def __init__(self, value: object, enum: object) -> None:
self.enum = enum
self.value = value
self.message = f"""
'{self.value}' is not a valid TrajectoryFormat.
'{self.value}' is not a valid {enum.__name__}.
Possible values are: {enum.member_repr()}
or their case insensitive string representation: {enum.value_repr()}"""
super().__init__(self.message)


class TrajectoryFormatError(FormatEnumError):
"""
Exception raised if the given enum is not valid
"""

def __init__(self, value: object, enum: object) -> None:
super().__init__(value, enum)


class MDEngineFormatError(FormatEnumError):
"""
Exception raised if the given enum is not valid
"""

def __init__(self, value: object, enum: object) -> None:
super().__init__(value, enum)
22 changes: 7 additions & 15 deletions PQAnalysis/io/energyFileReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .base import BaseReader
from .infoFileReader import InfoFileReader
from ..physicalData.energy import Energy
from ..traj.formats import MDEngineFormat


class EnergyFileReader(BaseReader):
Expand All @@ -34,15 +35,15 @@ class EnergyFileReader(BaseReader):
The name of the info file to read from.
withInfoFile : bool
If True, the info file was found.
format : MDEngineFormat
The format of the file. Default is MDEngineFormat.PIMD_QMCF.
"""

formats = ["pimd-qmcf", "qmcfc"]

def __init__(self,
filename: str,
info_filename: str | None = None,
use_info_file: bool = True,
format: str = "pimd-qmcf"
format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF
) -> None:
"""
Initializes the EnergyFileReader with the given filename.
Expand All @@ -61,13 +62,8 @@ def __init__(self,
The name of the info file to read from, by default None
use_info_file : bool, optional
If True, the info file is searched for, by default True
format : str, optional
The format of the info file, by default "pimd-qmcf"
Raises
------
ValueError
If the format is not supported.
format : MDEngineFormat | str, optional
The format of the file, by default MDEngineFormat.PIMD_QMCF
"""
super().__init__(filename)
self.info_filename = info_filename
Expand All @@ -77,11 +73,7 @@ def __init__(self,
else:
self.withInfoFile = False

if format not in self.formats:
raise ValueError(
f"Format {format} is not supported. Supported formats are {self.formats}.")

self.format = format
self.format = MDEngineFormat(format)

def read(self) -> Energy:
"""
Expand Down
2 changes: 1 addition & 1 deletion PQAnalysis/io/frameReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..core.cell import Cell
from ..types import Numpy2DFloatArray, Numpy1DFloatArray
from ..traj.frame import Frame
from ..traj.trajectory import TrajectoryFormat
from ..traj.formats import TrajectoryFormat
from ..exceptions import ElementNotFoundError


Expand Down
28 changes: 9 additions & 19 deletions PQAnalysis/io/infoFileReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from beartype.typing import Tuple, Dict

from .base import BaseReader
from ..traj.formats import MDEngineFormat


class InfoFileReader(BaseReader):
Expand All @@ -27,35 +28,24 @@ class InfoFileReader(BaseReader):
----------
filename : str
The name of the file to read from.
format : str
The format of the info file.
format : MDEngineFormat
The format of the info file. Default is MDEngineFormat.PIMD_QMCF.
"""

formats = ["pimd-qmcf", "qmcfc"]

def __init__(self, filename: str, format: str = "pimd-qmcf") -> None:
def __init__(self, filename: str, format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF) -> None:
"""
Initializes the InfoFileReader with the given filename.
Parameters
----------
filename : str
The name of the file to read from.
format : str, optional
The format of the info file, by default "pimd-qmcf"
Raises
------
ValueError
If the format is not supported.
format : MDEngineFormat | str, optional
The format of the info file. Default is MDEngineFormat.PIMD_QMCF.
"""
super().__init__(filename)

if format not in self.formats:
raise ValueError(
f"Format {format} is not supported. Supported formats are {self.formats}.")

self.format = format
self.format = MDEngineFormat(format)

def read(self) -> Tuple[Dict, Dict | None]:
"""
Expand All @@ -71,9 +61,9 @@ def read(self) -> Tuple[Dict, Dict | None]:
The units of the info file as a dictionary. The keys are the names of the
information strings. The values are the corresponding units.
"""
if self.format == "pimd-qmcf":
if self.format == MDEngineFormat.PIMD_QMCF:
return self.read_pimd_qmcf()
elif self.format == "qmcfc":
elif self.format == MDEngineFormat.QMCFC:
return self.read_qmcfc()

def read_pimd_qmcf(self) -> Tuple[Dict, Dict]:
Expand Down
3 changes: 2 additions & 1 deletion PQAnalysis/io/trajectoryReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"""

from .base import BaseReader
from ..traj.trajectory import Trajectory, TrajectoryFormat
from ..traj.trajectory import Trajectory
from ..traj.formats import TrajectoryFormat
from .frameReader import FrameReader


Expand Down
45 changes: 26 additions & 19 deletions PQAnalysis/io/trajectoryWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
from beartype.typing import List

from .base import BaseWriter
from ..traj.trajectory import Trajectory, TrajectoryFormat
from ..traj.trajectory import Trajectory
from ..traj.formats import TrajectoryFormat, MDEngineFormat
from ..traj.frame import Frame
from ..core.cell import Cell
from ..core.atom import Atom
from ..types import Numpy2DFloatArray, Numpy1DFloatArray


def write_trajectory(traj, filename: str | None = None, format: str | None = None, type: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None:
def write_trajectory(traj,
filename: str | None = None,
format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF,
type: TrajectoryFormat | str = TrajectoryFormat.XYZ
) -> None:
"""
Wrapper for TrajectoryWriter to write a trajectory to a file.
Expand All @@ -34,8 +39,8 @@ def write_trajectory(traj, filename: str | None = None, format: str | None = Non
The trajectory to write.
filename : str, optional
The name of the file to write to. If None, the output is printed to stdout.
format : str, optional
The format of the file. If None, the default PIMD-QMCF format is used.
format : MDEngineFormat | str, optional
The format of the md engine for the output file. The default is MDEngineFormat.PIMD_QMCF.
type : TrajectoryFormat | str, optional
The type of the data to write to the file. Default is TrajectoryFormat.XYZ.
Expand All @@ -59,6 +64,7 @@ class TrajectoryWriter(BaseWriter):
formats : list of str
The available formats for the trajectory file.
#TODO: put this description into formats!!!
PIMD-QMCF format for one frame:
header line containing the number of atoms and the cell information (if available)
arbitrary comment line
Expand All @@ -75,16 +81,16 @@ class TrajectoryWriter(BaseWriter):
Attributes
----------
format : str
The format of the file.
format : MDEngineFormat
The format of the md engine for the output file. The default is MDEngineFormat.PIMD_QMCF.
"""

formats = [None, 'pimd-qmcf', 'qmcfc']
_format: MDEngineFormat
_type: TrajectoryFormat = TrajectoryFormat.XYZ

def __init__(self,
filename: str | None = None,
format: str | None = None,
format: MDEngineFormat | str = MDEngineFormat.PIMD_QMCF,
mode: str = 'w'
) -> None:
"""
Expand All @@ -94,22 +100,15 @@ def __init__(self,
----------
filename : str, optional
The name of the file to write to. If None, the output is printed to stdout.
format : str, optional
The format of the file. If None, the default PIMD-QMCF format is used.
(see TrajectoryWriter.formats for available formats)
format : MDEngineFormat | str, optional
The format of the md engine for the output file. The default is MDEngineFormat.PIMD_QMCF.
mode : str, optional
The mode of the file. Either 'w' for write or 'a' for append.
"""

super().__init__(filename, mode)
if format not in self.formats:
raise ValueError(
'Invalid format. Has to be either \'pimd-qmcf\', \'qmcfc\' or \'None\'.')

if format is None:
format = 'pimd-qmcf'

self.format = format
self.format = MDEngineFormat(format)

def write(self, trajectory: Trajectory, type: TrajectoryFormat | str = TrajectoryFormat.XYZ) -> None:
"""
Expand Down Expand Up @@ -253,7 +252,7 @@ def _write_xyz(self, xyz: Numpy2DFloatArray, atoms: List[Atom]) -> None:
The elements of the frame.
"""

if self.format == "qmcfc" and self._type == TrajectoryFormat.XYZ:
if self.format == MDEngineFormat.QMCFC and self._type == TrajectoryFormat.XYZ:
print("X 0.0 0.0 0.0", file=self.file)

for i in range(len(atoms)):
Expand All @@ -275,3 +274,11 @@ def _write_scalar(self, scalar: Numpy1DFloatArray, atoms: List[Atom]) -> None:
for i in range(len(atoms)):
print(
f"{atoms[i].name} {scalar[i]}", file=self.file)

@property
def format(self) -> MDEngineFormat:
return self._format

@format.setter
def format(self, format: MDEngineFormat | str) -> None:
self._format = MDEngineFormat(format)
Loading

0 comments on commit 55ad208

Please sign in to comment.