diff --git a/CIME/XML/env_batch.py b/CIME/XML/env_batch.py
index 3417b58c3af..285f0e4c062 100644
--- a/CIME/XML/env_batch.py
+++ b/CIME/XML/env_batch.py
@@ -20,6 +20,7 @@
from collections import OrderedDict
import stat, re, math
import pathlib
+from itertools import zip_longest
logger = logging.getLogger(__name__)
@@ -1412,34 +1413,81 @@ def cancel_job(self, jobid):
else:
return True
+ def zip(self, other, name):
+ for self_pnode in self.get_children(name):
+ try:
+ other_pnode = other.get_children(name, attributes=self_pnode.attrib)[0]
+ except (TypeError, IndexError):
+ other_pnode = None
+
+ for node1 in self.get_children(root=self_pnode):
+ for node2 in other.scan_children(
+ node1.name, attributes=node1.attrib, root=other_pnode
+ ):
+ yield node1, node2
+
+ def _compare_arg(self, index, arg1, arg2):
+ try:
+ flag1 = arg1.attrib["flag"]
+ name1 = arg1.attrib.get("name", "")
+ except AttributeError:
+ flag2, name2 = arg2.attrib["flag"], arg2.attrib["name"]
+
+ return {f"arg{index}": ["", f"{flag2} {name2}"]}
+
+ try:
+ flag2 = arg2.attrib["flag"]
+ name2 = arg2.attrib.get("name", "")
+ except AttributeError:
+ return {f"arg{index}": [f"{flag1} {name1}", ""]}
+
+ if flag1 != flag2 or name1 != name2:
+ return {f"arg{index}": [f"{flag1} {name1}", f"{flag2} {name2}"]}
+
+ return {}
+
+ def _compare_argument(self, index, arg1, arg2):
+ if arg1.text != arg2.text:
+ return {f"argument{index}": [arg1.text, arg2.text]}
+
+ return {}
+
def compare_xml(self, other):
xmldiffs = {}
- f1batchnodes = self.get_children("batch_system")
- for bnode in f1batchnodes:
- f2bnodes = other.get_children("batch_system", attributes=self.attrib(bnode))
- f2bnode = None
- if len(f2bnodes):
- f2bnode = f2bnodes[0]
- f1batchnodes = self.get_children(root=bnode)
- for node in f1batchnodes:
- name = self.name(node)
- text1 = self.text(node)
- text2 = ""
- attribs = self.attrib(node)
- f2matches = other.scan_children(name, attributes=attribs, root=f2bnode)
- foundmatch = False
- for chkmatch in f2matches:
- name2 = other.name(chkmatch)
- attribs2 = other.attrib(chkmatch)
- text2 = other.text(chkmatch)
- if name == name2 and attribs == attribs2 and text1 == text2:
- foundmatch = True
- break
- if not foundmatch:
- xmldiffs[name] = [text1, text2]
-
- f1groups = self.get_children("group")
- for node in f1groups:
+
+ for node1, node2 in self.zip(other, "batch_system"):
+ if node1.name == "submit_args":
+ self_nodes = self.get_children(root=node1)
+ other_nodes = other.get_children(root=node2)
+ for i, (x, y) in enumerate(
+ zip_longest(self_nodes, other_nodes, fillvalue=None)
+ ):
+ if (x is not None and x.name == "arg") or (
+ y is not None and y.name == "arg"
+ ):
+ xmldiffs.update(self._compare_arg(i, x, y))
+ elif (x is not None and x.name == "argument") or (
+ y is not None and y.name == "argument"
+ ):
+ xmldiffs.update(self._compare_node(x, y, i))
+ elif node1.name == "directives":
+ self_nodes = self.get_children(root=node1)
+ other_nodes = other.get_children(root=node2)
+ for i, (x, y) in enumerate(
+ zip_longest(self_nodes, other_nodes, fillvalue=None)
+ ):
+ xmldiffs.update(self._compare_node(x, y, i))
+ elif node1.name == "queues":
+ self_nodes = self.get_children(root=node1)
+ other_nodes = other.get_children(root=node2)
+ for i, (x, y) in enumerate(
+ zip_longest(self_nodes, other_nodes, fillvalue=None)
+ ):
+ xmldiffs.update(self._compare_node(x, y, i))
+ else:
+ xmldiffs.update(self._compare_node(node1, node2))
+
+ for node in self.get_children("group"):
group = self.get(node, "id")
f2group = other.get_child("group", attributes={"id": group})
xmldiffs.update(
@@ -1447,6 +1495,36 @@ def compare_xml(self, other):
)
return xmldiffs
+ def _compare_node(self, x, y, index=None):
+ """Compares two XML nodes and returns diff.
+
+ Compares the attributes and text of two XML nodes. Handles the case when either node is `None`.
+
+ The `index` argument can be used to append the nodes tag. This can be useful when comparing a list
+ of XML nodes that all have the same tag to differentiate which nodes are different.
+
+ Args:
+ x (:obj:`CIME.XML.generic_xml._Element`): First node.
+ y (:obj:`CIME.XML.generic_xml._Element`): Second node.
+ index (int, optional): Index of the nodes.
+
+ Returns:
+ dict: Key is the tag and value is the difference.
+ """
+ diff = {}
+
+ if index is None:
+ index = ""
+
+ if x is None:
+ diff[f"{y.name}{index}"] = ["", y.text]
+ elif y is None:
+ diff[f"{x.name}{index}"] = [x.text, ""]
+ elif x.text != y.text or x.attrib != y.attrib:
+ diff[f"{x.name}{index}"] = [x.text, y.text]
+
+ return diff
+
def make_all_batch_files(self, case):
machdir = case.get_value("MACHDIR")
logger.info("Creating batch scripts")
diff --git a/CIME/XML/generic_xml.py b/CIME/XML/generic_xml.py
index 083743695b3..d35815d7829 100644
--- a/CIME/XML/generic_xml.py
+++ b/CIME/XML/generic_xml.py
@@ -36,6 +36,24 @@ def __hash__(self):
def __deepcopy__(self, _):
return _Element(deepcopy(self.xml_element))
+ def __str__(self):
+ return str(self.xml_element)
+
+ def __repr__(self):
+ return repr(self.xml_element)
+
+ @property
+ def name(self):
+ return self.xml_element.tag
+
+ @property
+ def text(self):
+ return self.xml_element.text
+
+ @property
+ def attrib(self):
+ return dict(self.xml_element.attrib)
+
class GenericXML(object):
diff --git a/CIME/tests/test_unit_xml_env_batch.py b/CIME/tests/test_unit_xml_env_batch.py
index d59c4b080c9..281abc7603d 100755
--- a/CIME/tests/test_unit_xml_env_batch.py
+++ b/CIME/tests/test_unit_xml_env_batch.py
@@ -3,15 +3,178 @@
import os
import unittest
import tempfile
+from contextlib import ExitStack
from unittest import mock
-from CIME.utils import CIMEError
+from CIME.utils import CIMEError, expect
from CIME.XML.env_batch import EnvBatch, get_job_deps
# pylint: disable=unused-argument
+XML_BASE = b"""
+
+
+ These variables may be changed anytime during a run, they
+ control arguments to the batch submit command.
+
+
+
+ char
+ miller_slurm,nersc_slurm,lc_slurm,moab,pbs,lsf,slurm,cobalt,cobalt_theta,none
+ The batch system type to use for this machine.
+
+
+
+
+ logical
+ TRUE,FALSE
+ whether the PROJECT value is required on this machine
+
+
+
+ squeue
+ sbatch
+ scancel
+ #SBATCH
+ (\\d+)$
+ --dependency=afterok:jobid
+ --dependency=afterany:jobid
+ :
+ %H:%M:%S
+ --mail-user
+ --mail-type
+ none, all, begin, end, fail
+
+
+
+
+
+
+ --job-name={{ job_id }}
+ --nodes={{ num_nodes }}
+ --output={{ job_id }}.%j
+ --exclusive
+
+
+
+
+ -w docker
+
+
+ debug
+ big
+ smallfast
+
+
+"""
+
+XML_DIFF = b"""
+
+
+ These variables may be changed anytime during a run, they
+ control arguments to the batch submit command.
+
+
+
+ char
+ miller_slurm,nersc_slurm,lc_slurm,moab,pbs,lsf,slurm,cobalt,cobalt_theta,none
+ The batch system type to use for this machine.
+
+
+
+
+ logical
+ TRUE,FALSE
+ whether the PROJECT value is required on this machine
+
+
+
+ squeue
+ batch
+ scancel
+ #SBATCH
+ (\\d+)$
+ --dependency=afterok:jobid
+ --dependency=afterany:jobid
+ :
+ %H:%M:%S
+ --mail-user
+ --mail-type
+ none, all, begin, end, fail
+
+
+
+
+
+
+
+ --job-name={{ job_id }}
+ --nodes=10
+ --output={{ job_id }}.%j
+ --exclusive
+ --qos=high
+
+
+
+
+ -w docker
+
+
+ debug
+ big
+
+
+"""
+
+
+def _open_temp_file(stack, data):
+ tfile = stack.enter_context(tempfile.NamedTemporaryFile())
+
+ tfile.write(data)
+
+ tfile.seek(0)
+
+ return tfile
+
class TestXMLEnvBatch(unittest.TestCase):
+ def test_compare_xml(self):
+ with ExitStack() as stack:
+ file1 = _open_temp_file(stack, XML_DIFF)
+ batch1 = EnvBatch(infile=file1.name)
+
+ file2 = _open_temp_file(stack, XML_BASE)
+ batch2 = EnvBatch(infile=file2.name)
+
+ diff = batch1.compare_xml(batch2)
+ diff2 = batch2.compare_xml(batch1)
+
+ expected_diff = {
+ "BATCH_SYSTEM": ["pbs", "slurm"],
+ "arg1": ["-p pbatch", "-p $JOB_QUEUE"],
+ "arg3": ["-m plane", ""],
+ "batch_submit": ["batch", "sbatch"],
+ "directive1": [" --nodes=10", " --nodes={{ num_nodes }}"],
+ "directive4": [" --qos=high ", ""],
+ "queue1": ["big", "big"],
+ "queue2": ["", "smallfast"],
+ }
+
+ assert diff == expected_diff
+
+ expected_diff2 = {
+ "BATCH_SYSTEM": ["slurm", "pbs"],
+ "arg1": ["-p $JOB_QUEUE", "-p pbatch"],
+ "arg3": ["", "-m plane"],
+ "batch_submit": ["sbatch", "batch"],
+ "directive1": [" --nodes={{ num_nodes }}", " --nodes=10"],
+ "directive4": ["", " --qos=high "],
+ "queue1": ["big", "big"],
+ "queue2": ["smallfast", ""],
+ }
+
+ assert diff2 == expected_diff2
+
@mock.patch("CIME.XML.env_batch.EnvBatch._submit_single_job")
def test_submit_jobs(self, _submit_single_job):
case = mock.MagicMock()
@@ -273,7 +436,7 @@ def test_get_submit_args(self):
sbatch
scancel
#SBATCH
- (\d+)$
+ (\\d+)$
--dependency=afterok:jobid
--dependency=afterany:jobid
: