Skip to content

Commit

Permalink
FIxed test
Browse files Browse the repository at this point in the history
  • Loading branch information
jmp1985 committed Oct 19, 2023
1 parent bd34a80 commit f165117
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 32 deletions.
7 changes: 3 additions & 4 deletions src/parakeet/analyse/_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import parakeet.microscope
import parakeet.sample
from functools import singledispatch
from parakeet.config import Device


__all__ = ["correct"]
Expand All @@ -30,7 +29,7 @@ def correct(
image_file: str,
corrected_file: str,
num_defocus: int = 1,
device: Device = Device.gpu,
device: str = None,
):
"""
Correct the images using 3D CTF correction
Expand All @@ -49,7 +48,7 @@ def correct(

# Set the device
if device is not None:
config.device = device
config.multiprocessing.device = device

# Print some options
parakeet.config.show(config)
Expand Down Expand Up @@ -115,5 +114,5 @@ def _correct_Config(
astigmatism=astigmatism,
astigmatism_angle=astigmatism_angle,
phase_shift=phase_shift,
device=config.device,
device=config.multiprocessing.device,
)
9 changes: 3 additions & 6 deletions src/parakeet/analyse/_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import parakeet.microscope
import parakeet.sample
from functools import singledispatch
from parakeet.config import Device


__all__ = ["reconstruct"]
Expand All @@ -25,9 +24,7 @@


@singledispatch
def reconstruct(
config_file, image_file: str, rec_file: str, device: Device = Device.gpu
):
def reconstruct(config_file, image_file: str, rec_file: str, device: str = None):
"""
Reconstruct the volume
Expand All @@ -44,7 +41,7 @@ def reconstruct(

# Set the device
if device is not None:
config.device = device
config.multiprocessing.device = device

# Print some options
parakeet.config.show(config)
Expand Down Expand Up @@ -110,5 +107,5 @@ def _reconstruct_Config(
astigmatism_angle=astigmatism_angle,
phase_shift=phase_shift,
angular_weights=True,
device=config.device,
device=config.multiprocessing.device,
)
2 changes: 1 addition & 1 deletion src/parakeet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ class Multiprocessing(BaseModel):

nproc: int = Field(1, description="The number of processes", gt=0)

gpu_id: List[int] = Field(None, description="The GPU id for each thread")
gpu_id: List[int] = Field([0], description="The GPU id for each thread")


class Config(BaseModel):
Expand Down
4 changes: 0 additions & 4 deletions src/parakeet/simulate/_ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def simulation_factory(microscope: Microscope, simulation: dict) -> Simulation:
Args:
microscope (object); The microscope object
exit_wave (object): The exit_wave object
scan (object): The scan object
device (str): The device to use
simulation (object): The simulation parameters
cluster (object): The cluster parameters
Returns:
object: The simulation object
Expand Down
28 changes: 19 additions & 9 deletions src/parakeet/simulate/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ class ImageSimulator(object):
"""

def __init__(
self, microscope=None, optics=None, scan=None, simulation=None, device="gpu"
self,
microscope=None,
optics=None,
scan=None,
simulation=None,
device="gpu",
gpu_id=None,
):
self.microscope = microscope
self.optics = optics
self.scan = scan
self.simulation = simulation
self.device = device
self.gpu_id = gpu_id

def __call__(self, index):
"""
Expand Down Expand Up @@ -130,9 +137,8 @@ def simulation_factory(
microscope: Microscope,
optics: object,
scan: Scan,
device: Device = Device.gpu,
simulation: dict = None,
cluster: dict = None,
multiprocessing: dict = None,
) -> Simulation:
"""
Create the simulation
Expand All @@ -141,27 +147,32 @@ def simulation_factory(
microscope (object); The microscope object
optics (object): The optics object
scan (object): The scan object
device (str): The device to use
simulation (object): The simulation parameters
cluster (object): The cluster parameters
multiprocessing (object): The multiprocessing parameters
Returns:
object: The simulation object
"""
# Check multiprocessing settings
if multiprocessing is None:
multiprocessing = {"device": "gpu", "nproc": 1, "gpu_id": 0}
else:
assert multiprocessing["nproc"] in [None, 1]
assert len(multiprocessing["gpu_id"]) == 1

# Create the simulation
return Simulation(
image_size=(microscope.detector.nx, microscope.detector.ny),
pixel_size=microscope.detector.pixel_size,
scan=scan,
cluster=cluster,
simulate_image=ImageSimulator(
microscope=microscope,
optics=optics,
scan=scan,
simulation=simulation,
device=device,
device=multiprocessing["device"],
gpu_id=multiprocessing["gpu_id"][0],
),
)

Expand Down Expand Up @@ -222,9 +233,8 @@ def _image_Config(config: parakeet.config.Config, optics_file: str, image_file:
microscope=microscope,
optics=optics,
scan=scan,
device=config.device,
simulation=config.simulation.dict(),
cluster=config.cluster.dict(),
multiprocessing=config.multiprocessing.dict(),
)

# Create the writer
Expand Down
30 changes: 22 additions & 8 deletions src/parakeet/simulate/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,20 @@ class SimpleImageSimulator(object):
"""

def __init__(
self, microscope=None, atoms=None, scan=None, simulation=None, device="gpu"
self,
microscope=None,
atoms=None,
scan=None,
simulation=None,
device="gpu",
gpu_id=None,
):
self.microscope = microscope
self.atoms = atoms
self.scan = scan
self.simulation = simulation
self.device = device
self.gpu_id = gpu_id

def __call__(self, index):
"""
Expand Down Expand Up @@ -93,7 +100,8 @@ def __call__(self, index):

# Create the multem system configuration
system_conf = parakeet.simulate.simulation.create_system_configuration(
self.device
self.device,
self.gpu_id,
)

# Create the multem input multislice object
Expand Down Expand Up @@ -134,18 +142,17 @@ def __call__(self, index):
def simulation_factory(
microscope: Microscope,
atoms: str,
device: Device = Device.gpu,
simulation: dict = None,
multiprocessing: dict = None,
):
"""
Create the simulation
Args:
microscope (object); The microscope object
atoms (object): The atom data
device (str): The device to use
simulation (object): The simulation parameters
cluster (object): The cluster parameters
multiprocessing (object): The multiprocessing parameters
Returns:
object: The simulation object
Expand All @@ -157,6 +164,13 @@ def simulation_factory(
# Get the margin
margin = 0 if simulation is None else simulation.get("margin", 0)

# Check multiprocessing settings
if multiprocessing is None:
multiprocessing = {"device": "gpu", "nproc": 1, "gpu_id": 0}
else:
assert multiprocessing["nproc"] in [None, 1]
assert len(multiprocessing["gpu_id"]) == 1

# Create the simulation
return Simulation(
image_size=(
Expand All @@ -165,13 +179,13 @@ def simulation_factory(
),
pixel_size=microscope.detector.pixel_size,
scan=scan,
cluster={"method": None},
simulate_image=SimpleImageSimulator(
microscope=microscope,
scan=scan,
atoms=atoms,
simulation=simulation,
device=device,
device=multiprocessing["device"],
gpu_id=multiprocessing["gpu_id"][0],
),
)

Expand Down Expand Up @@ -220,8 +234,8 @@ def _simple_Config(config: parakeet.config.Config, atoms_file: str, output_file:
simulation = simulation_factory(
microscope=microscope,
atoms=atoms,
device=config.device,
simulation=config.simulation.dict(),
multiprocessing=config.multiprocessing.dict(),
)

# Create the writer
Expand Down

0 comments on commit f165117

Please sign in to comment.