Add training script for atomic tensor #13

Merged: 2 commits, Jul 30, 2024
117 changes: 117 additions & 0 deletions scripts/configs/atomic_tensor.yaml
@@ -0,0 +1,117 @@
## Config file for atomic tensor (i.e., a tensor value for each atom)

seed_everything: 35
log_level: info

data:
  tensor_target_name: nmr_tensor
  atom_selector: atom_selector
  tensor_target_formula: ij=ji
  root: .
  trainset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json
  valset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json
  testset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json
  r_cut: 5.0
  reuse: false
  loader_kwargs:
    batch_size: 2
    shuffle: true

model:
  ##########
  # embedding
  ##########

  # atom species embedding
  species_embedding_dim: 16

  # spherical harmonics embedding of edge direction
  irreps_edge_sh: 0e + 1o + 2e

  # radial edge distance embedding
  radial_basis_type: bessel
  num_radial_basis: 8
  radial_basis_start: 0.
  radial_basis_end: 5.

  ##########
  # message passing conv layers
  ##########
  num_layers: 3

  # radial network
  invariant_layers: 2 # number of radial layers
  invariant_neurons: 32 # number of hidden neurons in radial function

  # Average number of neighbors used for normalization. Options:
  # 1. `auto` to determine it automatically, by setting it to the average number
  #    of neighbors of the training set
  # 2. a float or int provided here
  # 3. `null` to not use it
  average_num_neighbors: auto

  # point convolution
  conv_layer_irreps: 32x0o+32x0e + 16x1o+16x1e + 4x2o+4x2e
  nonlinearity_type: gate
  normalization: batch
  resnet: true

  ##########
  # output
  ##########

  # output_format and output_formula should be used together.
  # - output_format (can be `irreps` or `cartesian`) determines the space the loss
  #   function is computed in (either the irreps space or the Cartesian space).
  # - output_formula gives the Cartesian formula of the tensor.
  #   For example, ijkl=jikl=klij specifies a fourth-rank elasticity tensor.
  output_format: irreps
  output_formula: ij=ji

  # pooling node feats to graph feats
  reduce: mean

trainer:
  max_epochs: 10 # number of maximum training epochs
  num_nodes: 1
  accelerator: cpu
  devices: 1

  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val/score
        mode: min
        save_top_k: 3
        save_last: true
        verbose: false
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: val/score
        mode: min
        patience: 150
        min_delta: 0
        verbose: true
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: -1

  #logger:
  #  class_path: pytorch_lightning.loggers.wandb.WandbLogger
  #  init_args:
  #    save_dir: matten_logs
  #    project: matten_proj

optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.01
    weight_decay: 0.00001

lr_scheduler:
  class_path: torch.optim.lr_scheduler.ReduceLROnPlateau
  init_args:
    mode: min
    factor: 0.5
    patience: 50
    verbose: true
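The `output_format`/`output_formula` pair above controls how the target tensor is represented: `ij=ji` is a symmetric rank-2 tensor (e.g. an NMR shielding tensor), which decomposes into one scalar (0e) and one rank-2 (2e) irrep, six independent components in total. The sketch below illustrates that correspondence with `e3nn.io.CartesianTensor`; it assumes matten's `CartesianTensorWrapper` behaves the same way, which is an assumption rather than a statement about the library.

```python
import torch
from e3nn.io import CartesianTensor

ct = CartesianTensor("ij=ji")   # same formula as tensor_target_formula / output_formula
print(ct)                       # 1x0e+1x2e -> 6 independent components

cart = torch.randn(3, 3)
cart = 0.5 * (cart + cart.T)    # symmetrize so the tensor actually satisfies ij=ji

irreps_vec = ct.from_cartesian(cart)  # shape (6,); `output_format: irreps` puts the loss on this
back = ct.to_cartesian(irreps_vec)    # round-trip back to the 3x3 Cartesian form
print(torch.allclose(cart, back, atol=1e-5))  # True
```

`irreps_edge_sh` and `conv_layer_irreps` use the same e3nn irreps notation, e.g. `0e + 1o + 2e` for the scalar, vector, and rank-2 spherical-harmonic components of the edge direction.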
1 change: 1 addition & 0 deletions scripts/configs/materials_tensor.yaml
@@ -3,6 +3,7 @@ log_level: info

data:
  root: ../datasets/
  tensor_target_name: elastic_tensor_full
  trainset_filename: example_crystal_elasticity_tensor_n100.json
  valset_filename: example_crystal_elasticity_tensor_n100.json
  testset_filename: example_crystal_elasticity_tensor_n100.json
81 changes: 81 additions & 0 deletions scripts/train_atomic_tensor.py
@@ -0,0 +1,81 @@
"""Script to train the materials tensor model."""

from pathlib import Path
from typing import Dict, List, Union

import yaml
from loguru import logger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.cli import instantiate_class as lit_instantiate_class

from matten.dataset.structure_scalar_tensor import TensorDataModule
from matten.log import set_logger
from matten.model_factory.task import TensorRegressionTask
from matten.model_factory.tfn_atomic_tensor import AtomicTensorModel


def instantiate_class(d: Union[Dict, List]):
    args = tuple()  # no positional args
    if isinstance(d, dict):
        return lit_instantiate_class(args, d)
    elif isinstance(d, list):
        return [lit_instantiate_class(args, x) for x in d]
    else:
        raise ValueError(f"Cannot instantiate class from {d}")


def get_args(path: Path):
"""Get the arguments from the config file."""
with open(path, "r") as f:
config = yaml.safe_load(f)
return config


def main(config: Dict):
    dm = TensorDataModule(**config["data"])
    dm.prepare_data()
    dm.setup()

    model = AtomicTensorModel(
        tasks=TensorRegressionTask(name=config["data"]["tensor_target_name"]),
        backbone_hparams=config["model"],
        dataset_hparams=dm.get_to_model_info(),
        optimizer_hparams=config["optimizer"],
        lr_scheduler_hparams=config["lr_scheduler"],
    )

    # Instantiate callbacks and logger independently, so that a missing `logger`
    # section (it is commented out in the config) does not silently discard the callbacks.
    try:
        callbacks = instantiate_class(config["trainer"].pop("callbacks"))
    except KeyError:
        callbacks = None
    try:
        lit_logger = instantiate_class(config["trainer"].pop("logger"))
    except KeyError:
        lit_logger = None

    trainer = Trainer(
        callbacks=callbacks,
        logger=lit_logger,
        **config["trainer"],
    )

    logger.info("Start training!")
    trainer.fit(model, datamodule=dm)

    # test
    logger.info("Start testing!")
    trainer.test(ckpt_path="best", datamodule=dm)

    # print path of best checkpoint
    logger.info(f"Best checkpoint path: {trainer.checkpoint_callback.best_model_path}")


if __name__ == "__main__":
    config_file = Path(__file__).parent / "configs" / "atomic_tensor.yaml"
    config = get_args(config_file)

    seed = config.get("seed_everything", 1)
    seed_everything(seed)

    log_level = config.get("log_level", "INFO")
    set_logger(log_level)

    main(config)
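For reference, `pytorch_lightning.cli.instantiate_class` (imported above as `lit_instantiate_class`) takes a tuple of positional arguments plus a `class_path`/`init_args` dictionary, which is exactly the shape of each callback entry in `atomic_tensor.yaml`. A minimal sketch of one entry, with values copied from the config:

```python
from pytorch_lightning.cli import instantiate_class

# One callbacks entry from atomic_tensor.yaml, expressed as the dict the script would pop.
checkpoint_cfg = {
    "class_path": "pytorch_lightning.callbacks.ModelCheckpoint",
    "init_args": {"monitor": "val/score", "mode": "min", "save_top_k": 3, "save_last": True},
}

# First argument is the tuple of positional args (none here); the dict supplies class and kwargs.
checkpoint_callback = instantiate_class((), checkpoint_cfg)
print(type(checkpoint_callback).__name__)  # ModelCheckpoint
```

The script itself is run directly (`python scripts/train_atomic_tensor.py`) and resolves `configs/atomic_tensor.yaml` relative to its own location.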
2 changes: 1 addition & 1 deletion scripts/train_materials_tensor.py
@@ -37,7 +37,7 @@ def main(config: Dict):
    dm.setup()

    model = ScalarTensorModel(
-       tasks=TensorRegressionTask(name="elastic_tensor_full"),
+       tasks=TensorRegressionTask(name=config["data"]["tensor_target_name"]),
        backbone_hparams=config["model"],
        dataset_hparams=dm.get_to_model_info(),
        optimizer_hparams=config["optimizer"],
15 changes: 14 additions & 1 deletion src/matten/dataset/structure_scalar_tensor.py
@@ -50,6 +50,8 @@ class TensorDataset(InMemoryDataset):
            10.0, 0:1.0}}, then for data points with minLC less than 0, it will have
            a weight of 10, and for those with minLC larger than 0, it will have a
            weight of 1. By default `None` means all data points have a weight of 1.
        atom_selector: a list of bools to indicate which atoms to use in the structure
            to compute the target. If `None`, all atoms are used.
        global_featurizer: featurizer to compute global features.
        normalize_global_features: whether to normalize the global feature.
        atom_featurizer: featurizer to compute atom features.
@@ -79,6 +81,7 @@ def __init__(
        log_scalar_targets: List[bool] = None,
        normalize_scalar_targets: List[bool] = None,
        tensor_target_weight: Dict[str, Dict[str, float]] = None,
        atom_selector: List[bool] = None,
        global_featurizer: None = None,
        normalize_global_features: bool = False,
        atom_featurizer: None = None,
@@ -97,6 +100,7 @@ def __init__(
        self.normalize_tensor_target = normalize_tensor_target

        self.tensor_target_weight = tensor_target_weight
        self.atom_selector = atom_selector

        self.scalar_target_names = (
            [] if scalar_target_names is None else scalar_target_names
Expand Down Expand Up @@ -259,7 +263,7 @@ def _get_crystals(self, df):
            # TODO, convert to irreps tensor, assuming all input tensors are Cartesian
            converter = CartesianTensorWrapper(formula=self.tensor_target_formula)
            df[self.tensor_target_name] = df[self.tensor_target_name].apply(
-               lambda x: converter.from_cartesian(x).reshape(1, -1)
+               lambda x: torch.atleast_2d(converter.from_cartesian(x))
            )
        elif self.tensor_target_format == "cartesian":
            df[self.tensor_target_name] = df[self.tensor_target_name].apply(
@@ -303,6 +307,10 @@ def _get_crystals(self, df):
                y[self.tensor_target_name] * self.tensor_target_scale
            )

            # atom selector
            if self.atom_selector is not None:
                y["atom_selector"] = torch.as_tensor(row[self.atom_selector])

            x = None
            if self.global_featurizer:
                # feats
@@ -414,6 +422,7 @@ def __init__(
        tensor_target_scale: float = 1.0,
        normalize_tensor_target: bool = False,
        tensor_target_weight: Dict[str, Dict[str, float]] = None,
        atom_selector: List[bool] = None,
        scalar_target_names: List[str] = None,
        log_scalar_targets: List[bool] = None,
        normalize_scalar_targets: List[bool] = None,
@@ -457,6 +466,7 @@ def __init__(
        self.tensor_target_scale = tensor_target_scale
        self.normalize_tensor_target = normalize_tensor_target
        self.tensor_target_weight = tensor_target_weight
        self.atom_selector = atom_selector

        self.scalar_target_names = scalar_target_names
        self.log_scalar_targets = log_scalar_targets
@@ -563,6 +573,7 @@ def setup(self, stage: Optional[str] = None):
                log_scalar_targets=self.log_scalar_targets,
                normalize_scalar_targets=self.normalize_scalar_targets,
                tensor_target_weight=self.tensor_target_weight,
                atom_selector=self.atom_selector,
                global_featurizer=gf,
                normalize_global_features=self.normalize_global_features,
                atom_featurizer=af,
@@ -584,6 +595,7 @@ def setup(self, stage: Optional[str] = None):
                log_scalar_targets=self.log_scalar_targets,
                normalize_scalar_targets=self.normalize_scalar_targets,
                tensor_target_weight=self.tensor_target_weight,
                atom_selector=self.atom_selector,
                global_featurizer=gf,
                normalize_global_features=self.normalize_global_features,
                atom_featurizer=af,
@@ -605,6 +617,7 @@ def setup(self, stage: Optional[str] = None):
                log_scalar_targets=self.log_scalar_targets,
                normalize_scalar_targets=self.normalize_scalar_targets,
                tensor_target_weight=self.tensor_target_weight,
                atom_selector=self.atom_selector,
                global_featurizer=gf,
                normalize_global_features=self.normalize_global_features,
                atom_featurizer=af,
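The switch from `reshape(1, -1)` to `torch.atleast_2d` in `_get_crystals` matters for atomic tensors: a per-structure target converts to a single flat irreps vector that needs an explicit leading batch dimension, while a per-atom target already carries a leading `num_atoms` dimension that `reshape(1, -1)` would have flattened away. A small sketch of the shapes, again assuming the wrapper behaves like `e3nn.io.CartesianTensor`:

```python
import torch
from e3nn.io import CartesianTensor

converter = CartesianTensor("ij=ji")   # 1x0e+1x2e, 6 independent components

# Per-structure target: one 3x3 tensor -> flat (6,) vector -> atleast_2d adds the batch dim.
single = converter.from_cartesian(torch.eye(3))
print(torch.atleast_2d(single).shape)    # torch.Size([1, 6])

# Per-atom target: the converted result already has a leading num_atoms dimension,
# e.g. (4, 6) for 4 atoms; atleast_2d keeps it, reshape(1, -1) would give (1, 24).
per_atom = torch.randn(4, 6)             # stand-in for an already-converted per-atom tensor
print(torch.atleast_2d(per_atom).shape)  # torch.Size([4, 6])
print(per_atom.reshape(1, -1).shape)     # torch.Size([1, 24])
```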
7 changes: 7 additions & 0 deletions src/matten/model/model.py
@@ -338,6 +338,11 @@ def shared_step(self, batch, mode: str):
        # ========== compute predictions ==========
        preds = self.decode(graphs)

        # select atoms
        if "atom_selector" in labels:
            selector = labels["atom_selector"]
            preds = {k: v[selector] for k, v in preds.items()}

        # ========== compute losses ==========
        target_weight = graphs.get("target_weight", None)
        individual_loss, total_loss = self.compute_loss(
@@ -504,6 +509,8 @@ def preprocess_batch(self, batch: DataPoint) -> Tuple[DataPoint, Dict[str, Tensor]]:

        # task labels
        labels = {name: graphs.y[name] for name in self.tasks}
        if "atom_selector" in graphs.y:
            labels["atom_selector"] = graphs.y["atom_selector"]

        # convert graphs to a dict to use NequIP stuff
        graphs = graphs.tensor_property_to_dict()
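The `atom_selector` labels added in `shared_step` and `preprocess_batch` are plain boolean masks: predictions are made for every atom, then restricted to the selected atoms before the loss is computed. A self-contained sketch of that masking, with hypothetical shapes (5 atoms, 6 irreps components, 3 atoms selected):

```python
import torch

# Hypothetical per-atom predictions: 5 atoms, each with a 6-component irreps vector.
preds = {"nmr_tensor": torch.randn(5, 6)}
labels = {
    "nmr_tensor": torch.randn(3, 6),  # targets exist only for the selected atoms
    "atom_selector": torch.tensor([True, False, True, False, True]),
}

# Mirrors the shared_step logic: keep predictions only for atoms flagged by the selector,
# so predictions and targets line up before the loss is computed.
if "atom_selector" in labels:
    selector = labels["atom_selector"]
    preds = {k: v[selector] for k, v in preds.items()}

assert preds["nmr_tensor"].shape == labels["nmr_tensor"].shape  # both (3, 6)
```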