From 3adba49a8f9d259bd4ecb7eb90fdd72a4b751bae Mon Sep 17 00:00:00 2001 From: SeonghwanSeo Date: Thu, 1 Aug 2024 22:04:04 +0900 Subject: [PATCH 1/2] update feature extraction interface --- feature_extraction.py | 74 +++++++++-------- src/pmnet/module.py | 123 +++++++++-------------------- src/pmnet/pharmacophore_model.py | 10 +-- src/pmnet/utils/download_weight.py | 11 ++- 4 files changed, 79 insertions(+), 139 deletions(-) diff --git a/feature_extraction.py b/feature_extraction.py index 6542e3d..02173e8 100644 --- a/feature_extraction.py +++ b/feature_extraction.py @@ -24,49 +24,47 @@ def __init__(self): self.add_argument("--cuda", action="store_true", help="use gpu acceleration with CUDA") -""" -return tuple[multi_scale_features, hotspot_info] - multi_scale_features: list[torch.Tensor]: - - [96, 4, 4, 4], [96, 8, 8, 8], [96, 16, 16, 16], [96, 32, 32, 32], [96, 64, 64, 64] - hotspot_info - - hotspot_feature: torch.Tensor (192,) - - hotspot_position: tuple[float, float, float] - (x, y, z) - - hotspot_score: float in [0, 1] +def main(args): + """ + return tuple[multi_scale_features, hotspot_info] + multi_scale_features: list[torch.Tensor]: + - [96, 4, 4, 4], [96, 8, 8, 8], [96, 16, 16, 16], [96, 32, 32, 32], [96, 64, 64, 64] + hotspot_info + - hotspot_feature: torch.Tensor (192,) + - hotspot_position: tuple[float, float, float] - (x, y, z) + - hotspot_score: float in [0, 1] - - nci_type: str (10 types) - 'Hydrophobic': Hydrophobic interaction - 'PiStacking_P': PiStacking (Parallel) - 'PiStacking_T': PiStacking (T-shaped) - 'PiCation_lring': Interaction btw Protein Cation & Ligand Aromatic Ring - 'PiCation_pring': Interaction btw Protein Aromatic Ring & Ligand Cation - 'SaltBridge_pneg': SaltBridge btw Protein Anion & Ligand Cation - 'SaltBridge_lneg': SaltBridge btw Protein Cation & Ligand Anion - 'XBond': Halogen Bond - 'HBond_pdon': Hydrogen Bond btw Protein Donor & Ligand Acceptor - 'HBond_ldon': Hydrogen Bond btw Protein Acceptor & Ligand Donor + - nci_type: str (10 types) + 'Hydrophobic': Hydrophobic interaction + 'PiStacking_P': PiStacking (Parallel) + 'PiStacking_T': PiStacking (T-shaped) + 'PiCation_lring': Interaction btw Protein Cation & Ligand Aromatic Ring + 'PiCation_pring': Interaction btw Protein Aromatic Ring & Ligand Cation + 'SaltBridge_pneg': SaltBridge btw Protein Anion & Ligand Cation + 'SaltBridge_lneg': SaltBridge btw Protein Cation & Ligand Anion + 'XBond': Halogen Bond + 'HBond_pdon': Hydrogen Bond btw Protein Donor & Ligand Acceptor + 'HBond_ldon': Hydrogen Bond btw Protein Acceptor & Ligand Donor - - hotspot_type: str (7 types) - {'Hydrophobic', 'Aromatic', 'Cation', 'Anion', - 'Halogen', 'HBond_donor', 'HBond_acceptor'} - *** `type` is obtained from `nci_type`. - - point_type: str (7 types) - {'Hydrophobic', 'Aromatic', 'Cation', 'Anion', - 'Halogen', 'HBond_donor', 'HBond_acceptor'} - *** `type` is obtained from `nci_type`. -] -""" + - hotspot_type: str (7 types) + {'Hydrophobic', 'Aromatic', 'Cation', 'Anion', + 'Halogen', 'HBond_donor', 'HBond_acceptor'} + *** `type` is obtained from `nci_type`. + - point_type: str (7 types) + {'Hydrophobic', 'Aromatic', 'Cation', 'Anion', + 'Halogen', 'HBond_donor', 'HBond_acceptor'} + *** `type` is obtained from `nci_type`. + ] + """ + device = "cuda" if args.cuda else "cpu" + score_threshold = 0.5 # NOTE: RECOMMENDED_SCORE_THRESHOLD - -# NOTE: RECOMMENDED -RECOMMENDED_SCORE_THRESHOLD = 0.5 + module = PharmacoNet(device, score_threshold) + multi_scale_features, hotspot_infos = module.feature_extraction(args.protein, args.ref_ligand, args.center) + torch.save([multi_scale_features, hotspot_infos], args.out) if __name__ == "__main__": parser = ArgParser() args = parser.parse_args() - module = PharmacoNet( - device="cuda" if args.cuda else "cpu", - score_threshold=RECOMMENDED_SCORE_THRESHOLD, - ) - multi_scale_features, hotspot_infos = module.feature_extraction(args.protein, args.ref_ligand, args.center) - torch.save([multi_scale_features, hotspot_infos], args.out) + main(args) diff --git a/src/pmnet/module.py b/src/pmnet/module.py index f15853f..f007039 100644 --- a/src/pmnet/module.py +++ b/src/pmnet/module.py @@ -57,7 +57,8 @@ def __init__( ): running_path = Path(__file__) weight_path = running_path.parent / "weights" / "model.tar" - download_pretrained_model(weight_path) + if not weight_path.exists(): + download_pretrained_model(weight_path, verbose) checkpoint = torch.load(weight_path, map_location="cpu") self.config = config = OmegaConf.create(checkpoint["config"]) self.device = device @@ -118,7 +119,7 @@ def feature_extraction( protein_pdb_path: str, ref_ligand_path: str | None = None, center: ArrayLike | None = None, - return_density: bool = False, + return_numpy: bool = True, ) -> tuple[list[Tensor | NDArray[np.float32]], list[dict[str, Any]]]: if center is not None: center_array = np.array(center, dtype=np.float32) @@ -130,7 +131,8 @@ def feature_extraction( center_array = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32) assert center_array is not None assert center_array.shape == (3,) - return self._feature_extraction(protein_pdb_path, center_array, return_density) + multi_scale_features, hotspot_infos = self._feature_extraction(protein_pdb_path, center_array, return_numpy) + return multi_scale_features, hotspot_infos @torch.no_grad() def _run( @@ -322,28 +324,24 @@ def _create_density_maps( ) return hotspot_infos - def print_log(self, level, log): - if self.logger is None: - return None - if level == "debug": - self.logger.debug(log) - elif level == "info": - self.logger.info(log) - def _feature_extraction( self, protein_pdb_path: str, center: NDArray[np.float32], - return_density: bool, + return_numpy: bool = True, ) -> tuple[list[Tensor | NDArray[np.float32]], list[dict[str, Any]]]: protein_image, mask, token_positions, tokens = self._parse_protein(protein_pdb_path, center) - return self._run_feature_extraction( + multi_scale_features, hotspot_infos = self._run_feature_extraction( torch.from_numpy(protein_image), (torch.from_numpy(mask) if mask is not None else None), torch.from_numpy(token_positions), torch.from_numpy(tokens), - return_density, ) + if return_numpy: + multi_scale_features = [feat.cpu().numpy() for feat in multi_scale_features] + for info in hotspot_infos: + info["hotspot_feature"] = info["hotspot_feature"].cpu().numpy() + return multi_scale_features, hotspot_features def _run_feature_extraction( self, @@ -351,7 +349,6 @@ def _run_feature_extraction( mask: Tensor | None, token_positions: Tensor, tokens: Tensor, - return_density: bool, ) -> tuple[list[Tensor], list[dict[str, Any]]]: protein_image = protein_image.to(device=self.device, dtype=torch.float) token_positions = token_positions.to(device=self.device, dtype=torch.float) @@ -359,10 +356,6 @@ def _run_feature_extraction( mask = mask.to(device=self.device, dtype=torch.bool) if mask is not None else None with torch.amp.autocast(self.device, enabled=self.config.AMP_ENABLE): - self.print_log( - "debug", - f"Protein-based Pharmacophore Modeling... (device: {self.device})", - ) protein_image = protein_image.unsqueeze(0) multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] bottom_features = multi_scale_features[-1] @@ -386,77 +379,35 @@ def _run_feature_extraction( if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: continue # NOTE: Check the token exists in cavity - if typ in C.LONG_INTERACTION: - if not cavity_wide[0, x, y, z]: - continue - else: - if not cavity_narrow[0, x, y, z]: - continue + _cavity = cavity_wide if typ in C.LONG_INTERACTION else cavity_narrow + if not _cavity[0, x, y, z]: + continue indices.append(i) relative_scores.append(relative_score) - selected_indices = torch.tensor(indices, device=self.device, dtype=torch.long) # [Ntoken',] - - hotspots = tokens[selected_indices] # [Ntoken',] - hotspot_positions = token_positions[selected_indices] # [Ntoken', 3] - hotspot_features = token_features[selected_indices] # [Ntoken', F] - - if return_density: - density_maps_list = [] - if self.device == "cpu": - step = 1 - else: - step = 4 - for idx in range(0, hotspots.size(0), step): - _hotspots, _hotspot_features = ( - hotspots[idx : idx + step], - hotspot_features[idx : idx + step], - ) - density_maps = self.model.forward_segmentation( - multi_scale_features, [_hotspots], [_hotspot_features] - )[0] # [[4, D, H, W]] - density_maps = density_maps[0].sigmoid() # [4, D, H, W] - density_maps_list.append(density_maps) - - if len(density_maps_list) > 0: - density_maps = torch.cat(density_maps_list, dim=0) # [Ntoken', D, H, W] - box_area = token_inference.get_box_area( - hotspots, - self.config.VOXEL.RADII.PHARMACOPHORE, - self.out_resolution, - self.out_size, - ) - box_area = torch.from_numpy(box_area).to(device=self.device, dtype=torch.bool) # [Ntoken', D, H, W] - unavailable_area = ~(box_area & mask & cavity_narrow) # [Ntoken', D, H, W] - - # NOTE: masking should be performed before smoothing - masked area is not trained. - density_maps.masked_fill_(unavailable_area, 0.0) - density_maps = self.smoothing(density_maps) - density_maps.masked_fill_(unavailable_area, 0.0) - density_maps[density_maps < self.box_threshold] = 0.0 - else: - density_maps = [] - else: - density_maps = [None] * len(hotspots) + hotspots = tokens[indices] # [Ntoken',] + hotspot_positions = token_positions[indices] # [Ntoken', 3] + hotspot_features = token_features[indices] # [Ntoken', F] hotspot_infos = [] assert len(hotspots) == len(relative_scores) - for hotspot, score, position, feature, map in zip( - hotspots, relative_scores, hotspot_positions, hotspot_features, density_maps, strict=True - ): - if map is not None: - if torch.all(map < 1e-6): - continue + for hotspot, score, position, feature in zip(hotspots, relative_scores, hotspot_positions, hotspot_features): interaction_type = INTERACTION_LIST[int(hotspot[3])] - info = { - "nci_type": interaction_type, - "hotspot_type": INTERACTION_TO_HOTSPOT[interaction_type], - "hotspot_feature": feature, - "hotspot_position": tuple(position.tolist()), - "hotspot_score": float(score), - "point_type": INTERACTION_TO_PHARMACOPHORE[interaction_type], - } - if map is not None: - info["point_map"] = map - hotspot_infos.append(info) - multi_scale_features = [feature.squeeze(0) for feature in multi_scale_features] + hotspot_infos.append( + { + "nci_type": interaction_type, + "hotspot_type": INTERACTION_TO_HOTSPOT[interaction_type], + "hotspot_feature": feature, + "hotspot_position": tuple(position.tolist()), + "hotspot_score": float(score), + "point_type": INTERACTION_TO_PHARMACOPHORE[interaction_type], + } + ) return multi_scale_features, hotspot_infos + + def print_log(self, level, log): + if self.logger is None: + return None + if level == "debug": + self.logger.debug(log) + elif level == "info": + self.logger.info(log) diff --git a/src/pmnet/pharmacophore_model.py b/src/pmnet/pharmacophore_model.py index f8d63ad..48a04ab 100644 --- a/src/pmnet/pharmacophore_model.py +++ b/src/pmnet/pharmacophore_model.py @@ -116,18 +116,10 @@ def create( hotspot_infos: list[dict], ): graph = DensityMapGraph(center, resolution, size) - for node in hotspot_infos: - hotspot_type = node["nci_type"] - hotspot_pos = tuple(node["hotspot_position"].tolist()) - hotspot_score = float(node["hotspot_score"]) - map = node["point_map"] - graph.add_node(hotspot_type, hotspot_pos, hotspot_score, map) + graph.add_node(node["type"], node["position"], node["score"], node["map"]) graph.setup() - return cls.create_from_graph(pdbblock, graph) - @classmethod - def create_from_graph(cls, pdbblock: str, graph: DensityMapGraph): model = cls() model.pdbblock = pdbblock model.nodes = [ModelNode.create(model, node) for node in graph.nodes] diff --git a/src/pmnet/utils/download_weight.py b/src/pmnet/utils/download_weight.py index 82e4363..9c3c98c 100644 --- a/src/pmnet/utils/download_weight.py +++ b/src/pmnet/utils/download_weight.py @@ -1,9 +1,8 @@ -import logging import os from pathlib import Path -def download_pretrained_model(weight_path): +def download_pretrained_model(weight_path, verbose): if not os.path.exists(weight_path): weight_path = Path(weight_path) weight_path.parent.mkdir(exist_ok=True) @@ -15,12 +14,12 @@ def download_pretrained_model(weight_path): subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"]) import gdown - logging.debug(f"Download pre-trained model... (path: {weight_path})") + if verbose: + print(f"Download pre-trained model... (path: {weight_path})") gdown.download( "https://drive.google.com/uc?id=1gzjdM7bD3jPm23LBcDXtkSk18nETL04p", str(weight_path), quiet=False, ) - logging.debug(f"Download pre-trained model finish") - else: - logging.debug(f"Load pre-trained model (path: {weight_path})") + if verbose: + print("Download pre-trained model finish") From 59fcba4835c1972964200e23a869debbe94bd676 Mon Sep 17 00:00:00 2001 From: SeonghwanSeo Date: Wed, 7 Aug 2024 22:10:01 +0900 Subject: [PATCH 2/2] update --- .gitignore | 2 + README.md | 13 +- environment.yml | 7 - pyproject.toml | 13 +- src/pmnet/__init__.py | 2 +- src/pmnet/data/extract_pocket.py | 71 ++++- src/pmnet/data/token_inference.py | 34 +-- src/pmnet/module.py | 475 +++++++++++++----------------- src/pmnet/network/builder.py | 16 +- src/pmnet/pharmacophore_model.py | 10 +- src/pmnet/utils/density_map.py | 2 +- 11 files changed, 314 insertions(+), 331 deletions(-) diff --git a/.gitignore b/.gitignore index a3d38c1..bc76713 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ weights run.sh result/ examples/library/ +nogit/ +test.sh # Byte-compiled / optimized / DLL files diff --git a/README.md b/README.md index 63b9ab9..8b2c5ad 100644 --- a/README.md +++ b/README.md @@ -212,13 +212,22 @@ OUTPUT=(multi_scale_features, hotspot_info) For feature extraction, it is recommended to use `score_threshold=0.5` instead of default setting used for pharmacophore modeling. If you want to extract more features, decrease the `score_threshold`. ```python -from pmnet.module import PharmacoNet +from pmnet.module import PharmacoNet, parse_protein module = PharmacoNet( "cuda", - score_threshold = 0.5 # , + score_threshold = 0.5, # , + molvoxel_library = 'numpy' # ) +# End-to-End calculation multi_scale_features, hotspot_infos = module.feature_extraction(, ) multi_scale_features, hotspot_infos = module.feature_extraction(, center=(, , )) + +# Step-wise calculation +voxelizer = module.voxelizer +# In Dataset (Type: Tuple[Tensor, Tensor, Tensor, Tensor]) +protein_data = module.parse_protein(voxelizer, , , ) +# In Model +multi_scale_features, hotspot_infos = module.run_extraction(protein_data) ``` ### Paper List diff --git a/environment.yml b/environment.yml index a3afa51..11c50db 100644 --- a/environment.yml +++ b/environment.yml @@ -7,10 +7,3 @@ dependencies: - openbabel=3.1.1 - pymol-open-source=3.0.0 - numpy=1.26.4 - - pip: - - tqdm - - molvoxel==0.1.3 - - numba==0.59.1 - - omegaconf==2.3.0 - - gdown==5.1.0 - - biopython==1.83 diff --git a/pyproject.toml b/pyproject.toml index 56b112b..258d994 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pharmaconet" -version = "2.0.1" +version = "2.0.2" description = "PharmacoNet: Open-Source Software for Protein-based Pharmacophore Modeling and Virtual Screening" license = { text = "MIT" } authors = [{ name = "Seonghwan Seo", email = "shwan0106@kaist.ac.kr" }] @@ -24,13 +24,14 @@ classifiers = [ ] dependencies = [ + "tqdm", "torch>=1.13.0", - "numpy==1.26.4", - "numba==0.59.1", + "numpy>=1.26,<1.27", + "numba>=0.59", "omegaconf>=2.3.0", - "molvoxel==0.1.3", - "gdown==5.1.0", - "biopython==1.83" + "molvoxel>=0.1.3", + "gdown>=5.1.0", + "biopython>=1.83" ] [project.urls] diff --git a/src/pmnet/__init__.py b/src/pmnet/__init__.py index 9ee30ba..fef68fb 100644 --- a/src/pmnet/__init__.py +++ b/src/pmnet/__init__.py @@ -1,6 +1,6 @@ from .pharmacophore_model import PharmacophoreModel -__version__ = "2.0.1" +__version__ = "2.0.2" __citation_information__ = ( "Seo, S., & Kim, W. Y. (2023, December). " "PharmacoNet: Accelerating Large-Scale Virtual Screening by Deep Pharmacophore Modeling. " diff --git a/src/pmnet/data/extract_pocket.py b/src/pmnet/data/extract_pocket.py index 1712fa8..fb847ae 100644 --- a/src/pmnet/data/extract_pocket.py +++ b/src/pmnet/data/extract_pocket.py @@ -1,20 +1,63 @@ import os import numpy as np +import math from Bio.PDB import PDBParser, PDBIO from Bio.PDB.PDBIO import Select +from typing import Union from numpy.typing import ArrayLike +from pathlib import Path import warnings + warnings.filterwarnings("ignore") AMINO_ACID = [ - 'GLY', 'ALA', 'VAL', 'LEU', 'ILE', 'PRO', 'PHE', 'TYR', 'TRP', 'SER', - 'THR', 'CYS', 'MET', 'ASN', 'GLN', 'ASP', 'GLU', 'LYS', 'ARG', 'HIS', - 'HIP', 'HIE', 'TPO', 'HID', 'LEV', 'MEU', 'PTR', 'GLV', 'CYT', 'SEP', - 'HIZ', 'CYM', 'GLM', 'ASQ', 'TYS', 'CYX', 'GLZ', 'MSE', 'CSO', 'KCX', - 'CSD', 'MLY', 'PCA', 'LLP' + "GLY", + "ALA", + "VAL", + "LEU", + "ILE", + "PRO", + "PHE", + "TYR", + "TRP", + "SER", + "THR", + "CYS", + "MET", + "ASN", + "GLN", + "ASP", + "GLU", + "LYS", + "ARG", + "HIS", + "HIP", + "HIE", + "TPO", + "HID", + "LEV", + "MEU", + "PTR", + "GLV", + "CYT", + "SEP", + "HIZ", + "CYM", + "GLM", + "ASQ", + "TYS", + "CYX", + "GLZ", + "MSE", + "CSO", + "KCX", + "CSD", + "MLY", + "PCA", + "LLP", ] @@ -28,11 +71,9 @@ def accept_residue(self, residue): return 0 if residue.get_resname() not in AMINO_ACID: return 0 - residue_positions = np.array([ - list(atom.get_vector()) - for atom in residue.get_atoms() - if "H" not in atom.get_id() - ]) + residue_positions = np.array( + [list(atom.get_vector()) for atom in residue.get_atoms() if "H" not in atom.get_id()] + ) if residue_positions.shape[0] == 0: return 0 min_dis = np.min(np.linalg.norm(residue_positions - self.center, axis=-1)) @@ -42,14 +83,14 @@ def accept_residue(self, residue): return 0 +DEFAULT_CUTOFF = 16 * math.sqrt(3) + 5.0 + + def extract_pocket( - protein_pdb_path: str, - out_pocket_pdb_path: str, - center: ArrayLike, - cutoff: float + protein_pdb_path: Union[str, Path], out_pocket_pdb_path: str, center: ArrayLike, cutoff: float = DEFAULT_CUTOFF ): parser = PDBParser() - structure = parser.get_structure("protein", protein_pdb_path) + structure = parser.get_structure("protein", str(protein_pdb_path)) io = PDBIO() io.set_structure(structure) io.save(out_pocket_pdb_path, DistSelect(center, cutoff)) diff --git a/src/pmnet/data/token_inference.py b/src/pmnet/data/token_inference.py index 552fa70..70b0a61 100644 --- a/src/pmnet/data/token_inference.py +++ b/src/pmnet/data/token_inference.py @@ -8,9 +8,7 @@ from . import constant as C -def get_token_informations( - protein_obj: Protein, -) -> Tuple[NDArray[np.float32], NDArray[np.int16]]: +def get_token_informations(protein_obj: Protein) -> Tuple[NDArray[np.float32], NDArray[np.int16]]: """get token information Args: @@ -20,14 +18,15 @@ def get_token_informations( token_positions: [float, (N, 3)] token center positions token_classes: [int, (N,)] token interaction type """ - num_tokens = \ - len(protein_obj.hydrophobic_atoms_all) + \ - len(protein_obj.rings_all) * 3 + \ - len(protein_obj.hbond_donors_all) + \ - len(protein_obj.hbond_acceptors_all) + \ - len(protein_obj.pos_charged_atoms_all) * 2 + \ - len(protein_obj.neg_charged_atoms_all) + \ - len(protein_obj.xbond_acceptors_all) + num_tokens = ( + len(protein_obj.hydrophobic_atoms_all) + + len(protein_obj.rings_all) * 3 + + len(protein_obj.hbond_donors_all) + + len(protein_obj.hbond_acceptors_all) + + len(protein_obj.pos_charged_atoms_all) * 2 + + len(protein_obj.neg_charged_atoms_all) + + len(protein_obj.xbond_acceptors_all) + ) positions: List[Tuple[float, float, float]] = [] classes: List[int] = [] @@ -83,8 +82,6 @@ def get_token_and_filter( positions: NDArray[np.float32], classes: NDArray[np.int16], center: NDArray[np.float32], - resolution: float, - dimension: int, ) -> Tuple[NDArray[np.int16], NDArray[np.int16]]: """Create token and Filtering valid instances @@ -99,6 +96,7 @@ def get_token_and_filter( token: [int, (N_token, 4)] filter: [int, (N_token,)] """ + resolution, dimension = 0.5, 64 filter = [] tokens = [] x_center, y_center, z_center = center @@ -116,12 +114,7 @@ def get_token_and_filter( return np.array(tokens, dtype=np.int16), np.array(filter, dtype=np.int16) -def get_box_area( - tokens: ArrayLike, - pharmacophore_size: float, - resolution: float, - dimension: int, -) -> NDArray[np.bool_]: +def get_box_area(tokens: ArrayLike) -> NDArray[np.bool_]: """Create Box Area Args: @@ -132,9 +125,10 @@ def get_box_area( Returns: box_areas: BoolArray [Ntoken, D, H, W] D=H=W=dimension """ + resolution, dimension, pharmacophore_size = 0.5, 64, 1.0 num_tokens = len(tokens) box_areas = np.zeros((num_tokens, dimension, dimension, dimension), dtype=np.bool_) - grids = np.stack(np.meshgrid(np.arange(dimension), np.arange(dimension), np.arange(dimension), indexing='ij'), 3) + grids = np.stack(np.meshgrid(np.arange(dimension), np.arange(dimension), np.arange(dimension), indexing="ij"), 3) for i, (x, y, z, t) in enumerate(tokens): x, y, z, t = int(x), int(y), int(z), int(t) distances = np.linalg.norm(grids - np.array([[x, y, z]]), axis=-1) diff --git a/src/pmnet/module.py b/src/pmnet/module.py index f007039..8ed4eb9 100644 --- a/src/pmnet/module.py +++ b/src/pmnet/module.py @@ -1,18 +1,19 @@ +from __future__ import annotations import os import tempfile -import math import logging from pathlib import Path +from importlib.util import find_spec import tqdm -from omegaconf import OmegaConf from openbabel import pybel import torch import numpy as np -from typing import Any +from omegaconf import OmegaConf +from typing import Any from torch import Tensor -from numpy.typing import NDArray, ArrayLike +from numpy.typing import NDArray from molvoxel import create_voxelizer, BaseVoxelizer @@ -25,13 +26,8 @@ from pmnet.data.extract_pocket import extract_pocket from pmnet.utils.smoothing import GaussianSmoothing from pmnet.utils.download_weight import download_pretrained_model - from pmnet.pharmacophore_model import PharmacophoreModel, INTERACTION_TO_PHARMACOPHORE, INTERACTION_TO_HOTSPOT -from importlib.util import find_spec - -MOLVOXEL_LIBRARY = "numba" if find_spec("numba") is not None else "numpy" - DEFAULT_FOCUS_THRESHOLD = 0.5 DEFAULT_BOX_THRESHOLD = 0.5 DEFAULT_SCORE_THRESHOLD = { @@ -54,253 +50,178 @@ def __init__( device: str = "cpu", score_threshold: float | dict[str, float] | None = DEFAULT_SCORE_THRESHOLD, verbose: bool = True, + molvoxel_library: str = "numba", ): + """ + device: 'cpu' | 'cuda' + score_threshold: float | dict[str, float] | None + custom threshold to identify hotspots. + For feature extraction, recommended value is '0.5' + molvoxel_library: str + If you want to use PharmacoNet in DL model training, recommend to use 'numpy' + """ + assert molvoxel_library in ["numpy", "numba"] + if molvoxel_library == "numba" and (not find_spec("numba")): + molvoxel_library = "numpy" + running_path = Path(__file__) weight_path = running_path.parent / "weights" / "model.tar" if not weight_path.exists(): download_pretrained_model(weight_path, verbose) checkpoint = torch.load(weight_path, map_location="cpu") - self.config = config = OmegaConf.create(checkpoint["config"]) - self.device = device + config = OmegaConf.create(checkpoint["config"]) model = build_model(config.MODEL) model.load_state_dict(checkpoint["model"]) model.eval() self.model: PharmacoFormer = model.to(device) - self.smoothing = GaussianSmoothing(kernel_size=5, sigma=0.5).to(device) - self.focus_threshold: float = DEFAULT_FOCUS_THRESHOLD - self.box_threshold: float = DEFAULT_BOX_THRESHOLD + self.device = device self.score_distributions = { typ: np.array(distribution["focus"]) for typ, distribution in checkpoint["score_distributions"].items() } + del checkpoint + self.focus_threshold: float = DEFAULT_FOCUS_THRESHOLD + self.box_threshold: float = DEFAULT_BOX_THRESHOLD self.score_threshold: dict[str, float] if isinstance(score_threshold, dict): self.score_threshold = score_threshold - else: + elif isinstance(score_threshold, float): self.score_threshold = {typ: score_threshold for typ in INTERACTION_LIST} - del checkpoint + else: + self.score_threshold = DEFAULT_SCORE_THRESHOLD - in_resolution = config.VOXEL.IN.RESOLUTION - in_size = config.VOXEL.IN.SIZE - self.in_voxelizer: BaseVoxelizer = create_voxelizer( - in_resolution, in_size, sigma=(1 / 3), library=MOLVOXEL_LIBRARY + self.resolution = 0.5 + self.size = 64 + self.voxelizer: BaseVoxelizer = create_voxelizer( + self.resolution, self.size, sigma=(1 / 3), library=molvoxel_library ) - self.pocket_cutoff = (in_resolution * in_size * math.sqrt(3) / 2) + 5.0 - self.out_resolution = config.VOXEL.OUT.RESOLUTION - self.out_size = config.VOXEL.OUT.SIZE - + self.smoothing = GaussianSmoothing(kernel_size=5, sigma=0.5).to(device) if verbose: self.logger = logging.getLogger("PharmacoNet") else: self.logger = None + @torch.no_grad() def run( self, - protein_pdb_path: str, - ref_ligand_path: str | None = None, - center: ArrayLike | None = None, + protein_pdb_path: str | Path, + ref_ligand_path: str | Path | None = None, + center: tuple[float, float, float] | NDArray | None = None, ) -> PharmacophoreModel: assert (ref_ligand_path is not None) or (center is not None) - if center is not None: - center_array = np.array(center, dtype=np.float32) - else: - assert ref_ligand_path is not None - extension = os.path.splitext(ref_ligand_path)[1] - assert extension in [".sdf", ".pdb", ".mol2"] - ref_ligand = next(pybel.readfile(extension[1:], str(ref_ligand_path))) - center_array = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32) - assert center_array is not None - assert center_array.shape == (3,) - return self._run(protein_pdb_path, center_array) + center = self.get_center(ref_ligand_path, center) + protein_data = parse_protein(self.voxelizer, protein_pdb_path, center, 0.0, True) + hotspot_infos = self.create_density_maps(protein_data) + with open(protein_pdb_path) as f: + pdbblock: str = "\n".join(f.readlines()) + return PharmacophoreModel.create(pdbblock, center, hotspot_infos) @torch.no_grad() def feature_extraction( self, - protein_pdb_path: str, - ref_ligand_path: str | None = None, - center: ArrayLike | None = None, - return_numpy: bool = True, - ) -> tuple[list[Tensor | NDArray[np.float32]], list[dict[str, Any]]]: + protein_pdb_path: str | Path, + ref_ligand_path: str | Path | None = None, + center: tuple[float, float, float] | NDArray | None = None, + ) -> tuple[list[Tensor], list[dict[str, Any]]]: + assert (ref_ligand_path is not None) or (center is not None) + center = self.get_center(ref_ligand_path, center) + protein_data = parse_protein(self.voxelizer, protein_pdb_path, center, 0.0, True) + return self.run_extraction(protein_data) + + @staticmethod + def get_center( + ref_ligand_path: str | Path | None = None, + center: tuple[float, float, float] | NDArray | None = None, + ) -> tuple[float, float, float]: if center is not None: - center_array = np.array(center, dtype=np.float32) + assert len(center) == 3 + x, y, z = center else: assert ref_ligand_path is not None extension = os.path.splitext(ref_ligand_path)[1] assert extension in [".sdf", ".pdb", ".mol2"] ref_ligand = next(pybel.readfile(extension[1:], str(ref_ligand_path))) - center_array = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32) - assert center_array is not None - assert center_array.shape == (3,) - multi_scale_features, hotspot_infos = self._feature_extraction(protein_pdb_path, center_array, return_numpy) - return multi_scale_features, hotspot_infos + x, y, z = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32).to_list() + return float(x), float(y), float(z) @torch.no_grad() - def _run( - self, - protein_pdb_path: str, - center: NDArray[np.float32], - ): - protein_image, mask, token_positions, tokens = self._parse_protein(protein_pdb_path, center) - with open(protein_pdb_path) as f: - pdbblock: str = "\n".join(f.readlines()) - - hotspot_infos = self._create_density_maps( - torch.from_numpy(protein_image), - (torch.from_numpy(mask) if mask is not None else None), - torch.from_numpy(token_positions), - torch.from_numpy(tokens), - ) - x, y, z = center.tolist() - return PharmacophoreModel.create(pdbblock, (x, y, z), self.out_resolution, self.out_size, hotspot_infos) - - def _parse_protein( - self, - protein_pdb_path: str, - center: NDArray[np.float32], - pocket_extract: bool = True, - ) -> tuple[NDArray, NDArray | None, NDArray, NDArray]: - self.print_log("debug", "Extract Pocket...") - if pocket_extract: - with tempfile.TemporaryDirectory() as dirname: - pocket_path = os.path.join(dirname, "pocket.pdb") - extract_pocket(protein_pdb_path, pocket_path, center, self.pocket_cutoff) # root(3) - protein_obj: Protein = Protein.from_pdbfile(pocket_path) - self.print_log("debug", "Extract Pocket Finish") - else: - protein_obj: Protein = Protein.from_pdbfile(protein_pdb_path) - - token_positions, token_classes = token_inference.get_token_informations(protein_obj) - tokens, filter = token_inference.get_token_and_filter( - token_positions, token_classes, center, self.out_resolution, self.out_size - ) - token_positions = token_positions[filter] - - protein_positions, protein_features = pointcloud.get_protein_pointcloud(protein_obj) - - self.print_log("debug", "MolVoxel:Voxelize Pocket...") - protein_image = np.asarray( - self.in_voxelizer.forward_features( - protein_positions, - center, - protein_features, - radii=self.config.VOXEL.RADII.PROTEIN, - ), - np.float32, - ) - if self.config.VOXEL.RADII.PROTEIN_MASKING > 0: - mask = np.logical_not( - np.asarray( - self.in_voxelizer.forward_single( - protein_positions, - center, - radii=self.config.VOXEL.RADII.PROTEIN_MASKING, - ), - np.bool_, - ) - ) - else: - mask = None - self.print_log("debug", "MolVoxel:Voxelize Pocket Finish") - del protein_obj - - return protein_image, mask, token_positions, tokens - - def _create_density_maps( - self, - protein_image: Tensor, - mask: Tensor | None, - token_positions: Tensor, - tokens: Tensor, - ): + def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor]): + protein_image, mask, token_positions, tokens = protein_data protein_image = protein_image.to(device=self.device, dtype=torch.float) token_positions = token_positions.to(device=self.device, dtype=torch.float) tokens = tokens.to(device=self.device, dtype=torch.long) - mask = mask.to(device=self.device, dtype=torch.bool) if mask is not None else None + mask = mask.to(device=self.device, dtype=torch.bool) - with torch.amp.autocast(self.device, enabled=self.config.AMP_ENABLE): - self.print_log( - "debug", - f"Protein-based Pharmacophore Modeling... (device: {self.device})", - ) - protein_image = protein_image.unsqueeze(0) - multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] - bottom_features = multi_scale_features[-1] - - token_scores, token_features = self.model.forward_token_prediction(bottom_features, [tokens]) - token_scores = token_scores[0].sigmoid() # [Ntoken,] - token_features = token_features[0] # [Ntoken, F] - - cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(bottom_features) - cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] - cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] - - num_tokens = tokens.shape[0] - indices = [] - relative_scores = [] - for i in range(num_tokens): - x, y, z, typ = tokens[i].tolist() - # NOTE: Check the token score - absolute_score = token_scores[i].item() - relative_score = float((self.score_distributions[INTERACTION_LIST[int(typ)]] < absolute_score).mean()) - if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: + self.print_log( + "debug", + f"Protein-based Pharmacophore Modeling... (device: {self.device})", + ) + protein_image = protein_image.unsqueeze(0) + multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] + bottom_features = multi_scale_features[-1] + + token_scores, token_features = self.model.forward_token_prediction(bottom_features, [tokens]) + token_scores = token_scores[0].sigmoid() # [Ntoken,] + token_features = token_features[0] # [Ntoken, F] + + cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(bottom_features) + cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] + cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] + + num_tokens = tokens.shape[0] + indices = [] + relative_scores = [] + for i in range(num_tokens): + x, y, z, typ = tokens[i].tolist() + # NOTE: Check the token score + absolute_score = token_scores[i].item() + relative_score = float((self.score_distributions[INTERACTION_LIST[int(typ)]] < absolute_score).mean()) + if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: + continue + # NOTE: Check the token exists in cavity + if typ in C.LONG_INTERACTION: + if not cavity_wide[0, x, y, z]: continue - # NOTE: Check the token exists in cavity - if typ in C.LONG_INTERACTION: - if not cavity_wide[0, x, y, z]: - continue - else: - if not cavity_narrow[0, x, y, z]: - continue - indices.append(i) - relative_scores.append(relative_score) - selected_indices = torch.tensor(indices, device=self.device, dtype=torch.long) # [Ntoken',] - hotspots = tokens[selected_indices] # [Ntoken',] - hotspot_positions = token_positions[selected_indices] # [Ntoken', 3] - hotspot_features = token_features[selected_indices] # [Ntoken', F] - del tokens - del token_positions - del token_features - - density_maps_list = [] - if self.device == "cpu": - step = 1 else: - step = 4 - with tqdm.tqdm( - desc="hotspots", - total=hotspots.size(0), - leave=False, - disable=(self.logger is None), - ) as pbar: - for idx in range(0, hotspots.size(0), step): - _hotspots, _hotspot_features = ( - hotspots[idx : idx + step], - hotspot_features[idx : idx + step], - ) - density_maps = self.model.forward_segmentation( - multi_scale_features, [_hotspots], [_hotspot_features] - )[0] # [[4, D, H, W]] - - density_maps = density_maps[0].sigmoid() # [4, D, H, W] - density_maps_list.append(density_maps) - pbar.update(len(_hotspots)) - - density_maps = torch.cat(density_maps_list, dim=0) # [Ntoken', D, H, W] - - box_area = token_inference.get_box_area( - hotspots, - self.config.VOXEL.RADII.PHARMACOPHORE, - self.out_resolution, - self.out_size, - ) - box_area = torch.from_numpy(box_area).to(device=self.device, dtype=torch.bool) # [Ntoken', D, H, W] - unavailable_area = ~(box_area & mask & cavity_narrow) # [Ntoken', D, H, W] - - # NOTE: masking should be performed before smoothing - masked area is not trained. - density_maps.masked_fill_(unavailable_area, 0.0) - density_maps = self.smoothing(density_maps) - density_maps.masked_fill_(unavailable_area, 0.0) - density_maps[density_maps < self.box_threshold] = 0.0 + if not cavity_narrow[0, x, y, z]: + continue + indices.append(i) + relative_scores.append(relative_score) + selected_indices = torch.tensor(indices, device=self.device, dtype=torch.long) # [Ntoken',] + hotspots = tokens[selected_indices] # [Ntoken',] + hotspot_positions = token_positions[selected_indices] # [Ntoken', 3] + hotspot_features = token_features[selected_indices] # [Ntoken', F] + del protein_image, tokens, token_positions, token_features + + density_maps_list = [] + if self.device == "cpu": + step = 1 + else: + step = 4 + with tqdm.tqdm( + desc="hotspots", + total=hotspots.size(0), + leave=False, + disable=(self.logger is None), + ) as pbar: + for idx in range(0, hotspots.size(0), step): + _hotspots, _hotspot_features = [hotspots[idx : idx + step]], [hotspot_features[idx : idx + step]] + density_maps = self.model.forward_segmentation(multi_scale_features, _hotspots, _hotspot_features)[0] + density_maps = density_maps[0].sigmoid() # [4, D, H, W] + density_maps_list.append(density_maps) + pbar.update(len(_hotspots)) + + density_maps = torch.cat(density_maps_list, dim=0) # [Ntoken', D, H, W] + + box_area = token_inference.get_box_area(hotspots) + box_area = torch.from_numpy(box_area).to(device=self.device, dtype=torch.bool) # [Ntoken', D, H, W] + unavailable_area = ~(box_area & mask & cavity_narrow) # [Ntoken', D, H, W] + + # NOTE: masking should be performed before smoothing - masked area is not trained. + density_maps.masked_fill_(unavailable_area, 0.0) + density_maps = self.smoothing(density_maps) + density_maps.masked_fill_(unavailable_area, 0.0) + density_maps[density_maps < self.box_threshold] = 0.0 hotspot_infos = [] assert len(hotspots) == len(relative_scores) @@ -324,69 +245,47 @@ def _create_density_maps( ) return hotspot_infos - def _feature_extraction( - self, - protein_pdb_path: str, - center: NDArray[np.float32], - return_numpy: bool = True, - ) -> tuple[list[Tensor | NDArray[np.float32]], list[dict[str, Any]]]: - protein_image, mask, token_positions, tokens = self._parse_protein(protein_pdb_path, center) - multi_scale_features, hotspot_infos = self._run_feature_extraction( - torch.from_numpy(protein_image), - (torch.from_numpy(mask) if mask is not None else None), - torch.from_numpy(token_positions), - torch.from_numpy(tokens), - ) - if return_numpy: - multi_scale_features = [feat.cpu().numpy() for feat in multi_scale_features] - for info in hotspot_infos: - info["hotspot_feature"] = info["hotspot_feature"].cpu().numpy() - return multi_scale_features, hotspot_features - - def _run_feature_extraction( - self, - protein_image: Tensor, - mask: Tensor | None, - token_positions: Tensor, - tokens: Tensor, + @torch.no_grad() + def run_extraction( + self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor] ) -> tuple[list[Tensor], list[dict[str, Any]]]: + protein_image, mask, token_positions, tokens = protein_data protein_image = protein_image.to(device=self.device, dtype=torch.float) token_positions = token_positions.to(device=self.device, dtype=torch.float) tokens = tokens.to(device=self.device, dtype=torch.long) - mask = mask.to(device=self.device, dtype=torch.bool) if mask is not None else None - - with torch.amp.autocast(self.device, enabled=self.config.AMP_ENABLE): - protein_image = protein_image.unsqueeze(0) - multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] - bottom_features = multi_scale_features[-1] - - token_scores, token_features = self.model.forward_token_prediction(bottom_features, [tokens]) - token_scores = token_scores[0].sigmoid() # [Ntoken,] - token_features = token_features[0] # [Ntoken, F] - - cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(bottom_features) - cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] - cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] - - num_tokens = tokens.shape[0] - indices = [] - relative_scores = [] - for i in range(num_tokens): - x, y, z, typ = tokens[i].tolist() - # NOTE: Check the token score - absolute_score = token_scores[i].item() - relative_score = float((self.score_distributions[INTERACTION_LIST[int(typ)]] < absolute_score).mean()) - if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: - continue - # NOTE: Check the token exists in cavity - _cavity = cavity_wide if typ in C.LONG_INTERACTION else cavity_narrow - if not _cavity[0, x, y, z]: - continue - indices.append(i) - relative_scores.append(relative_score) - hotspots = tokens[indices] # [Ntoken',] - hotspot_positions = token_positions[indices] # [Ntoken', 3] - hotspot_features = token_features[indices] # [Ntoken', F] + mask = mask.to(device=self.device, dtype=torch.bool) + + protein_image = protein_image.unsqueeze(0) + multi_scale_features = self.model.forward_feature(protein_image) # List[[1, D, H, W, F]] + bottom_features = multi_scale_features[-1] + + token_scores, token_features = self.model.forward_token_prediction(bottom_features, [tokens]) + token_scores = token_scores[0].sigmoid() # [Ntoken,] + token_features = token_features[0] # [Ntoken, F] + + cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(bottom_features) + cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] + cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] + + num_tokens = tokens.shape[0] + indices = [] + relative_scores = [] + for i in range(num_tokens): + x, y, z, typ = tokens[i].tolist() + # NOTE: Check the token score + absolute_score = token_scores[i].item() + relative_score = float((self.score_distributions[INTERACTION_LIST[int(typ)]] < absolute_score).mean()) + if relative_score < self.score_threshold[INTERACTION_LIST[int(typ)]]: + continue + # NOTE: Check the token exists in cavity + _cavity = cavity_wide if typ in C.LONG_INTERACTION else cavity_narrow + if not _cavity[0, x, y, z]: + continue + indices.append(i) + relative_scores.append(relative_score) + hotspots = tokens[indices] # [Ntoken',] + hotspot_positions = token_positions[indices] # [Ntoken', 3] + hotspot_features = token_features[indices] # [Ntoken', F] hotspot_infos = [] assert len(hotspots) == len(relative_scores) @@ -402,6 +301,7 @@ def _run_feature_extraction( "point_type": INTERACTION_TO_PHARMACOPHORE[interaction_type], } ) + del protein_image, mask, token_positions, tokens return multi_scale_features, hotspot_infos def print_log(self, level, log): @@ -411,3 +311,42 @@ def print_log(self, level, log): self.logger.debug(log) elif level == "info": self.logger.info(log) + + +# NOTE: For DL Model Training +def parse_protein( + voxelizer: BaseVoxelizer, + protein_pdb_path: str | Path, + center: NDArray[np.float32] | tuple[float, float, float], + center_noise: float = 0.0, + pocket_extract: bool = True, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + if isinstance(center, tuple): + center = np.array(center, dtype=np.float32) + if center_noise > 0: + center = center + (np.random.rand(3) * 2 - 1) * center_noise + + if pocket_extract: + with tempfile.TemporaryDirectory() as dirname: + pocket_path = os.path.join(dirname, "pocket.pdb") + extract_pocket(protein_pdb_path, pocket_path, center) + protein_obj: Protein = Protein.from_pdbfile(pocket_path) + else: + protein_obj: Protein = Protein.from_pdbfile(protein_pdb_path) + + token_positions, token_classes = token_inference.get_token_informations(protein_obj) + tokens, filter = token_inference.get_token_and_filter(token_positions, token_classes, center) + token_positions = token_positions[filter] + + protein_positions, protein_features = pointcloud.get_protein_pointcloud(protein_obj) + protein_image = np.asarray( + voxelizer.forward_features(protein_positions, center, protein_features, radii=1.5), np.float32 + ) + mask = np.logical_not(np.asarray(voxelizer.forward_single(protein_positions, center, radii=1.0), np.bool_)) + del protein_obj + return ( + torch.from_numpy(protein_image), + torch.from_numpy(mask), + torch.from_numpy(token_positions), + torch.from_numpy(tokens), + ) diff --git a/src/pmnet/network/builder.py b/src/pmnet/network/builder.py index 3af204e..d7c4c00 100644 --- a/src/pmnet/network/builder.py +++ b/src/pmnet/network/builder.py @@ -3,17 +3,17 @@ from .utils.registry import Registry -BACKBONE = Registry('Backbone') -NECK = Registry('Neck') -DECODER = Registry('Decoder') +BACKBONE = Registry("Backbone") +NECK = Registry("Neck") +DECODER = Registry("Decoder") -EMBEDDING = Registry('Embedding') -HEAD = Registry('Head') +EMBEDDING = Registry("Embedding") +HEAD = Registry("Head") -MODEL = Registry('Model') +MODEL = Registry("Model") def build_model(config: Dict) -> nn.Module: - registry_key = 'registry' - module_key = 'name' + registry_key = "registry" + module_key = "name" return Registry.build_from_config(config, registry_key, module_key, convert_key_to_lower_case=True, safe_build=True) diff --git a/src/pmnet/pharmacophore_model.py b/src/pmnet/pharmacophore_model.py index 48a04ab..2bc9cae 100644 --- a/src/pmnet/pharmacophore_model.py +++ b/src/pmnet/pharmacophore_model.py @@ -110,11 +110,15 @@ def _scoring( def create( cls, pdbblock: str, - center: tuple[float, float, float], - resolution: float, - size: int, + center: tuple[float, float, float] | NDArray, hotspot_infos: list[dict], + resolution: float = 0.5, + size: int = 64, ): + assert len(center) == 3 + if not isinstance(center, tuple): + x, y, z = center.tolist() + center = (x, y, z) graph = DensityMapGraph(center, resolution, size) for node in hotspot_infos: graph.add_node(node["type"], node["position"], node["score"], node["map"]) diff --git a/src/pmnet/utils/density_map.py b/src/pmnet/utils/density_map.py index 74af07d..2624c82 100644 --- a/src/pmnet/utils/density_map.py +++ b/src/pmnet/utils/density_map.py @@ -27,7 +27,7 @@ def coords_to_position(coords, center, resolution, size) -> tuple[float, float, class DensityMapGraph: - def __init__(self, center: tuple[float, float, float], resolution: float, size: int): + def __init__(self, center: tuple[float, float, float], resolution: float = 0.5, size: int = 64): self.center: tuple[float, float, float] = center self.resolution: float = resolution self.size: int = size