diff --git a/notebooks/predict_atomic.py b/notebooks/predict_atomic.py new file mode 100644 index 0000000..26c113c --- /dev/null +++ b/notebooks/predict_atomic.py @@ -0,0 +1,29 @@ +""" +An example script make predictions of any tensor. +""" + +from pymatgen.core import Structure + +from matten.predict import predict + + +def get_structure(): + a = 5.46 + lattice = [[0, a / 2, a / 2], [a / 2, 0, a / 2], [a / 2, a / 2, 0]] + basis = [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]] + Si = Structure(lattice, ["Si", "Si"], basis) + + return Si + + +if __name__ == "__main__": + structure = get_structure() + + # predict for one structure + tensors = predict( + structure, + model_identifier="/Users/mjwen.admin/Downloads/trained", + checkpoint="epoch=9-step=100.ckpt", + is_atomic_tensor=True, + ) + print("value:", tensors) diff --git a/scripts/configs/pizeoelectric.yaml b/scripts/configs/pizeoelectric.yaml deleted file mode 100644 index ebae2c0..0000000 --- a/scripts/configs/pizeoelectric.yaml +++ /dev/null @@ -1,116 +0,0 @@ -seed_everything: 35 -log_level: info - -data: - root: . - tensor_target_name: piezoelectric_tensor_total - tensor_target_formula: ijk=ikj - trainset_filename: /Users/mjwen.admin/Documents/Dataset/di_pizeoelectric_tensor/piezoelectric_tensors_n20.json - valset_filename: /Users/mjwen.admin/Documents/Dataset/di_pizeoelectric_tensor/piezoelectric_tensors_n20.json - testset_filename: /Users/mjwen.admin/Documents/Dataset/di_pizeoelectric_tensor/piezoelectric_tensors_n20.json - r_cut: 5.0 - reuse: false - loader_kwargs: - batch_size: 32 - shuffle: true - -model: - ########## - # embedding - ########## - - # atom species embedding - species_embedding_dim: 16 - - # spherical harmonics embedding of edge direction - irreps_edge_sh: 0e + 1o + 2e + 3o - - # radial edge distance embedding - radial_basis_type: bessel - num_radial_basis: 8 - radial_basis_start: 0. - radial_basis_end: 5. - - ########## - # message passing conv layers - ########## - num_layers: 3 - - # radial network - invariant_layers: 2 # number of radial layers - invariant_neurons: 32 # number of hidden neurons in radial function - - # Average number of neighbors used for normalization. Options: - # 1. `auto` to determine it automatically, by setting it to average number - # of neighbors of the training set - # 2. float or int provided here. - # 3. `null` to not use it - average_num_neighbors: auto - - # point convolution - conv_layer_irreps: 32x0o+32x0e + 16x1o+16x1e + 4x2o+4x2e + 2x3o+2x3e - nonlinearity_type: gate - normalization: batch - resnet: true - - ########## - # output - ########## - - conv_to_output_hidden_irreps_out: 16x1o + 4x2o + 2x3o - - # output_format and output_formula should be used together. - # - output_format (can be `irreps` or `cartesian`) determines what the loss - # function will be on (either on the irreps space or the cartesian space). - # - output_formula gives what the cartesian formula of the tensor is. - # For example, ijkl=jikl=klij specifies a forth-rank elasticity tensor. - output_format: irreps - output_formula: ijk=ikj - - # pooling node feats to graph feats - reduce: mean - -trainer: - max_epochs: 10 # number of maximum training epochs - num_nodes: 1 - accelerator: cpu - devices: 1 - - callbacks: - - class_path: pytorch_lightning.callbacks.ModelCheckpoint - init_args: - monitor: val/score - mode: min - save_top_k: 3 - save_last: true - verbose: false - - class_path: pytorch_lightning.callbacks.EarlyStopping - init_args: - monitor: val/score - mode: min - patience: 150 - min_delta: 0 - verbose: true - - class_path: pytorch_lightning.callbacks.ModelSummary - init_args: - max_depth: -1 - - #logger: - # class_path: pytorch_lightning.loggers.wandb.WandbLogger - # init_args: - # save_dir: matten_logs - # project: matten_proj - -optimizer: - class_path: torch.optim.Adam - init_args: - lr: 0.01 - weight_decay: 0.00001 - -lr_scheduler: - class_path: torch.optim.lr_scheduler.ReduceLROnPlateau - init_args: - mode: min - factor: 0.5 - patience: 50 - verbose: true diff --git a/src/matten/predict.py b/src/matten/predict.py index c4300b0..140a3a5 100644 --- a/src/matten/predict.py +++ b/src/matten/predict.py @@ -13,6 +13,7 @@ from matten.dataset.structure_scalar_tensor import TensorDatasetPrediction from matten.log import set_logger +from matten.model_factory.tfn_atomic_tensor import AtomicTensorModel from matten.model_factory.tfn_scalar_tensor import ScalarTensorModel from matten.utils import CartesianTensorWrapper, yaml_load @@ -31,9 +32,11 @@ def get_pretrained_model_dir(identifier: str) -> Path: return Path(__file__).parent.parent.parent / "pretrained" / identifier -def get_pretrained_model(identifier: str, checkpoint: str = "model_final.ckpt"): +def get_pretrained_model( + identifier: str, checkpoint: str = "model_final.ckpt", model_class=ScalarTensorModel +): directory = get_pretrained_model_dir(identifier) - model = ScalarTensorModel.load_from_checkpoint( + model = model_class.load_from_checkpoint( checkpoint_path=directory.joinpath(checkpoint).as_posix(), map_location="cpu", ) @@ -62,6 +65,7 @@ def get_data_loader( "valset_filename", "testset_filename", "compute_dataset_statistics", + "atom_selector", ]: try: config.pop(k) @@ -151,6 +155,7 @@ def predict( batch_size: int = 200, logger_level: str = "ERROR", is_elasticity_tensor: bool = True, + is_atomic_tensor: bool = False, ) -> Union[ElasticTensor, List[ElasticTensor]]: f""" Predict the property of a structure or a list of structures. @@ -174,6 +179,8 @@ def predict( is_elasticity_tensor: whether the target property is an elasticity tensor. If `True`, the returned value will be a pymargen `ElasticTensor` object. Otherwise, it will be numpy array. + is_atomic_tensor: whether the target property is an atomic tensor. If `True`, + we predict a tensor value for each atom in the structure. Returns: Predicted tensor(s). `None` if the model cannot make prediction for a structure. @@ -186,7 +193,16 @@ def predict( else: single_struct = False - model = get_pretrained_model(identifier=model_identifier, checkpoint=checkpoint) + if is_atomic_tensor: + model_class = AtomicTensorModel + is_elasticity_tensor = False + else: + model_class = ScalarTensorModel + model = get_pretrained_model( + identifier=model_identifier, + checkpoint=checkpoint, + model_class=model_class, + ) check_species(model, structure) loader = get_data_loader(structure, model_identifier, batch_size=batch_size) @@ -223,10 +239,10 @@ def predict( else: pred_tensors = predictions - if single_struct: + if single_struct and not is_atomic_tensor: return pred_tensors[0] - else: - return pred_tensors + + return pred_tensors if __name__ == "__main__":