Commit 80bcc34 (1 parent: ab7f969)
Showing 6 changed files with 562 additions and 4 deletions.
@@ -0,0 +1,139 @@
import os
import gemmi
import numpy as np
from scipy.spatial.transform import Rotation as R
from .ctf import contrast_transfer_function
from typing import Union, List
from pathlib import Path


AXES = ["Cartn_x", "Cartn_y", "Cartn_z"]


def pdb_to_coordinates(filename: os.PathLike) -> np.ndarray:
    """Read a PDB file and return the atomic coordinates.

    Parameters
    ----------
    filename : PathLike
        A filename for the PDBx/mmCIF file describing the atomic coordinates.

    Returns
    -------
    coords : np.ndarray (N, 3)
        A numpy array of the cartesian coordinates of the atoms of the model.

    Notes
    -----
    This is super basic, and does not check for multiple chains, cofactors etc.
    """

    doc = gemmi.cif.read_file(filename)  # copy all the data from mmCIF file
    block = doc.sole_block()  # mmCIF has exactly one block

    data = block.find("_atom_site.", AXES)

    coords = np.stack(
        [[float(r) for r in data.column(idx)] for idx in range(len(AXES))],
        axis=-1,
    )

    # center the molecule in XYZ
    centroids = np.mean(coords, axis=0)
    coords = coords - centroids

    return coords


class DensitySimulator:
    """Simulate a Cryo-EM image using atomic coordinates from PDB files.

    Parameters
    ----------
    filenames : PathLike or list of PathLike
        One or more PDBx/mmCIF files, or a directory containing .cif files,
        used to extract atomic coordinates.
    pixel_size : float
        The pixel size for the image in angstroms.
    box_size : int
        The size of the box in pixels for image generation.
    defocus : float
        The defocus value passed to the contrast transfer function.
    add_poisson_noise : bool, default = False
        Add shot noise to the final image (a keyword argument of __call__).

    Returns
    -------
    density : np.ndarray (N, N)
        The simulated projection of the electron density, returned when the
        simulator is called.
    """

    def __init__(
        self,
        filenames: Union[List[os.PathLike], os.PathLike],
        pixel_size: float = 1.0,
        box_size: int = 128,
        defocus: float = 5e3,
    ):
        if isinstance(filenames, list):
            filenames = [Path(f) for f in filenames]
        else:
            # a single path may be either one file or a directory of .cif files
            filenames = Path(filenames)
            if filenames.is_dir():
                filenames = [
                    f for f in filenames.iterdir() if f.suffix == ".cif"
                ]
            else:
                filenames = [filenames]

        self.filenames = filenames
        self.pixel_size = pixel_size
        self.box_size = box_size

        self.structures = {
            filename.name: pdb_to_coordinates(str(filename))
            for filename in self.filenames
        }

        self.ctf = contrast_transfer_function(
            defocus=defocus,
            box_size=box_size * 2,
            pixel_size=pixel_size,
        )

    def keys(self):
        return list(self.structures.keys())

    def __call__(
        self,
        key: str,
        transform_euler_angles: list = [0, 0, 0],
        transform_translate: list = [0, 0, 0],
        project: bool = True,
        add_poisson_noise: bool = False,
    ) -> np.ndarray:
        # get the atomic coordinates
        coords = self.structures[key]

        # do a transform
        r = R.from_euler("xyz", transform_euler_angles, degrees=True)
        coords = r.apply(coords)

        # centre the molecule, ish
        pad = self.box_size // 2
        data = coords / self.pixel_size + [pad, pad, pad]

        # discretize the atomic coords assuming 1 px == 1 Angstrom
        density, _ = np.histogramdd(
            data,
            bins=self.box_size,
            range=tuple([(0, self.box_size - 1)] * 3),
        )

        if not project:
            assert density.ndim == 3
            return density

        density = np.sum(density, axis=-1)
        density = density + 1.0
        assert density.ndim == 2

        return density
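Not part of the commit: a minimal usage sketch for DensitySimulator, assuming the class is importable from this module and that a local structure file named 7b5s.cif exists (both the import context and the filename are placeholders).

# Hypothetical usage sketch (not part of the commit); the filename is a placeholder.
sim = DensitySimulator("7b5s.cif", pixel_size=1.0, box_size=128)
key = sim.keys()[0]

# rotate the model by Euler angles (degrees) and project along z
image = sim(key, transform_euler_angles=[30.0, 0.0, 90.0], project=True)
assert image.shape == (128, 128)

# the unprojected 3D density is also available
volume = sim(key, project=False)
assert volume.shape == (128, 128, 128)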
@@ -0,0 +1,204 @@
import enum
import itertools
import mrcfile
import torch
import warnings

import numpy as np
import numpy.typing as npt
import pandas as pd

from pathlib import Path
from scipy.ndimage import rotate
from typing import List, Tuple
from tqdm import tqdm

from vne.utils.utils import InfinitePaddedImage


class SHRECModelType(str, enum.Enum):
    RECONSTRUCTION = "reconstruction"
    GRANDMODEL = "grandmodel"
    CLASSMASK = "classmask"


class SHRECModel:
    def __init__(
        self,
        model_path: Path,
        *,
        model_type: SHRECModelType = "reconstruction",
        exclude: List[str] = ["vesicle", "fiducial", "4V94"],
        boxsize: Tuple[int, int, int] = (32, 32, 32),
        augment: bool = False,
    ) -> None:
        self.model_path = model_path
        self.model_type = SHRECModelType(model_type)
        particles = pd.read_csv(
            model_path / "particle_locations.txt",
            names=[
                "class",
                "x",
                "y",
                "z",
                "rotation_Z1",
                "rotation_X",
                "rotation_Z2",
            ],
            sep=" ",
        )
        self.exclude = exclude
        self.particles = particles[~particles["class"].isin(exclude)]
        self.data = None
        self.boxsize = np.array(boxsize)
        self.augment = augment

    def __len__(self) -> int:
        return len(self.particles)

    def keys(self) -> List[str]:
        return list(set(self.particles["class"].tolist()))

    def __getitem__(self, idx: int) -> Tuple[npt.NDArray, str]:
        """Get the particle and class."""

        if self.data is None:
            self._load_volume()

        particle = self.particles.iloc[idx]

        # if augmenting, crop a slightly larger volume
        cropsize = (
            np.array(self.boxsize) + 16 if self.augment else self.boxsize
        )

        slices = [
            slice(
                particle[dim] - cropsize[d] // 2,
                particle[dim] - cropsize[d] // 2 + cropsize[d],
                1,
            )
            for d, dim in enumerate(["z", "y", "x"])
        ]

        sz, sy, sx = slices
        subvolume = self.data[sz, sy, sx]
        return subvolume, str(particle["class"])

    def _load_volume(self) -> None:
        if self.data is not None:
            return

        data_fn = self.model_path / f"{self.model_type}.mrc"

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            with mrcfile.open(data_fn, permissive=True) as mrc:
                self.data = InfinitePaddedImage(np.asarray(mrc.data))


class SHRECDataset(torch.utils.data.Dataset):
    def __init__(self, root_path: Path, **kwargs):
        super().__init__()

        root_path = Path(root_path)  # allow str or Path
        self.boxsize = kwargs.get("boxsize", (32, 32, 32))
        self.augment = kwargs.get("augment", True)
        self.model_type = kwargs.get("model_type", SHRECModelType.GRANDMODEL)

        self._subtomo_fn = (
            root_path / f"subtomograms_{self.model_type.upper()}.npz"
        )

        if self._subtomo_fn.exists():
            print(f"Loading dataset: {self._subtomo_fn}...")
            dataset = np.load(self._subtomo_fn)
            self._subvolumes = dataset["volumes"]
            self._molecule_ids = dataset["molecule_ids"]
            self._keys = dataset["keys"].tolist()
            self._n_molecules = self._subvolumes.shape[0]
            return

        self.models = [
            SHRECModel(model_path, **kwargs)
            for model_path in root_path.iterdir()
            if model_path.stem.startswith("model_")
        ]
        # self.models = self.models[0:1]
        self._n_molecules = sum(len(model) for model in self.models)
        self.extract_subvolumes()

    def extract_subvolumes(self):
        # build the class-name lookup first, since the extraction loop below
        # maps each molecule class to its index in this list
        self._keys = list(
            set(
                itertools.chain.from_iterable(
                    [model.keys() for model in self.models]
                )
            )
        )

        subvolume_shape = self.models[0][0][0].shape
        self._subvolumes = np.zeros(
            (self._n_molecules, *subvolume_shape), dtype=np.float32
        )
        self._molecule_ids = np.zeros((self._n_molecules,), dtype=np.uint8)
        jdx = 0
        for model in tqdm(self.models, desc="Extracting subvolumes"):
            for idx in range(len(model)):
                self._subvolumes[jdx, ...], molecule_id = model[idx]
                self._molecule_ids[jdx] = self._keys.index(molecule_id)
                jdx += 1

        np.savez(
            self._subtomo_fn,
            volumes=self._subvolumes,
            molecule_ids=self._molecule_ids,
            keys=self._keys,
        )

    def keys(self) -> List[str]:
        return self._keys

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        subvolume = self._subvolumes[idx]
        molecule_idx = self._molecule_ids[idx]

        if self.augment:
            theta = np.random.uniform(low=-30.0, high=30.0, size=(3,))

            for d in range(3):
                axes = (d, (d + 1) % 3)
                subvolume = rotate(subvolume, theta[d], axes, reshape=False)

            subvolume = subvolume[8:-8, 8:-8, 8:-8]
            assert subvolume.shape == tuple(self.boxsize), subvolume.shape

        if self.model_type == SHRECModelType.RECONSTRUCTION:
            subvolume = (subvolume - np.mean(subvolume)) / max(
                1 / np.sqrt(subvolume.size), np.std(subvolume)
            )

        subvolume = torch.as_tensor(subvolume[None, ...], dtype=torch.float32)
        return subvolume, molecule_idx

    def __len__(self):
        return self._n_molecules

    def examples(self) -> Tuple[torch.Tensor, List[str]]:
        """Return a set of examples of subvolumes."""
        x_idx = set()
        x_complete = set(range(len(self._keys)))
        examples = []
        examples_class = []
        idx = 0

        while x_complete.difference(x_idx) != set():
            vol, mol_idx = self[idx]
            if mol_idx not in x_idx:
                x_idx.add(mol_idx)
                examples.append(vol)
                examples_class.append(mol_idx)
            idx += 1

        return torch.stack(examples, dim=0), [
            self._keys[idx] for idx in examples_class
        ]
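Not part of the commit: a sketch of how the dataset might be consumed with a PyTorch DataLoader; the root path is a placeholder for a local copy of the SHREC data containing model_* subdirectories.

# Hypothetical usage sketch (not part of the commit); the path is a placeholder.
from torch.utils.data import DataLoader

dataset = SHRECDataset(
    "/path/to/shrec",  # directory containing model_0, model_1, ...
    model_type=SHRECModelType.GRANDMODEL,
    boxsize=(32, 32, 32),
    augment=True,
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# each batch is (B, 1, 32, 32, 32) subvolumes with integer class labels
volumes, labels = next(iter(loader))
class_names = [dataset.keys()[int(i)] for i in labels]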
@@ -0,0 +1 @@
from vne.utils.napari import GenerativeAffinityVAEWidget  # NOQA: F401
@@ -0,0 +1,32 @@
import dataclasses

import numpy as np


@dataclasses.dataclass
class CyclicAnnealing:
    n_cycles: int = 4
    n_iterations: int = 1_000
    R: float = 0.5
    beta: float = 1.0

    def __post_init__(self):
        self.reset()
        self._func = lambda t: 1.0 / (1.0 + np.exp(-(t * 18 - 3)))

    def step(self) -> float:
        self._iteration += 1
        t_over_m = self.n_iterations / self.n_cycles
        tau = np.mod((self._iteration - 1), np.ceil(t_over_m)) / t_over_m

        if self._iteration >= self.n_iterations:
            return self.beta

        if tau <= self.R:
            return self._func(tau) * self.beta
        else:
            return self.beta

    def reset(self):
        self._iteration = 0
        print(f"Resetting cyclic annealing: {self._iteration}")
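Not part of the commit: a sketch of the annealing schedule driving a beta-weighted loss term in a training loop; the loss names are placeholders.

# Hypothetical usage sketch (not part of the commit); loss terms are placeholders.
annealer = CyclicAnnealing(n_cycles=4, n_iterations=1_000, beta=1.0)

for iteration in range(1_000):
    beta = annealer.step()
    # total_loss = reconstruction_loss + beta * kl_divergence
    if iteration % 250 == 0:
        print(f"iteration {iteration}: beta = {beta:.3f}")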