Commit

Allow multiple particles in extract
jmp1985 committed Oct 14, 2024
1 parent f3a6472 commit a2ec1c9
Showing 1 changed file with 69 additions and 55 deletions.
src/parakeet/analyse/_extract.py — 124 changes: 69 additions & 55 deletions
@@ -120,8 +120,12 @@ def rebin(data, shape):
     assert voxel_size[0] == voxel_size[2]
     size = np.array(tomogram.shape)[[2, 0, 1]] * voxel_size
 
+    # Create a file to store particles
+    handle = h5py.File(extract_file, "w")
+    handle["voxel_size"] = voxel_size * particle_sampling
+    handle.create_group("data")
+
     # Loop through the
-    assert sample.number_of_molecules == 1
     for name, (atoms, positions, orientations) in sample.iter_molecules():
         # Compute the box size based on the size of the particle so that any
         # orientation should fit within the box
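With this change the output file and its top-level "data" group are created once, before looping over the molecule types, and the single-molecule assertion is dropped. A minimal sketch of what the resulting extract file looks like when read back (the file name is illustrative, not from the commit):

```python
import h5py

# Sketch only: after this commit the extract file stores the voxel size plus a
# "data" group containing one dataset per molecule name; the first axis of each
# dataset indexes the extracted particles for that molecule type.
with h5py.File("extracted_particles.h5", "r") as handle:
    print("Voxel size:", tuple(handle["voxel_size"][:]))
    for name, dataset in handle["data"].items():
        print(name, dataset.shape)  # (num_particles, z, y, x) per molecule type
```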
@@ -178,10 +182,8 @@ def rebin(data, shape):
 
         # Create a file to store particles
         sampled_shape = (sampled_length, sampled_length, sampled_length)
-        handle = h5py.File(extract_file, "w")
-        handle["voxel_size"] = voxel_size * particle_sampling
-        data_handle = handle.create_dataset(
-            "data", (0,) + sampled_shape, maxshape=(None,) + shape
+        data_handle = handle["data"].create_dataset(
+            name, (0,) + sampled_shape, maxshape=(None,) + shape
         )
 
         # Loop through all the particles
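Each molecule type now gets its own resizable dataset inside the shared "data" group, created with an unlimited first axis via maxshape. The append step itself is not shown in this hunk, but an h5py dataset created this way is typically grown with resize before each write; a rough sketch under that assumption (file name, dataset name, and box size are illustrative):

```python
import h5py
import numpy as np

sampled_shape = (32, 32, 32)  # illustrative particle box size

with h5py.File("extracted_particles.h5", "a") as handle:
    group = handle.require_group("data")
    # Unlimited first axis so particles can be appended one at a time
    data_handle = group.create_dataset(
        "my_molecule", (0,) + sampled_shape, maxshape=(None,) + sampled_shape
    )
    particle = np.zeros(sampled_shape, dtype="float32")  # placeholder sub-volume
    data_handle.resize(data_handle.shape[0] + 1, axis=0)
    data_handle[-1, :, :, :] = particle
```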
@@ -219,54 +221,66 @@ def average_extracted_particles(
     """
 
+    def _average_particle(
+        data, voxel_size, name, half1_filename, half2_filename, num_particles
+    ):
+        # Get the number of particles
+        if num_particles is None or num_particles <= 0:
+            num_particles = data.shape[0]
+        half_num_particles = num_particles // 2
+        assert half_num_particles > 0
+        assert num_particles <= data.shape[0]
+
+        # Setup the arrays
+        half = np.zeros((2,) + data.shape[1:], dtype="float32")
+        num = np.zeros(2)
+
+        # Get the random indices
+        indices = list(
+            np.random.choice(range(data.shape[0]), size=num_particles, replace=False)
+        )
+        indices = [indices[:half_num_particles], indices[half_num_particles:]]
+
+        # Average the particles
+        print("Summing particles")
+        for half_index, particle_indices in enumerate(indices):
+            for i, particle_index in enumerate(particle_indices):
+                print(
+                    "Half %d: adding %d / %d"
+                    % (half_index + 1, i + 1, len(particle_indices))
+                )
+                half[half_index, :, :, :] += data[particle_index, :, :, :]
+                num[half_index] += 1
+
+        # Average the sub tomograms
+        print("Averaging half 1 with %d particles" % num[0])
+        print("Averaging half 2 with %d particles" % num[1])
+        if num[0] > 0:
+            half[0, :, :, :] = half[0, :, :, :] / num[0]
+        if num[1] > 0:
+            half[1, :, :, :] = half[1, :, :, :] / num[1]
+
+        # Set prefix
+        half1_filename = "%s_%s" % (name, half1_filename)
+        half2_filename = "%s_%s" % (name, half2_filename)
+
+        # Save the averaged data
+        print("Saving half 1 to %s" % half1_filename)
+        handle = mrcfile.new(half1_filename, overwrite=True)
+        handle.set_data(half[0, :, :, :])
+        handle.voxel_size = voxel_size
+        print("Saving half 2 to %s" % half2_filename)
+        handle = mrcfile.new(half2_filename, overwrite=True)
+        handle.set_data(half[1, :, :, :])
+        handle.voxel_size = voxel_size
 
-    # Open the particles file
-    handle = h5py.File(particles_filename, "r")
-    data = handle["data"]
-    voxel_size = tuple(handle["voxel_size"][:])
-    print("Voxel size: %s" % str(voxel_size))
-
-    # Get the number of particles
-    if num_particles is None or num_particles <= 0:
-        num_particles = data.shape[0]
-    half_num_particles = num_particles // 2
-    assert half_num_particles > 0
-    assert num_particles <= data.shape[0]
-
-    # Setup the arrays
-    half = np.zeros((2,) + data.shape[1:], dtype="float32")
-    num = np.zeros(2)
-
-    # Get the random indices
-    indices = list(
-        np.random.choice(range(data.shape[0]), size=num_particles, replace=False)
-    )
-    indices = [indices[:half_num_particles], indices[half_num_particles:]]
-
-    # Average the particles
-    print("Summing particles")
-    for half_index, particle_indices in enumerate(indices):
-        for i, particle_index in enumerate(particle_indices):
-            print(
-                "Half %d: adding %d / %d"
-                % (half_index + 1, i + 1, len(particle_indices))
-            )
-            half[half_index, :, :, :] += data[particle_index, :, :, :]
-            num[half_index] += 1
-
-    # Average the sub tomograms
-    print("Averaging half 1 with %d particles" % num[0])
-    print("Averaging half 2 with %d particles" % num[1])
-    if num[0] > 0:
-        half[0, :, :, :] = half[0, :, :, :] / num[0]
-    if num[1] > 0:
-        half[1, :, :, :] = half[1, :, :, :] / num[1]
-
-    # Save the averaged data
-    print("Saving half 1 to %s" % half1_filename)
-    handle = mrcfile.new(half1_filename, overwrite=True)
-    handle.set_data(half[0, :, :, :])
-    handle.voxel_size = voxel_size
-    print("Saving half 2 to %s" % half2_filename)
-    handle = mrcfile.new(half2_filename, overwrite=True)
-    handle.set_data(half[1, :, :, :])
-    handle.voxel_size = voxel_size
+    particles_handle = h5py.File(particles_filename, "r")
+
+    for name in particles_handle["data"].keys():
+        data = particles_handle["data"][name]
+        voxel_size = tuple(particles_handle["voxel_size"][:])
+        print("Voxel size: %s" % str(voxel_size))
+        _average_particle(
+            data, voxel_size, name, half1_filename, half2_filename, num_particles
+        )
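The averaging logic is now wrapped in a per-molecule helper driven by a loop over the dataset names in the "data" group, so the two half-map MRC files are written once per molecule type, with the molecule name prefixed to the supplied filenames. A sketch of reading those outputs back (filenames and molecule name are illustrative):

```python
import mrcfile

# Sketch: assuming "half1.mrc"/"half2.mrc" were passed as the half-map names and
# "my_molecule" is a dataset name in the extract file, the outputs are written
# as "my_molecule_half1.mrc" and "my_molecule_half2.mrc".
for name in ["my_molecule"]:
    for half in ["half1.mrc", "half2.mrc"]:
        with mrcfile.open("%s_%s" % (name, half)) as handle:
            print(name, half, handle.data.shape, handle.voxel_size)
```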
