From 4b10408052f40ce5089cb7e8f887ab9e7c699982 Mon Sep 17 00:00:00 2001 From: SeonghwanSeo Date: Mon, 26 Aug 2024 13:34:09 +0900 Subject: [PATCH 1/2] update --- environment.yml | 2 +- feature_extraction.py | 6 +- pyproject.toml | 11 +- src/pmnet/__init__.py | 2 +- src/pmnet/api/__init__.py | 20 ++++ src/pmnet/data/constant.py | 42 +++---- src/pmnet/data/parser.py | 98 ++++++++++++++++ src/pmnet/module.py | 233 +++++++++++++++---------------------- test/maintain.sh | 1 + 9 files changed, 246 insertions(+), 169 deletions(-) create mode 100644 src/pmnet/api/__init__.py create mode 100644 src/pmnet/data/parser.py diff --git a/environment.yml b/environment.yml index 11c50db..b2f0ab9 100644 --- a/environment.yml +++ b/environment.yml @@ -5,5 +5,5 @@ dependencies: - python=3.11 - pip=24.0 - openbabel=3.1.1 - - pymol-open-source=3.0.0 - numpy=1.26.4 + - pymol-open-source=3.0.0 diff --git a/feature_extraction.py b/feature_extraction.py index e152841..263ab05 100644 --- a/feature_extraction.py +++ b/feature_extraction.py @@ -1,6 +1,6 @@ import argparse import torch -from pmnet.module import PharmacoNet +from pmnet.api import get_pmnet_dev class ArgParser(argparse.ArgumentParser): @@ -57,9 +57,7 @@ def main(args): ] """ device = "cuda" if args.cuda else "cpu" - score_threshold = 0.5 # NOTE: RECOMMENDED_SCORE_THRESHOLD - - module = PharmacoNet(device, score_threshold) + module = get_pmnet_dev(device) multi_scale_features, hotspot_infos = module.feature_extraction(args.protein, args.ref_ligand, args.center) torch.save([multi_scale_features, hotspot_infos], args.out) diff --git a/pyproject.toml b/pyproject.toml index 6c051a0..8666ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pharmaconet" -version = "2.0.3" +version = "2.1.0" description = "PharmacoNet: Open-Source Software for Protein-based Pharmacophore Modeling and Virtual Screening" license = { text = "MIT" } authors = [{ name = "Seonghwan Seo", email = "shwan0106@kaist.ac.kr" }] @@ -34,6 +34,15 @@ dependencies = [ "biopython>=1.83" ] +[project.optional-dependencies] +appl = [ + "torch==2.3.1", + "torch-geometric==2.4.0", + "torch-scatter==2.1.2", + "torch-sparse==0.6.18", + "torch-cluster==1.6.3", +] + [project.urls] Website = "https://github.com/SeonghwanSeo/PharmacoNet" "Source Code" = "https://github.com/SeonghwanSeo/PharmacoNet" diff --git a/src/pmnet/__init__.py b/src/pmnet/__init__.py index dc1e6dd..b6e84a2 100644 --- a/src/pmnet/__init__.py +++ b/src/pmnet/__init__.py @@ -1,6 +1,6 @@ from .pharmacophore_model import PharmacophoreModel -__version__ = "2.0.3" +__version__ = "2.1.0" __citation_information__ = ( "Seo, S., & Kim, W. Y. (2023, December). " "PharmacoNet: Accelerating Large-Scale Virtual Screening by Deep Pharmacophore Modeling. " diff --git a/src/pmnet/api/__init__.py b/src/pmnet/api/__init__.py new file mode 100644 index 0000000..bc5b622 --- /dev/null +++ b/src/pmnet/api/__init__.py @@ -0,0 +1,20 @@ +# NOTE: For DL Model Training +__all__ = ["PharmacoNet", "ProteinParser", "get_pmnet_dev"] + +import torch +from pmnet.module import PharmacoNet +from pmnet.data.parser import ProteinParser + + +def get_pmnet_dev( + device: str | torch.device = "cpu", score_threshold: float = 0.5, molvoxel_library: str = "numpy" +) -> PharmacoNet: + """ + device: 'cpu' | 'cuda' + score_threshold: float | dict[str, float] | None + custom threshold to identify hotspots. 
+ For feature extraction, recommended value is '0.5' + molvoxel_library: str + If you want to use PharmacoNet in DL model training, recommend to use 'numpy' + """ + return PharmacoNet(device, score_threshold, False, molvoxel_library) diff --git a/src/pmnet/data/constant.py b/src/pmnet/data/constant.py index 529dbe8..947b1e9 100644 --- a/src/pmnet/data/constant.py +++ b/src/pmnet/data/constant.py @@ -1,16 +1,16 @@ from typing import Sequence, Set INTERACTION_LIST: Sequence[str] = ( - 'Hydrophobic', - 'PiStacking_P', - 'PiStacking_T', - 'PiCation_lring', - 'PiCation_pring', - 'HBond_ldon', - 'HBond_pdon', - 'SaltBridge_lneg', - 'SaltBridge_pneg', - 'XBond' + "Hydrophobic", + "PiStacking_P", + "PiStacking_T", + "PiCation_lring", + "PiCation_pring", + "HBond_ldon", + "HBond_pdon", + "SaltBridge_lneg", + "SaltBridge_pneg", + "XBond", ) NUM_INTERACTION_TYPES: int = 10 @@ -28,16 +28,16 @@ # PLIP Distance + 0.5 A INTERACTION_DIST = { - HYDROPHOBIC: 4.5, # 4.0 + 0.5 - PISTACKING_P: 6.0, # 5.5 + 0.5 - PISTACKING_T: 6.0, # 5.5 + 0.5 - PICATION_LRING: 6.5, # 6.0 + 0.5 - PICATION_PRING: 6.5, # 6.0 + 0.5 - HBOND_LDON: 4.5, # 4.1 + 0.5 - 0.1 (to be devided to 0.5) - HBOND_PDON: 4.5, # 4.1 + 0.5 - 0.1 - SALTBRIDGE_LNEG: 6.0, # 5.5 + 0.5 - SALTBRIDGE_PNEG: 6.0, # 5.5 + 0.5 - XBOND: 4.5, # 4.0 + 0.5 + HYDROPHOBIC: 4.5, # 4.0 + 0.5 + PISTACKING_P: 6.0, # 5.5 + 0.5 + PISTACKING_T: 6.0, # 5.5 + 0.5 + PICATION_LRING: 6.5, # 6.0 + 0.5 + PICATION_PRING: 6.5, # 6.0 + 0.5 + HBOND_LDON: 4.5, # 4.1 + 0.5 - 0.1 (to be devided to 0.5) + HBOND_PDON: 4.5, # 4.1 + 0.5 - 0.1 + SALTBRIDGE_LNEG: 6.0, # 5.5 + 0.5 + SALTBRIDGE_PNEG: 6.0, # 5.5 + 0.5 + XBOND: 4.5, # 4.0 + 0.5 } LONG_INTERACTION: Set[int] = { @@ -46,7 +46,7 @@ PICATION_PRING, PICATION_LRING, SALTBRIDGE_LNEG, - SALTBRIDGE_PNEG + SALTBRIDGE_PNEG, } SHORT_INTERACTION: Set[int] = { diff --git a/src/pmnet/data/parser.py b/src/pmnet/data/parser.py new file mode 100644 index 0000000..6beaa0f --- /dev/null +++ b/src/pmnet/data/parser.py @@ -0,0 +1,98 @@ +import os +import tempfile +from pathlib import Path + +import torch +import numpy as np +from openbabel import pybel + +from pmnet.data import token_inference, pointcloud +from pmnet.data.objects import Protein +from pmnet.data.extract_pocket import extract_pocket + +from molvoxel import create_voxelizer, BaseVoxelizer +from torch import Tensor +from numpy.typing import NDArray + + +class ProteinParser: + def __init__(self, center_noise: float = 0.0, pocket_extract: bool = True, molvoxel_library: str = "numpy"): + """ + center_noise: for data augmentation + pocket_extract: if True, we read pocket instead of entire protein. 
(faster) + """ + self.voxelizer = create_voxelizer(0.5, 64, sigma=1 / 3, library=molvoxel_library) + self.noise: float = center_noise + self.extract: bool = pocket_extract + + def __call__( + self, + protein_pdb_path: str | Path, + ref_ligand_path: str | Path | None = None, + center: NDArray[np.float32] | tuple[float, float, float] | None = None, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + return self.parse(protein_pdb_path, ref_ligand_path, center) + + def parse( + self, + protein_pdb_path: str | Path, + ref_ligand_path: str | Path | None = None, + center: NDArray[np.float32] | tuple[float, float, float] | None = None, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + assert (ref_ligand_path is not None) or (center is not None) + _center = self.get_center(ref_ligand_path, center) + return parse_protein(self.voxelizer, protein_pdb_path, _center, self.noise, self.extract) + + @staticmethod + def get_center( + ref_ligand_path: str | Path | None = None, + center: tuple[float, float, float] | NDArray | None = None, + ) -> tuple[float, float, float]: + if center is not None: + assert len(center) == 3 + x, y, z = center + else: + assert ref_ligand_path is not None + extension = os.path.splitext(ref_ligand_path)[1] + assert extension in [".sdf", ".pdb", ".mol2"] + ref_ligand = next(pybel.readfile(extension[1:], str(ref_ligand_path))) + x, y, z = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32).tolist() + return float(x), float(y), float(z) + + +def parse_protein( + voxelizer: BaseVoxelizer, + protein_pdb_path: str | Path, + center: NDArray[np.float32] | tuple[float, float, float], + center_noise: float = 0.0, + pocket_extract: bool = True, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + if isinstance(center, tuple): + center = np.array(center, dtype=np.float32) + if center_noise > 0: + center = center + (np.random.rand(3) * 2 - 1) * center_noise + + if pocket_extract: + with tempfile.TemporaryDirectory() as dirname: + pocket_path = os.path.join(dirname, "pocket.pdb") + extract_pocket(protein_pdb_path, pocket_path, center) + protein_obj: Protein = Protein.from_pdbfile(pocket_path) + else: + protein_obj: Protein = Protein.from_pdbfile(protein_pdb_path) + + token_positions, token_classes = token_inference.get_token_informations(protein_obj) + tokens, filter = token_inference.get_token_and_filter(token_positions, token_classes, center) + token_positions = token_positions[filter] + + protein_positions, protein_features = pointcloud.get_protein_pointcloud(protein_obj) + protein_image = np.asarray( + voxelizer.forward_features(protein_positions, center, protein_features, radii=1.5), np.float32 + ) + mask = np.logical_not(np.asarray(voxelizer.forward_single(protein_positions, center, radii=1.0), np.bool_)) + del protein_obj + return ( + torch.from_numpy(protein_image).to(torch.float), + torch.from_numpy(mask).to(torch.bool), + torch.from_numpy(token_positions).to(torch.float), + torch.from_numpy(tokens).to(torch.long), + ) diff --git a/src/pmnet/module.py b/src/pmnet/module.py index cedbac4..3df8a27 100644 --- a/src/pmnet/module.py +++ b/src/pmnet/module.py @@ -1,11 +1,10 @@ from __future__ import annotations import os -import tempfile +import tqdm import logging from pathlib import Path from importlib.util import find_spec -import tqdm from openbabel import pybel import torch import numpy as np @@ -15,15 +14,11 @@ from torch import Tensor from numpy.typing import NDArray -from molvoxel import create_voxelizer, BaseVoxelizer - from pmnet.network import build_model from 
pmnet.network.detector import PharmacoFormer -from pmnet.data import token_inference, pointcloud from pmnet.data import constant as C -from pmnet.data import INTERACTION_LIST -from pmnet.data.objects import Protein -from pmnet.data.extract_pocket import extract_pocket +from pmnet.data import token_inference +from pmnet.data.parser import ProteinParser from pmnet.utils.smoothing import GaussianSmoothing from pmnet.utils.download_weight import download_pretrained_model from pmnet.pharmacophore_model import PharmacophoreModel, INTERACTION_TO_PHARMACOPHORE, INTERACTION_TO_HOTSPOT @@ -47,7 +42,7 @@ class PharmacoNet: def __init__( self, - device: str = "cpu", + device: str | torch.device = "cpu", score_threshold: float | dict[str, float] | None = DEFAULT_SCORE_THRESHOLD, verbose: bool = True, molvoxel_library: str = "numba", @@ -63,6 +58,7 @@ def __init__( assert molvoxel_library in ["numpy", "numba"] if molvoxel_library == "numba" and (not find_spec("numba")): molvoxel_library = "numpy" + self.parser: ProteinParser = ProteinParser(molvoxel_library=molvoxel_library) running_path = Path(__file__) weight_path = running_path.parent / "weights" / "model.tar" @@ -74,7 +70,7 @@ def __init__( model.load_state_dict(checkpoint["model"]) model.eval() self.model: PharmacoFormer = model.to(device) - self.device = device + self.smoothing = GaussianSmoothing(kernel_size=5, sigma=0.5).to(device) self.score_distributions = { typ: np.array(distribution["focus"]) for typ, distribution in checkpoint["score_distributions"].items() } @@ -86,16 +82,10 @@ def __init__( if isinstance(score_threshold, dict): self.score_threshold = score_threshold elif isinstance(score_threshold, float): - self.score_threshold = {typ: score_threshold for typ in INTERACTION_LIST} + self.score_threshold = {typ: score_threshold for typ in C.INTERACTION_LIST} else: self.score_threshold = DEFAULT_SCORE_THRESHOLD - self.resolution = 0.5 - self.size = 64 - self.voxelizer: BaseVoxelizer = create_voxelizer( - self.resolution, self.size, sigma=(1 / 3), library=molvoxel_library - ) - self.smoothing = GaussianSmoothing(kernel_size=5, sigma=0.5).to(device) if verbose: self.logger = logging.getLogger("PharmacoNet") else: @@ -110,7 +100,7 @@ def run( ) -> PharmacophoreModel: assert (ref_ligand_path is not None) or (center is not None) center = self.get_center(ref_ligand_path, center) - protein_data = parse_protein(self.voxelizer, protein_pdb_path, center, 0.0, True) + protein_data = self.parser.parse(protein_pdb_path, center=center) hotspot_infos = self.create_density_maps(protein_data) with open(protein_pdb_path) as f: pdbblock: str = "\n".join(f.readlines()) @@ -123,11 +113,70 @@ def feature_extraction( ref_ligand_path: str | Path | None = None, center: tuple[float, float, float] | NDArray | None = None, ) -> tuple[list[Tensor], list[dict[str, Any]]]: - assert (ref_ligand_path is not None) or (center is not None) - center = self.get_center(ref_ligand_path, center) - protein_data = parse_protein(self.voxelizer, protein_pdb_path, center, 0.0, True) + protein_data = self.parser.parse(protein_pdb_path, ref_ligand_path, center) return self.run_extraction(protein_data) + @torch.no_grad() + def run_extraction( + self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor] + ) -> tuple[list[Tensor], list[dict[str, Any]]]: + protein_image, mask, token_pos, tokens = protein_data + protein_image = protein_image.to(device=self.device) + token_pos = token_pos.to(device=self.device) + tokens = tokens.to(device=self.device) + mask = mask.to(device=self.device) + + 
multi_scale_features = self.model.forward_feature(protein_image.unsqueeze(0)) # List[[1, D, H, W, F]] + token_scores, token_features = self.model.forward_token_prediction(multi_scale_features[-1], [tokens]) + token_scores = token_scores[0].sigmoid() # [Ntoken,] + token_features = token_features[0] # [Ntoken, F] + cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(multi_scale_features[-1]) + cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] + cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] + + indices = [] + rel_scores = [] + for i in range(tokens.shape[0]): + x, y, z, typ = tokens[i].tolist() + # NOTE: Check the token score + absolute_score = token_scores[i].item() + relative_score = float((self.score_distributions[C.INTERACTION_LIST[int(typ)]] < absolute_score).mean()) + if relative_score < self.score_threshold[C.INTERACTION_LIST[int(typ)]]: + continue + # NOTE: Check the token exists in cavity + _cavity = cavity_wide if typ in C.LONG_INTERACTION else cavity_narrow + if not _cavity[0, x, y, z]: + continue + indices.append(i) + rel_scores.append(relative_score) + hotspots = tokens[indices] # [Ntoken',] + hotpsot_pos = token_pos[indices] # [Ntoken', 3] + hotspot_features = token_features[indices] # [Ntoken', F] + del protein_image, mask, token_pos, tokens + + hotspot_infos = [] + for hotspot, score, position, feature in zip(hotspots, rel_scores, hotpsot_pos, hotspot_features, strict=True): + interaction_type = C.INTERACTION_LIST[int(hotspot[3])] + hotspot_infos.append( + { + "nci_type": interaction_type, + "hotspot_type": INTERACTION_TO_HOTSPOT[interaction_type], + "hotspot_feature": feature, + "hotspot_position": tuple(position.tolist()), + "hotspot_score": float(score), + "point_type": INTERACTION_TO_PHARMACOPHORE[interaction_type], + } + ) + return multi_scale_features, hotspot_infos + + def print_log(self, level, log): + if self.logger is None: + return None + if level == "debug": + self.logger.debug(log) + elif level == "info": + self.logger.info(log) + @staticmethod def get_center( ref_ligand_path: str | Path | None = None, @@ -146,9 +195,9 @@ def get_center( @torch.no_grad() def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor]): - protein_image, mask, token_positions, tokens = protein_data + protein_image, mask, token_pos, tokens = protein_data protein_image = protein_image.to(device=self.device, dtype=torch.float) - token_positions = token_positions.to(device=self.device, dtype=torch.float) + token_pos = token_pos.to(device=self.device, dtype=torch.float) tokens = tokens.to(device=self.device, dtype=torch.long) mask = mask.to(device=self.device, dtype=torch.bool) @@ -156,27 +205,23 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor "debug", f"Protein-based Pharmacophore Modeling... 
(device: {self.device})", ) - protein_image = protein_image.unsqueeze(0) - multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] - bottom_features = multi_scale_features[-1] - - token_scores, token_features = self.model.forward_token_prediction(bottom_features, [tokens]) + multi_scale_features = self.model.forward_feature(protein_image.unsqueeze(0)) # List[[1, D, H, W, F]] + token_scores, token_features = self.model.forward_token_prediction(multi_scale_features[-1], [tokens]) token_scores = token_scores[0].sigmoid() # [Ntoken,] token_features = token_features[0] # [Ntoken, F] - - cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(bottom_features) + cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(multi_scale_features[-1]) cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] num_tokens = tokens.shape[0] indices = [] - relative_scores = [] + rel_scores = [] for i in range(num_tokens): x, y, z, typ = tokens[i].tolist() # NOTE: Check the token score absolute_score = token_scores[i].item() - relative_score = float((self.score_distributions[INTERACTION_LIST[int(typ)]] < absolute_score).mean()) - if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: + relative_score = float((self.score_distributions[C.INTERACTION_LIST[int(typ)]] < absolute_score).mean()) + if relative_score < self.score_threshold[C.INTERACTION_LIST[int(typ)]]: continue # NOTE: Check the token exists in cavity if typ in C.LONG_INTERACTION: @@ -186,12 +231,12 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor if not cavity_narrow[0, x, y, z]: continue indices.append(i) - relative_scores.append(relative_score) + rel_scores.append(relative_score) selected_indices = torch.tensor(indices, device=self.device, dtype=torch.long) # [Ntoken',] hotspots = tokens[selected_indices] # [Ntoken',] - hotspot_positions = token_positions[selected_indices] # [Ntoken', 3] + hotpsot_pos = token_pos[selected_indices] # [Ntoken', 3] hotspot_features = token_features[selected_indices] # [Ntoken', F] - del protein_image, tokens, token_positions, token_features + del protein_image, tokens, token_pos, token_features density_maps_list = [] if self.device == "cpu": @@ -224,11 +269,10 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor density_maps[density_maps < self.box_threshold] = 0.0 hotspot_infos = [] - assert len(hotspots) == len(relative_scores) - for hotspot, score, position, map in zip(hotspots, relative_scores, hotspot_positions, density_maps): + for hotspot, score, position, map in zip(hotspots, rel_scores, hotpsot_pos, density_maps, strict=True): if torch.all(map < 1e-6): continue - interaction_type = INTERACTION_LIST[int(hotspot[3])] + interaction_type = C.INTERACTION_LIST[int(hotspot[3])] hotspot_infos.append( { "nci_type": interaction_type, @@ -245,108 +289,15 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor ) return hotspot_infos - @torch.no_grad() - def run_extraction( - self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor] - ) -> tuple[list[Tensor], list[dict[str, Any]]]: - protein_image, mask, token_positions, tokens = protein_data - protein_image = protein_image.to(device=self.device, dtype=torch.float) - token_positions = token_positions.to(device=self.device, dtype=torch.float) - tokens = tokens.to(device=self.device, dtype=torch.long) - mask = 
mask.to(device=self.device, dtype=torch.bool) - - protein_image = protein_image.unsqueeze(0) - multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] - bottom_features = multi_scale_features[-1] - - token_scores, token_features = self.model.forward_token_prediction(bottom_features, [tokens]) - token_scores = token_scores[0].sigmoid() # [Ntoken,] - token_features = token_features[0] # [Ntoken, F] - - cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(bottom_features) - cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] - cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] - - num_tokens = tokens.shape[0] - indices = [] - relative_scores = [] - for i in range(num_tokens): - x, y, z, typ = tokens[i].tolist() - # NOTE: Check the token score - absolute_score = token_scores[i].item() - relative_score = float((self.score_distributions[INTERACTION_LIST[int(typ)]] < absolute_score).mean()) - if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: - continue - # NOTE: Check the token exists in cavity - _cavity = cavity_wide if typ in C.LONG_INTERACTION else cavity_narrow - if not _cavity[0, x, y, z]: - continue - indices.append(i) - relative_scores.append(relative_score) - hotspots = tokens[indices] # [Ntoken',] - hotspot_positions = token_positions[indices] # [Ntoken', 3] - hotspot_features = token_features[indices] # [Ntoken', F] - - hotspot_infos = [] - assert len(hotspots) == len(relative_scores) - for hotspot, score, position, feature in zip(hotspots, relative_scores, hotspot_positions, hotspot_features): - interaction_type = INTERACTION_LIST[int(hotspot[3])] - hotspot_infos.append( - { - "nci_type": interaction_type, - "hotspot_type": INTERACTION_TO_HOTSPOT[interaction_type], - "hotspot_feature": feature, - "hotspot_position": tuple(position.tolist()), - "hotspot_score": float(score), - "point_type": INTERACTION_TO_PHARMACOPHORE[interaction_type], - } - ) - del protein_image, mask, token_positions, tokens - return multi_scale_features, hotspot_infos - - def print_log(self, level, log): - if self.logger is None: - return None - if level == "debug": - self.logger.debug(log) - elif level == "info": - self.logger.info(log) - - -# NOTE: For DL Model Training -def parse_protein( - voxelizer: BaseVoxelizer, - protein_pdb_path: str | Path, - center: NDArray[np.float32] | tuple[float, float, float], - center_noise: float = 0.0, - pocket_extract: bool = True, -) -> tuple[Tensor, Tensor, Tensor, Tensor]: - if isinstance(center, tuple): - center = np.array(center, dtype=np.float32) - if center_noise > 0: - center = center + (np.random.rand(3) * 2 - 1) * center_noise + @property + def device(self): + return next(self.model.parameters()).device - if pocket_extract: - with tempfile.TemporaryDirectory() as dirname: - pocket_path = os.path.join(dirname, "pocket.pdb") - extract_pocket(protein_pdb_path, pocket_path, center) - protein_obj: Protein = Protein.from_pdbfile(pocket_path) - else: - protein_obj: Protein = Protein.from_pdbfile(protein_pdb_path) + def to(self, device): + self.model = self.model.to(device) - token_positions, token_classes = token_inference.get_token_informations(protein_obj) - tokens, filter = token_inference.get_token_and_filter(token_positions, token_classes, center) - token_positions = token_positions[filter] + def cuda(self): + self.model = self.model.cuda() - protein_positions, protein_features = pointcloud.get_protein_pointcloud(protein_obj) - protein_image = 
np.asarray( - voxelizer.forward_features(protein_positions, center, protein_features, radii=1.5), np.float32 - ) - mask = np.logical_not(np.asarray(voxelizer.forward_single(protein_positions, center, radii=1.0), np.bool_)) - del protein_obj - return ( - torch.from_numpy(protein_image), - torch.from_numpy(mask), - torch.from_numpy(token_positions), - torch.from_numpy(tokens), - ) + def cpu(self): + self.model = self.model.cpu() diff --git a/test/maintain.sh b/test/maintain.sh index ad4a4a7..90999f9 100644 --- a/test/maintain.sh +++ b/test/maintain.sh @@ -1,5 +1,6 @@ /bin/rm -rf result/6oim python modeling.py --cuda --pdb 6oim +python modeling.py --cuda --pdb 6oim -c D python modeling.py --cuda --pdb 6oim --ref_ligand ./result/6oim/6oim_B_MG.pdb python modeling.py --cuda --protein ./result/6oim/6oim.pdb --prefix 6oim python modeling.py --cuda --protein ./result/6oim/6oim.pdb --ref_ligand ./result/6oim/6oim_B_MG.pdb --prefix 6oim From b4ad69f0180e0bc2b02f5766bd751576b70f02ca Mon Sep 17 00:00:00 2001 From: SeonghwanSeo Date: Wed, 28 Aug 2024 23:10:24 +0900 Subject: [PATCH 2/2] release application, python api --- .gitignore | 3 + README.md | 69 +++--- pyproject.toml | 9 + src/pmnet/api/__init__.py | 3 +- src/pmnet/api/typing.py | 6 + src/pmnet/data/parser.py | 3 + src/pmnet/module.py | 10 +- src/pmnet_appl/README.md | 23 ++ src/pmnet_appl/docking_reward/__init__.py | 29 +++ src/pmnet_appl/docking_reward/config.py | 71 ++++++ src/pmnet_appl/docking_reward/dataset.py | 55 +++++ src/pmnet_appl/docking_reward/model.py | 101 +++++++++ .../docking_reward/network/__init__.py | 3 + src/pmnet_appl/docking_reward/network/head.py | 39 ++++ .../docking_reward/network/ligand_encoder.py | 66 ++++++ .../network/pharmacophore_encoder.py | 60 +++++ src/pmnet_appl/docking_reward/trainer.py | 207 ++++++++++++++++++ src/pmnet_appl/docking_reward/utils.py | 88 ++++++++ src/pmnet_appl/scripts/train_debug.py | 18 ++ test/maintain.sh | 9 - 20 files changed, 818 insertions(+), 54 deletions(-) create mode 100644 src/pmnet/api/typing.py create mode 100644 src/pmnet_appl/README.md create mode 100644 src/pmnet_appl/docking_reward/__init__.py create mode 100644 src/pmnet_appl/docking_reward/config.py create mode 100644 src/pmnet_appl/docking_reward/dataset.py create mode 100644 src/pmnet_appl/docking_reward/model.py create mode 100644 src/pmnet_appl/docking_reward/network/__init__.py create mode 100644 src/pmnet_appl/docking_reward/network/head.py create mode 100644 src/pmnet_appl/docking_reward/network/ligand_encoder.py create mode 100644 src/pmnet_appl/docking_reward/network/pharmacophore_encoder.py create mode 100644 src/pmnet_appl/docking_reward/trainer.py create mode 100644 src/pmnet_appl/docking_reward/utils.py create mode 100644 src/pmnet_appl/scripts/train_debug.py delete mode 100644 test/maintain.sh diff --git a/.gitignore b/.gitignore index 25b5a33..eb72a4b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ run.sh result/ examples/library/ nogit/ +maintain_test/ +tacogfn_reward +largfn_reward # Byte-compiled / optimized / DLL files diff --git a/README.md b/README.md index 8b2c5ad..a267662 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Official Github for **_PharmacoNet: Accelerating Large-Scale Virtual Screening b 1. Fully automated protein-based pharmacophore modeling based on image instance segmentation modeling 2. Coarse-grained graph matching at the pharmacophore level for high throughput 3. 
Pharmacophore-aware scoring function with parameterized analytical function for robust generalization ability
+4. Better pocket representation for deep learning developers. ([Section](#pharmacophore-feature-extraction))

PharmacoNet is an extremely rapid yet reasonably accurate ligand evaluation tool with high generation ability.

@@ -164,27 +165,34 @@ score = model.scoring_smiles(, )

## Pharmacophore Feature Extraction

-For deep learning researcher who want to use PharmacoNet as pre-trained model for feature extraction, we provide the script `feature_extraction.py`.
-```bash
-python feature_extraction.py --protein --ref_ligand --out
-python feature_extraction.py --protein --center --out
-```
+***See: [`./src/pmnet_appl/`](/src/pmnet_appl/).***
-```bash
-OUTPUT=(multi_scale_features, hotspot_info)
- multi_scale_features: list[torch.Tensor]:
- - torch.Tensor [96, 4, 4, 4]
- - torch.Tensor [96, 8, 8, 8]
- - torch.Tensor [96, 16, 16, 16]
- - torch.Tensor [96, 32, 32, 32]
- - torch.Tensor [96, 64, 64, 64]
- hotspot_infos: list[hotspot_info]
- info: dict[str, Any]
- - hotspot_feature: torch.Tensor (192,)
+For deep learning researchers who want to use PharmacoNet as a pre-trained model for feature extraction, we provide a Python API.
+
+```python
+from pmnet.api import PharmacoNet, get_pmnet_dev, ProteinParser
+module: PharmacoNet = get_pmnet_dev('cuda')  # default: score_threshold=0.5 (a lower threshold yields more features)
+
+# End-to-End calculation
+pmnet_attr = module.feature_extraction(<PROTEIN_PDB_PATH>, ref_ligand_path=<REF_LIGAND_PATH>)
+pmnet_attr = module.feature_extraction(<PROTEIN_PDB_PATH>, center=(<X>, <Y>, <Z>))
+
+# Step-wise calculation
+## In Dataset
+parser = ProteinParser(center_noise=<CENTER_NOISE>)  # center_noise: for data augmentation
+## In Model (frozen; the method is decorated with torch.no_grad())
+pmnet_attr = module.run_extraction(protein_data)
+
+"""
+pmnet_attr = (multi_scale_features, hotspot_infos)
+- multi_scale_features: tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ - [96, 4, 4, 4], [96, 8, 8, 8], [96, 16, 16, 16], [96, 32, 32, 32], [96, 64, 64, 64]
+- hotspot_infos: list[hotspot_info]
+ hotspot_info: dict[str, Any]
+ - hotspot_feature: Tensor [192,]
- hotspot_position: tuple[float, float, float]
- (x, y, z)
- hotspot_score: float in [0, 1]
-
- nci_type: str (10 types)
'Hydrophobic': Hydrophobic interaction
'PiStacking_P': PiStacking (Parallel)
'PiStacking_T': PiStacking (T-shaped)
'PiCation_lring': Interaction btw Protein Cation & Ligand Aromatic Ring
'PiCation_pring': Interaction btw Protein Aromatic Ring & Ligand Cation
'SaltBridge_pneg': SaltBridge btw Protein Anion & Ligand Cation
'SaltBridge_lneg': SaltBridge btw Protein Cation & Ligand Anion
'XBond': Halogen Bond
'HBond_pdon': Hydrogen Bond btw Protein Donor & Ligand Acceptor
'HBond_ldon': Hydrogen Bond btw Protein Acceptor & Ligand Donor
+ # Features obtained from `nci_type`, i.e. `nci_type` is all you need.
- hotspot_type: str (7 types)
{'Hydrophobic', 'Aromatic', 'Cation', 'Anion', 'Halogen', 'HBond_donor', 'HBond_acceptor'}
- *** `type` is obtained from `nci_type`.
- point_type: str (7 types)
{'Hydrophobic', 'Aromatic', 'Cation', 'Anion', 'Halogen', 'HBond_donor', 'HBond_acceptor'}
- *** `type` is obtained from `nci_type`.
-```
-
-### Python Script
-
-For feature extraction, it is recommended to use `score_threshold=0.5` instead of default setting used for pharmacophore modeling. If you want to extract more features, decrease the `score_threshold`.
- -```python -from pmnet.module import PharmacoNet, parse_protein -module = PharmacoNet( - "cuda", - score_threshold = 0.5, # , - molvoxel_library = 'numpy' # -) -# End-to-End calculation -multi_scale_features, hotspot_infos = module.feature_extraction(, ) -multi_scale_features, hotspot_infos = module.feature_extraction(, center=(, , )) - -# Step-wise calculation -voxelizer = module.voxelizer -# In Dataset (Type: Tuple[Tensor, Tensor, Tensor, Tensor]) -protein_data = module.parse_protein(voxelizer, , , ) -# In Model -multi_scale_features, hotspot_infos = module.run_extraction(protein_data) +""" ``` ### Paper List diff --git a/pyproject.toml b/pyproject.toml index 8666ef6..b77ec2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,15 @@ appl = [ "torch-sparse==0.6.18", "torch-cluster==1.6.3", ] +dev = [ + "torch==2.3.1", + "torch-geometric==2.4.0", + "torch-scatter==2.1.2", + "torch-sparse==0.6.18", + "torch-cluster==1.6.3", + "wandb", + "tensorboard", +] [project.urls] Website = "https://github.com/SeonghwanSeo/PharmacoNet" diff --git a/src/pmnet/api/__init__.py b/src/pmnet/api/__init__.py index bc5b622..d5cf7b8 100644 --- a/src/pmnet/api/__init__.py +++ b/src/pmnet/api/__init__.py @@ -1,9 +1,10 @@ # NOTE: For DL Model Training -__all__ = ["PharmacoNet", "ProteinParser", "get_pmnet_dev"] +__all__ = ["PharmacoNet", "ProteinParser", "get_pmnet_dev", "MultiScaleFeature", "HotspotInfo"] import torch from pmnet.module import PharmacoNet from pmnet.data.parser import ProteinParser +from . import typing def get_pmnet_dev( diff --git a/src/pmnet/api/typing.py b/src/pmnet/api/typing.py new file mode 100644 index 0000000..4bc5cdb --- /dev/null +++ b/src/pmnet/api/typing.py @@ -0,0 +1,6 @@ +from torch import Tensor +from typing import Any + + +MultiScaleFeature = tuple[Tensor, Tensor, Tensor, Tensor, Tensor] +HotspotInfo = dict[str, Any] diff --git a/src/pmnet/data/parser.py b/src/pmnet/data/parser.py index 6beaa0f..d4eecf0 100644 --- a/src/pmnet/data/parser.py +++ b/src/pmnet/data/parser.py @@ -25,6 +25,9 @@ def __init__(self, center_noise: float = 0.0, pocket_extract: bool = True, molvo self.noise: float = center_noise self.extract: bool = pocket_extract + ob_log_handler = pybel.ob.OBMessageHandler() + ob_log_handler.SetOutputLevel(0) # 0: None + def __call__( self, protein_pdb_path: str | Path, diff --git a/src/pmnet/module.py b/src/pmnet/module.py index 3df8a27..33f94e4 100644 --- a/src/pmnet/module.py +++ b/src/pmnet/module.py @@ -38,6 +38,9 @@ "Hydrophobic": 0.85, } +MultiScaleFeature = tuple[Tensor, Tensor, Tensor, Tensor, Tensor] +HotspotInfo = dict[str, Any] + class PharmacoNet: def __init__( @@ -70,6 +73,9 @@ def __init__( model.load_state_dict(checkpoint["model"]) model.eval() self.model: PharmacoFormer = model.to(device) + for param in self.model.parameters(): + param.requires_grad = False + self.smoothing = GaussianSmoothing(kernel_size=5, sigma=0.5).to(device) self.score_distributions = { typ: np.array(distribution["focus"]) for typ, distribution in checkpoint["score_distributions"].items() @@ -112,14 +118,14 @@ def feature_extraction( protein_pdb_path: str | Path, ref_ligand_path: str | Path | None = None, center: tuple[float, float, float] | NDArray | None = None, - ) -> tuple[list[Tensor], list[dict[str, Any]]]: + ) -> tuple[MultiScaleFeature, list[HotspotInfo]]: protein_data = self.parser.parse(protein_pdb_path, ref_ligand_path, center) return self.run_extraction(protein_data) @torch.no_grad() def run_extraction( self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor] - 
) -> tuple[list[Tensor], list[dict[str, Any]]]: + ) -> tuple[MultiScaleFeature, list[HotspotInfo]]: protein_image, mask, token_pos, tokens = protein_data protein_image = protein_image.to(device=self.device) token_pos = token_pos.to(device=self.device) diff --git a/src/pmnet_appl/README.md b/src/pmnet_appl/README.md new file mode 100644 index 0000000..2fb1f32 --- /dev/null +++ b/src/pmnet_appl/README.md @@ -0,0 +1,23 @@ +# Application of PharmacoNet + +Example scripts to use PharmacoNet's protein pharmacophore representation, which depend on `torch-geometric`. + +```bash +# construct conda environment; pymol-open-source is not required. +conda create -n pmnet-dev python=3.10 openbabel=3.1.1 +conda activate pmnet-dev + +# install PharmacoNet&torch_geometric +pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.3.1+cu121.html + +# if you want to train model with example scripts +pip install wandb, tensorboard +``` + +## Comming Soon (For Archiving): +- TacoGFN: Target-conditioned GFlowNet for Structure-based Drug Design [[paper](https://arxiv.org/abs/2310.03223)] + + + + + diff --git a/src/pmnet_appl/docking_reward/__init__.py b/src/pmnet_appl/docking_reward/__init__.py new file mode 100644 index 0000000..378d72d --- /dev/null +++ b/src/pmnet_appl/docking_reward/__init__.py @@ -0,0 +1,29 @@ +import os +import wandb + +from pmnet_appl.docking_reward.config import Config +from pmnet_appl.docking_reward.trainer import Trainer + + +def run_config(config: Config, project: str, name: str): + wandb.init(project=project, config=config.to_dict(), name=name) + trainer = Trainer(config, device="cuda") + trainer.fit() + + +if __name__ == "__main__": + PROJECT = "pmnet-appl" + NAME = "debug" + + config = Config() + config.data.protein_dir = "/home/share/DATA/SBDDReward/protein/train/" + config.data.ligand_dir = "/home/share/DATA/SBDDReward/lmdb/train" + config.data.ligand_dir = "/home/shwan/GFLOWNET_PROJECT/DATA/" + + config.train.max_iterations = 1000 + config.train.batch_size = 8 + + config.log_dir = f"./result/{NAME}" + assert not os.path.exists(config.log_dir) + os.mkdir(config.log_dir) + run_config(config, PROJECT, NAME) diff --git a/src/pmnet_appl/docking_reward/config.py b/src/pmnet_appl/docking_reward/config.py new file mode 100644 index 0000000..06a1388 --- /dev/null +++ b/src/pmnet_appl/docking_reward/config.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass +from omegaconf import MISSING + + +@dataclass +class ModelConfig: + hidden_dim: int = 128 + ligand_num_convs: int = 4 + + +@dataclass +class DataConfig: + protein_info_path: str = MISSING + train_protein_code_path: str = MISSING + protein_dir: str = MISSING + ligand_path: str = MISSING + + +@dataclass +class LrSchedulerConfig: + scheduler: str = "lambdalr" + lr_decay: int = 50_000 + + +@dataclass +class OptimizerConfig: + opt: str = "adam" + lr: float = 1e-3 + eps: float = 1e-8 + betas: tuple[float, float] = (0.9, 0.999) + weight_decay: float = 0.05 + clip_grad: float = 1.0 + + +@dataclass +class TrainConfig: + val_every: int = 2_000 + log_every: int = 10 + print_every: int = 100 + save_every: int = 1_000 + max_iterations: int = 300_000 + batch_size: int = 4 + num_workers: int = 4 + + opt: OptimizerConfig = OptimizerConfig() + lr_scheduler: LrSchedulerConfig = LrSchedulerConfig() + + # NOTE: HYPER PARAMETER + split_ratio: float = 0.9 + center_noise: float = 3.0 + + +@dataclass +class Config: + log_dir: str = MISSING + model: ModelConfig = ModelConfig() + train: TrainConfig = TrainConfig() + data: DataConfig = 
DataConfig() + + def to_dict(self): + return config_to_dict(self) + + +def config_to_dict(obj) -> dict: + if not hasattr(obj, "__dataclass_fields__"): + return obj + result = {} + for field in obj.__dataclass_fields__.values(): + value = getattr(obj, field.name) + result[field.name] = config_to_dict(value) + return {"config": result} diff --git a/src/pmnet_appl/docking_reward/dataset.py b/src/pmnet_appl/docking_reward/dataset.py new file mode 100644 index 0000000..f1d1703 --- /dev/null +++ b/src/pmnet_appl/docking_reward/dataset.py @@ -0,0 +1,55 @@ +import pickle +from pathlib import Path +import torch + +from torch import Tensor +from torch.utils.data import Dataset +from torch_geometric.data import Data, Batch + +from pmnet.api import ProteinParser + +from .utils import smi2graphdata + + +class BaseDataset(Dataset): + def __init__( + self, + code_list: list[str], + protein_info: dict[str, tuple[float, float, float]], + protein_dir: Path | str, + ligand_path: Path | str, + center_noise: float = 0.0, + ): + self.parser: ProteinParser = ProteinParser(center_noise) + + self.code_list: list[str] = code_list + self.protein_info = protein_info + self.protein_dir = Path(protein_dir) + self.center_noise = center_noise + with open(ligand_path, "rb") as f: + self.ligand_data: dict[str, list[tuple[str, str, float]]] = pickle.load(f) + + def __len__(self): + return len(self.code_list) + + def __getitem__(self, index: int) -> tuple[tuple[Tensor, Tensor, Tensor, Tensor], Batch]: + code = self.code_list[index] + protein_path: str = str(self.protein_dir / f"{code}.pdb") + center: tuple[float, float, float] = self.protein_info[code] + pharmacophore_info = self.parser(protein_path, center=center) + ligands = self.ligand_data[code] + ligand_graphs: Batch = Batch.from_data_list(list(map(self.get_ligand_data, ligands))) + return pharmacophore_info, ligand_graphs + + @staticmethod + def get_ligand_data(args: tuple[str, str, float]) -> Data: + ligand_id, smiles, affinity = args + data = smi2graphdata(smiles) + x, edge_index, edge_attr = data["x"], data["edge_index"], data["edge_attr"] + affinity = min(float(affinity), 0.0) + return Data( + x, + edge_index, + edge_attr, + affinity=torch.FloatTensor([affinity]), + ) diff --git a/src/pmnet_appl/docking_reward/model.py b/src/pmnet_appl/docking_reward/model.py new file mode 100644 index 0000000..b1c4468 --- /dev/null +++ b/src/pmnet_appl/docking_reward/model.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from torch import Tensor +from torch_geometric.data import Batch +from pathlib import Path +from numpy.typing import NDArray +from omegaconf import DictConfig +from typing import NewType + +from pmnet.api import PharmacoNet, get_pmnet_dev + +from .network.pharmacophore_encoder import PharmacophoreEncoder +from .network.ligand_encoder import GraphEncoder +from .network.head import AffinityHead +from .utils import NUM_ATOM_FEATURES, NUM_BOND_FEATURES, smi2graph +from .config import Config + +Cache = NewType("Cache", tuple[Tensor, Tensor, Tensor]) + + +class AffinityModel(nn.Module): + def __init__(self, config: Config | DictConfig, device: str = "cuda"): + super().__init__() + self.pmnet: PharmacoNet = get_pmnet_dev(device) + self.global_cfg = config + self.cfg = config.model + self.pharmacophore_encoder: PharmacophoreEncoder = PharmacophoreEncoder(self.cfg.hidden_dim) + self.ligand_encoder: GraphEncoder = GraphEncoder( + NUM_ATOM_FEATURES, NUM_BOND_FEATURES, self.cfg.hidden_dim, self.cfg.hidden_dim, self.cfg.ligand_num_convs + ) + self.head: AffinityHead 
= AffinityHead(self.cfg.hidden_dim) + self.l2_loss: nn.MSELoss = nn.MSELoss() + self.to(device) + self.initialize_weights() + + def initialize_weights(self): + self.pharmacophore_encoder.initialize_weights() + self.ligand_encoder.initialize_weights() + self.head.initialize_weights() + + # NOTE: Model training + def forward_train(self, batch) -> Tensor: + if self.pmnet.device != self.device: + self.pmnet.to(self.device) + + loss_list = [] + for pharmacophore_info, ligand_graphs in batch: + # NOTE: Run PharmacoNet Feature Extraction + # (Model is freezed; method `run_extraction` is decorated by torch.no_grad()) + pmnet_attr = self.pmnet.run_extraction(pharmacophore_info) + del pharmacophore_info + + # NOTE: Binding Affinity Prediction + x_protein, pos_protein, Z_protein = self.pharmacophore_encoder.forward(pmnet_attr) + x_ligand = self.ligand_encoder.forward(ligand_graphs.to(self.device)) + affinity = self.head.forward(x_protein, x_ligand, ligand_graphs.batch, ligand_graphs.num_graphs) + + loss_list.append(self.l2_loss.forward(affinity, ligand_graphs.affinity)) + loss = torch.stack(loss_list).mean() + return loss + + # NOTE: Python API + @torch.no_grad() + def feature_extraction( + self, + protein_pdb_path: str | Path, + ref_ligand_path: str | Path | None = None, + center: tuple[float, float, float] | NDArray | None = None, + ) -> Cache: + multi_scale_features, hotspot_infos = self.pmnet.feature_extraction(protein_pdb_path, ref_ligand_path, center) + return self.pharmacophore_encoder(multi_scale_features, hotspot_infos) + + def scoring(self, target: str, smiles: str) -> Tensor: + return self._scoring(self.cache[target], smiles) + + def scoring_list(self, target: str, smiles_list: list[str]) -> Tensor: + return self._scoring_list(self.cache[target], smiles_list) + + @torch.no_grad() + def _scoring(self, cache: Cache, smiles: str) -> Tensor: + return self._scoring_list(cache, [smiles]) + + @torch.no_grad() + def _scoring_list(self, cache: Cache, smiles_list: list[str]) -> Tensor: + Z_protein, X_protein, pos_protein = cache + Z_protein = Z_protein.to(self.device) + X_protein = X_protein.to(self.device) + pos_protein = pos_protein.to(self.device) + ligand_batch = Batch.from_data_list([smi2graph(smiles) for smiles in smiles_list]).to(self.device) + X_ligand, Z_ligand = self.ligand_encoder(ligand_batch) + return self.head.scoring(X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch) + + def to(self, device: str | torch.device): + super().to(device) + if self.pmnet is not None: + if self.pmnet.device != self.device: + self.pmnet.to(device) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device diff --git a/src/pmnet_appl/docking_reward/network/__init__.py b/src/pmnet_appl/docking_reward/network/__init__.py new file mode 100644 index 0000000..5892048 --- /dev/null +++ b/src/pmnet_appl/docking_reward/network/__init__.py @@ -0,0 +1,3 @@ +from .ligand_encoder import GraphEncoder +from .pharmacophore_encoder import PharmacophoreEncoder +from .head import AffinityHead diff --git a/src/pmnet_appl/docking_reward/network/head.py b/src/pmnet_appl/docking_reward/network/head.py new file mode 100644 index 0000000..60a1ae3 --- /dev/null +++ b/src/pmnet_appl/docking_reward/network/head.py @@ -0,0 +1,39 @@ +import torch +from torch import nn + +from torch import Tensor +from torch_geometric.utils import to_dense_batch + + +class AffinityHead(nn.Module): + def __init__(self, hidden_dim: int, p_dropout: float = 0.1): + super().__init__() + 
self.interaction_mlp: nn.Module = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LeakyReLU(), + ) + self.mlp_affinity: nn.Module = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(), nn.Linear(hidden_dim, 1) + ) + self.dropout = nn.Dropout(p_dropout) + + def initialize_weights(self): + def _init_weight(m): + if isinstance(m, (nn.Linear)): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + self.apply(_init_weight) + + def forward(self, x_protein: Tensor, x_ligand: Tensor, ligand_batch: Tensor, num_ligands: int) -> Tensor: + """ + affinity predict header for (single protein - multi ligands) + output: (N_ligand,) + """ + Z_complex = torch.einsum("ik,jk->ijk", x_ligand, x_protein) # [Vlig, Vprot, Fh] + Z_complex, mask_complex = to_dense_batch(Z_complex, ligand_batch, batch_size=num_ligands) + mask_complex = mask_complex.unsqueeze(-1) # [N, Vlig, 1] + Z_complex = self.interaction_mlp(self.dropout(Z_complex)) + pair_affinity = self.mlp_affinity(Z_complex).squeeze(-1) * mask_complex + return pair_affinity.sum((1, 2)) diff --git a/src/pmnet_appl/docking_reward/network/ligand_encoder.py b/src/pmnet_appl/docking_reward/network/ligand_encoder.py new file mode 100644 index 0000000..6c18fb8 --- /dev/null +++ b/src/pmnet_appl/docking_reward/network/ligand_encoder.py @@ -0,0 +1,66 @@ +from torch import nn +from torch import Tensor + +import torch_geometric.nn as gnn +from torch_geometric.data import Batch, Data + + +class GraphEncoder(nn.Module): + def __init__( + self, + input_node_dim: int, + input_edge_dim: int, + hidden_dim: int, + out_dim: int, + num_convs: int, + ): + super().__init__() + self.graph_channels: int = out_dim + self.atom_channels: int = out_dim + + # Ligand Encoding + self.node_layer = nn.Linear(input_node_dim, hidden_dim) + self.edge_layer = nn.Linear(input_edge_dim, hidden_dim) + self.conv_list = nn.ModuleList( + [ + gnn.GINEConv( + nn=nn.Sequential(gnn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()), + edge_dim=hidden_dim, + ) + for _ in range(num_convs) + ] + ) + + self.head = nn.Sequential(nn.Linear(hidden_dim, out_dim), nn.LayerNorm(out_dim)) + + def initialize_weights(self): + def _init_weight(m): + if isinstance(m, (nn.Linear)): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Embedding): + m.weight.data.uniform_(-1, 1) + + self.apply(_init_weight) + + def forward(self, data: Data | Batch) -> Tensor: + """Affinity Prediction + + Args: + x: Node Feature + edge_attr: Edge Feature + edge_index: Edge Index + + Returns: + updated_data: Union[Data, Batch] + """ + x: Tensor = self.node_layer(data.x) + edge_attr: Tensor = self.edge_layer(data.edge_attr) + + skip_x = x + edge_index = data.edge_index + for layer in self.conv_list: + x = layer(x, edge_index, edge_attr) + x = skip_x + x + return self.head(x) diff --git a/src/pmnet_appl/docking_reward/network/pharmacophore_encoder.py b/src/pmnet_appl/docking_reward/network/pharmacophore_encoder.py new file mode 100644 index 0000000..f732a53 --- /dev/null +++ b/src/pmnet_appl/docking_reward/network/pharmacophore_encoder.py @@ -0,0 +1,60 @@ +import torch +from torch import nn +from torch import Tensor +from pmnet.api.typing import MultiScaleFeature, HotspotInfo + + +class PharmacophoreEncoder(nn.Module): + def __init__(self, hidden_dim: int): + super().__init__() + self.multi_scale_dims = [96, 96, 96, 96, 96] + self.hotspot_dim = 192 + self.hidden_dim = hidden_dim + self.hotspot_mlp: nn.Module = 
nn.Sequential(nn.SiLU(), nn.Linear(self.hotspot_dim, hidden_dim)) + self.pocket_mlp_list: nn.ModuleList = nn.ModuleList( + [nn.Sequential(nn.SiLU(), nn.Conv3d(channels, hidden_dim, 3)) for channels in self.multi_scale_dims] + ) + self.pocket_layer: nn.Module = nn.Sequential( + nn.SiLU(), nn.Linear(5 * hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim) + ) + + def initialize_weights(self): + def _init_weight(m): + if isinstance(m, nn.Linear | nn.Conv3d): + nn.init.normal_(m.weight, std=0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + self.apply(_init_weight) + + def forward(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> tuple[Tensor, Tensor, Tensor]: + """ + Out: + - hotspot_features: FloatTensor (V, Fh) + - hotspot_positions: FloatTensor (V, 3) (* Real value.) + - pocket_features: FloatTensor (Fh,) + """ + + multi_scale_features, hotspot_infos = pmnet_attr + dev = multi_scale_features[0].device + + # NOTE: Node features + if len(hotspot_infos) > 0: + hotspot_positions = torch.tensor([info["hotspot_position"] for info in hotspot_infos], device=dev) + hotspot_features = torch.stack([info["hotspot_feature"] for info in hotspot_infos]) + hotspot_features = self.hotspot_mlp(hotspot_features) + else: + hotspot_positions = torch.zeros((0, 3), device=dev) + hotspot_features = torch.zeros((0, self.hidden_dim), device=dev) + + # NOTE: Global features + pocket_features: Tensor = torch.cat( + [ + mlp(feat.squeeze(0)).mean((-1, -2, -3)) + for mlp, feat in zip(self.pocket_mlp_list, multi_scale_features, strict=True) + ], + dim=-1, + ) + pocket_features = self.pocket_layer(pocket_features) + + return hotspot_features, hotspot_positions, pocket_features diff --git a/src/pmnet_appl/docking_reward/trainer.py b/src/pmnet_appl/docking_reward/trainer.py new file mode 100644 index 0000000..b48505b --- /dev/null +++ b/src/pmnet_appl/docking_reward/trainer.py @@ -0,0 +1,207 @@ +import sys +import random +import gc +import logging +from pathlib import Path +import time +from omegaconf import OmegaConf +import wandb + +import numpy as np +import torch +import torch.multiprocessing +import torch.utils.tensorboard +from torch.utils.data import DataLoader + +from pmnet.module import PharmacoNet + +from .model import AffinityModel +from .dataset import BaseDataset +from .config import Config + +torch.multiprocessing.set_sharing_strategy("file_system") + + +class Trainer: + def __init__(self, config: Config, device: str = "cuda"): + self.config = config + self.device = device + self.log_dir = Path(config.log_dir) + self.log_dir.mkdir(parents=True) + self.save_dir = self.log_dir / "save" + self.save_dir.mkdir(parents=True) + + self.dictconfig = OmegaConf.create(config.to_dict()) + OmegaConf.save(self.dictconfig, self.log_dir / "config.yaml") + self.logger = create_logger(logfile=self.log_dir / "train.log") + if wandb.run is None: + self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.log_dir) + + self.model = AffinityModel(config, device) + self.pmnet: PharmacoNet = self.model.pmnet + self.setup_data() + self.setup_train() + + def fit(self): + it = 1 + epoch = 0 + best_loss = float("inf") + self.model.train() + while it <= self.config.train.max_iterations: + for batch in self.train_dataloader: + if it > self.config.train.max_iterations: + break + if it % 1024 == 0: + gc.collect() + torch.cuda.empty_cache() + + tick = time.time() + info = self.train_batch(batch) + info["time"] = time.time() - tick + + if it % self.config.train.print_every == 0: + 
self.logger.info( + f"epoch {epoch} iteration {it} train : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items()) + ) + if it % self.config.train.log_every == 0: + self.log(info, it, epoch, "train") + if it % self.config.train.save_every == 0: + self.save_checkpoint(f"epoch-{epoch}-it-{it}.pth") + if it % self.config.train.val_every == 0: + tick = time.time() + info = self.evaluate() + info["time"] = time.time() - tick + self.logger.info( + f"epoch {epoch} iteration {it} valid : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items()) + ) + self.log(info, it, epoch, "valid") + if info["loss"] < best_loss: + torch.save(self.model.state_dict(), self.save_dir / "best.pth") + best_loss = info["loss"] + it += 1 + epoch += 1 + torch.save(self.model.state_dict(), self.save_dir / "last.pth") + + def log(self, info, index, epoch, key): + info.update({"step": index, "epoch": epoch}) + if wandb.run is not None: + wandb.log({f"{key}/{k}": v for k, v in info.items()}, step=index) + else: + for k, v in info.items(): + self._summary_writer.add_scalar(f"{key}/{k}", v, index) + + def train_batch(self, batch) -> dict[str, float]: + loss = self.model.forward_train(batch) + loss.backward() + torch.nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), self.config.train.opt.clip_grad) + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + return {"loss": loss.item()} + + @torch.no_grad() + def evaluate(self) -> dict[str, float]: + self.model.eval() + logs = {"loss": []} + for batch in self.val_dataloader: + loss = self.model.forward_train(batch) + logs["loss"].append(loss.item()) + self.model.train() + return {k: float(np.mean(v)) for k, v in logs.items()} + + def setup_data(self): + config = self.config + protein_info = {} + with open(config.data.protein_info_path) as f: + lines = f.readlines() + for line in lines: + code, x, y, z = line.strip().split(",") + protein_info[code] = (float(x), float(y), float(z)) + + with open(config.data.train_protein_code_path) as f: + codes = [ln.strip() for ln in f.readlines()] + random.seed(0) + random.shuffle(codes) + split_offset = int(len(codes) * config.train.split_ratio) + train_codes = codes[:split_offset] + val_codes = codes[split_offset:] + + self.train_dataset = BaseDataset( + train_codes, + protein_info, + config.data.protein_dir, + config.data.ligand_path, + config.train.center_noise, + ) + + self.val_dataset = BaseDataset( + val_codes, + protein_info, + config.data.protein_dir, + config.data.ligand_path, + ) + + self.train_dataloader: DataLoader = DataLoader( + self.train_dataset, + batch_size=config.train.batch_size, + shuffle=True, + num_workers=config.train.num_workers, + drop_last=True, + collate_fn=collate_fn, + ) + + self.val_dataloader: DataLoader = DataLoader( + self.val_dataset, + batch_size=config.train.batch_size, + shuffle=False, + num_workers=config.train.num_workers, + collate_fn=collate_fn, + ) + + self.logger.info(f"train set: {len(self.train_dataset)}") + self.logger.info(f"valid set: {len(self.val_dataset)}") + + def setup_train(self): + self.optimizer = torch.optim.Adam( + self.model.parameters(), + self.config.train.opt.lr, + eps=self.config.train.opt.eps, + ) + + self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda steps: 2 ** (-steps / self.config.train.lr_scheduler.lr_decay) + ) + + def save_checkpoint(self, filename: str): + ckpt = { + "model_state_dict": self.model.state_dict(), + "config": self.dictconfig, + } + torch.save(ckpt, self.save_dir / filename) + + +def 
collate_fn(batch): + return batch + + +def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): + logger = logging.getLogger(name) + logger.setLevel(loglevel) + formatter = logging.Formatter( + fmt="%(asctime)s - %(message)s", + datefmt="%d/%m/%Y %H:%M:%S", + ) + handlers = [] + if logfile is not None: + handlers.append(logging.FileHandler(logfile, mode="a")) + if streamHandle: + handlers.append(logging.StreamHandler(stream=sys.stdout)) + + for handler in logger.handlers[:]: + logging.root.removeHandler(handler) + + for handler in handlers: + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger diff --git a/src/pmnet_appl/docking_reward/utils.py b/src/pmnet_appl/docking_reward/utils.py new file mode 100644 index 0000000..b8742d7 --- /dev/null +++ b/src/pmnet_appl/docking_reward/utils.py @@ -0,0 +1,88 @@ +from openbabel import pybel +from openbabel.pybel import ob +import torch +from torch_geometric.data import Data as Data + +pybel.ob.OBMessageHandler().SetOutputLevel(0) # 0: None + + +ATOM_DICT = { + 6: 0, # C + 7: 1, # N + 8: 2, # O + 9: 3, # F + 15: 4, # P + 16: 5, # S + 17: 6, # Cl + 35: 7, # Br + 53: 8, # I + -1: 9, # UNKNOWN +} +NUM_ATOM_FEATURES = 10 + 2 + 2 + +BOND_DICT = { + 1: 0, + 2: 1, + 3: 2, + 1.5: 3, # AROMATIC + -1: 4, # UNKNOWN +} +NUM_BOND_FEATURES = 5 + + +def smi2graph(smiles: str) -> Data: + return Data(**smi2graphdata(smiles)) + + +def smi2graphdata(smiles: str) -> dict[str, torch.Tensor]: + pbmol = pybel.readstring("smi", smiles) + atom_features = get_atom_features(pbmol) + edge_attr, edge_index = get_bond_features(pbmol) + return dict( + x=torch.FloatTensor(atom_features), + edge_index=torch.LongTensor(edge_index), + edge_attr=torch.FloatTensor(edge_attr), + ) + + +def get_atom_features(pbmol: pybel.Molecule) -> list[list[float]]: + facade = pybel.ob.OBStereoFacade(pbmol.OBMol) + features = [] + for atom in pbmol.atoms: + feat = [0] * NUM_ATOM_FEATURES + feat[ATOM_DICT.get(atom.atomicnum, 9)] = 1 + + mid = atom.OBAtom.GetId() + if facade.HasTetrahedralStereo(mid): + stereo = facade.GetTetrahedralStereo(mid).GetConfig().winding + if stereo == pybel.ob.OBStereo.Clockwise: + feat[10] = 1 + else: + feat[11] = 1 + charge = atom.formalcharge + if charge > 0: + feat[12] = 1 + elif charge < 0: + feat[13] = 1 + features.append(feat) + return features + + +def get_bond_features(pbmol: pybel.Molecule) -> tuple[list[list[float]], tuple[list[int], list[int]]]: + edge_index_row = [] + edge_index_col = [] + edge_attr = [] + obmol: ob.OBMol = pbmol.OBMol + for obbond in ob.OBMolBondIter(obmol): + obbond: ob.OBBond + edge_index_row.append(obbond.GetBeginAtomIdx() - 1) + edge_index_col.append(obbond.GetEndAtomIdx() - 1) + + feat = [0] * NUM_BOND_FEATURES + if obbond.IsAromatic(): + feat[3] = 1 + else: + feat[BOND_DICT.get(obbond.GetBondOrder(), 4)] = 1 + edge_attr.append(feat) + edge_index = (edge_index_row, edge_index_col) + return edge_attr, edge_index diff --git a/src/pmnet_appl/scripts/train_debug.py b/src/pmnet_appl/scripts/train_debug.py new file mode 100644 index 0000000..d78825e --- /dev/null +++ b/src/pmnet_appl/scripts/train_debug.py @@ -0,0 +1,18 @@ +from pmnet_appl.docking_reward.config import Config +from pmnet_appl.docking_reward.trainer import Trainer + + +if __name__ == "__main__": + config = Config() + config.data.protein_dir = "./dataset/protein/" + config.data.train_protein_code_path = "./dataset/train_key.txt" + config.data.ligand_path = "./dataset/ligand.pkl" + config.train.max_iterations = 100 + 
config.train.batch_size = 16 + config.train.num_workers = 4 + config.train.log_every = 1 + config.train.print_every = 1 + config.train.val_every = 10 + config.log_dir = "./result/debug" + trainer = Trainer(config, device="cuda") + trainer.fit() diff --git a/test/maintain.sh b/test/maintain.sh deleted file mode 100644 index 90999f9..0000000 --- a/test/maintain.sh +++ /dev/null @@ -1,9 +0,0 @@ -/bin/rm -rf result/6oim -python modeling.py --cuda --pdb 6oim -python modeling.py --cuda --pdb 6oim -c D -python modeling.py --cuda --pdb 6oim --ref_ligand ./result/6oim/6oim_B_MG.pdb -python modeling.py --cuda --protein ./result/6oim/6oim.pdb --prefix 6oim -python modeling.py --cuda --protein ./result/6oim/6oim.pdb --ref_ligand ./result/6oim/6oim_B_MG.pdb --prefix 6oim -python modeling.py --cuda --protein ./result/6oim/6oim.pdb --center 1.872 -8.260 -1.361 --prefix 6oim -python screening.py -p ./result/6oim/6oim_D_MOV_model.pm -d ./examples/library/ --cpus 4 -o tmp.csv -python feature_extraction.py --cuda -p ./result/6oim/6oim.pdb --center 1.872 -8.260 -1.361 -o tmp.pt
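
---

For reviewers trying out this patch series, a minimal usage sketch of the new `pmnet.api` entry points, based only on the code added above; the file paths and pocket center below are placeholders, not files shipped with the repository:

```python
import torch

from pmnet.api import ProteinParser, get_pmnet_dev

# Load the frozen PharmacoNet module; score_threshold=0.5 is the value
# recommended for feature extraction (lower it to keep more hotspots).
module = get_pmnet_dev("cuda" if torch.cuda.is_available() else "cpu")

# End-to-end: parse the pocket and extract multi-scale features + hotspot descriptors.
multi_scale_features, hotspot_infos = module.feature_extraction(
    "protein.pdb", ref_ligand_path="ref_ligand.sdf"  # placeholder paths
)

# Step-wise: parse in a Dataset worker, then run the frozen model (run_extraction
# is decorated with torch.no_grad()).
parser = ProteinParser(center_noise=0.0)  # center_noise > 0 only for training-time augmentation
protein_data = parser("protein.pdb", center=(1.0, 2.0, 3.0))  # placeholder pocket center
multi_scale_features, hotspot_infos = module.run_extraction(protein_data)

for info in hotspot_infos:
    print(info["nci_type"], info["hotspot_position"], info["hotspot_score"])
```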