Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JDFTx jobs.py to run JDFTx using atomate2/jobflow #349

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/custodian/jdftx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
This package implements various JDFTx Jobs and Error Handlers.
Used Cp2kJob developed by Nick Winner as a template.
"""
110 changes: 110 additions & 0 deletions src/custodian/jdftx/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""This module implements basic kinds of jobs for JDFTx runs."""

import logging
import os
import shlex
import subprocess

import psutil

from custodian.custodian import Job

logger = logging.getLogger(__name__)


class JDFTxJob(Job):
"""A basic JDFTx job. Runs whatever is in the working directory."""

def __init__(
self,
jdftx_cmd,
input_file="init.in",
output_file="jdftx.out",
stderr_file="std_err.txt",
) -> None:
"""
Args:
jdftx_cmd (str): Command to run JDFTx as a string.
input_file (str): Name of the file to use as input to JDFTx
executable. Defaults to "init.in"
output_file (str): Name of file to direct standard out to.
Defaults to "jdftx.out".
stderr_file (str): Name of file to direct standard error to.
Defaults to "std_err.txt".
"""
self.jdftx_cmd = jdftx_cmd
self.input_file = input_file
self.output_file = output_file
self.stderr_file = stderr_file

def setup(self, directory="./") -> None:
"""No setup required."""

def run(self, directory="./"):
"""
Perform the actual JDFTx run.

Returns:
-------
(subprocess.Popen) Used for monitoring.
"""
cmd = self.jdftx_cmd + " -i " + self.input_file + " -o " + self.output_file
logger.info(f"Running {cmd}")
with (
open(os.path.join(directory, self.output_file), "w") as f_std,
open(os.path.join(directory, self.stderr_file), "w", buffering=1) as f_err,
):
# use line buffering for stderr
return subprocess.run(
shlex.split(cmd),
cwd=directory,
stdout=f_std,
stderr=f_err,
shell=False,
check=False,
)

def postprocess(self, directory="./") -> None:
"""No post-processing required."""

def terminate(self, directory="./") -> None:
"""Terminate JDFTx."""
work_dir = directory
logger.info(f"Killing JDFTx processes in {work_dir=}.")
for proc in psutil.process_iter():
try:
if "jdftx" in proc.name():
print("name:", proc.name())
open_paths = [file.path for file in proc.open_files()]
run_path = os.path.join(work_dir, self.output_file)
if (run_path in open_paths) and psutil.pid_exists(proc.pid):
self.terminate_process(proc)
return
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
logger.warning(f"Exception {e} encountered while killing JDFTx.")
continue

logger.warning(
f"Killing JDFTx processes in {work_dir=} failed with subprocess.Popen.terminate(). Resorting to 'killall'."
)
cmd = self.jdftx_cmd
print("cmd:", cmd)
if "jdftx" in cmd:
subprocess.run(["killall", f"{cmd}"], check=False)

@staticmethod
def terminate_process(proc, timeout=5):
"""Terminate a process gracefully, then forcefully if necessary."""
try:
proc.terminate()
try:
proc.wait(timeout=timeout)
except psutil.TimeoutExpired:
# If process is still running after the timeout, kill it
logger.warning(f"Process {proc.pid} did not terminate gracefully, killing it.")
proc.kill()
# proc.wait()
else:
logger.info(f"Process {proc.pid} terminated gracefully.")
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
logger.warning(f"Error while terminating process {proc.pid}: {e}")
3 changes: 2 additions & 1 deletion src/custodian/vasp/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ def run(self, directory="./"):
open(os.path.join(directory, self.stderr_file), "w", buffering=1) as f_err,
):
# use line buffering for stderr
return subprocess.Popen(cmd, cwd=directory, stdout=f_std, stderr=f_err, start_new_session=True) # pylint: disable=R1732
return subprocess.Popen(cmd, cwd=directory, stdout=f_std, stderr=f_err, start_new_session=True)
# pylint: disable=R1732

def postprocess(self, directory="./") -> None:
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/jdftx/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest

from custodian.jdftx.jobs import JDFTxJob


@pytest.fixture
def jdftx_job():
return JDFTxJob(jdftx_cmd="jdftx", output_file="jdftx.out")
174 changes: 174 additions & 0 deletions tests/jdftx/test_jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import os
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, MagicMock, patch

import psutil

from custodian.jdftx.jobs import JDFTxJob

TEST_DIR = Path(__file__).resolve().parent.parent
TEST_FILES = f"{TEST_DIR}/files/jdftx"


def create_mock_process(
name="jdftx", open_files=None, pid=12345, name_side_effect=None, wait_side_effect=None, terminate_side_effect=None
):
if open_files is None: # Set a default value if not provided
open_files = [MagicMock(path=os.path.join("default_path", "jdftx.out"))]

mock_process = mock.Mock(spec=psutil.Process)
mock_process.name.return_value = name
mock_process.open_files.return_value = open_files
mock_process.pid = pid
mock_process.name.side_effect = name_side_effect
mock_process.wait.side_effect = wait_side_effect
mock_process.terminate.side_effect = terminate_side_effect
return mock_process


def test_jdftx_job_init(jdftx_job):
assert jdftx_job.jdftx_cmd == "jdftx"
assert jdftx_job.input_file == "init.in"
assert jdftx_job.output_file == "jdftx.out"
assert jdftx_job.stderr_file == "std_err.txt"


def test_jdftx_job_setup(jdftx_job, tmp_path):
jdftx_job.setup(str(tmp_path))
# Setup method doesn't do anything, so just checking that it doesn't raise an exception


def test_jdftx_job_run(jdftx_job, tmp_path):
with patch("subprocess.run") as mock_run:
mock_process = MagicMock()
mock_run.return_value = mock_process

result = jdftx_job.run(str(tmp_path))

assert result == mock_process
mock_run.assert_called_once_with(
["jdftx", "-i", "init.in", "-o", "jdftx.out"],
cwd=str(tmp_path),
stdout=ANY,
stderr=ANY,
shell=False,
check=False,
)


def test_jdftx_job_run_creates_output_files(jdftx_job, tmp_path):
with patch("subprocess.run"):
jdftx_job.run(str(tmp_path))

assert os.path.exists(os.path.join(str(tmp_path), "jdftx.out"))
assert os.path.exists(os.path.join(str(tmp_path), "std_err.txt"))


def test_jdftx_job_postprocess(jdftx_job, tmp_path):
jdftx_job.postprocess(str(tmp_path))
# Postprocess method doesn't do anything, so we just check that it doesn't raise an exception


@mock.patch("psutil.pid_exists")
@mock.patch("subprocess.run")
@mock.patch.object(JDFTxJob, "terminate_process", autospec=True)
def test_jdftx_job_terminate(mock_terminate_process, mock_subprocess_run, mock_pid_exists, jdftx_job, tmp_path, caplog):
open_files = [MagicMock(path=os.path.join(str(tmp_path), jdftx_job.output_file))]
# Test when JDFTx process exists
mock_process = create_mock_process(name="jdftx", open_files=open_files, pid=12345)

with patch("psutil.process_iter", return_value=[mock_process]):
mock_pid_exists.return_value = True
jdftx_job.terminate(str(tmp_path))
mock_terminate_process.assert_called_once_with(mock_process)
mock_subprocess_run.assert_not_called()

mock_terminate_process.reset_mock()
mock_subprocess_run.reset_mock()

# Test when no JDFTx process exists
mock_process = create_mock_process(name="vasp", open_files=open_files, pid=12345)

with patch("psutil.process_iter", return_value=[mock_process]):
jdftx_job.terminate(str(tmp_path))
mock_terminate_process.assert_not_called()
mock_subprocess_run.assert_called_once_with(["killall", "jdftx"], check=False)

mock_terminate_process.reset_mock()
mock_subprocess_run.reset_mock()

# Test when psutil.process_iter raises NoSuchProcess
mock_process = create_mock_process(
name="jdftx", open_files=open_files, pid=12345, name_side_effect=psutil.NoSuchProcess(pid=12345)
)

with caplog.at_level("WARNING"):
with patch("psutil.process_iter", return_value=[mock_process]):
jdftx_job.terminate(str(tmp_path))
mock_terminate_process.assert_not_called()
mock_subprocess_run.assert_called_with(["killall", "jdftx"], check=False)

assert "Exception" in caplog.text
assert "encountered while killing JDFTx" in caplog.text

mock_terminate_process.reset_mock()
mock_subprocess_run.reset_mock()

# Test when psutil.process_iter raises AccessDenied
with caplog.at_level("WARNING"):
mock_process = create_mock_process(
name="jdftx", open_files=open_files, pid=12345, name_side_effect=psutil.AccessDenied(pid=12345)
)
with patch("psutil.process_iter", return_value=[mock_process]):
jdftx_job.terminate(str(tmp_path))
mock_terminate_process.assert_not_called()
mock_subprocess_run.assert_called_with(["killall", "jdftx"], check=False)

assert "Exception" in caplog.text
assert "encountered while killing JDFTx" in caplog.text


def test_terminate_process(jdftx_job, caplog):
# Test successful termination
mock_process = create_mock_process()
mock_process.terminate.return_value = None # Simulate successful termination
mock_process.wait.return_value = None # Simulate process finished immediately
with caplog.at_level("INFO"):
jdftx_job.terminate_process(mock_process)

mock_process.terminate.assert_called_once()
mock_process.wait.assert_called_once()

assert "Process" in caplog.text
assert "terminated gracefully" in caplog.text

mock_process.reset_mock()

# Test when process doesn't terminate gracefully
mock_process = create_mock_process(pid=12345, wait_side_effect=psutil.TimeoutExpired(seconds=5))
mock_process.terminate.return_value = None

jdftx_job.terminate_process(mock_process)
mock_process.terminate.assert_called_once()
mock_process.kill.assert_called_once()
mock_process.wait.assert_called()

mock_process.reset_mock()

# Test when process raises NoSuchProcess
mock_process = create_mock_process(pid=12345, terminate_side_effect=psutil.NoSuchProcess(pid=12345))
with caplog.at_level("WARNING"):
jdftx_job.terminate_process(mock_process)

assert "Error while terminating process" in caplog.text

mock_process.reset_mock()

# Test when process raises AccessDenied
mock_process = create_mock_process(pid=12345, terminate_side_effect=psutil.AccessDenied(pid=12345))

with caplog.at_level("WARNING"):
jdftx_job.terminate_process(mock_process)

assert "Error while terminating process" in caplog.text