Commit
updates
quantumjot committed Oct 3, 2024
1 parent ab7f969 commit 80bcc34
Showing 6 changed files with 562 additions and 4 deletions.
139 changes: 139 additions & 0 deletions vne/special/pdb.py
@@ -0,0 +1,139 @@
import os
import gemmi
import numpy as np
from scipy.spatial.transform import Rotation as R
from .ctf import contrast_transfer_function
from typing import Tuple, Union, List
from pathlib import Path


AXES = ["Cartn_x", "Cartn_y", "Cartn_z"]


def pdb_to_coordinates(filename: os.PathLike) -> np.ndarray:
"""Read a PDB file and return the atomic coordinates.
Parameters
----------
filename : PathLike
A filename for the PDBx/mmCIF file describing the atomic coordinates.
Returns
-------
coords : np.ndarray (N, 3)
A numpy array of the cartesian coordinates of the atoms of the model.
Notes
-----
This is super basic, and does not check for mutliple chains, cofactors etc.
"""

doc = gemmi.cif.read_file(filename) # copy all the data from mmCIF file
block = doc.sole_block() # mmCIF has exactly one block

data = block.find("_atom_site.", AXES)

coords = np.stack(
[[float(r) for r in data.column(idx)] for idx in range(len(AXES))],
axis=-1,
)

# center the molecule in XYZ
centroids = np.mean(coords, axis=0)
coords = coords - centroids

return coords


class DensitySimulator:
"""Simulate a Cryo-EM image using atomic coordinates from PDB files.
Parameters
----------
filename : PathLike
The PDB filename to use to extract atomic coordinates.
pixel_size : float
The pixel size for the image in angstroms.
box_size : float
The size of the box in pixels for image generation.
add_poission_noise : bool, default = True
Add shot noise to the final image.
Returns
-------
density : np.ndarray (N, N)
The simulated projection of the electron density.
"""

def __init__(
self,
filenames: Union[List[os.PathLike], os.PathLike],
pixel_size: float = 1.0,
box_size: int = 128,
defocus: float = 5e3,
):
        if isinstance(filenames, list):
            filenames = [Path(f) for f in filenames]
        else:
            filenames = Path(filenames)
            if filenames.is_dir():
                # use every mmCIF file in the directory
                filenames = [
                    f for f in filenames.iterdir() if f.suffix == ".cif"
                ]
            else:
                filenames = [filenames]

self.filenames = filenames
self.pixel_size = pixel_size
self.box_size = box_size

self.structures = {
filename.name: pdb_to_coordinates(str(filename))
for filename in self.filenames
}

self.ctf = contrast_transfer_function(
defocus=defocus,
box_size=box_size * 2,
pixel_size=pixel_size,
)

def keys(self):
return list(self.structures.keys())

    def __call__(
        self,
        key: str,
        transform_euler_angles: list = [0, 0, 0],
        transform_translate: list = [0, 0, 0],
        project: bool = True,
        add_poisson_noise: bool = False,
    ) -> np.ndarray:
        # get the atomic coordinates
        coords = self.structures[key]

        # apply the rotation, then the translation (assumed to be in angstroms)
        r = R.from_euler("xyz", transform_euler_angles, degrees=True)
        coords = r.apply(coords) + np.asarray(transform_translate)

        # centre the molecule, ish
        pad = self.box_size // 2
        data = coords / self.pixel_size + [pad, pad, pad]

        # discretize the atomic coordinates onto the voxel grid
density, _ = np.histogramdd(
data,
bins=self.box_size,
range=tuple([(0, self.box_size - 1)] * 3),
)

        if not project:
            assert density.ndim == 3
            return density

        # project along the z-axis and add a constant offset to avoid zeros
        density = np.sum(density, axis=-1) + 1.0
        assert density.ndim == 2

        if add_poisson_noise:
            # shot noise, treating the projected density as a mean count
            density = np.random.poisson(density).astype(np.float64)

        return density
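A minimal usage sketch for the module above (not part of the commit); the mmCIF filename is a placeholder and the example assumes a single structure file on disk.

# Hypothetical usage of DensitySimulator; "6z6u.cif" is a placeholder path.
from vne.special.pdb import DensitySimulator, pdb_to_coordinates

coords = pdb_to_coordinates("6z6u.cif")  # (N, 3) centred Cartesian coordinates
sim = DensitySimulator("6z6u.cif", pixel_size=1.0, box_size=128, defocus=5e3)

# project a rotated copy of the molecule along the z-axis
image = sim(sim.keys()[0], transform_euler_angles=[30, 0, 45])
assert image.shape == (128, 128)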
204 changes: 204 additions & 0 deletions vne/special/shrec.py
@@ -0,0 +1,204 @@
import enum
import itertools
import mrcfile
import torch
import warnings

import numpy as np
import numpy.typing as npt
import pandas as pd

from pathlib import Path
from scipy.ndimage import rotate
from typing import List, Tuple
from tqdm import tqdm

from vne.utils.utils import InfinitePaddedImage


class SHRECModelType(str, enum.Enum):
RECONSTRUCTION = "reconstruction"
GRANDMODEL = "grandmodel"
CLASSMASK = "classmask"


class SHRECModel:
def __init__(
self,
model_path: Path,
*,
model_type: SHRECModelType = "reconstruction",
exclude: List[str] = ["vesicle", "fiducial", "4V94"],
boxsize: Tuple[int, int, int] = (32, 32, 32),
augment: bool = False,
) -> None:
self.model_path = model_path
self.model_type = SHRECModelType(model_type)
particles = pd.read_csv(
model_path / "particle_locations.txt",
names=[
"class",
"x",
"y",
"z",
"rotation_Z1",
"rotation_X",
"rotation_Z2",
],
sep=" ",
)
self.exclude = exclude
self.particles = particles[~particles["class"].isin(exclude)]
self.data = None
self.boxsize = np.array(boxsize)
self.augment = augment

def __len__(self) -> int:
return len(self.particles)

def keys(self) -> List[str]:
return list(set(self.particles["class"].tolist()))

def __getitem__(self, idx: int) -> Tuple[npt.NDArray, str]:
"""Get the particle and class."""

if self.data is None:
self._load_volume()

particle = self.particles.iloc[idx]

# if augmenting, crop a slightly larger volume
cropsize = (
np.array(self.boxsize) + 16 if self.augment else self.boxsize
)

slices = [
slice(
particle[dim] - cropsize[d] // 2,
particle[dim] - cropsize[d] // 2 + cropsize[d],
1,
)
for d, dim in enumerate(["z", "y", "x"])
]

sz, sy, sx = slices
subvolume = self.data[sz, sy, sx]
return subvolume, str(particle["class"])

def _load_volume(self) -> None:
if self.data is not None:
return

data_fn = self.model_path / f"{self.model_type}.mrc"

with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
with mrcfile.open(data_fn, permissive=True) as mrc:
self.data = InfinitePaddedImage(np.asarray(mrc.data))


class SHRECDataset(torch.utils.data.Dataset):
    def __init__(self, root_path: Path, **kwargs):
        super().__init__()

        root_path = Path(root_path)
        self.boxsize = kwargs.get("boxsize", (32, 32, 32))
        self.augment = kwargs.get("augment", True)
        self.model_type = kwargs.get("model_type", SHRECModelType.GRANDMODEL)

        self._subtomo_fn = (
            root_path / f"subtomograms_{self.model_type.upper()}.npz"
        )

        if self._subtomo_fn.exists():
            print(f"Loading dataset: {self._subtomo_fn}...")
            dataset = np.load(self._subtomo_fn)
            self._subvolumes = dataset["volumes"]
            self._molecule_ids = dataset["molecule_ids"]
            self._keys = dataset["keys"].tolist()
            self._n_molecules = self._subvolumes.shape[0]
            return

        self.models = [
            SHRECModel(model_path, **kwargs)
            for model_path in root_path.iterdir()
            if model_path.stem.startswith("model_")
        ]
        self._n_molecules = sum(len(model) for model in self.models)
        self.extract_subvolumes()

    def extract_subvolumes(self):
        # build the class key list *before* the extraction loop, since the
        # molecule class index is looked up during extraction; sorted for a
        # reproducible ordering
        self._keys = sorted(
            set(
                itertools.chain.from_iterable(
                    model.keys() for model in self.models
                )
            )
        )

        subvolume_shape = self.models[0][0][0].shape
        self._subvolumes = np.zeros(
            (self._n_molecules, *subvolume_shape), dtype=np.float32
        )
        self._molecule_ids = np.zeros((self._n_molecules,), dtype=np.uint8)
        jdx = 0
        for model in tqdm(self.models, desc="Extracting subvolumes"):
            for idx in range(len(model)):
                self._subvolumes[jdx, ...], molecule_id = model[idx]
                self._molecule_ids[jdx] = self._keys.index(molecule_id)
                jdx += 1

        np.savez(
            self._subtomo_fn,
            volumes=self._subvolumes,
            molecule_ids=self._molecule_ids,
            keys=self._keys,
        )

def keys(self) -> List[str]:
return self._keys

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
subvolume = self._subvolumes[idx]
molecule_idx = self._molecule_ids[idx]

if self.augment:
theta = np.random.uniform(low=-30.0, high=30, size=(3,))

for d in range(3):
axis = (d, (d + 1) % 3)
subvolume = rotate(subvolume, theta[d], axis, reshape=False)

subvolume = subvolume[8:-8, 8:-8, 8:-8]
assert subvolume.shape == tuple(self.boxsize), subvolume.shape

if self.model_type == SHRECModelType.RECONSTRUCTION:
subvolume = (subvolume - np.mean(subvolume)) / max(
1 / np.sqrt(subvolume.size), np.std(subvolume)
)

subvolume = torch.as_tensor(subvolume[None, ...], dtype=torch.float32)
return subvolume, molecule_idx

def __len__(self):
return self._n_molecules

def examples(self) -> Tuple[torch.Tensor, List[str]]:
"""Return a set of examples of subvolumes."""
x_idx = set()
x_complete = set(range(len(self._keys)))
examples = []
examples_class = []
idx = 0

while x_complete.difference(x_idx) != set():
vol, mol_idx = self[idx]
if mol_idx not in x_idx:
x_idx.add(mol_idx)
examples.append(vol)
examples_class.append(mol_idx)
idx += 1

        return torch.stack(examples, dim=0), [
            self._keys[idx] for idx in examples_class
        ]
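A hypothetical usage sketch for the dataset above (not part of the commit); the root path is a placeholder and assumes the standard SHREC layout of model_* directories containing particle_locations.txt and the corresponding .mrc volumes.

# Hypothetical usage of SHRECDataset; the root path below is a placeholder.
from pathlib import Path
from torch.utils.data import DataLoader
from vne.special.shrec import SHRECDataset, SHRECModelType

dataset = SHRECDataset(
    Path("/data/shrec2021"),
    boxsize=(32, 32, 32),
    augment=True,
    model_type=SHRECModelType.GRANDMODEL,
)
loader = DataLoader(dataset, batch_size=64, shuffle=True)
volumes, class_idx = next(iter(loader))  # volumes: (64, 1, 32, 32, 32)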
1 change: 1 addition & 0 deletions vne/utils/__init__.py
@@ -0,0 +1 @@
from vne.utils.napari import GenerativeAffinityVAEWidget # NOQA: F401
32 changes: 32 additions & 0 deletions vne/utils/anneal.py
@@ -0,0 +1,32 @@
import dataclasses

import numpy as np


@dataclasses.dataclass
class CyclicAnnealing:
    """Cyclic annealing schedule that ramps a weight from ~0 up to `beta`
    over the first fraction `R` of each cycle, then holds it at `beta`."""

    n_cycles: int = 4
    n_iterations: int = 1_000
    R: float = 0.5
    beta: float = 1.0

def __post_init__(self):
self.reset()
self._func = lambda t: 1.0 / (1.0 + np.exp(-(t * 18 - 3)))

def step(self) -> float:
self._iteration += 1
t_over_m = self.n_iterations / self.n_cycles
tau = np.mod((self._iteration - 1), np.ceil(t_over_m)) / t_over_m

if self._iteration >= self.n_iterations:
return self.beta

if tau <= self.R:
return self._func(tau) * self.beta
else:
return self.beta

def reset(self):
self._iteration = 0
print(f"Resetting cyclic annealing: {self._iteration}")