From fe8f8785b21f76aeae0d4cd7e0c8924dac002c94 Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 12:05:12 +0100 Subject: [PATCH 1/8] Implement lr scheduler + weight decay --- grace/training/train.py | 93 ++++++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 28 deletions(-) diff --git a/grace/training/train.py b/grace/training/train.py index 7e3ea98..ae69c54 100644 --- a/grace/training/train.py +++ b/grace/training/train.py @@ -19,11 +19,15 @@ def train_model( model: torch.nn.Module, train_dataset: list[torch_geometric.data.Data], valid_dataset: list[torch_geometric.data.Data], - valid_target_list: list[nx.Graph], *, + valid_target_list: list[nx.Graph] = None, epochs: int = 100, batch_size: int = 64, learning_rate: float = 0.001, + scheduler_type: str = None, + scheduler_step: int = 1, + scheduler_gamma: float = 1.0, + weight_decay: float = 0.0, node_masked_class: Annotation = Annotation.UNKNOWN, edge_masked_class: Annotation = Annotation.UNKNOWN, log_dir: Optional[str] = None, @@ -81,8 +85,27 @@ def train_model( model.parameters(), lr=learning_rate, # weight_decay=5e-4, + weight_decay=weight_decay, ) + # Define the scheduler: + if scheduler_type is not None: + if scheduler_type == "step": + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=scheduler_step, + gamma=scheduler_gamma, + ) + elif scheduler_type == "expo": + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, + gamma=scheduler_gamma, + ) + else: + raise NotImplementedError( + f"Scheduler type '{scheduler_type}' is not implemented." + ) + # Specify node & edge criterion: # TODO: Implement class weighting node_criterion = torch.nn.CrossEntropyLoss( @@ -170,14 +193,28 @@ def valid( # Iterate over all epochs: for epoch in range(1, epochs + 1): - train(train_loader) # computes loss, backprops grads, updates params + # Computes loss, backprop grads, update params: + train(train_loader) + + # Get the current learning rate from the optimizer + current_lr = optimizer.param_groups[0]["lr"] + + # Call the scheduler step after each epoch + if scheduler_type is not None: + scheduler.step() + + # Loss & metrics on both Dataloaders: train_metrics = valid(train_loader) valid_metrics = valid(valid_loader) # Log the loss & metrics data: logger_string = f"Epoch: {epoch:03d} | " + logger_string += f"Learning rate: {current_lr} | " + logger_string += f"Scheduler type: {scheduler_type} | " for metric in train_metrics: + logger_string += "\n\t" + for regime, metric_dict in [ ("train", train_metrics), ("valid", valid_metrics), @@ -190,9 +227,11 @@ def valid( "edge": edge_value, } + # Combine node & edge losses: if len(metric_dict[metric]) == 3: metric_out["total"] = metric_dict[metric][2] + # Record floating point values (loss & numerical metrics): if isinstance(node_value, float): logger_string += ( f"{metric_name} (node): " f"{node_value:.4f} | " @@ -209,7 +248,7 @@ def valid( f"{metric_name} (edge)", metric_out["edge"], epoch ) - # elif isinstance(node_value, plt.Figure): + # Upload a figure to Tensorboard (confusion matrix): else: if epoch % tensorboard_update_frequency == 0: writer.add_figure( @@ -222,31 +261,29 @@ def valid( # Print out the logging string: LOGGER.info(logger_string) - # At chosen epochs, visualise the prediction probabs for whole graph: - if epoch % valid_graph_ploter_frequency == 0: - # Instantiate the model with frozen weights from current epoch: - GLP = GraphLabelPredictor(model) - - # Iterate through all validation graphs & predict nodes / edges: - for 
valid_target in valid_target_list: - valid_graph = valid_target["graph"] - - # Filename: - valid_name = valid_target["metadata"]["image_filename"] - valid_name = f"{valid_name}-Epoch_{epoch}" - - # Update probabs & visualise the graph: - GLP.set_node_and_edge_probabilities(G=valid_graph) - GLP.visualise_prediction_probs_on_graph( - G=valid_graph, - graph_filename=valid_name, - save_figure=log_dir / "valid", - show_figure=False, - ) - - # Save the graph out & clear figure: - # plt.savefig(log_dir / "valid" / valid_name) - # plt.close() + # Choose whether to visualise the valudation progress or not: + if valid_target_list is not None: + # At chosen epochs, visualise the prediction probabs for whole graph: + if epoch % valid_graph_ploter_frequency == 0: + # Instantiate the model with frozen weights from current epoch: + GLP = GraphLabelPredictor(model) + + # Iterate through all validation graphs & predict nodes / edges: + for valid_target in valid_target_list: + valid_graph = valid_target["graph"] + + # Filename: + valid_name = valid_target["metadata"]["image_filename"] + valid_name = f"{valid_name}-Epoch_{epoch}" + + # Update probabs & visualise the graph: + GLP.set_node_and_edge_probabilities(G=valid_graph) + GLP.visualise_prediction_probs_on_graph( + G=valid_graph, + graph_filename=valid_name, + save_figure=log_dir / "valid", + show_figure=False, + ) # Clear & close the tensorboard writer: writer.flush() From 982f42198c0ed523ea99394cc7c9d80c15ed7bac Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 12:06:03 +0100 Subject: [PATCH 2/8] Include configurable lr hparams + checks --- grace/training/config.py | 61 ++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/grace/training/config.py b/grace/training/config.py index c7c4c0f..922e65f 100644 --- a/grace/training/config.py +++ b/grace/training/config.py @@ -11,17 +11,28 @@ @dataclass class Config: + # Paths to inputs & outputs: train_image_dir: Optional[os.PathLike] = None train_grace_dir: Optional[os.PathLike] = None valid_image_dir: Optional[os.PathLike] = None valid_grace_dir: Optional[os.PathLike] = None infer_image_dir: Optional[os.PathLike] = None infer_grace_dir: Optional[os.PathLike] = None - extractor_fn: Optional[os.PathLike] = None log_dir: Optional[os.PathLike] = None run_dir: Optional[os.PathLike] = log_dir + + # Feature extraction: filetype: str = "mrc" + keep_node_unknown_labels: bool = False + keep_edge_unknown_labels: bool = False + + # Feature extraction: + extractor_fn: Optional[os.PathLike] = None + patch_size: tuple[int] = (224, 224) + feature_dim: int = 2048 normalize: tuple[bool] = (False, False) + + # Augmentations: img_graph_augs: list[str] = field( default_factory=lambda: [ "random_edge_addition_and_removal", @@ -40,26 +51,35 @@ class Config: patch_aug_params: list[dict[str, Any]] = field( default_factory=lambda: [{}] ) - patch_size: tuple[int] = (224, 224) keep_patch_fraction: float = 1.0 - keep_node_unknown_labels: bool = False - keep_edge_unknown_labels: bool = False - feature_dim: int = 2048 + # Classifier architecture setup classifier_type: str = "GCN" num_node_classes: int = 2 num_edge_classes: int = 2 - epochs: int = 100 hidden_channels: list[int] = field(default_factory=lambda: [1024, 256, 64]) + + # Training run hyperparameters: + batch_size: int = 64 + epochs: int = 100 + dropout: float = 0.2 + learning_rate: float = 0.001 + weight_decay: float = 0.0 + + # Learning rate scheduler: + scheduler_type: str = None + scheduler_step: int = 1 + 
scheduler_gamma: float = 1.0 + + # Performance evaluation: metrics_classifier: list[str] = field( default_factory=lambda: ["accuracy", "f1_score", "confusion_matrix"] ) metrics_objects: list[str] = field( default_factory=lambda: ["exact", "approx"] ) - dropout: float = 0.2 - batch_size: int = 64 - learning_rate: float = 0.001 + + # Validation & visualisation: tensorboard_update_frequency: int = 1 valid_graph_ploter_frequency: int = 1 animate_valid_progress: bool = False @@ -151,14 +171,25 @@ def validate_required_config_hparams(config: Config) -> None: if not config.extractor_fn.is_file(): raise PathNotDefinedError(path_name=dr) - # Define which metrics to calculate: + # Validate the learning rate schedule is implemented: + assert config.scheduler_type in {"step", "expo"} + + # Define which object metrics to calculate: + for i in range(len(config.metrics_classifier)): + m = config.metrics_classifier[i].lower() + config.metrics_classifier[i] = m + assert all( + m in {"accuracy", "f1_score", "confusion_matrix"} + for m in config.metrics_classifier + ) + + # Define which object metrics to calculate: for i in range(len(config.metrics_objects)): - m = config.metrics_objects[i].upper() - if m == "APPROXIMATE": - m = "APPROX" + m = config.metrics_objects[i].lower() + if m == "approximate": + m = "approx" config.metrics_objects[i] = m - - # Make sure saving file suffix is expected: + assert all(m in {"exact", "approx"} for m in config.metrics_objects) # HACK: not automated yet: if config.animate_valid_progress is True: From ab7aa393af7165759f2daffc4ff149a27932bdbd Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 12:08:27 +0100 Subject: [PATCH 3/8] Update config.yaml documentation --- grace/training/config.yaml | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/grace/training/config.yaml b/grace/training/config.yaml index c099699..2d55039 100644 --- a/grace/training/config.yaml +++ b/grace/training/config.yaml @@ -44,20 +44,24 @@ epochs: 500 dropout: 0.0 batch_size: 512 learning_rate: 0.05 +weight_decay: 0.0 -# Training Tensorboard logging: -metrics_classifier: - - accuracy - - f1_score - - confusion_matrix -tensorboard_update_frequency: 1 # stores loss / metrics output every x epochs +# Learning rate scheduler: +scheduler_type: step +scheduler_step: 5 +scheduler_gamma: 0.9 # Validation monitoring: +tensorboard_update_frequency: 1 # stores loss / metrics output every X epochs valid_graph_ploter_frequency: 20 animate_valid_progress: False visualise_tsne_manifold: False -# Optimisation evaluation: +# Performance metrics evaluation: +metrics_classifier: + - accuracy + - f1_score + - confusion_matrix metrics_objects: - exact - approx From 6533ed7fc97bfe5359841e40a26736517a684d76 Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 12:09:13 +0100 Subject: [PATCH 4/8] Include new hparam specification in run.py --- grace/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/grace/run.py b/grace/run.py index 027d38d..a1cd619 100644 --- a/grace/run.py +++ b/grace/run.py @@ -150,10 +150,14 @@ def prepare_dataset( classifier, train_dataset, valid_dataset, - valid_target_list, + valid_target_list=valid_target_list, epochs=config.epochs, batch_size=config.batch_size, learning_rate=config.learning_rate, + weight_decay=config.weight_decay, + scheduler_type=config.scheduler_type, + scheduler_step=config.scheduler_step, + scheduler_gamma=config.scheduler_gamma, log_dir=run_dir, metrics=config.metrics_classifier, 
tensorboard_update_frequency=config.tensorboard_update_frequency, From 123f5851c3a4bfe7eea4ba0e4e9a2e47c23b982f Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 15:08:42 +0100 Subject: [PATCH 5/8] Modify None to none for parsing --- grace/training/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grace/training/config.py b/grace/training/config.py index 922e65f..f8bef8e 100644 --- a/grace/training/config.py +++ b/grace/training/config.py @@ -67,7 +67,7 @@ class Config: weight_decay: float = 0.0 # Learning rate scheduler: - scheduler_type: str = None + scheduler_type: str = "none" scheduler_step: int = 1 scheduler_gamma: float = 1.0 @@ -172,7 +172,7 @@ def validate_required_config_hparams(config: Config) -> None: raise PathNotDefinedError(path_name=dr) # Validate the learning rate schedule is implemented: - assert config.scheduler_type in {"step", "expo"} + assert config.scheduler_type in {"none", "step", "expo"} # Define which object metrics to calculate: for i in range(len(config.metrics_classifier)): From c85b169da6622c1656a6efa0a3ff965f23bb6e00 Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 15:09:24 +0100 Subject: [PATCH 6/8] Change None to str none if scheduler is not supplied --- grace/training/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grace/training/train.py b/grace/training/train.py index ae69c54..285325b 100644 --- a/grace/training/train.py +++ b/grace/training/train.py @@ -24,7 +24,7 @@ def train_model( epochs: int = 100, batch_size: int = 64, learning_rate: float = 0.001, - scheduler_type: str = None, + scheduler_type: str = "none", scheduler_step: int = 1, scheduler_gamma: float = 1.0, weight_decay: float = 0.0, @@ -89,7 +89,7 @@ def train_model( ) # Define the scheduler: - if scheduler_type is not None: + if scheduler_type != "none": if scheduler_type == "step": scheduler = torch.optim.lr_scheduler.StepLR( optimizer, @@ -200,7 +200,7 @@ def valid( current_lr = optimizer.param_groups[0]["lr"] # Call the scheduler step after each epoch - if scheduler_type is not None: + if scheduler_type != "none": scheduler.step() # Loss & metrics on both Dataloaders: From a4f3c1536090344b9b117052ae873253972b770a Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 15:30:08 +0100 Subject: [PATCH 7/8] Add docstring documentation for scheduler --- grace/training/train.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/grace/training/train.py b/grace/training/train.py index 285325b..999bc39 100644 --- a/grace/training/train.py +++ b/grace/training/train.py @@ -24,10 +24,10 @@ def train_model( epochs: int = 100, batch_size: int = 64, learning_rate: float = 0.001, + weight_decay: float = 0.0, scheduler_type: str = "none", scheduler_step: int = 1, scheduler_gamma: float = 1.0, - weight_decay: float = 0.0, node_masked_class: Annotation = Annotation.UNKNOWN, edge_masked_class: Annotation = Annotation.UNKNOWN, log_dir: Optional[str] = None, @@ -52,7 +52,15 @@ def train_model( batch_size : int Batch size learning_rate : float - Learning rate to use during training + (Base) learning rate to use during training + weight_decay : float + Weight decay (L2 penalty) (default: 0.0) + scheduler_type : str + Learning rate scheduler (default: "none") + scheduler_step: int + Period of learning rate decay in epochs. 
+ scheduler_gamma : float + Multiplicative factor of learning rate decay (default: 1.0) node_masked_class : Annotation Target node class for which to set the loss to 0 edge_masked_class : Annotation @@ -84,7 +92,6 @@ def train_model( optimizer = torch.optim.Adam( model.parameters(), lr=learning_rate, - # weight_decay=5e-4, weight_decay=weight_decay, ) From fbd29f8c6b5d823ce2f04950de0aba8b4492728b Mon Sep 17 00:00:00 2001 From: KristinaUlicna Date: Fri, 29 Sep 2023 15:35:57 +0100 Subject: [PATCH 8/8] Round lr float to 8 significant figures --- grace/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grace/training/train.py b/grace/training/train.py index 999bc39..0ce3d59 100644 --- a/grace/training/train.py +++ b/grace/training/train.py @@ -216,7 +216,7 @@ def valid( # Log the loss & metrics data: logger_string = f"Epoch: {epoch:03d} | " - logger_string += f"Learning rate: {current_lr} | " + logger_string += f"Learning rate: {current_lr:.8f} | " logger_string += f"Scheduler type: {scheduler_type} | " for metric in train_metrics:
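
For reference, the scheduler wiring introduced in this series can be exercised in isolation. Below is a minimal standalone sketch using the defaults from the updated config.yaml (scheduler_type: step, scheduler_step: 5, scheduler_gamma: 0.9); the single-layer stand-in model and the 20-epoch loop are illustrative assumptions, not GRACE code, but the optimizer construction, the per-epoch scheduler step, and the learning-rate readback mirror what train_model() now does.

import torch

# Toy stand-in for the GRACE classifier: a single learnable layer.
model = torch.nn.Linear(1, 1)

# Mirrors the optimizer construction in train_model(), with the L2 penalty exposed:
optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0.0)

# "step" scheduler per the config.yaml defaults: multiply the LR by 0.9 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

for epoch in range(1, 21):
    # ... training pass over the DataLoader would go here ...
    optimizer.step()

    # Read the current LR back from the optimizer, as train.py logs it:
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch: {epoch:03d} | Learning rate: {current_lr:.8f}")

    # Scheduler is stepped once per epoch, after the optimizer update:
    scheduler.step()

Running this prints a learning rate of 0.05000000 for epochs 1-5, 0.04500000 for epochs 6-10, and so on, which is the decay behaviour the new scheduler_step / scheduler_gamma hyperparameters control.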