Skip to content

Commit

Permalink
Merge pull request #7950 from DIRACGridBot/cherry-pick-2-6e4c9ec01-in…
Browse files Browse the repository at this point in the history
…tegration

[sweep:integration] Some refactoring
  • Loading branch information
fstagni authored Dec 13, 2024
2 parents 7120c95 + 881ab6c commit ec1855e
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 154 deletions.
106 changes: 105 additions & 1 deletion src/DIRAC/Core/Utilities/DictCache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""
DictCache.
DictCache and TwoLevelCache
"""
import datetime
import threading
import weakref
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any

from cachetools import TTLCache

# DIRAC
from DIRAC.Core.Utilities.LockRing import LockRing
Expand Down Expand Up @@ -248,3 +254,101 @@ def _purgeAll(lock, cache, deleteFunction):
finally:
if lock:
lock.release()


class TwoLevelCache:
"""A two-level caching system with soft and hard time-to-live (TTL) expiration.
This cache implements a two-tier caching mechanism to allow for background refresh
of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback,
which helps in reducing latency and maintaining data freshness.
Attributes:
soft_cache (TTLCache): A cache with a shorter TTL for quick access.
hard_cache (TTLCache): A cache with a longer TTL as a fallback.
locks (defaultdict): Thread-safe locks for each cache key.
futures (dict): Stores ongoing asynchronous population tasks.
pool (ThreadPoolExecutor): Thread pool for executing cache population tasks.
Args:
soft_ttl (int): Time-to-live in seconds for the soft cache.
hard_ttl (int): Time-to-live in seconds for the hard cache.
max_workers (int): Maximum number of workers in the thread pool.
max_items (int): Maximum number of items in the cache.
Example:
>>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
>>> def populate_func():
... return "cached_value"
>>> value = cache.get("key", populate_func)
Note:
The cache uses a ThreadPoolExecutor with a maximum of 10 workers to
handle concurrent cache population requests.
"""

def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000):
"""Initialize the TwoLevelCache with specified TTLs."""
self.soft_cache = TTLCache(max_items, soft_ttl)
self.hard_cache = TTLCache(max_items, hard_ttl)
self.locks = defaultdict(threading.Lock)
self.futures: dict[str, Future] = {}
self.pool = ThreadPoolExecutor(max_workers=max_workers)

def get(self, key: str, populate_func: Callable[[], Any]) -> dict:
"""Retrieve a value from the cache, populating it if necessary.
This method first checks the soft cache for the key. If not found,
it checks the hard cache while initiating a background refresh.
If the key is not in either cache, it waits for the populate_func
to complete and stores the result in both caches.
Locks are used to ensure there is never more than one concurrent
population task for a given key.
Args:
key (str): The cache key to retrieve or populate.
populate_func (Callable[[], Any]): A function to call to populate the cache
if the key is not found.
Returns:
Any: The cached value associated with the key.
Note:
This method is thread-safe and handles concurrent requests for the same key.
"""
if result := self.soft_cache.get(key):
return result
with self.locks[key]:
if key not in self.futures:
self.futures[key] = self.pool.submit(self._work, key, populate_func)
if result := self.hard_cache.get(key):
self.soft_cache[key] = result
return result
# It is critical that ``future`` is waited for outside of the lock as
# _work aquires the lock before filling the caches. This also means
# we can guarantee that the future has not yet been removed from the
# futures dict.
future = self.futures[key]
wait([future])
return self.hard_cache[key]

def _work(self, key: str, populate_func: Callable[[], Any]) -> None:
"""Internal method to execute the populate_func and update caches.
This method is intended to be run in a separate thread. It calls the
populate_func, stores the result in both caches, and cleans up the
associated future.
Args:
key (str): The cache key to populate.
populate_func (Callable[[], Any]): The function to call to get the value.
Note:
This method is not intended to be called directly by users of the class.
"""
result = populate_func()
with self.locks[key]:
self.futures.pop(key)
self.hard_cache[key] = result
self.soft_cache[key] = result
202 changes: 49 additions & 153 deletions src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,121 +2,14 @@
Utilities and classes here are used by the Matcher
"""
import threading
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, wait, Future
from functools import partial
from typing import Any

from cachetools import TTLCache

from DIRAC import S_OK, S_ERROR
from DIRAC import gLogger

from DIRAC.Core.Utilities.DictCache import DictCache
from DIRAC.Core.Utilities.DErrno import cmpError, ESECTION
from DIRAC import S_ERROR, S_OK, gLogger
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
from DIRAC.Core.Utilities.DErrno import ESECTION, cmpError
from DIRAC.Core.Utilities.DictCache import DictCache, TwoLevelCache
from DIRAC.WorkloadManagementSystem.Client import JobStatus


class TwoLevelCache:
"""A two-level caching system with soft and hard time-to-live (TTL) expiration.
This cache implements a two-tier caching mechanism to allow for background refresh
of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback,
which helps in reducing latency and maintaining data freshness.
Attributes:
soft_cache (TTLCache): A cache with a shorter TTL for quick access.
hard_cache (TTLCache): A cache with a longer TTL as a fallback.
locks (defaultdict): Thread-safe locks for each cache key.
futures (dict): Stores ongoing asynchronous population tasks.
pool (ThreadPoolExecutor): Thread pool for executing cache population tasks.
Args:
soft_ttl (int): Time-to-live in seconds for the soft cache.
hard_ttl (int): Time-to-live in seconds for the hard cache.
max_workers (int): Maximum number of workers in the thread pool.
max_items (int): Maximum number of items in the cache.
Example:
>>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
>>> def populate_func():
... return "cached_value"
>>> value = cache.get("key", populate_func)
Note:
The cache uses a ThreadPoolExecutor with a maximum of 10 workers to
handle concurrent cache population requests.
"""

def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000):
"""Initialize the TwoLevelCache with specified TTLs."""
self.soft_cache = TTLCache(max_items, soft_ttl)
self.hard_cache = TTLCache(max_items, hard_ttl)
self.locks = defaultdict(threading.Lock)
self.futures: dict[str, Future] = {}
self.pool = ThreadPoolExecutor(max_workers=max_workers)

def get(self, key: str, populate_func: Callable[[], Any]):
"""Retrieve a value from the cache, populating it if necessary.
This method first checks the soft cache for the key. If not found,
it checks the hard cache while initiating a background refresh.
If the key is not in either cache, it waits for the populate_func
to complete and stores the result in both caches.
Locks are used to ensure there is never more than one concurrent
population task for a given key.
Args:
key (str): The cache key to retrieve or populate.
populate_func (Callable[[], Any]): A function to call to populate the cache
if the key is not found.
Returns:
Any: The cached value associated with the key.
Note:
This method is thread-safe and handles concurrent requests for the same key.
"""
if result := self.soft_cache.get(key):
return result
with self.locks[key]:
if key not in self.futures:
self.futures[key] = self.pool.submit(self._work, key, populate_func)
if result := self.hard_cache.get(key):
self.soft_cache[key] = result
return result
# It is critical that ``future`` is waited for outside of the lock as
# _work aquires the lock before filling the caches. This also means
# we can gaurentee that the future has not yet been removed from the
# futures dict.
future = self.futures[key]
wait([future])
return self.hard_cache[key]

def _work(self, key: str, populate_func: Callable[[], Any]) -> None:
"""Internal method to execute the populate_func and update caches.
This method is intended to be run in a separate thread. It calls the
populate_func, stores the result in both caches, and cleans up the
associated future.
Args:
key (str): The cache key to populate.
populate_func (Callable[[], Any]): The function to call to get the value.
Note:
This method is not intended to be called directly by users of the class.
"""
result = populate_func()
with self.locks[key]:
self.futures.pop(key)
self.hard_cache[key] = result
self.soft_cache[key] = result
from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB


class Limiter:
Expand Down Expand Up @@ -152,72 +45,76 @@ def getNegativeCond(self):
orCond = self.condCache.get("GLOBAL")
if orCond:
return orCond
negCond = {}
negativeCondition = {}

# Run Limit
result = self.__opsHelper.getSections(self.__runningLimitSection)
sites = []
if result["OK"]:
sites = result["Value"]
for siteName in sites:
if not result["OK"]:
self.log.error("Issue getting running conditions", result["Message"])
sites_with_running_limits = []
else:
sites_with_running_limits = result["Value"]
self.log.verbose(f"Found running conditions for {len(sites_with_running_limits)} sites")

for siteName in sites_with_running_limits:
result = self.__getRunningCondition(siteName)
if not result["OK"]:
continue
data = result["Value"]
if data:
negCond[siteName] = data
self.log.error("Issue getting running conditions", result["Message"])
running_condition = {}
else:
running_condition = result["Value"]
if running_condition:
negativeCondition[siteName] = running_condition

# Delay limit
result = self.__opsHelper.getSections(self.__matchingDelaySection)
sites = []
if result["OK"]:
sites = result["Value"]
for siteName in sites:
result = self.__getDelayCondition(siteName)
if self.__opsHelper.getValue("JobScheduling/CheckMatchingDelay", True):
result = self.__opsHelper.getSections(self.__matchingDelaySection)
if not result["OK"]:
continue
data = result["Value"]
if not data:
continue
if siteName in negCond:
negCond[siteName] = self.__mergeCond(negCond[siteName], data)
self.log.error("Issue getting delay conditions", result["Message"])
sites_with_matching_delay = []
else:
negCond[siteName] = data
sites_with_matching_delay = result["Value"]
self.log.verbose(f"Found delay conditions for {len(sites_with_matching_delay)} sites")

for siteName in sites_with_matching_delay:
delay_condition = self.__getDelayCondition(siteName)
if siteName in negativeCondition:
negativeCondition[siteName] = self.__mergeCond(negativeCondition[siteName], delay_condition)
else:
negativeCondition[siteName] = delay_condition

orCond = []
for siteName in negCond:
negCond[siteName]["Site"] = siteName
orCond.append(negCond[siteName])
for siteName in negativeCondition:
negativeCondition[siteName]["Site"] = siteName
orCond.append(negativeCondition[siteName])
self.condCache.add("GLOBAL", 10, orCond)
return orCond

def getNegativeCondForSite(self, siteName, gridCE=None):
"""Generate a negative query based on the limits set on the site"""
# Check if Limits are imposed onto the site
negativeCond = {}
if self.__opsHelper.getValue("JobScheduling/CheckJobLimits", True):
result = self.__getRunningCondition(siteName)
if not result["OK"]:
self.log.error("Issue getting running conditions", result["Message"])
negativeCond = {}
else:
negativeCond = result["Value"]
self.log.verbose(
"Negative conditions for site", f"{siteName} after checking limits are: {str(negativeCond)}"
)
self.log.verbose(
"Negative conditions for site", f"{siteName} after checking limits are: {str(negativeCond)}"
)

if gridCE:
result = self.__getRunningCondition(siteName, gridCE)
if not result["OK"]:
self.log.error("Issue getting running conditions", result["Message"])
else:
negativeCondCE = result["Value"]
negativeCond = self.__mergeCond(negativeCond, negativeCondCE)
negativeCond = self.__mergeCond(negativeCond, result["Value"])

if self.__opsHelper.getValue("JobScheduling/CheckMatchingDelay", True):
result = self.__getDelayCondition(siteName)
if result["OK"]:
delayCond = result["Value"]
self.log.verbose(
"Negative conditions for site", f"{siteName} after delay checking are: {str(delayCond)}"
)
negativeCond = self.__mergeCond(negativeCond, delayCond)
delayCond = self.__getDelayCondition(siteName)
self.log.verbose("Negative conditions for site", f"{siteName} after delay checking are: {str(delayCond)}")
negativeCond = self.__mergeCond(negativeCond, delayCond)

if negativeCond:
self.log.info("Negative conditions for site", f"{siteName} are: {str(negativeCond)}")
Expand Down Expand Up @@ -337,14 +234,14 @@ def updateDelayCounters(self, siteName, jid):
def __getDelayCondition(self, siteName):
"""Get extra conditions allowing matching delay"""
if siteName not in self.delayMem:
return S_OK({})
return {}
lastRun = self.delayMem[siteName].getKeys()
negCond = {}
for attName, attValue in lastRun:
if attName not in negCond:
negCond[attName] = []
negCond[attName].append(attValue)
return S_OK(negCond)
return negCond

def _countsByJobType(self, siteName, attName):
result = self.jobDB.getCounters(
Expand All @@ -354,6 +251,5 @@ def _countsByJobType(self, siteName, attName):
)
if not result["OK"]:
return result
data = result["Value"]
data = {k[0][attName]: k[1] for k in data}
data = {k[0][attName]: k[1] for k in result["Value"]}
return data

0 comments on commit ec1855e

Please sign in to comment.