diff --git a/src/parakeet/_run.py b/src/parakeet/_run.py index 38d2ccc..447b2b4 100644 --- a/src/parakeet/_run.py +++ b/src/parakeet/_run.py @@ -14,10 +14,9 @@ import parakeet.metadata import parakeet.sample import parakeet.simulate +from parakeet.config import Device from functools import singledispatch -Device = parakeet.config.Device - __all__ = ["run"] @@ -33,7 +32,7 @@ def run( exit_wave_file: str, optics_file: str, image_file: str, - device: str = None, + device: Device = None, nproc: int = None, gpu_id: list = None, steps: list = None, diff --git a/src/parakeet/analyse/_correct.py b/src/parakeet/analyse/_correct.py index e75e788..3b12142 100644 --- a/src/parakeet/analyse/_correct.py +++ b/src/parakeet/analyse/_correct.py @@ -13,6 +13,7 @@ import random import parakeet.microscope import parakeet.sample +from parakeet.config import Device from functools import singledispatch @@ -29,7 +30,7 @@ def correct( image_file: str, corrected_file: str, num_defocus: int = 1, - device: str = None, + device: Device = None, ): """ Correct the images using 3D CTF correction diff --git a/src/parakeet/analyse/_reconstruct.py b/src/parakeet/analyse/_reconstruct.py index 8154a6a..eec9612 100644 --- a/src/parakeet/analyse/_reconstruct.py +++ b/src/parakeet/analyse/_reconstruct.py @@ -13,6 +13,7 @@ import random import parakeet.microscope import parakeet.sample +from parakeet.config import Device from functools import singledispatch @@ -24,7 +25,7 @@ @singledispatch -def reconstruct(config_file, image_file: str, rec_file: str, device: str = None): +def reconstruct(config_file, image_file: str, rec_file: str, device: Device = None): """ Reconstruct the volume diff --git a/src/parakeet/simulate/_cbed.py b/src/parakeet/simulate/_cbed.py index 9c3f3e4..9dfcf0d 100644 --- a/src/parakeet/simulate/_cbed.py +++ b/src/parakeet/simulate/_cbed.py @@ -20,6 +20,7 @@ import parakeet.io import parakeet.sample import parakeet.simulate +from parakeet.config import Device from parakeet.simulate.simulation import Simulation from parakeet.simulate.engine import SimulationEngine from parakeet.microscope import Microscope @@ -32,9 +33,6 @@ __all__ = ["cbed"] -Device = parakeet.config.Device -Sample = parakeet.sample.Sample - # Get the logger logger = logging.getLogger(__name__) @@ -323,7 +321,7 @@ def cbed( config_file, sample_file: str, image_file: str, - device: str = None, + device: Device = None, nproc: int = None, gpu_id: list = None, ): diff --git a/src/parakeet/simulate/_exit_wave.py b/src/parakeet/simulate/_exit_wave.py index 7eb2e4f..bc455d7 100644 --- a/src/parakeet/simulate/_exit_wave.py +++ b/src/parakeet/simulate/_exit_wave.py @@ -20,6 +20,7 @@ import parakeet.io import parakeet.sample import parakeet.simulate +from parakeet.config import Device from parakeet.simulate.simulation import Simulation from parakeet.simulate.engine import SimulationEngine from parakeet.microscope import Microscope @@ -312,7 +313,7 @@ def exit_wave( config_file, sample_file: str, exit_wave_file: str, - device: str = None, + device: Device = None, nproc: int = None, gpu_id: list = None, ): diff --git a/src/parakeet/simulate/_optics.py b/src/parakeet/simulate/_optics.py index a28477b..6bc2741 100644 --- a/src/parakeet/simulate/_optics.py +++ b/src/parakeet/simulate/_optics.py @@ -19,6 +19,9 @@ import parakeet.inelastic import parakeet.io import parakeet.sample +from parakeet.config import Device +from parakeet.microscope import Microscope +from parakeet.scan import Scan from functools import singledispatch from parakeet.simulate.simulation import Simulation from parakeet.simulate.engine import SimulationEngine @@ -456,7 +459,7 @@ def optics( config_file, exit_wave_file: str, optics_file: str, - device: str = None, + device: Device = None, nproc: int = None, gpu_id: list = None, ): diff --git a/src/parakeet/simulate/_potential.py b/src/parakeet/simulate/_potential.py index b835968..78e4405 100644 --- a/src/parakeet/simulate/_potential.py +++ b/src/parakeet/simulate/_potential.py @@ -18,6 +18,7 @@ import parakeet.futures import parakeet.inelastic import parakeet.sample +from parakeet.config import Device from parakeet.microscope import Microscope from parakeet.scan import Scan from parakeet.simulate.simulation import Simulation @@ -221,7 +222,7 @@ def potential( config_file, sample_file: str, potential_prefix: str, - device: str = None, + device: Device = None, nproc: int = None, gpu_id: list = None, ):