From 3781333a2d1b3a6d04b13aa6190b8de7ac466c64 Mon Sep 17 00:00:00 2001
From: aldbr
Date: Tue, 3 Dec 2024 17:43:16 +0100
Subject: [PATCH] fix(wms): correctly log the pilot job reference during the matching process

---
 .../Client/Matcher.py                         | 173 +++++++++---------
 .../WorkloadManagementSystem/DB/JobDB.py      |  11 ++
 .../DB/JobLoggingDB.py                        |  10 +
 .../DB/PilotAgentsDB.py                       |  10 +
 .../DB/TaskQueueDB.py                         |  10 +
 .../Service/MatcherHandler.py                 |   2 +-
 .../Utilities/ContextVars.py                  |  16 ++
 7 files changed, 142 insertions(+), 90 deletions(-)
 create mode 100644 src/DIRAC/WorkloadManagementSystem/Utilities/ContextVars.py

diff --git a/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py b/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py
index 33fd28b0e95..087847e1644 100644
--- a/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py
+++ b/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py
@@ -4,20 +4,19 @@
 """
 import time
 
-from DIRAC import gLogger, convertToPy3VersionNumber
-
-from DIRAC.Core.Utilities.PrettyPrint import printDict
-from DIRAC.Core.Security import Properties
+from DIRAC import convertToPy3VersionNumber, gLogger
 from DIRAC.ConfigurationSystem.Client.Helpers import Registry
 from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
-from DIRAC.WorkloadManagementSystem.Client import JobStatus
+from DIRAC.Core.Security import Properties
+from DIRAC.Core.Utilities.PrettyPrint import printDict
+from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
+from DIRAC.WorkloadManagementSystem.Client import JobStatus, PilotStatus
 from DIRAC.WorkloadManagementSystem.Client.Limiter import Limiter
-from DIRAC.WorkloadManagementSystem.Client import PilotStatus
-from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import TaskQueueDB, singleValueDefFields, multiValueMatchFields
-from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
 from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
 from DIRAC.WorkloadManagementSystem.DB.JobLoggingDB import JobLoggingDB
-from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
+from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
+from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import TaskQueueDB, multiValueMatchFields, singleValueDefFields
+from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import setPilotRefLogger
 
 
 class PilotVersionError(Exception):
@@ -52,11 +51,7 @@ def __init__(self, pilotAgentsDB=None, jobDB=None, tqDB=None, jlDB=None, opsHelp
             self.opsHelper = Operations()
 
         if pilotRef:
-            self.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
-            self.pilotAgentsDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
-            self.jobDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
-            self.tqDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
-            self.jlDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
+            self.log = gLogger.getLocalSubLogger(f"[{pilotRef}]Matcher")
         else:
             self.log = gLogger.getSubLogger("Matcher")
 
@@ -66,86 +61,86 @@ def __init__(self, pilotAgentsDB=None, jobDB=None, tqDB=None, jlDB=None, opsHelp
 
     def selectJob(self, resourceDescription, credDict):
         """Main job selection function to find the highest priority job matching the resource capacity"""
+        with setPilotRefLogger(self.log):
+            startTime = time.time()
+
+            resourceDict = self._getResourceDict(resourceDescription, credDict)
+
+            # Make a nice print of the resource matching parameters
+            toPrintDict = dict(resourceDict)
+            if "MaxRAM" in resourceDescription:
+                toPrintDict["MaxRAM"] = resourceDescription["MaxRAM"]
+            if "NumberOfProcessors" in resourceDescription:
+                toPrintDict["NumberOfProcessors"] = resourceDescription["NumberOfProcessors"]
+            toPrintDict["Tag"] = []
+            if "Tag" in resourceDict:
+                for tag in resourceDict["Tag"]:
+                    if not tag.endswith("GB") and not tag.endswith("Processors"):
+                        toPrintDict["Tag"].append(tag)
+            if not toPrintDict["Tag"]:
+                toPrintDict.pop("Tag")
+            self.log.info("Resource description for matching", printDict(toPrintDict))
+
+            negativeCond = self.limiter.getNegativeCondForSite(resourceDict["Site"], resourceDict.get("GridCE"))
+            result = self.tqDB.matchAndGetJob(resourceDict, negativeCond=negativeCond)
 
-        startTime = time.time()
-
-        resourceDict = self._getResourceDict(resourceDescription, credDict)
-
-        # Make a nice print of the resource matching parameters
-        toPrintDict = dict(resourceDict)
-        if "MaxRAM" in resourceDescription:
-            toPrintDict["MaxRAM"] = resourceDescription["MaxRAM"]
-        if "NumberOfProcessors" in resourceDescription:
-            toPrintDict["NumberOfProcessors"] = resourceDescription["NumberOfProcessors"]
-        toPrintDict["Tag"] = []
-        if "Tag" in resourceDict:
-            for tag in resourceDict["Tag"]:
-                if not tag.endswith("GB") and not tag.endswith("Processors"):
-                    toPrintDict["Tag"].append(tag)
-        if not toPrintDict["Tag"]:
-            toPrintDict.pop("Tag")
-        self.log.info("Resource description for matching", printDict(toPrintDict))
-
-        negativeCond = self.limiter.getNegativeCondForSite(resourceDict["Site"], resourceDict.get("GridCE"))
-        result = self.tqDB.matchAndGetJob(resourceDict, negativeCond=negativeCond)
-
-        if not result["OK"]:
-            raise RuntimeError(result["Message"])
-        result = result["Value"]
-        if not result["matchFound"]:
-            self.log.info("No match found")
-            return {}
-
-        jobID = result["jobId"]
-        resAtt = self.jobDB.getJobAttributes(jobID, ["OwnerDN", "OwnerGroup", "Status"])
-        if not resAtt["OK"]:
-            raise RuntimeError("Could not retrieve job attributes")
-        if not resAtt["Value"]:
-            raise RuntimeError("No attributes returned for job")
-        if not resAtt["Value"]["Status"] == "Waiting":
-            self.log.error("Job matched by the TQ is not in Waiting state", str(jobID))
-            result = self.tqDB.deleteJob(jobID)
             if not result["OK"]:
                 raise RuntimeError(result["Message"])
-            raise RuntimeError(f"Job {str(jobID)} is not in Waiting state")
+            result = result["Value"]
+            if not result["matchFound"]:
+                self.log.info("No match found")
+                return {}
+
+            jobID = result["jobId"]
+            resAtt = self.jobDB.getJobAttributes(jobID, ["OwnerDN", "OwnerGroup", "Status"])
+            if not resAtt["OK"]:
+                raise RuntimeError("Could not retrieve job attributes")
+            if not resAtt["Value"]:
+                raise RuntimeError("No attributes returned for job")
+            if not resAtt["Value"]["Status"] == "Waiting":
+                self.log.error("Job matched by the TQ is not in Waiting state", str(jobID))
+                result = self.tqDB.deleteJob(jobID)
+                if not result["OK"]:
+                    raise RuntimeError(result["Message"])
+                raise RuntimeError(f"Job {str(jobID)} is not in Waiting state")
 
-        self._reportStatus(resourceDict, jobID)
+            self._reportStatus(resourceDict, jobID)
 
-        result = self.jobDB.getJobJDL(jobID)
-        if not result["OK"]:
-            raise RuntimeError("Failed to get the job JDL")
-
-        resultDict = {}
-        resultDict["JDL"] = result["Value"]
-        resultDict["JobID"] = jobID
-
-        matchTime = time.time() - startTime
-        self.log.verbose("Match time", f"[{str(matchTime)}]")
-
-        # Get some extra stuff into the response returned
-        resOpt = self.jobDB.getJobOptParameters(jobID)
-        if resOpt["OK"]:
-            for key, value in resOpt["Value"].items():
-                resultDict[key] = value
-        resAtt = self.jobDB.getJobAttributes(jobID, ["OwnerDN", "OwnerGroup"])
-        if not resAtt["OK"]:
-            raise RuntimeError("Could not retrieve job attributes")
-        if not resAtt["Value"]:
-            raise RuntimeError("No attributes returned for job")
-
-        if self.opsHelper.getValue("JobScheduling/CheckMatchingDelay", True):
-            self.limiter.updateDelayCounters(resourceDict["Site"], jobID)
-
-        pilotInfoReportedFlag = resourceDict.get("PilotInfoReportedFlag", False)
-        if not pilotInfoReportedFlag:
-            self._updatePilotInfo(resourceDict)
-        self._updatePilotJobMapping(resourceDict, jobID)
-
-        resultDict["DN"] = resAtt["Value"]["OwnerDN"]
-        resultDict["Group"] = resAtt["Value"]["OwnerGroup"]
-        resultDict["PilotInfoReportedFlag"] = True
-
-        return resultDict
+            result = self.jobDB.getJobJDL(jobID)
+            if not result["OK"]:
+                raise RuntimeError("Failed to get the job JDL")
+
+            resultDict = {}
+            resultDict["JDL"] = result["Value"]
+            resultDict["JobID"] = jobID
+
+            matchTime = time.time() - startTime
+            self.log.verbose("Match time", f"[{str(matchTime)}]")
+
+            # Get some extra stuff into the response returned
+            resOpt = self.jobDB.getJobOptParameters(jobID)
+            if resOpt["OK"]:
+                for key, value in resOpt["Value"].items():
+                    resultDict[key] = value
+            resAtt = self.jobDB.getJobAttributes(jobID, ["OwnerDN", "OwnerGroup"])
+            if not resAtt["OK"]:
+                raise RuntimeError("Could not retrieve job attributes")
+            if not resAtt["Value"]:
+                raise RuntimeError("No attributes returned for job")
+
+            if self.opsHelper.getValue("JobScheduling/CheckMatchingDelay", True):
+                self.limiter.updateDelayCounters(resourceDict["Site"], jobID)
+
+            pilotInfoReportedFlag = resourceDict.get("PilotInfoReportedFlag", False)
+            if not pilotInfoReportedFlag:
+                self._updatePilotInfo(resourceDict)
+            self._updatePilotJobMapping(resourceDict, jobID)
+
+            resultDict["DN"] = resAtt["Value"]["OwnerDN"]
+            resultDict["Group"] = resAtt["Value"]["OwnerGroup"]
+            resultDict["PilotInfoReportedFlag"] = True
+
+            return resultDict
 
     def _getResourceDict(self, resourceDescription, credDict):
         """from resourceDescription to resourceDict (just various mods)"""
diff --git a/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py b/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py
index 532c5dace93..08091fe031f 100755
--- a/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py
+++ b/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py
@@ -32,6 +32,7 @@
     extractJDL,
     fixJDL,
 )
+from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger
 
 
 class JobDB(DB):
@@ -42,6 +43,8 @@ def __init__(self, parentLogger=None):
 
         DB.__init__(self, "JobDB", "WorkloadManagement/JobDB", parentLogger=parentLogger)
 
+        self._defaultLogger = self.log
+
         # data member to check if __init__ went through without error
         self.__initialized = False
         self.maxRescheduling = self.getCSOption("MaxRescheduling", 3)
@@ -64,6 +67,14 @@ def __init__(self, parentLogger=None):
         self.log.info("==================================================")
         self.__initialized = True
 
+    @property
+    def log(self):
+        return pilotRefLogger.get() or self._defaultLogger
+
+    @log.setter
+    def log(self, value):
+        self._defaultLogger = value
+
     def isValid(self):
         """Check if correctly initialised"""
         return self.__initialized
diff --git a/src/DIRAC/WorkloadManagementSystem/DB/JobLoggingDB.py b/src/DIRAC/WorkloadManagementSystem/DB/JobLoggingDB.py
index ee4b16df55e..62f5f8e3a58 100755
--- a/src/DIRAC/WorkloadManagementSystem/DB/JobLoggingDB.py
+++ b/src/DIRAC/WorkloadManagementSystem/DB/JobLoggingDB.py
@@ -11,6 +11,7 @@
 from DIRAC import S_ERROR, S_OK
 from DIRAC.Core.Base.DB import DB
 from DIRAC.Core.Utilities import TimeUtilities
+from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger
 
 MAGIC_EPOC_NUMBER = 1270000000
 
@@ -24,6 +25,15 @@ def __init__(self, parentLogger=None):
         """Standard Constructor"""
 
         DB.__init__(self, "JobLoggingDB", "WorkloadManagement/JobLoggingDB", parentLogger=parentLogger)
+        self._defaultLogger = self.log
+
+    @property
+    def log(self):
+        return pilotRefLogger.get() or self._defaultLogger
+
+    @log.setter
+    def log(self, value):
+        self._defaultLogger = value
 
     #############################################################################
     def addLoggingRecord(
diff --git a/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py b/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py
index 7d7e3cb873c..6a64c07f5ea 100755
--- a/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py
+++ b/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py
@@ -32,13 +32,23 @@
 from DIRAC.Core.Utilities.MySQL import _quotedList
 from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
 from DIRAC.WorkloadManagementSystem.Client import PilotStatus
+from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger
 
 
 class PilotAgentsDB(DB):
     def __init__(self, parentLogger=None):
         super().__init__("PilotAgentsDB", "WorkloadManagement/PilotAgentsDB", parentLogger=parentLogger)
+        self._defaultLogger = self.log
         self.lock = threading.Lock()
 
+    @property
+    def log(self):
+        return pilotRefLogger.get() or self._defaultLogger
+
+    @log.setter
+    def log(self, value):
+        self._defaultLogger = value
+
     ##########################################################################################
     def addPilotReferences(self, pilotRef, ownerGroup, gridType="DIRAC", pilotStampDict={}):
         """Add a new pilot job reference"""
diff --git a/src/DIRAC/WorkloadManagementSystem/DB/TaskQueueDB.py b/src/DIRAC/WorkloadManagementSystem/DB/TaskQueueDB.py
index e94fc617e78..d2d426c92be 100755
--- a/src/DIRAC/WorkloadManagementSystem/DB/TaskQueueDB.py
+++ b/src/DIRAC/WorkloadManagementSystem/DB/TaskQueueDB.py
@@ -12,6 +12,7 @@
 from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
 from DIRAC.ConfigurationSystem.Client.Helpers import Registry
 from DIRAC.WorkloadManagementSystem.private.SharesCorrector import SharesCorrector
+from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger
 
 DEFAULT_GROUP_SHARE = 1000
 TQ_MIN_SHARE = 0.001
@@ -37,6 +38,7 @@ class TaskQueueDB(DB):
 
     def __init__(self, parentLogger=None):
         DB.__init__(self, "TaskQueueDB", "WorkloadManagement/TaskQueueDB", parentLogger=parentLogger)
+        self._defaultLogger = self.log
         self.__maxJobsInTQ = 5000
         self.__defaultCPUSegments = [
             6 * 60,
@@ -64,6 +66,14 @@ def __init__(self, parentLogger=None):
         if not result["OK"]:
             raise Exception(f"Can't create tables: {result['Message']}")
 
+    @property
+    def log(self):
+        return pilotRefLogger.get() or self._defaultLogger
+
+    @log.setter
+    def log(self, value):
+        self._defaultLogger = value
+
     def enableAllTaskQueues(self):
         """Enable all Task queues"""
         return self.updateFields("tq_TaskQueues", updateDict={"Enabled": "1"})
diff --git a/src/DIRAC/WorkloadManagementSystem/Service/MatcherHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/MatcherHandler.py
index 7d7bf4428b6..cd3dc992c48 100755
--- a/src/DIRAC/WorkloadManagementSystem/Service/MatcherHandler.py
+++ b/src/DIRAC/WorkloadManagementSystem/Service/MatcherHandler.py
@@ -56,7 +56,7 @@ def export_requestJob(self, resourceDescription):
         resourceDescription["Setup"] = self.serviceInfoDict["clientSetup"]
         credDict = self.getRemoteCredentials()
 
-        pilotRef = resourceDescription.get("PilotReference", "Unknown")
+        pilotRef = resourceDescription.get("PilotReference")
 
         try:
             opsHelper = Operations(group=credDict["group"])
diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/ContextVars.py b/src/DIRAC/WorkloadManagementSystem/Utilities/ContextVars.py
new file mode 100644
index 00000000000..1044c453764
--- /dev/null
+++ b/src/DIRAC/WorkloadManagementSystem/Utilities/ContextVars.py
@@ -0,0 +1,16 @@
+""" Context variables for the Workload Management System """
+
+# Context variable for the logger (adapted to the request of the pilot reference)
+import contextvars
+from contextlib import contextmanager
+
+pilotRefLogger = contextvars.ContextVar("PilotRefLogger", default=None)
+
+
+@contextmanager
+def setPilotRefLogger(logger_name):
+    token = pilotRefLogger.set(logger_name)
+    try:
+        yield
+    finally:
+        pilotRefLogger.reset(token)