Config hyperparameters for enhanced training #273

Merged · 8 commits · Sep 29, 2023
6 changes: 5 additions & 1 deletion grace/run.py
@@ -150,10 +150,14 @@ def prepare_dataset(
classifier,
train_dataset,
valid_dataset,
valid_target_list,
valid_target_list=valid_target_list,
epochs=config.epochs,
batch_size=config.batch_size,
learning_rate=config.learning_rate,
weight_decay=config.weight_decay,
scheduler_type=config.scheduler_type,
scheduler_step=config.scheduler_step,
scheduler_gamma=config.scheduler_gamma,
log_dir=run_dir,
metrics=config.metrics_classifier,
tensorboard_update_frequency=config.tensorboard_update_frequency,
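Note the change from positional to keyword passing of valid_target_list: the updated train_model signature (see the grace/training/train.py diff below) places it after a bare *, making it keyword-only. A minimal standalone sketch of that behaviour, with stand-in names rather than the real grace objects:

def train_model(model, train_dataset, valid_dataset, *, valid_target_list=None):
    # Stand-in for the real signature; only the argument handling matters here.
    return valid_target_list

train_model("m", [], [], valid_target_list=["g"])  # OK: keyword passing
# train_model("m", [], [], ["g"])  # TypeError: too many positional arguments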
61 changes: 46 additions & 15 deletions grace/training/config.py
@@ -11,17 +11,28 @@

@dataclass
class Config:
# Paths to inputs & outputs:
train_image_dir: Optional[os.PathLike] = None
train_grace_dir: Optional[os.PathLike] = None
valid_image_dir: Optional[os.PathLike] = None
valid_grace_dir: Optional[os.PathLike] = None
infer_image_dir: Optional[os.PathLike] = None
infer_grace_dir: Optional[os.PathLike] = None
extractor_fn: Optional[os.PathLike] = None
log_dir: Optional[os.PathLike] = None
run_dir: Optional[os.PathLike] = log_dir

# Feature extraction:
filetype: str = "mrc"
keep_node_unknown_labels: bool = False
keep_edge_unknown_labels: bool = False

# Feature extraction:
extractor_fn: Optional[os.PathLike] = None
patch_size: tuple[int] = (224, 224)
feature_dim: int = 2048
normalize: tuple[bool] = (False, False)

# Augmentations:
img_graph_augs: list[str] = field(
default_factory=lambda: [
"random_edge_addition_and_removal",
@@ -40,26 +51,35 @@ class Config:
patch_aug_params: list[dict[str, Any]] = field(
default_factory=lambda: [{}]
)
patch_size: tuple[int] = (224, 224)
keep_patch_fraction: float = 1.0
keep_node_unknown_labels: bool = False
keep_edge_unknown_labels: bool = False
feature_dim: int = 2048

# Classifier architecture setup
classifier_type: str = "GCN"
num_node_classes: int = 2
num_edge_classes: int = 2
epochs: int = 100
hidden_channels: list[int] = field(default_factory=lambda: [1024, 256, 64])

# Training run hyperparameters:
batch_size: int = 64
epochs: int = 100
dropout: float = 0.2
learning_rate: float = 0.001
weight_decay: float = 0.0

# Learning rate scheduler:
scheduler_type: str = "none"
scheduler_step: int = 1
scheduler_gamma: float = 1.0

# Performance evaluation:
metrics_classifier: list[str] = field(
default_factory=lambda: ["accuracy", "f1_score", "confusion_matrix"]
)
metrics_objects: list[str] = field(
default_factory=lambda: ["exact", "approx"]
)
dropout: float = 0.2
batch_size: int = 64
learning_rate: float = 0.001

# Validation & visualisation:
tensorboard_update_frequency: int = 1
valid_graph_ploter_frequency: int = 1
animate_valid_progress: bool = False
@@ -151,14 +171,25 @@ def validate_required_config_hparams(config: Config) -> None:
if not config.extractor_fn.is_file():
raise PathNotDefinedError(path_name=dr)

# Define which metrics to calculate:
# Validate that the chosen learning rate scheduler is implemented:
assert config.scheduler_type in {"none", "step", "expo"}

# Define which classifier metrics to calculate:
for i in range(len(config.metrics_classifier)):
m = config.metrics_classifier[i].lower()
config.metrics_classifier[i] = m
assert all(
m in {"accuracy", "f1_score", "confusion_matrix"}
for m in config.metrics_classifier
)

# Define which object metrics to calculate:
for i in range(len(config.metrics_objects)):
m = config.metrics_objects[i].upper()
if m == "APPROXIMATE":
m = "APPROX"
m = config.metrics_objects[i].lower()
if m == "approximate":
m = "approx"
config.metrics_objects[i] = m

# Make sure only expected metric names remain:
assert all(m in {"exact", "approx"} for m in config.metrics_objects)

# HACK: not automated yet:
if config.animate_valid_progress is True:
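For reference, a standalone sketch of the metric-name normalisation that validate_required_config_hparams now performs on config.metrics_objects (plain Python, not the grace API; the input values are illustrative):

# Lower-case each requested object metric and map the long alias onto the
# short one, then check that only known names remain:
metrics_objects = ["EXACT", "Approximate"]
for i in range(len(metrics_objects)):
    m = metrics_objects[i].lower()
    if m == "approximate":
        m = "approx"
    metrics_objects[i] = m
assert all(m in {"exact", "approx"} for m in metrics_objects)
print(metrics_objects)  # ['exact', 'approx']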
18 changes: 11 additions & 7 deletions grace/training/config.yaml
@@ -44,20 +44,24 @@ epochs: 500
dropout: 0.0
batch_size: 512
learning_rate: 0.05
weight_decay: 0.0

# Training Tensorboard logging:
metrics_classifier:
- accuracy
- f1_score
- confusion_matrix
tensorboard_update_frequency: 1 # stores loss / metrics output every x epochs
# Learning rate scheduler:
scheduler_type: step
scheduler_step: 5
scheduler_gamma: 0.9

# Validation monitoring:
tensorboard_update_frequency: 1 # stores loss / metrics output every X epochs
valid_graph_ploter_frequency: 20
animate_valid_progress: False
visualise_tsne_manifold: False

# Optimisation evaluation:
# Performance metrics evaluation:
metrics_classifier:
- accuracy
- f1_score
- confusion_matrix
metrics_objects:
- exact
- approx
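With these values (base learning rate 0.05, a step scheduler with step 5 and gamma 0.9, and scheduler.step() called once per epoch as in train.py below), the learning rate after e completed epochs is 0.05 * 0.9 ** (e // 5). A quick arithmetic sketch of the resulting decay:

# Closed-form StepLR decay for the YAML values above:
base_lr, step, gamma = 0.05, 5, 0.9
for epoch in (0, 5, 10, 50, 100, 500):
    print(f"after epoch {epoch:3d}: lr = {base_lr * gamma ** (epoch // step):.8f}")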
104 changes: 74 additions & 30 deletions grace/training/train.py
@@ -19,11 +19,15 @@ def train_model(
model: torch.nn.Module,
train_dataset: list[torch_geometric.data.Data],
valid_dataset: list[torch_geometric.data.Data],
valid_target_list: list[nx.Graph],
*,
valid_target_list: Optional[list[nx.Graph]] = None,
epochs: int = 100,
batch_size: int = 64,
learning_rate: float = 0.001,
weight_decay: float = 0.0,
scheduler_type: str = "none",
scheduler_step: int = 1,
scheduler_gamma: float = 1.0,
node_masked_class: Annotation = Annotation.UNKNOWN,
edge_masked_class: Annotation = Annotation.UNKNOWN,
log_dir: Optional[str] = None,
@@ -48,7 +52,15 @@
batch_size : int
Batch size
learning_rate : float
Learning rate to use during training
(Base) learning rate to use during training
weight_decay : float
Weight decay (L2 penalty) (default: 0.0)
scheduler_type : str
Learning rate scheduler (default: "none")
scheduler_step : int
Period of learning rate decay in epochs.
scheduler_gamma : float
Multiplicative factor of learning rate decay (default: 1.0)
node_masked_class : Annotation
Target node class for which to set the loss to 0
edge_masked_class : Annotation
@@ -80,9 +92,27 @@
optimizer = torch.optim.Adam(
model.parameters(),
lr=learning_rate,
# weight_decay=5e-4,
weight_decay=weight_decay,
)

# Define the scheduler:
if scheduler_type != "none":
if scheduler_type == "step":
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=scheduler_step,
gamma=scheduler_gamma,
)
elif scheduler_type == "expo":
scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer,
gamma=scheduler_gamma,
)
else:
raise NotImplementedError(
f"Scheduler type '{scheduler_type}' is not implemented."
)

# Specify node & edge criterion:
# TODO: Implement class weighting
node_criterion = torch.nn.CrossEntropyLoss(
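As a cross-check on the scheduler branch above, a self-contained sketch of the equivalent PyTorch objects (hyperparameter values taken from config.yaml; the single zero parameter is a placeholder so the optimizer has something to manage):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.05, weight_decay=0.0)

# scheduler_type "step": multiply the lr by gamma every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
# scheduler_type "expo" would instead be:
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(1, 11):
    optimizer.step()  # placeholder for one epoch of training
    scheduler.step()  # decay once per epoch, as in the training loop below
    print(epoch, optimizer.param_groups[0]["lr"])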
@@ -170,14 +200,28 @@ def valid(

# Iterate over all epochs:
for epoch in range(1, epochs + 1):
train(train_loader) # computes loss, backprops grads, updates params
# Computes loss, backprop grads, update params:
train(train_loader)

# Get the current learning rate from the optimizer
current_lr = optimizer.param_groups[0]["lr"]

# Call the scheduler step after each epoch
if scheduler_type != "none":
scheduler.step()

# Loss & metrics on both Dataloaders:
train_metrics = valid(train_loader)
valid_metrics = valid(valid_loader)

# Log the loss & metrics data:
logger_string = f"Epoch: {epoch:03d} | "
logger_string += f"Learning rate: {current_lr:.8f} | "
logger_string += f"Scheduler type: {scheduler_type} | "

for metric in train_metrics:
logger_string += "\n\t"

for regime, metric_dict in [
("train", train_metrics),
("valid", valid_metrics),
@@ -190,9 +234,11 @@ def valid(
"edge": edge_value,
}

# Combine node & edge losses:
if len(metric_dict[metric]) == 3:
metric_out["total"] = metric_dict[metric][2]

# Record floating point values (loss & numerical metrics):
if isinstance(node_value, float):
logger_string += (
f"{metric_name} (node): " f"{node_value:.4f} | "
@@ -209,7 +255,7 @@ def valid(
f"{metric_name} (edge)", metric_out["edge"], epoch
)

# elif isinstance(node_value, plt.Figure):
# Upload a figure to Tensorboard (confusion matrix):
else:
if epoch % tensorboard_update_frequency == 0:
writer.add_figure(
@@ -222,31 +268,29 @@
# Print out the logging string:
LOGGER.info(logger_string)

# At chosen epochs, visualise the prediction probabs for whole graph:
if epoch % valid_graph_ploter_frequency == 0:
# Instantiate the model with frozen weights from current epoch:
GLP = GraphLabelPredictor(model)

# Iterate through all validation graphs & predict nodes / edges:
for valid_target in valid_target_list:
valid_graph = valid_target["graph"]

# Filename:
valid_name = valid_target["metadata"]["image_filename"]
valid_name = f"{valid_name}-Epoch_{epoch}"

# Update probabs & visualise the graph:
GLP.set_node_and_edge_probabilities(G=valid_graph)
GLP.visualise_prediction_probs_on_graph(
G=valid_graph,
graph_filename=valid_name,
save_figure=log_dir / "valid",
show_figure=False,
)

# Save the graph out & clear figure:
# plt.savefig(log_dir / "valid" / valid_name)
# plt.close()
# Choose whether to visualise the validation progress or not:
if valid_target_list is not None:
# At chosen epochs, visualise the prediction probabs for whole graph:
if epoch % valid_graph_ploter_frequency == 0:
# Instantiate the model with frozen weights from current epoch:
GLP = GraphLabelPredictor(model)

# Iterate through all validation graphs & predict nodes / edges:
for valid_target in valid_target_list:
valid_graph = valid_target["graph"]

# Filename:
valid_name = valid_target["metadata"]["image_filename"]
valid_name = f"{valid_name}-Epoch_{epoch}"

# Update probabs & visualise the graph:
GLP.set_node_and_edge_probabilities(G=valid_graph)
GLP.visualise_prediction_probs_on_graph(
G=valid_graph,
graph_filename=valid_name,
save_figure=log_dir / "valid",
show_figure=False,
)

# Clear & close the tensorboard writer:
writer.flush()
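Putting the pieces together, a hedged sketch of a call that exercises the new keyword-only hyperparameters, mirroring the grace/run.py diff at the top (model and the dataset lists are stand-ins you would construct beforehand):

train_model(
    model,                   # torch.nn.Module classifier
    train_dataset,           # list[torch_geometric.data.Data]
    valid_dataset,           # list[torch_geometric.data.Data]
    valid_target_list=None,  # now optional: skips the per-epoch graph plots
    epochs=500,
    batch_size=512,
    learning_rate=0.05,
    weight_decay=0.0,
    scheduler_type="step",
    scheduler_step=5,
    scheduler_gamma=0.9,
)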