Skip to content

Commit

Permalink
New Gaussian plugin (#325)
Browse files Browse the repository at this point in the history
* Change better_guess func name in tests

* Add type annotations to GaussianErrorHandler

* `float` instead of `int | float` for mem
  • Loading branch information
rashatwi authored Apr 3, 2024
1 parent 180990b commit c3dcf12
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 76 deletions.
149 changes: 74 additions & 75 deletions custodian/gaussian/handlers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""This module implements error handlers for Gaussian runs."""

from __future__ import annotations

import datetime
import glob
import logging
import math
import os
import re
import shutil
from typing import TYPE_CHECKING, Any

import numpy as np
from monty.io import zopen
Expand All @@ -15,6 +18,9 @@
from custodian.custodian import ErrorHandler
from custodian.utils import backup

if TYPE_CHECKING:
from collections.abc import Iterable

__author__ = "Rasha Atwi"
__version__ = "0.1"
__maintainer__ = "Rasha Atwi"
Expand Down Expand Up @@ -91,17 +97,17 @@ class GaussianErrorHandler(ErrorHandler):

def __init__(
self,
input_file,
output_file,
stderr_file="stderr.txt",
cart_coords=True,
scf_max_cycles=100,
opt_max_cycles=100,
job_type="normal",
lower_functional=None,
lower_basis_set=None,
prefix="error",
check_convergence=True,
input_file: str,
output_file: str,
stderr_file: str = "stderr.txt",
cart_coords: bool = True,
scf_max_cycles: int = 100,
opt_max_cycles: int = 100,
job_type: str = "normal",
lower_functional: str | None = None,
lower_basis_set: str | None = None,
prefix: str = "error",
check_convergence: bool = True,
):
"""
Initialize the GaussianErrorHandler class.
Expand Down Expand Up @@ -134,23 +140,23 @@ def __init__(
self.output_file = output_file
self.stderr_file = stderr_file
self.cart_coords = cart_coords
self.errors = set()
self.gout = None
self.gin = None
self.errors: set[str] = set()
self.gout: GaussianOutput = None
self.gin: GaussianInput = None
self.scf_max_cycles = scf_max_cycles
self.opt_max_cycles = opt_max_cycles
self.job_type = job_type
self.lower_functional = lower_functional
self.lower_basis_set = lower_basis_set
self.prefix = prefix
self.check_convergence = check_convergence
self.conv_data = None
self.recom_mem = None
self.logger = logging.getLogger(self.__class__.__name__)
self.conv_data: dict[str, dict[str, Any]] = {}
self.recom_mem: float | None = None
self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
logging.basicConfig(level=logging.INFO)

@staticmethod
def _recursive_lowercase(obj):
def _recursive_lowercase(obj: dict[str, Any] | str | Iterable[Any]) -> dict[str, Any] | str | Iterable[Any]:
"""
Recursively convert all string elements in a given object to lowercase.
Expand All @@ -174,21 +180,15 @@ def _recursive_lowercase(obj):
input `obj`.
"""
if isinstance(obj, dict):
updated_obj = {}
for k, v in obj.items():
updated_obj[k.lower()] = GaussianErrorHandler._recursive_lowercase(v)
return updated_obj
return {k.lower(): GaussianErrorHandler._recursive_lowercase(v) for k, v in obj.items()}
if isinstance(obj, str):
return obj.lower()
if hasattr(obj, "__iter__"):
updated_obj = []
for i in obj:
updated_obj.append(GaussianErrorHandler._recursive_lowercase(i))
return updated_obj
return [GaussianErrorHandler._recursive_lowercase(i) for i in obj]
return obj

@staticmethod
def _recursive_remove_space(obj):
def _recursive_remove_space(obj: dict[str, Any]) -> dict[str, Any]:
"""
Recursively remove leading and trailing whitespace from keys and string values
in a dictionary.
Expand All @@ -209,18 +209,17 @@ def _recursive_remove_space(obj):
A new dictionary with all keys and string values stripped of leading
and trailing whitespace. The structure of the dictionary is preserved.
"""
updated_obj = {}
for key, value in obj.items():
if isinstance(value, dict):
updated_obj[key.strip()] = GaussianErrorHandler._recursive_remove_space(value)
elif isinstance(value, str):
updated_obj[key.strip()] = value.strip()
else:
updated_obj[key.strip()] = value
return updated_obj
return {
key.strip(): GaussianErrorHandler._recursive_remove_space(value)
if isinstance(value, dict)
else value.strip()
if isinstance(value, str)
else value
for key, value in obj.items()
}

@staticmethod
def _update_route_params(route_params, key, value):
def _update_route_params(route_params: dict, key: str, value: str | dict) -> dict:
"""
Update Gaussian route parameters with new key-value pairs, handling nested
structures.
Expand All @@ -243,12 +242,11 @@ def _update_route_params(route_params, key, value):
update = {key: {obj: None, **value}} if isinstance(value, dict) else {key: {obj: None, value: None}}
route_params.update(update)
elif isinstance(obj, dict):
update = value if isinstance(value, dict) else {value: None}
route_params[key].update(update)
route_params[key].update(value if isinstance(value, dict) else {value: None})
return route_params

@staticmethod
def _int_keyword(route_params):
def _int_keyword(route_params: dict[str, str | dict]) -> tuple[str, str | dict]:
"""
Determine the keyword used for 'Integral' in the Gaussian route parameters of
the input file. Possible keywords are 'int' and 'integral'. If neither keyword
Expand All @@ -273,7 +271,7 @@ def _int_keyword(route_params):
return int_key, route_params.get(int_key, "")

@staticmethod
def _int_grid(route_params):
def _int_grid(route_params: dict[str, str | dict]) -> bool:
"""
Check if the integration grid used for numerical integrations matches specific
options.
Expand All @@ -299,12 +297,12 @@ def _int_grid(route_params):
return False

@staticmethod
def convert_mem(mem, unit):
def convert_mem(mem: float, unit: str) -> float:
"""
Convert memory size between different units to megabytes (MB).
Args:
mem (float | int): The memory size to convert.
mem (float): The memory size to convert.
unit (str): The unit of the input memory size. Supported units include
'kb', 'mb', 'gb', 'tb', and word units ('kw', 'mw', 'gw', 'tw'), or an
empty string for default conversion (from words).
Expand All @@ -327,7 +325,7 @@ def convert_mem(mem, unit):
return mem * conversion[unit]

@staticmethod
def _find_dynamic_memory_allocated(link0_params):
def _find_dynamic_memory_allocated(link0_params: dict[str, str]) -> tuple[str | None, float | None]:
"""
Find and convert the memory allocation from Gaussian link0 parameters. This
method searches for the '%mem' key in the link0 parameters of a Gaussian job
Expand All @@ -346,24 +344,25 @@ def _find_dynamic_memory_allocated(link0_params):
allocation in MB. If '%mem' is not found, the second element will be None.
"""
mem_key = None
dynamic_mem = None
for k in link0_params:
if k.lower() == "%mem":
mem_key = k
break
dynamic_mem = link0_params.get(mem_key)
if dynamic_mem:
if mem_key:
dynamic_mem_str = link0_params[mem_key]
# default memory unit in Gaussian is words
dynamic_mem = dynamic_mem.lower()
dynamic_mem_str = dynamic_mem_str.lower()
mem_unit = ""
for unit in GaussianErrorHandler.MEM_UNITS:
if unit in dynamic_mem:
if unit in dynamic_mem_str:
mem_unit = unit
break
dynamic_mem = float(dynamic_mem.strip(mem_unit))
dynamic_mem = float(dynamic_mem_str.strip(mem_unit))
dynamic_mem = GaussianErrorHandler.convert_mem(dynamic_mem, mem_unit)
return mem_key, dynamic_mem

def _add_int(self):
def _add_int(self) -> bool:
"""
Check and update the integration grid setting ('int') in the Gaussian input
file's route parameters to 'ultrafine', if necessary.
Expand Down Expand Up @@ -407,7 +406,7 @@ def _add_int(self):
self.logger.warning(warning_msg)
self.gin.route_parameters[int_key]["grid"] = "ultrafine"
return True
if int_value in self.GRID_NAMES or self.grid_patt.match(int_value):
if isinstance(int_value, str) and (int_value in self.GRID_NAMES or self.grid_patt.match(int_value)):
# if int grid is set and is different from ultrafine,
# set it to ultrafine (works when no other int options
# are specified)
Expand All @@ -424,7 +423,7 @@ def _add_int(self):
return False

@staticmethod
def _not_g16(gout):
def _not_g16(gout: GaussianOutput) -> bool:
"""
Determine if the Gaussian version is not 16.
Expand All @@ -437,7 +436,7 @@ def _not_g16(gout):
return "16" not in gout.version

@staticmethod
def _monitor_convergence(data, directory="./"):
def _monitor_convergence(data: dict[str, dict[str, Any]], directory: str = "./") -> None:
"""
Plot and save a convergence graph for an optimization job as a function of the
number of iterations.
Expand Down Expand Up @@ -470,7 +469,7 @@ def _monitor_convergence(data, directory="./"):
plt.tight_layout()
plt.savefig(os.path.join(directory, "convergence.png"))

def check(self, directory="./"):
def check(self, directory: str = "./") -> bool:
"""Check for errors in the Gaussian output file."""
# TODO: this backups the original file instead of the actual one
if "linear_bend" in self.errors:
Expand All @@ -481,6 +480,7 @@ def check(self, directory="./"):

self.gin = GaussianInput.from_file(os.path.join(directory, self.input_file))
self.gin.route_parameters = GaussianErrorHandler._recursive_lowercase(self.gin.route_parameters)
assert isinstance(self.gin.route_parameters, dict)
self.gin.route_parameters = GaussianErrorHandler._recursive_remove_space(self.gin.route_parameters)
self.gout = GaussianOutput(os.path.join(directory, self.output_file))
self.errors = set()
Expand All @@ -505,8 +505,8 @@ def check(self, directory="./"):

if self.check_convergence and "opt" in self.gin.route_parameters:
for k, v in GaussianErrorHandler.conv_critera.items():
if v.search(line):
m = v.search(line)
m = v.search(line)
if m:
if k not in self.conv_data["values"]:
self.conv_data["values"][k] = [m.group(2)]
self.conv_data["thresh"][k] = float(m.group(3))
Expand All @@ -524,9 +524,9 @@ def check(self, directory="./"):
self.logger.error(patt)
return len(self.errors) > 0

def correct(self, directory="./"):
def correct(self, directory: str = "./"):
"""Perform necessary actions to correct the errors in the Gaussian output."""
actions = []
actions: list[Any] = []
# to avoid situations like 'linear_bend', where if we backup input_file,
# it will not be the actual input used in the current calc
# shutil.copy(self.input_file, f'{self.input_file}.backup')
Expand All @@ -544,13 +544,13 @@ def correct(self, directory="./"):
if "scf_convergence" in self.errors:
self.gin.route_parameters = GaussianErrorHandler._update_route_params(self.gin.route_parameters, "scf", {})
# if the SCF procedure has failed to converge
if self.gin.route_parameters.get("scf").get("maxcycle") != str(self.scf_max_cycles):
if self.gin.route_parameters.get("scf", {}).get("maxcycle") != str(self.scf_max_cycles):
# increase number of cycles if not already set or is different
# from scf_max_cycles
self.gin.route_parameters["scf"]["maxcycle"] = self.scf_max_cycles
actions.append({"scf_max_cycles": self.scf_max_cycles})

elif not {"xqc", "yqc", "qc"}.intersection(self.gin.route_parameters.get("scf")):
elif not {"xqc", "yqc", "qc"}.intersection(self.gin.route_parameters.get("scf", set())):
# use an alternate SCF converger
self.gin.route_parameters["scf"]["xqc"] = None
actions.append({"scf_algorithm": "xqc"})
Expand All @@ -577,7 +577,7 @@ def correct(self, directory="./"):
)
else:
self.logger.info("SCF calculation failed. Exiting...")
return {"errors": list[self.errors], "actions": None}
return {"errors": list(self.errors), "actions": None}

elif "opt_steps" in self.errors:
# int_actions = self._add_int()
Expand Down Expand Up @@ -825,16 +825,16 @@ class WallTimeErrorHandler(ErrorHandler):
time is less than or equal to the buffer time.
"""

is_monitor = True
is_monitor: bool = True

def __init__(
self,
wall_time,
buffer_time,
input_file,
output_file,
stderr_file="stderr.txt",
prefix="error",
wall_time: int,
buffer_time: int,
input_file: str,
output_file: str,
stderr_file: str = "stderr.txt",
prefix: str = "error",
):
"""
Initialize the WalTimeErrorHandler class.
Expand All @@ -858,17 +858,16 @@ def __init__(
self.output_file = output_file
self.stderr_file = stderr_file
self.prefix = prefix
self.logger = logging.getLogger(self.__class__.__name__)
self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
logging.basicConfig(level=logging.INFO)

now_ = datetime.datetime.now()
now_str = datetime.datetime.strftime(now_, "%a %b %d %H:%M:%S UTC %Y")
init_time_str = os.environ.get("JOB_START_TIME", now_str)
os.environ["JOB_START_TIME"] = init_time_str
self.init_time = datetime.datetime.strptime(init_time_str, "%a %b %d %H:%M:%S %Z %Y")

self.init_time = os.environ.get("JOB_START_TIME", now_str)
os.environ["JOB_START_TIME"] = self.init_time
self.init_time = datetime.datetime.strptime(self.init_time, "%a %b %d %H:%M:%S %Z %Y")

def check(self, directory="./"):
def check(self, directory: str = "./") -> bool:
"""Check if the job is nearing the walltime. If so, return True, else False."""
if self.wall_time:
run_time = datetime.datetime.now() - self.init_time
Expand All @@ -877,7 +876,7 @@ def check(self, directory="./"):
return True
return False

def correct(self, directory="./"):
def correct(self, directory: str = "./") -> dict:
"""Perform the corrections."""
# TODO: when using restart, the rwf file might be in a different dir
backup_files = [self.input_file, self.output_file, self.stderr_file, *BACKUP_FILES.values()]
Expand Down
2 changes: 1 addition & 1 deletion tests/gaussian/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_normal(self):
assert os.path.exists(f"{SCR_DIR}/test.out{self.suffix}")

def test_better_guess(self):
g = GaussianJob.better_guess(
g = GaussianJob.generate_better_guess(
self.gaussian_cmd,
self.input_file,
self.output_file,
Expand Down

0 comments on commit c3dcf12

Please sign in to comment.