Merge pull request #273 from alan-turing-institute/scheduler
Config hyperparameters for enhanced training
KristinaUlicna authored Sep 29, 2023
2 parents 710af32 + fbd29f8 commit ee85a0c
Showing 4 changed files with 136 additions and 53 deletions.
6 changes: 5 additions & 1 deletion grace/run.py
@@ -150,10 +150,14 @@ def prepare_dataset(
classifier,
train_dataset,
valid_dataset,
valid_target_list,
valid_target_list=valid_target_list,
epochs=config.epochs,
batch_size=config.batch_size,
learning_rate=config.learning_rate,
weight_decay=config.weight_decay,
scheduler_type=config.scheduler_type,
scheduler_step=config.scheduler_step,
scheduler_gamma=config.scheduler_gamma,
log_dir=run_dir,
metrics=config.metrics_classifier,
tensorboard_update_frequency=config.tensorboard_update_frequency,
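The call-site change from a positional valid_target_list to valid_target_list=valid_target_list mirrors the new `*` marker in train_model's signature (see train.py below), which makes the validation targets and all training hyperparameters keyword-only. A minimal sketch of that pattern, using a hypothetical stand-in function rather than the real train_model:

# Sketch only: `fit` stands in for train_model; the argument names follow the diff.
def fit(model, train_dataset, valid_dataset, *, valid_target_list=None, epochs=100):
    return model

fit("model", [], [], valid_target_list=[], epochs=10)  # OK: keyword-only args passed by name
# fit("model", [], [], [])  # TypeError: takes 3 positional arguments but 4 were given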
61 changes: 46 additions & 15 deletions grace/training/config.py
@@ -11,17 +11,28 @@

@dataclass
class Config:
# Paths to inputs & outputs:
train_image_dir: Optional[os.PathLike] = None
train_grace_dir: Optional[os.PathLike] = None
valid_image_dir: Optional[os.PathLike] = None
valid_grace_dir: Optional[os.PathLike] = None
infer_image_dir: Optional[os.PathLike] = None
infer_grace_dir: Optional[os.PathLike] = None
extractor_fn: Optional[os.PathLike] = None
log_dir: Optional[os.PathLike] = None
run_dir: Optional[os.PathLike] = log_dir

# Feature extraction:
filetype: str = "mrc"
keep_node_unknown_labels: bool = False
keep_edge_unknown_labels: bool = False

# Feature extraction:
extractor_fn: Optional[os.PathLike] = None
patch_size: tuple[int] = (224, 224)
feature_dim: int = 2048
normalize: tuple[bool] = (False, False)

# Augmentations:
img_graph_augs: list[str] = field(
default_factory=lambda: [
"random_edge_addition_and_removal",
@@ -40,26 +51,35 @@ class Config:
patch_aug_params: list[dict[str, Any]] = field(
default_factory=lambda: [{}]
)
patch_size: tuple[int] = (224, 224)
keep_patch_fraction: float = 1.0
keep_node_unknown_labels: bool = False
keep_edge_unknown_labels: bool = False
feature_dim: int = 2048

# Classifier architecture setup
classifier_type: str = "GCN"
num_node_classes: int = 2
num_edge_classes: int = 2
epochs: int = 100
hidden_channels: list[int] = field(default_factory=lambda: [1024, 256, 64])

# Training run hyperparameters:
batch_size: int = 64
epochs: int = 100
dropout: float = 0.2
learning_rate: float = 0.001
weight_decay: float = 0.0

# Learning rate scheduler:
scheduler_type: str = "none"
scheduler_step: int = 1
scheduler_gamma: float = 1.0

# Performance evaluation:
metrics_classifier: list[str] = field(
default_factory=lambda: ["accuracy", "f1_score", "confusion_matrix"]
)
metrics_objects: list[str] = field(
default_factory=lambda: ["exact", "approx"]
)
dropout: float = 0.2
batch_size: int = 64
learning_rate: float = 0.001

# Validation & visualisation:
tensorboard_update_frequency: int = 1
valid_graph_ploter_frequency: int = 1
animate_valid_progress: bool = False
@@ -151,14 +171,25 @@ def validate_required_config_hparams(config: Config) -> None:
if not config.extractor_fn.is_file():
raise PathNotDefinedError(path_name=dr)

# Define which metrics to calculate:
# Validate that the chosen learning rate scheduler is implemented:
assert config.scheduler_type in {"none", "step", "expo"}

# Define which classifier metrics to calculate:
for i in range(len(config.metrics_classifier)):
m = config.metrics_classifier[i].lower()
config.metrics_classifier[i] = m
assert all(
m in {"accuracy", "f1_score", "confusion_matrix"}
for m in config.metrics_classifier
)

# Define which object metrics to calculate:
for i in range(len(config.metrics_objects)):
m = config.metrics_objects[i].upper()
if m == "APPROXIMATE":
m = "APPROX"
m = config.metrics_objects[i].lower()
if m == "approximate":
m = "approx"
config.metrics_objects[i] = m

# Make sure only supported object metrics are requested:
assert all(m in {"exact", "approx"} for m in config.metrics_objects)

# HACK: not automated yet:
if config.animate_valid_progress is True:
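For orientation, the reorganised dataclass now groups the optimiser and scheduler settings together, and validate_required_config_hparams restricts scheduler_type to the three implemented options. A minimal sketch of constructing a Config with the new fields (assuming Config is importable as below; the path and extractor fields, which the validator also checks, are left at their None defaults here):

# Sketch only: path/extractor fields are omitted, so the full validator would
# still reject this instance until they are filled in.
from grace.training.config import Config

config = Config(
    learning_rate=0.001,
    weight_decay=0.0,
    scheduler_type="step",   # must be one of {"none", "step", "expo"}
    scheduler_step=1,        # decay period in epochs
    scheduler_gamma=1.0,     # multiplicative decay factor; 1.0 means no decay
)
assert config.scheduler_type in {"none", "step", "expo"}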
18 changes: 11 additions & 7 deletions grace/training/config.yaml
@@ -44,20 +44,24 @@ epochs: 500
dropout: 0.0
batch_size: 512
learning_rate: 0.05
weight_decay: 0.0

# Training Tensorboard logging:
metrics_classifier:
- accuracy
- f1_score
- confusion_matrix
tensorboard_update_frequency: 1 # stores loss / metrics output every x epochs
# Learning rate scheduler:
scheduler_type: step
scheduler_step: 5
scheduler_gamma: 0.9

# Validation monitoring:
tensorboard_update_frequency: 1 # stores loss / metrics output every X epochs
valid_graph_ploter_frequency: 20
animate_valid_progress: False
visualise_tsne_manifold: False

# Optimisation evaluation:
# Performance metrics evaluation:
metrics_classifier:
- accuracy
- f1_score
- confusion_matrix
metrics_objects:
- exact
- approx
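
With these values the step scheduler shrinks the base learning rate of 0.05 by a factor of 0.9 every 5 epochs, i.e. lr(epoch) = 0.05 * 0.9 ** (epoch // 5) under PyTorch's StepLR convention. A small stand-alone sketch to preview the decay (the helper is illustrative, not repo code):

base_lr, step, gamma = 0.05, 5, 0.9

def stepped_lr(epoch: int) -> float:
    # StepLR multiplies the learning rate by gamma once per `step` completed epochs.
    return base_lr * gamma ** (epoch // step)

for epoch in (0, 5, 10, 50, 100):
    print(epoch, stepped_lr(epoch))
# 0 -> 0.05, 5 -> 0.045, 10 -> 0.0405, 50 -> ~0.0174, 100 -> ~0.0061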
104 changes: 74 additions & 30 deletions grace/training/train.py
@@ -19,11 +19,15 @@ def train_model(
model: torch.nn.Module,
train_dataset: list[torch_geometric.data.Data],
valid_dataset: list[torch_geometric.data.Data],
valid_target_list: list[nx.Graph],
*,
valid_target_list: list[nx.Graph] = None,
epochs: int = 100,
batch_size: int = 64,
learning_rate: float = 0.001,
weight_decay: float = 0.0,
scheduler_type: str = "none",
scheduler_step: int = 1,
scheduler_gamma: float = 1.0,
node_masked_class: Annotation = Annotation.UNKNOWN,
edge_masked_class: Annotation = Annotation.UNKNOWN,
log_dir: Optional[str] = None,
@@ -48,7 +52,15 @@ def train_model(
batch_size : int
Batch size
learning_rate : float
Learning rate to use during training
(Base) learning rate to use during training
weight_decay : float
Weight decay (L2 penalty) (default: 0.0)
scheduler_type : str
Learning rate scheduler (default: "none")
scheduler_step : int
Period of learning rate decay in epochs.
scheduler_gamma : float
Multiplicative factor of learning rate decay (default: 1.0)
node_masked_class : Annotation
Target node class for which to set the loss to 0
edge_masked_class : Annotation
@@ -80,9 +92,27 @@
optimizer = torch.optim.Adam(
model.parameters(),
lr=learning_rate,
# weight_decay=5e-4,
weight_decay=weight_decay,
)

# Define the scheduler:
if scheduler_type != "none":
if scheduler_type == "step":
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=scheduler_step,
gamma=scheduler_gamma,
)
elif scheduler_type == "expo":
scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer,
gamma=scheduler_gamma,
)
else:
raise NotImplementedError(
f"Scheduler type '{scheduler_type}' is not implemented."
)

# Specify node & edge criterion:
# TODO: Implement class weighting
node_criterion = torch.nn.CrossEntropyLoss(
@@ -170,14 +200,28 @@ def valid(

# Iterate over all epochs:
for epoch in range(1, epochs + 1):
train(train_loader) # computes loss, backprops grads, updates params
# Computes loss, backprop grads, update params:
train(train_loader)

# Get the current learning rate from the optimizer
current_lr = optimizer.param_groups[0]["lr"]

# Call the scheduler step after each epoch
if scheduler_type != "none":
scheduler.step()

# Loss & metrics on both Dataloaders:
train_metrics = valid(train_loader)
valid_metrics = valid(valid_loader)

# Log the loss & metrics data:
logger_string = f"Epoch: {epoch:03d} | "
logger_string += f"Learning rate: {current_lr:.8f} | "
logger_string += f"Scheduler type: {scheduler_type} | "

for metric in train_metrics:
logger_string += "\n\t"

for regime, metric_dict in [
("train", train_metrics),
("valid", valid_metrics),
@@ -190,9 +234,11 @@ def valid(
"edge": edge_value,
}

# Combine node & edge losses:
if len(metric_dict[metric]) == 3:
metric_out["total"] = metric_dict[metric][2]

# Record floating point values (loss & numerical metrics):
if isinstance(node_value, float):
logger_string += (
f"{metric_name} (node): " f"{node_value:.4f} | "
@@ -209,7 +255,7 @@ def valid(
f"{metric_name} (edge)", metric_out["edge"], epoch
)

# elif isinstance(node_value, plt.Figure):
# Upload a figure to Tensorboard (confusion matrix):
else:
if epoch % tensorboard_update_frequency == 0:
writer.add_figure(
@@ -222,31 +268,29 @@
# Print out the logging string:
LOGGER.info(logger_string)

# At chosen epochs, visualise the prediction probabs for whole graph:
if epoch % valid_graph_ploter_frequency == 0:
# Instantiate the model with frozen weights from current epoch:
GLP = GraphLabelPredictor(model)

# Iterate through all validation graphs & predict nodes / edges:
for valid_target in valid_target_list:
valid_graph = valid_target["graph"]

# Filename:
valid_name = valid_target["metadata"]["image_filename"]
valid_name = f"{valid_name}-Epoch_{epoch}"

# Update probabs & visualise the graph:
GLP.set_node_and_edge_probabilities(G=valid_graph)
GLP.visualise_prediction_probs_on_graph(
G=valid_graph,
graph_filename=valid_name,
save_figure=log_dir / "valid",
show_figure=False,
)

# Save the graph out & clear figure:
# plt.savefig(log_dir / "valid" / valid_name)
# plt.close()
# Choose whether to visualise the validation progress or not:
if valid_target_list is not None:
# At chosen epochs, visualise the prediction probabs for whole graph:
if epoch % valid_graph_ploter_frequency == 0:
# Instantiate the model with frozen weights from current epoch:
GLP = GraphLabelPredictor(model)

# Iterate through all validation graphs & predict nodes / edges:
for valid_target in valid_target_list:
valid_graph = valid_target["graph"]

# Filename:
valid_name = valid_target["metadata"]["image_filename"]
valid_name = f"{valid_name}-Epoch_{epoch}"

# Update probabs & visualise the graph:
GLP.set_node_and_edge_probabilities(G=valid_graph)
GLP.visualise_prediction_probs_on_graph(
G=valid_graph,
graph_filename=valid_name,
save_figure=log_dir / "valid",
show_figure=False,
)

# Clear & close the tensorboard writer:
writer.flush()
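
The scheduler branch above follows the standard PyTorch pattern: wrap the optimiser once, read the current learning rate from optimizer.param_groups for logging, then advance the scheduler once per epoch. A condensed, self-contained sketch of that loop, with a toy model in place of the GNN classifier:

import torch

# Toy stand-ins for the classifier and a training batch; only the
# optimiser/scheduler wiring reflects the code in this diff.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

for epoch in range(1, 21):
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

    current_lr = optimizer.param_groups[0]["lr"]   # value logged for this epoch
    scheduler.step()                               # decay takes effect from the next epoch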
