Skip to content

Commit

Permalink
Add ability to subsample extracted sub tomos
Browse files Browse the repository at this point in the history
  • Loading branch information
jmp1985 committed Sep 30, 2024
1 parent dde244d commit f3a6472
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
42 changes: 36 additions & 6 deletions src/parakeet/analyse/_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def extract(
rec_file: str,
particles_file: str,
particle_size: int,
particle_sampling: int,
):
"""
Perform sub tomogram extraction
Expand All @@ -45,6 +46,7 @@ def extract(
rec_file: The reconstruction filename
particles_file: The file to extract the particles to
particle_size: The particle size (px)
particle_sampling: The particle sampling (factor of 2)
"""

Expand All @@ -58,7 +60,9 @@ def extract(
sample = parakeet.sample.load(sample_file)

# Do the sub tomogram averaging
_extract_Config(config, sample, rec_file, particles_file, particle_size)
_extract_Config(
config, sample, rec_file, particles_file, particle_size, particle_sampling
)


@extract.register(parakeet.config.Config)
Expand All @@ -68,12 +72,31 @@ def _extract_Config(
rec_filename: str,
extract_file: str,
particle_size: int = 0,
particle_sampling: int = 1,
):
"""
Extract particles for post-processing
"""

def is_power_of_2(n):
return (n & (n - 1) == 0) and n != 0

def rebin(data, shape):
shape = (
shape[0],
data.shape[0] // shape[0],
shape[1],
data.shape[1] // shape[1],
shape[2],
data.shape[2] // shape[2],
)
output = data.reshape(shape).sum(-1).sum(-2).sum(-3)
return output

# Check particle sampling is power of 2
assert is_power_of_2(particle_sampling)

# Get the scan config
# scan = config.model_dump()

Expand Down Expand Up @@ -128,11 +151,17 @@ def _extract_Config(
else:
half_length = particle_size // 2
length = 2 * half_length

# Check length is compatible with particle sampling
assert (length % particle_sampling) == 0
sampled_length = length // particle_sampling

# Check number positions
assert len(positions) == len(orientations)
num_particles = len(positions)
print(
"Extracting %d %s particles with box size %d"
% (num_particles, name, length)
"Extracting %d %s particles with box size %d and sampling %d"
% (num_particles, name, length, particle_sampling)
)

# Create the average array
Expand All @@ -148,10 +177,11 @@ def _extract_Config(
indices = [list(range(len(positions)))]

# Create a file to store particles
sampled_shape = (sampled_length, sampled_length, sampled_length)
handle = h5py.File(extract_file, "w")
handle["voxel_size"] = voxel_size
handle["voxel_size"] = voxel_size * particle_sampling
data_handle = handle.create_dataset(
"data", (0,) + shape, maxshape=(None,) + shape
"data", (0,) + sampled_shape, maxshape=(None,) + shape
)

# Loop through all the particles
Expand All @@ -173,7 +203,7 @@ def _extract_Config(
):
# Add the particle to the file
data_handle.resize(num + 1, axis=0)
data_handle[num, :, :, :] = data
data_handle[num, :, :, :] = rebin(data, sampled_shape)
num += 1
print("Count: ", num)

Expand Down
15 changes: 14 additions & 1 deletion src/parakeet/command_line/analyse/_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ def get_parser(parser: ArgumentParser = None) -> ArgumentParser:
dest="particle_size",
help="The size of the particles extracted (px)",
)
parser.add_argument(
"-psm",
"--particle_sampling",
type=int,
default=1,
dest="particle_sampling",
help="The sampling of the particle volume (factor of 2)",
)

return parser

Expand All @@ -107,7 +115,12 @@ def extract_impl(args):

# Do the work
parakeet.analyse.extract(
args.config, args.sample, args.rec, args.particles, args.particle_size
args.config,
args.sample,
args.rec,
args.particles,
args.particle_size,
args.particle_sampling,
)

# Write some timing stats
Expand Down

0 comments on commit f3a6472

Please sign in to comment.