From f165117716a4aafec94ebc5111804341f261f2f3 Mon Sep 17 00:00:00 2001 From: James Parkhurst Date: Thu, 19 Oct 2023 18:04:54 +0100 Subject: [PATCH] FIxed test --- src/parakeet/analyse/_correct.py | 7 +++---- src/parakeet/analyse/_reconstruct.py | 9 +++------ src/parakeet/config.py | 2 +- src/parakeet/simulate/_ctf.py | 4 ---- src/parakeet/simulate/_image.py | 28 +++++++++++++++++--------- src/parakeet/simulate/_simple.py | 30 ++++++++++++++++++++-------- 6 files changed, 48 insertions(+), 32 deletions(-) diff --git a/src/parakeet/analyse/_correct.py b/src/parakeet/analyse/_correct.py index 6c3b502..e75e788 100644 --- a/src/parakeet/analyse/_correct.py +++ b/src/parakeet/analyse/_correct.py @@ -14,7 +14,6 @@ import parakeet.microscope import parakeet.sample from functools import singledispatch -from parakeet.config import Device __all__ = ["correct"] @@ -30,7 +29,7 @@ def correct( image_file: str, corrected_file: str, num_defocus: int = 1, - device: Device = Device.gpu, + device: str = None, ): """ Correct the images using 3D CTF correction @@ -49,7 +48,7 @@ def correct( # Set the device if device is not None: - config.device = device + config.multiprocessing.device = device # Print some options parakeet.config.show(config) @@ -115,5 +114,5 @@ def _correct_Config( astigmatism=astigmatism, astigmatism_angle=astigmatism_angle, phase_shift=phase_shift, - device=config.device, + device=config.multiprocessing.device, ) diff --git a/src/parakeet/analyse/_reconstruct.py b/src/parakeet/analyse/_reconstruct.py index 3e21c06..8154a6a 100644 --- a/src/parakeet/analyse/_reconstruct.py +++ b/src/parakeet/analyse/_reconstruct.py @@ -14,7 +14,6 @@ import parakeet.microscope import parakeet.sample from functools import singledispatch -from parakeet.config import Device __all__ = ["reconstruct"] @@ -25,9 +24,7 @@ @singledispatch -def reconstruct( - config_file, image_file: str, rec_file: str, device: Device = Device.gpu -): +def reconstruct(config_file, image_file: str, rec_file: str, device: str = None): """ Reconstruct the volume @@ -44,7 +41,7 @@ def reconstruct( # Set the device if device is not None: - config.device = device + config.multiprocessing.device = device # Print some options parakeet.config.show(config) @@ -110,5 +107,5 @@ def _reconstruct_Config( astigmatism_angle=astigmatism_angle, phase_shift=phase_shift, angular_weights=True, - device=config.device, + device=config.multiprocessing.device, ) diff --git a/src/parakeet/config.py b/src/parakeet/config.py index c06b978..cd71c22 100644 --- a/src/parakeet/config.py +++ b/src/parakeet/config.py @@ -741,7 +741,7 @@ class Multiprocessing(BaseModel): nproc: int = Field(1, description="The number of processes", gt=0) - gpu_id: List[int] = Field(None, description="The GPU id for each thread") + gpu_id: List[int] = Field([0], description="The GPU id for each thread") class Config(BaseModel): diff --git a/src/parakeet/simulate/_ctf.py b/src/parakeet/simulate/_ctf.py index 4a787a6..880828b 100644 --- a/src/parakeet/simulate/_ctf.py +++ b/src/parakeet/simulate/_ctf.py @@ -104,11 +104,7 @@ def simulation_factory(microscope: Microscope, simulation: dict) -> Simulation: Args: microscope (object); The microscope object - exit_wave (object): The exit_wave object - scan (object): The scan object - device (str): The device to use simulation (object): The simulation parameters - cluster (object): The cluster parameters Returns: object: The simulation object diff --git a/src/parakeet/simulate/_image.py b/src/parakeet/simulate/_image.py index 8808c3c..d65d0ef 100644 --- a/src/parakeet/simulate/_image.py +++ b/src/parakeet/simulate/_image.py @@ -45,13 +45,20 @@ class ImageSimulator(object): """ def __init__( - self, microscope=None, optics=None, scan=None, simulation=None, device="gpu" + self, + microscope=None, + optics=None, + scan=None, + simulation=None, + device="gpu", + gpu_id=None, ): self.microscope = microscope self.optics = optics self.scan = scan self.simulation = simulation self.device = device + self.gpu_id = gpu_id def __call__(self, index): """ @@ -130,9 +137,8 @@ def simulation_factory( microscope: Microscope, optics: object, scan: Scan, - device: Device = Device.gpu, simulation: dict = None, - cluster: dict = None, + multiprocessing: dict = None, ) -> Simulation: """ Create the simulation @@ -141,27 +147,32 @@ def simulation_factory( microscope (object); The microscope object optics (object): The optics object scan (object): The scan object - device (str): The device to use simulation (object): The simulation parameters - cluster (object): The cluster parameters + multiprocessing (object): The multiprocessing parameters Returns: object: The simulation object """ + # Check multiprocessing settings + if multiprocessing is None: + multiprocessing = {"device": "gpu", "nproc": 1, "gpu_id": 0} + else: + assert multiprocessing["nproc"] in [None, 1] + assert len(multiprocessing["gpu_id"]) == 1 # Create the simulation return Simulation( image_size=(microscope.detector.nx, microscope.detector.ny), pixel_size=microscope.detector.pixel_size, scan=scan, - cluster=cluster, simulate_image=ImageSimulator( microscope=microscope, optics=optics, scan=scan, simulation=simulation, - device=device, + device=multiprocessing["device"], + gpu_id=multiprocessing["gpu_id"][0], ), ) @@ -222,9 +233,8 @@ def _image_Config(config: parakeet.config.Config, optics_file: str, image_file: microscope=microscope, optics=optics, scan=scan, - device=config.device, simulation=config.simulation.dict(), - cluster=config.cluster.dict(), + multiprocessing=config.multiprocessing.dict(), ) # Create the writer diff --git a/src/parakeet/simulate/_simple.py b/src/parakeet/simulate/_simple.py index 47e91c9..d6dba46 100644 --- a/src/parakeet/simulate/_simple.py +++ b/src/parakeet/simulate/_simple.py @@ -47,13 +47,20 @@ class SimpleImageSimulator(object): """ def __init__( - self, microscope=None, atoms=None, scan=None, simulation=None, device="gpu" + self, + microscope=None, + atoms=None, + scan=None, + simulation=None, + device="gpu", + gpu_id=None, ): self.microscope = microscope self.atoms = atoms self.scan = scan self.simulation = simulation self.device = device + self.gpu_id = gpu_id def __call__(self, index): """ @@ -93,7 +100,8 @@ def __call__(self, index): # Create the multem system configuration system_conf = parakeet.simulate.simulation.create_system_configuration( - self.device + self.device, + self.gpu_id, ) # Create the multem input multislice object @@ -134,8 +142,8 @@ def __call__(self, index): def simulation_factory( microscope: Microscope, atoms: str, - device: Device = Device.gpu, simulation: dict = None, + multiprocessing: dict = None, ): """ Create the simulation @@ -143,9 +151,8 @@ def simulation_factory( Args: microscope (object); The microscope object atoms (object): The atom data - device (str): The device to use simulation (object): The simulation parameters - cluster (object): The cluster parameters + multiprocessing (object): The multiprocessing parameters Returns: object: The simulation object @@ -157,6 +164,13 @@ def simulation_factory( # Get the margin margin = 0 if simulation is None else simulation.get("margin", 0) + # Check multiprocessing settings + if multiprocessing is None: + multiprocessing = {"device": "gpu", "nproc": 1, "gpu_id": 0} + else: + assert multiprocessing["nproc"] in [None, 1] + assert len(multiprocessing["gpu_id"]) == 1 + # Create the simulation return Simulation( image_size=( @@ -165,13 +179,13 @@ def simulation_factory( ), pixel_size=microscope.detector.pixel_size, scan=scan, - cluster={"method": None}, simulate_image=SimpleImageSimulator( microscope=microscope, scan=scan, atoms=atoms, simulation=simulation, - device=device, + device=multiprocessing["device"], + gpu_id=multiprocessing["gpu_id"][0], ), ) @@ -220,8 +234,8 @@ def _simple_Config(config: parakeet.config.Config, atoms_file: str, output_file: simulation = simulation_factory( microscope=microscope, atoms=atoms, - device=config.device, simulation=config.simulation.dict(), + multiprocessing=config.multiprocessing.dict(), ) # Create the writer