From eb0ef3b2102b31b5488b2f735ebcbecf8ea7383f Mon Sep 17 00:00:00 2001 From: AntoineRichard Date: Tue, 24 Sep 2024 20:14:01 +0200 Subject: [PATCH] Improved thread management. Now the code catches Ctrl + C. Still closing the simulation window will lead to a mess. --- src/environments/large_scale_lunar.py | 3 + .../ros2/base_wrapper_ros2.py | 7 + .../ros2/largescale_ros2.py | 8 +- .../ros2/simulation_manager_ros2.py | 41 ++- .../high_resolution_DEM_generator.py | 45 ++- .../high_resolution_DEM_workers.py | 333 ++++++++++++++++-- 6 files changed, 370 insertions(+), 67 deletions(-) diff --git a/src/environments/large_scale_lunar.py b/src/environments/large_scale_lunar.py index c93aa4a..00378ff 100644 --- a/src/environments/large_scale_lunar.py +++ b/src/environments/large_scale_lunar.py @@ -128,6 +128,9 @@ def update(self) -> None: p, _ = self.pose_tracker() self.LSTM.update_visual_mesh((p[0], p[1])) + def monitor_thread_is_alive(self) -> None: + return self.LSTM.map_manager.hr_dem_gen.monitor_thread.thread.is_alive() + def load(self) -> None: """ Loads the lab interactive elements in the stage. diff --git a/src/environments_wrappers/ros2/base_wrapper_ros2.py b/src/environments_wrappers/ros2/base_wrapper_ros2.py index 8230a3f..fbe04e3 100644 --- a/src/environments_wrappers/ros2/base_wrapper_ros2.py +++ b/src/environments_wrappers/ros2/base_wrapper_ros2.py @@ -210,3 +210,10 @@ def clean_scene(self): """ self.destroy_node() + + def monitor_thread_is_alive(self): + """ + Checks if the monitor thread is alive. + """ + + return True diff --git a/src/environments_wrappers/ros2/largescale_ros2.py b/src/environments_wrappers/ros2/largescale_ros2.py index d7d5915..a49d7b8 100644 --- a/src/environments_wrappers/ros2/largescale_ros2.py +++ b/src/environments_wrappers/ros2/largescale_ros2.py @@ -25,6 +25,7 @@ def __init__( self, environment_cfg: dict = None, is_simulation_alive: callable = lambda: True, + close_simulation: callable = lambda: None, **kwargs, ) -> None: """ @@ -37,7 +38,9 @@ def __init__( """ super().__init__(environment_cfg=environment_cfg, **kwargs) - self.LC = LargeScaleController(**environment_cfg, is_simulation_alive=is_simulation_alive) + self.LC = LargeScaleController( + **environment_cfg, is_simulation_alive=is_simulation_alive, close_simulation=close_simulation + ) self.LC.load() self.create_subscription(Float32, "/OmniLRS/Sun/Intensity", self.set_sun_intensity, 1) @@ -118,3 +121,6 @@ def set_sun_pose(self, data: Pose) -> None: position = [data.position.x, data.position.y, data.position.z] orientation = [data.orientation.w, data.orientation.y, data.orientation.z, data.orientation.x] self.modifications.append([self.LC.set_sun_pose, {"position": position, "orientation": orientation}]) + + def monitor_thread_is_alive(self): + return self.LC.monitor_thread_is_alive() diff --git a/src/environments_wrappers/ros2/simulation_manager_ros2.py b/src/environments_wrappers/ros2/simulation_manager_ros2.py index 25ba5c7..9885fce 100644 --- a/src/environments_wrappers/ros2/simulation_manager_ros2.py +++ b/src/environments_wrappers/ros2/simulation_manager_ros2.py @@ -11,6 +11,7 @@ from omni.isaac.kit import SimulationApp from omni.isaac.core import World from typing import Union +import logging import omni import time @@ -21,6 +22,10 @@ from src.configurations.procedural_terrain_confs import TerrainManagerConf from rclpy.executors import SingleThreadedExecutor as Executor from src.physics.physics_scene import PhysicsSceneManager +import rclpy + +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p") class Rate: @@ -93,7 +98,7 @@ def register( def __call__( self, cfg: dict, - is_simulation_alive: callable, + **kwargs, ) -> Union[ROS_LunalabManager, ROS_LunaryardManager]: """ Returns an instance of the lab manager corresponding to the environment name. @@ -107,7 +112,7 @@ def __call__( return self._lab_managers[cfg["environment"]["name"]]( environment_cfg=cfg["environment"], - is_simulation_alive=is_simulation_alive, + **kwargs, ) @@ -160,16 +165,18 @@ def __init__( self.rate = Rate(is_disabled=True) # Lab manager thread - self.ROSLabManager = ROS2_LMF(cfg, self.simulation_app.is_running) - exec1 = Executor() - exec1.add_node(self.ROSLabManager) - self.exec1_thread = Thread(target=exec1.spin, daemon=True, args=()) + self.ROSLabManager = ROS2_LMF( + cfg, is_simulation_alive=self.simulation_app.is_running, close_simulation=self.simulation_app.close + ) + self.exec1 = Executor() + self.exec1.add_node(self.ROSLabManager) + self.exec1_thread = Thread(target=self.exec1.spin, daemon=True, args=()) self.exec1_thread.start() # Robot manager thread self.ROSRobotManager = ROS_RobotManager(cfg["environment"]["robots_settings"]) - exec2 = Executor() - exec2.add_node(self.ROSRobotManager) - self.exec2_thread = Thread(target=exec2.spin, daemon=True, args=()) + self.exec2 = Executor() + self.exec2.add_node(self.ROSRobotManager) + self.exec2_thread = Thread(target=self.exec2.spin, daemon=True, args=()) self.exec2_thread.start() # Have you ever asked your self: "Is there a limit of topics one can subscribe to in ROS2?" @@ -222,6 +229,20 @@ def run_simulation(self) -> None: if self.world.current_time_step_index >= (self.deform_delay * self.world.get_physics_dt()): self.ROSLabManager.LC.deform_terrain() # self.ROSLabManager.LC.applyTerramechanics() - self.rate.sleep() + if not self.ROSLabManager.monitor_thread_is_alive(): + logger.debug("Destroying the ROS nodes") + self.ROSLabManager.destroy_node() + self.ROSRobotManager.destroy_node() + logger.debug("Shutting down the ROS executors") + self.exec1.shutdown() + self.exec2.shutdown() + logger.debug("Joining the ROS threads") + self.exec1_thread.join() + self.exec2_thread.join() + logger.debug("Shutting down ROS2") + rclpy.shutdown() + break + self.rate.sleep() + self.world.stop() self.timeline.stop() diff --git a/src/terrain_management/large_scale_terrain/high_resolution_DEM_generator.py b/src/terrain_management/large_scale_terrain/high_resolution_DEM_generator.py index 134ccff..f7264b4 100644 --- a/src/terrain_management/large_scale_terrain/high_resolution_DEM_generator.py +++ b/src/terrain_management/large_scale_terrain/high_resolution_DEM_generator.py @@ -10,7 +10,7 @@ import numpy as np import dataclasses import threading -import warnings +import logging import time import copy @@ -34,6 +34,9 @@ ThreadMonitor, ) +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p") + @dataclasses.dataclass class HighResDEMConf: @@ -111,15 +114,21 @@ def __init__( low_res_dem: np.ndarray, settings: HighResDEMGenConf, profiling: bool = True, + is_simulation_alive: callable = lambda: True, + close_simulation: callable = lambda: None, ) -> None: """ Args: low_res_dem (np.ndarray): The low resolution DEM. settings (HighResDEMGenConf): The settings for the high resolution DEM generation. profiling (bool): True if the profiling is enabled, False otherwise. + is_simulation_alive (callable): A callable that returns True if the simulation is alive, False otherwise. + close_simulation (callable): A callable that closes the simulation. """ self.low_res_dem = low_res_dem self.settings = settings + self.is_simulation_alive = is_simulation_alive + self.close_simulation = close_simulation self.current_block_coord = (0, 0) self.sim_is_warm = False @@ -149,7 +158,9 @@ def build(self) -> None: # Creates the worker managers that will distribute the work to the workers. # This enables the generation of craters and the interpolation of the terrain # data to be done in parallel. - self.monitor_thread = ThreadMonitor() + self.monitor_thread = ThreadMonitor( + is_simulation_alive=self.is_simulation_alive, close_simulation=self.close_simulation + ) self.crater_builder_manager = CraterBuilderManager( settings=self.settings.crater_worker_manager_cfg, @@ -161,6 +172,7 @@ def build(self) -> None: interp=self.interpolator, parent_thread=self.monitor_thread.thread, ) + self.monitor_thread.add_shutdowns(self.crater_builder_manager.shutdown, self.interpolator_manager.shutdown) # Instantiates the high resolution DEM with the given settings. self.settings = self.settings.high_res_dem_cfg self.build_block_grid() @@ -429,7 +441,7 @@ def shift(self, coordinates: Tuple[float, float]) -> None: # even if the terrain reconstruction is not complete. This would prevent a full simulation lock # that waits for one reconstruction to finish before starting the next one. - print("Called shift") + logger.debug("Called shift") with ScopedTimer("Shift", active=self.profiling): # Compute initial coordinates in block space with ScopedTimer("Cast coordinates to block space", active=self.profiling): @@ -459,7 +471,7 @@ def shift(self, coordinates: Tuple[float, float]) -> None: with ScopedTimer("Add terrain blocks to queues", active=self.profiling): # Asynchronous terrain block generation self.generate_terrain_blocks() # <-- This is the bit that needs to be edited to support multiple calls - print("Shift done") + logger.debug("Shift done") def get_height(self, coordinates: Tuple[float, float]) -> float: """ @@ -576,30 +588,29 @@ def update_high_res_dem(self, coords: Tuple[float, float]) -> bool: # Initial map generation if not self.sim_is_warm: - print("Warming up simulation") + logger.debug("Warming up simulation") self.shift(block_coordinates) # Threaded update, the function will return before the update is done - threading.Thread(target=self.threaded_high_res_dem_update).start() + thread = threading.Thread(target=self.threaded_high_res_dem_update) + thread.start() self.sim_is_warm = True updated = True # Map update if the block has changed if self.current_block_coord != block_coordinates: - print("Triggering high res DEM update") + logger.debug("Triggering high res DEM update") if not self.terrain_is_primed: - print("Map is not done, waiting for the terrain data") + logger.debug("Map is not done, waiting for the terrain data") while not self.terrain_is_primed: time.sleep(0.2) - print([self.block_grid_tracker[coord] for coord in self.list_missing_blocks()]) - # print(self.crater_builder_manager.get_load_per_worker()) - # print(self.interpolator_manager.get_load_per_worker()) - print("Map is done carrying on") + logger.debug([self.block_grid_tracker[coord] for coord in self.list_missing_blocks()]) + logger.debug("Map is done carrying on") # Threaded update, the function will return before the update is done if self.thread is None: self.shift(coords) self.thread = threading.Thread(target=self.threaded_high_res_dem_update).start() elif self.thread.is_alive(): - print("Thread is alive waiting for it to finish") + logger.debug("Thread is alive waiting for it to finish") while self.thread.is_alive(): time.sleep(0.1) self.shift(coords) @@ -639,12 +650,12 @@ def threaded_high_res_dem_update(self) -> None: # even if the terrain reconstruction is not complete. This would prevent a full simulation lock # that waits for one reconstruction to finish before starting the next one. - print("Opening thread") + logger.debug("Opening thread") self.terrain_is_primed = False while (not self.is_map_done()) and (self.monitor_thread.thread.is_alive()): self.collect_terrain_data() time.sleep(0.1) - print("Thread closing map is done") + logger.debug("Thread closing map is done") self.terrain_is_primed = True def generate_craters_metadata(self, new_block_coord) -> None: @@ -672,7 +683,7 @@ def generate_craters_metadata(self, new_block_coord) -> None: self.block_grid_tracker[local_coords]["has_crater_metadata"] = True else: self.block_grid_tracker[local_coords]["has_crater_metadata"] = False - print(f"Block {coords} does not have crater metadata") + logger.warn(f"Block {coords} does not have crater metadata") def querry_low_res_dem(self, coordinates: Tuple[float, float]) -> np.ndarray: """ @@ -766,7 +777,7 @@ def collect_terrain_data(self) -> None: # Offset to account for the padding blocks offset = int((self.settings.num_blocks + 1) * self.settings.block_size / self.settings.resolution) - print("collecting...") + logger.debug("collecting...") # Collect the results from the workers responsible for adding craters crater_results = self.crater_builder_manager.collect_results() for coords, data in crater_results: diff --git a/src/terrain_management/large_scale_terrain/high_resolution_DEM_workers.py b/src/terrain_management/large_scale_terrain/high_resolution_DEM_workers.py index a7d6c05..1edfa17 100644 --- a/src/terrain_management/large_scale_terrain/high_resolution_DEM_workers.py +++ b/src/terrain_management/large_scale_terrain/high_resolution_DEM_workers.py @@ -12,6 +12,8 @@ import numpy as np import dataclasses import threading +import logging +import signal import time import copy import PIL @@ -21,6 +23,9 @@ CraterBuilder, ) +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p") + @dataclasses.dataclass class InterpolatorConf: @@ -43,13 +48,13 @@ def __post_init__(self): self.method = self.method.lower() self.update() if self.source_padding < 2: - print("Warning: Padding may be too small for interpolation.") + logger.warn("Padding may be too small for interpolation.") self.source_padding = 2 if self.method == "bicubic": self.method = cv2.INTER_CUBIC if self.fx < 1.0: - print("Warning: Bicubic interpolation with downscaling. Consider using a different method.") + logger.warn("Bicubic interpolation with downscaling. Consider using a different method.") elif self.method == "nearest": self.method = cv2.INTER_NEAREST elif self.method == "linear": @@ -57,7 +62,7 @@ def __post_init__(self): elif self.method == "area": self.method = cv2.INTER_AREA if self.fx > 1.0: - print("Warning: Area interpolation with upscaling. Consider using a different method.") + logger.warn("Area interpolation with upscaling. Consider using a different method.") else: raise ValueError(f"Invalid interpolation method: {self.method}") @@ -191,6 +196,94 @@ def interpolate(self, data: np.ndarray) -> np.ndarray: ] +class BaseWorker(multiprocessing.Process): + """ + BaseWorker class. This class is used to create worker processes that can be used + to distribute the work across multiple processes. + """ + + def __init__( + self, + queue_size: int, + output_queue: multiprocessing.JoinableQueue, + parent_thread: threading.Thread, + thread_timeout: float = 1.0, + **kwargs, + ): + """ + Args: + queue_size (int): The size of the input queue. + output_queue (multiprocessing.JoinableQueue): The output queue. + parent_thread (threading.Thread): The parent thread. + thread_timeout (float): The timeout for the worker thread. (seconds) + **kwargs: Additional arguments. + """ + + super().__init__() + self.stop_event = multiprocessing.Event() + self.input_queue = multiprocessing.JoinableQueue(maxsize=queue_size) + self.output_queue = output_queue + self.daemon = True + self.parent_thread = parent_thread + self.thread_timeout = thread_timeout + + def get_input_queue_length(self) -> int: + """ + Returns the length of the input queue. + + Returns: + int: The length of the input queue. + """ + + return self.input_queue.qsize() + + def is_input_queue_empty(self) -> bool: + """ + Checks if the input queue is empty. + + Returns: + bool: True if the input queue is empty, False otherwise. + """ + + return self.input_queue.empty() + + def is_input_queue_full(self): + """ + Checks if the input queue is full. + + Returns: + bool: True if the input queue is full, False otherwise. + """ + + return self.input_queue.full() + + def run(self) -> None: + """ + The main function of the worker process. + This function is called when the worker process is started. + + TO BE IMPLEMENTED BY SUBCLASSES. + """ + + raise NotImplementedError + + def shutdown(self) -> None: + """ + Shuts down the worker process. + """ + + logger.debug("Shutting down worker.") + self.input_queue.put(((0, 0), None)) + self.stop_event.set() + self.join() + logger.debug("Worker joined.") + while not self.input_queue.empty(): + self.input_queue.get() + logger.debug("Worker input queue emptied.") + self.input_queue.join() + logger.debug("Worker input queue joined.") + + @dataclasses.dataclass class WorkerManagerConf: """ @@ -198,11 +291,13 @@ class WorkerManagerConf: Args: num_workers (int): The number of workers. + worker_queue_size (int): The size of the worker queue. input_queue_size (int): The size of the input queue. output_queue_size (int): The size of the output queue. """ num_workers: int = dataclasses.field(default_factory=int) + worker_queue_size: int = dataclasses.field(default_factory=int) input_queue_size: int = dataclasses.field(default_factory=int) output_queue_size: int = dataclasses.field(default_factory=int) @@ -234,11 +329,12 @@ def __init__( self.manager = multiprocessing.Manager() # Create the input and output queues - self.input_queue = self.manager.Queue(maxsize=settings.input_queue_size) + self.input_queue = multiprocessing.JoinableQueue(maxsize=settings.input_queue_size) self.output_queue = self.manager.Queue(maxsize=settings.output_queue_size) # Create the workers self.num_workers = settings.num_workers + self.worker_queue_size = settings.worker_queue_size self.worker_class = worker_class self.parent_thread = parent_thread self.thread_timeout = thread_timeout @@ -246,6 +342,11 @@ def __init__( self.instantiate_workers(**kwargs) + # Start the manager thread + self.stop_event = threading.Event() + self.manager_thread = threading.Thread(target=self.dispatch_jobs, daemon=True) + self.manager_thread.start() + def instantiate_workers( self, **kwargs, @@ -257,15 +358,34 @@ def instantiate_workers( **kwargs: Additional arguments. Used to pass the objects the workers need. """ - worker_instance = self.worker_class(self.thread_timeout, **kwargs) - self.workers = [ - multiprocessing.Process(target=worker_instance.run, args=(self.input_queue, self.output_queue)) + self.worker_class( + self.worker_queue_size, self.output_queue, self.parent_thread, self.thread_timeout, **kwargs + ) for _ in range(self.num_workers) ] for worker in self.workers: worker.start() + def get_load_per_worker(self) -> Dict[str, int]: + """ + Returns the load of each worker. + The dictionary contains the load of each worker in the form of: + { + "worker_0": 10, + "worker_1": 20, + ... + } + + Where the key is the worker name and the value is the number of jobs + in the worker's input queue. + + Returns: + dict: A dictionary with the load of each worker. + """ + + return {"worker_{}".format(i): worker.get_input_queue_length() for i, worker in enumerate(self.workers)} + def get_input_queue_length(self) -> int: """ Returns the length of the input queue. @@ -326,6 +446,20 @@ def is_output_queue_full(self) -> bool: return self.output_queue.full() + def get_shortest_queue_index(self) -> int: + """ + Returns the index of the worker with the shortest input queue. + This is used to distribute the work in a balanced way. + + Returns: + int: The index of the worker with the shortest input queue. + """ + + return min( + range(len(self.workers)), + key=lambda i: self.workers[i].get_input_queue_length(), + ) + def are_workers_done(self) -> bool: """ Checks if all the workers are done. @@ -335,7 +469,16 @@ def are_workers_done(self) -> bool: bool: True if all the workers are done, False otherwise. """ - return self.is_input_queue_empty() and self.is_output_queue_empty() + return all([worker.is_input_queue_empty() for worker in self.workers]) and self.is_output_queue_empty() + + def dispatch_jobs(self, parent_thread: threading.Thread) -> None: + """ + Dispatches the jobs to the workers. + + TO BE IMPLEMENTED BY SUBCLASSES. + """ + + raise NotImplementedError def process_data(self, coords: Tuple[float, float], data) -> None: """ @@ -372,10 +515,32 @@ def shutdown(self) -> None: Shuts down the worker manager and its workers. """ + logger.debug("Shutting down manager...") + self.input_queue.put(((0, 0), None)) + self.manager_thread.join() + logger.debug("Joined manager thread.") + while not self.input_queue.empty(): + self.input_queue.get() + self.input_queue.task_done() + logger.debug("Emptied input queue.") + self.input_queue.join() + logger.debug("Joined input queue.") + while not self.output_queue.empty(): + self.output_queue.get() + logger.debug("Emptied output queue.") + self.manager.shutdown() + logger.debug("Multiprocessing.Manager shutdown complete.") + logger.debug("Manager shutdown complete.") + + def shutdown_workers(self) -> None: + """ + Shuts down the workers. + """ + + logger.debug("Shutting down workers...") for worker in self.workers: - self.input_queue.put(((0, 0), None)) - for worker in self.workers: - worker.join() + worker.shutdown() + logger.debug("Workers shutdown complete.") def __del__(self) -> None: """ @@ -389,13 +554,16 @@ def __del__(self) -> None: pass -class CraterBuilderWorker: +class CraterBuilderWorker(BaseWorker): """ CraterBuilderWorker class. This class is responsible for generating the craters. """ def __init__( self, + queue_size: int = 10, + output_queue: multiprocessing.Queue = multiprocessing.JoinableQueue(), + parent_thread: threading.Thread = None, thread_timeout: float = 1.0, builder: CraterBuilder = None, ) -> None: @@ -405,37 +573,34 @@ def __init__( builder (CraterBuilder): The crater builder. """ - self.thread_timeout = thread_timeout + super().__init__(queue_size, output_queue, parent_thread, thread_timeout) self.builder = copy.copy(builder) - def run(self, input_queue: multiprocessing.Queue, output_queue: multiprocessing.Queue) -> None: + def run(self) -> None: """ The main function of the worker process. This function is called when the worker process is started. It takes craters metadata and coordinates from the input queue and generates images with inprinted craters. - - Args: - input_queue (multiprocessing.JoinableQueue): The input queue. - output_queue (multiprocessing.JoinableQueue): The output queue. """ - while True: + while not self.stop_event.is_set(): try: - coords, crater_meta_data = input_queue.get(timeout=self.thread_timeout) + coords, crater_meta_data = self.input_queue.get(timeout=self.thread_timeout) + self.input_queue.task_done() if crater_meta_data is None: break data_not_in_queue = True out = (coords, self.builder.generate_craters(crater_meta_data, coords)) while data_not_in_queue: try: - output_queue.put(out, timeout=0.1) + self.output_queue.put(out, timeout=0.1) data_not_in_queue = False except Exception as e: pass except Exception as e: pass - print("crater worker dead.") + logger.debug("Crater Builder Worker exited processing loop.") class CraterBuilderManager(BaseWorkerManager): @@ -468,8 +633,28 @@ def __init__( thread_timeout=thread_timeout, ) + def dispatch_jobs(self) -> None: + """ + Dispatches the jobs to the workers. + The jobs are dispatched in a balanced way such that the workers have a similar load. + """ + + # Assigns the jobs such that the workers have a balanced load + while (not self.stop_event.is_set()) and (self.parent_thread.is_alive()): + try: + coords, crater_metadata = self.input_queue.get(timeout=self.thread_timeout) + if crater_metadata is None: # Check for shutdown signal + self.input_queue.task_done() + break + self.workers[self.get_shortest_queue_index()].input_queue.put((coords, crater_metadata)) + self.input_queue.task_done() + except: + pass + self.shutdown_workers() + logger.debug("Crater Builder Manager exited processing loop.") -class BicubicInterpolatorWorker: + +class BicubicInterpolatorWorker(BaseWorker): """ BicubicInterpolatorWorker class. This class is responsible for interpolating the terrain data. @@ -477,6 +662,9 @@ class BicubicInterpolatorWorker: def __init__( self, + queue_size: int = 10, + output_queue: multiprocessing.JoinableQueue = multiprocessing.JoinableQueue(), + parent_thread: threading.Thread = None, thread_timeout: float = 1.0, interp: Interpolator = None, ): @@ -486,40 +674,33 @@ def __init__( interp (Interpolator): The interpolator. """ + super().__init__(queue_size, output_queue, parent_thread, thread_timeout) self.interpolator = copy.copy(interp) - self.thread_timeout = thread_timeout - def run( - self, - input_queue: multiprocessing.Queue, - output_queue: multiprocessing.Queue, - ) -> None: + def run(self) -> None: """ The main function of the worker process. This function is called when the worker process is started. It takes terrain data from the input queue and interpolates it. - - Args: - input_queue (multiprocessing.JoinableQueue): The input queue. - output_queue (multiprocessing.JoinableQueue): The output queue. """ - while True: + while not self.stop_event.is_set(): try: - coords, data = input_queue.get(timeout=self.thread_timeout) + coords, data = self.input_queue.get(timeout=self.thread_timeout) + self.input_queue.task_done() if data is None: break out = (coords, self.interpolator.interpolate(data)) data_not_in_queue = True while data_not_in_queue: try: - output_queue.put(out, timeout=0.1) + self.output_queue.put(out, timeout=0.1) data_not_in_queue = False except Exception as e: pass except Exception as e: pass - print("bicubic worker dead.") + logger.debug("Bicubic Interpolator Worker exited processing loop.") class BicubicInterpolatorManager(BaseWorkerManager): @@ -555,21 +736,95 @@ def __init__( ) cv2.setNumThreads(num_cv2_threads) + def dispatch_jobs(self) -> None: + """ + Dispatches the jobs to the workers. + The jobs are dispatched in a balanced way such that the workers have a similar load. + """ + + while (not self.stop_event.is_set()) and (self.parent_thread.is_alive()): + try: + coords, data = self.input_queue.get(timeout=self.thread_timeout) + self.input_queue.task_done() + if data is None: + break + self.workers[self.get_shortest_queue_index()].input_queue.put((coords, data)) + except: + pass + self.shutdown_workers() + logger.debug("Bicubic Interpolator manager exited processing loop.") + class ThreadMonitor: - def __init__(self): + def __init__(self, is_simulation_alive: callable = lambda: True, close_simulation: callable = lambda: None): + # Catch SIGINT + self.ctrl_c = False + signal.signal(signal.SIGINT, self.catch_sigint) + # Catch Simulation shutdown + self.is_simulation_alive = is_simulation_alive + self.close_simulation = close_simulation + # Start the monitor thread self.event = threading.Event() self.thread = threading.Thread(target=self.monitor_main_thread) self.thread.start() + self.shutdowns = [] + + def catch_sigint(self, sig, frame) -> None: + """ + Catch the ctrl-c signal. + """ + + self.ctrl_c = True + + def add_shutdowns(self, *args): + self.shutdowns += list(args) + + def apply_shutdowns(self): + logger.debug("Applying registered shutdowns.") + for shutdown in self.shutdowns: + shutdown() + + def apply_shutdowns_in_different_process(self): + logger.debug("Applying registered shutdowns in different process.") + for shutdown in self.shutdowns: + p = multiprocessing.Process(target=shutdown) + p.start() + p.join() def monitor_main_thread(self) -> None: """ Monitors the main thread. This function is used to monitor the main thread and check if it is still alive. + + It monitors three things: + - The main python thread. + - The simulation thread. + - The ctrl-c signal. This was added as it seems that the simulation was completely ignoring the ctrl-c signal. """ + use_multiprocessing = False while not self.event.is_set(): time.sleep(1) if not threading.main_thread().is_alive(): - print("Main thread is dead, shutting down workers.") + logger.debug("Main thread is dead, shutting down workers.") + break + if not self.is_simulation_alive(): + logger.debug("Simulation is dead, trying to shut down workers.") + logger.warn( + "When the simulation is exited by clicking the close button, Isaac is aggressively killing everyting including the threads in charge of cleaning up..." + ) + logger.warn("If the simulation window does not close it means that not all threads exited.") + logger.warn("Use ps aux | grep run.py to find the remaining threads.") + logger.warn( + "Use kill -9 PID to kill the remaining threads. PID being the PIDs returned by the previous command." + ) + use_multiprocessing = True break + if self.ctrl_c: + logger.debug("Ctrl-C caught, shutting down workers.") + break + if use_multiprocessing: + self.apply_shutdowns_in_different_process() + else: + self.apply_shutdowns() + logger.debug("Thread monitor exiting.")