From 881ab6c03388abdbd7a5e84021234963793ef94c Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Fri, 13 Dec 2024 15:44:37 +0100 Subject: [PATCH] sweep: #7943 Some refactoring --- src/DIRAC/Core/Utilities/DictCache.py | 106 ++++++++- .../Client/Limiter.py | 202 +++++------------- 2 files changed, 154 insertions(+), 154 deletions(-) diff --git a/src/DIRAC/Core/Utilities/DictCache.py b/src/DIRAC/Core/Utilities/DictCache.py index cbed339df9f..7ed1028fdf9 100644 --- a/src/DIRAC/Core/Utilities/DictCache.py +++ b/src/DIRAC/Core/Utilities/DictCache.py @@ -1,9 +1,15 @@ """ - DictCache. +DictCache and TwoLevelCache """ import datetime import threading import weakref +from collections import defaultdict +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor, wait +from typing import Any + +from cachetools import TTLCache # DIRAC from DIRAC.Core.Utilities.LockRing import LockRing @@ -248,3 +254,101 @@ def _purgeAll(lock, cache, deleteFunction): finally: if lock: lock.release() + + +class TwoLevelCache: + """A two-level caching system with soft and hard time-to-live (TTL) expiration. + + This cache implements a two-tier caching mechanism to allow for background refresh + of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback, + which helps in reducing latency and maintaining data freshness. + + Attributes: + soft_cache (TTLCache): A cache with a shorter TTL for quick access. + hard_cache (TTLCache): A cache with a longer TTL as a fallback. + locks (defaultdict): Thread-safe locks for each cache key. + futures (dict): Stores ongoing asynchronous population tasks. + pool (ThreadPoolExecutor): Thread pool for executing cache population tasks. + + Args: + soft_ttl (int): Time-to-live in seconds for the soft cache. + hard_ttl (int): Time-to-live in seconds for the hard cache. + max_workers (int): Maximum number of workers in the thread pool. + max_items (int): Maximum number of items in the cache. + + Example: + >>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300) + >>> def populate_func(): + ... return "cached_value" + >>> value = cache.get("key", populate_func) + + Note: + The cache uses a ThreadPoolExecutor with a maximum of 10 workers to + handle concurrent cache population requests. + """ + + def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000): + """Initialize the TwoLevelCache with specified TTLs.""" + self.soft_cache = TTLCache(max_items, soft_ttl) + self.hard_cache = TTLCache(max_items, hard_ttl) + self.locks = defaultdict(threading.Lock) + self.futures: dict[str, Future] = {} + self.pool = ThreadPoolExecutor(max_workers=max_workers) + + def get(self, key: str, populate_func: Callable[[], Any]) -> dict: + """Retrieve a value from the cache, populating it if necessary. + + This method first checks the soft cache for the key. If not found, + it checks the hard cache while initiating a background refresh. + If the key is not in either cache, it waits for the populate_func + to complete and stores the result in both caches. + + Locks are used to ensure there is never more than one concurrent + population task for a given key. + + Args: + key (str): The cache key to retrieve or populate. + populate_func (Callable[[], Any]): A function to call to populate the cache + if the key is not found. + + Returns: + Any: The cached value associated with the key. + + Note: + This method is thread-safe and handles concurrent requests for the same key. + """ + if result := self.soft_cache.get(key): + return result + with self.locks[key]: + if key not in self.futures: + self.futures[key] = self.pool.submit(self._work, key, populate_func) + if result := self.hard_cache.get(key): + self.soft_cache[key] = result + return result + # It is critical that ``future`` is waited for outside of the lock as + # _work aquires the lock before filling the caches. This also means + # we can guarantee that the future has not yet been removed from the + # futures dict. + future = self.futures[key] + wait([future]) + return self.hard_cache[key] + + def _work(self, key: str, populate_func: Callable[[], Any]) -> None: + """Internal method to execute the populate_func and update caches. + + This method is intended to be run in a separate thread. It calls the + populate_func, stores the result in both caches, and cleans up the + associated future. + + Args: + key (str): The cache key to populate. + populate_func (Callable[[], Any]): The function to call to get the value. + + Note: + This method is not intended to be called directly by users of the class. + """ + result = populate_func() + with self.locks[key]: + self.futures.pop(key) + self.hard_cache[key] = result + self.soft_cache[key] = result diff --git a/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py b/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py index d0c005613b7..226ebeb9693 100644 --- a/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py +++ b/src/DIRAC/WorkloadManagementSystem/Client/Limiter.py @@ -2,121 +2,14 @@ Utilities and classes here are used by the Matcher """ -import threading -from collections import defaultdict -from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor, wait, Future from functools import partial -from typing import Any -from cachetools import TTLCache - -from DIRAC import S_OK, S_ERROR -from DIRAC import gLogger - -from DIRAC.Core.Utilities.DictCache import DictCache -from DIRAC.Core.Utilities.DErrno import cmpError, ESECTION +from DIRAC import S_ERROR, S_OK, gLogger from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations -from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB +from DIRAC.Core.Utilities.DErrno import ESECTION, cmpError +from DIRAC.Core.Utilities.DictCache import DictCache, TwoLevelCache from DIRAC.WorkloadManagementSystem.Client import JobStatus - - -class TwoLevelCache: - """A two-level caching system with soft and hard time-to-live (TTL) expiration. - - This cache implements a two-tier caching mechanism to allow for background refresh - of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback, - which helps in reducing latency and maintaining data freshness. - - Attributes: - soft_cache (TTLCache): A cache with a shorter TTL for quick access. - hard_cache (TTLCache): A cache with a longer TTL as a fallback. - locks (defaultdict): Thread-safe locks for each cache key. - futures (dict): Stores ongoing asynchronous population tasks. - pool (ThreadPoolExecutor): Thread pool for executing cache population tasks. - - Args: - soft_ttl (int): Time-to-live in seconds for the soft cache. - hard_ttl (int): Time-to-live in seconds for the hard cache. - max_workers (int): Maximum number of workers in the thread pool. - max_items (int): Maximum number of items in the cache. - - Example: - >>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300) - >>> def populate_func(): - ... return "cached_value" - >>> value = cache.get("key", populate_func) - - Note: - The cache uses a ThreadPoolExecutor with a maximum of 10 workers to - handle concurrent cache population requests. - """ - - def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000): - """Initialize the TwoLevelCache with specified TTLs.""" - self.soft_cache = TTLCache(max_items, soft_ttl) - self.hard_cache = TTLCache(max_items, hard_ttl) - self.locks = defaultdict(threading.Lock) - self.futures: dict[str, Future] = {} - self.pool = ThreadPoolExecutor(max_workers=max_workers) - - def get(self, key: str, populate_func: Callable[[], Any]): - """Retrieve a value from the cache, populating it if necessary. - - This method first checks the soft cache for the key. If not found, - it checks the hard cache while initiating a background refresh. - If the key is not in either cache, it waits for the populate_func - to complete and stores the result in both caches. - - Locks are used to ensure there is never more than one concurrent - population task for a given key. - - Args: - key (str): The cache key to retrieve or populate. - populate_func (Callable[[], Any]): A function to call to populate the cache - if the key is not found. - - Returns: - Any: The cached value associated with the key. - - Note: - This method is thread-safe and handles concurrent requests for the same key. - """ - if result := self.soft_cache.get(key): - return result - with self.locks[key]: - if key not in self.futures: - self.futures[key] = self.pool.submit(self._work, key, populate_func) - if result := self.hard_cache.get(key): - self.soft_cache[key] = result - return result - # It is critical that ``future`` is waited for outside of the lock as - # _work aquires the lock before filling the caches. This also means - # we can gaurentee that the future has not yet been removed from the - # futures dict. - future = self.futures[key] - wait([future]) - return self.hard_cache[key] - - def _work(self, key: str, populate_func: Callable[[], Any]) -> None: - """Internal method to execute the populate_func and update caches. - - This method is intended to be run in a separate thread. It calls the - populate_func, stores the result in both caches, and cleans up the - associated future. - - Args: - key (str): The cache key to populate. - populate_func (Callable[[], Any]): The function to call to get the value. - - Note: - This method is not intended to be called directly by users of the class. - """ - result = populate_func() - with self.locks[key]: - self.futures.pop(key) - self.hard_cache[key] = result - self.soft_cache[key] = result +from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB class Limiter: @@ -152,72 +45,76 @@ def getNegativeCond(self): orCond = self.condCache.get("GLOBAL") if orCond: return orCond - negCond = {} + negativeCondition = {} + # Run Limit result = self.__opsHelper.getSections(self.__runningLimitSection) - sites = [] - if result["OK"]: - sites = result["Value"] - for siteName in sites: + if not result["OK"]: + self.log.error("Issue getting running conditions", result["Message"]) + sites_with_running_limits = [] + else: + sites_with_running_limits = result["Value"] + self.log.verbose(f"Found running conditions for {len(sites_with_running_limits)} sites") + + for siteName in sites_with_running_limits: result = self.__getRunningCondition(siteName) if not result["OK"]: - continue - data = result["Value"] - if data: - negCond[siteName] = data + self.log.error("Issue getting running conditions", result["Message"]) + running_condition = {} + else: + running_condition = result["Value"] + if running_condition: + negativeCondition[siteName] = running_condition + # Delay limit - result = self.__opsHelper.getSections(self.__matchingDelaySection) - sites = [] - if result["OK"]: - sites = result["Value"] - for siteName in sites: - result = self.__getDelayCondition(siteName) + if self.__opsHelper.getValue("JobScheduling/CheckMatchingDelay", True): + result = self.__opsHelper.getSections(self.__matchingDelaySection) if not result["OK"]: - continue - data = result["Value"] - if not data: - continue - if siteName in negCond: - negCond[siteName] = self.__mergeCond(negCond[siteName], data) + self.log.error("Issue getting delay conditions", result["Message"]) + sites_with_matching_delay = [] else: - negCond[siteName] = data + sites_with_matching_delay = result["Value"] + self.log.verbose(f"Found delay conditions for {len(sites_with_matching_delay)} sites") + + for siteName in sites_with_matching_delay: + delay_condition = self.__getDelayCondition(siteName) + if siteName in negativeCondition: + negativeCondition[siteName] = self.__mergeCond(negativeCondition[siteName], delay_condition) + else: + negativeCondition[siteName] = delay_condition + orCond = [] - for siteName in negCond: - negCond[siteName]["Site"] = siteName - orCond.append(negCond[siteName]) + for siteName in negativeCondition: + negativeCondition[siteName]["Site"] = siteName + orCond.append(negativeCondition[siteName]) self.condCache.add("GLOBAL", 10, orCond) return orCond def getNegativeCondForSite(self, siteName, gridCE=None): """Generate a negative query based on the limits set on the site""" # Check if Limits are imposed onto the site - negativeCond = {} if self.__opsHelper.getValue("JobScheduling/CheckJobLimits", True): result = self.__getRunningCondition(siteName) if not result["OK"]: self.log.error("Issue getting running conditions", result["Message"]) + negativeCond = {} else: negativeCond = result["Value"] - self.log.verbose( - "Negative conditions for site", f"{siteName} after checking limits are: {str(negativeCond)}" - ) + self.log.verbose( + "Negative conditions for site", f"{siteName} after checking limits are: {str(negativeCond)}" + ) if gridCE: result = self.__getRunningCondition(siteName, gridCE) if not result["OK"]: self.log.error("Issue getting running conditions", result["Message"]) else: - negativeCondCE = result["Value"] - negativeCond = self.__mergeCond(negativeCond, negativeCondCE) + negativeCond = self.__mergeCond(negativeCond, result["Value"]) if self.__opsHelper.getValue("JobScheduling/CheckMatchingDelay", True): - result = self.__getDelayCondition(siteName) - if result["OK"]: - delayCond = result["Value"] - self.log.verbose( - "Negative conditions for site", f"{siteName} after delay checking are: {str(delayCond)}" - ) - negativeCond = self.__mergeCond(negativeCond, delayCond) + delayCond = self.__getDelayCondition(siteName) + self.log.verbose("Negative conditions for site", f"{siteName} after delay checking are: {str(delayCond)}") + negativeCond = self.__mergeCond(negativeCond, delayCond) if negativeCond: self.log.info("Negative conditions for site", f"{siteName} are: {str(negativeCond)}") @@ -337,14 +234,14 @@ def updateDelayCounters(self, siteName, jid): def __getDelayCondition(self, siteName): """Get extra conditions allowing matching delay""" if siteName not in self.delayMem: - return S_OK({}) + return {} lastRun = self.delayMem[siteName].getKeys() negCond = {} for attName, attValue in lastRun: if attName not in negCond: negCond[attName] = [] negCond[attName].append(attValue) - return S_OK(negCond) + return negCond def _countsByJobType(self, siteName, attName): result = self.jobDB.getCounters( @@ -354,6 +251,5 @@ def _countsByJobType(self, siteName, attName): ) if not result["OK"]: return result - data = result["Value"] - data = {k[0][attName]: k[1] for k in data} + data = {k[0][attName]: k[1] for k in result["Value"]} return data