
Simplify run.py πŸƒβ€β™€οΈ #308

Closed
wants to merge 10 commits into from
77 changes: 25 additions & 52 deletions grace/run.py
@@ -1,19 +1,19 @@
from typing import Union
from functools import partial

import os
import click
import torch
from datetime import datetime
from tqdm.auto import tqdm

from grace.styling import LOGGER
from grace.io.image_dataset import ImageGraphDataset
from grace.base import EdgeProps

from grace.models.datasets import dataset_from_graph
from grace.models.classifier import Classifier
from grace.models.optimiser import optimise_graph

from grace.training.train import train_model
from grace.training.build import check_and_chop_dataset
from grace.training.config import (
validate_required_config_hparams,
load_config_params,
@@ -57,57 +57,31 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None:
subfolder_path = run_dir / subfolder
subfolder_path.mkdir(parents=True, exist_ok=True)

# Process the datasets as desired:
def prepare_dataset(
image_dir: Union[str, os.PathLike],
grace_dir: Union[str, os.PathLike],
num_hops: int | str,
connection: str = "spiderweb",
verbose: bool = True,
) -> tuple[list]:
# Read the data & iterate through images & extract node features:
input_data = ImageGraphDataset(
image_dir=image_dir,
grace_dir=grace_dir,
image_filetype=config.filetype,
keep_node_unknown_labels=config.keep_node_unknown_labels,
keep_edge_unknown_labels=config.keep_edge_unknown_labels,
)

# Process the (sub)graph data into torch_geometric dataset:
target_list, subgraph_dataset = [], []
desc = "Extracting patch features from images... "
for _, target in tqdm(input_data, desc=desc, disable=not verbose):
file_name = target["metadata"]["image_filename"]
LOGGER.info(f"Processing file: {file_name}")

# Store the valid graph list:
target_list.append(target)

# Chop graph into subgraphs & store:
graph_data = dataset_from_graph(
target["graph"],
num_hops=num_hops,
connection=connection,
)
subgraph_dataset.extend(graph_data)

return target_list, subgraph_dataset
# Create a transform function with frozen arguments:
check_and_chop_partial = partial(
Collaborator comment: I find the name of this function a bit confusing; it is not very clear what is happening here. Since this is in the main run.py script, maybe we should add more comments to explain what is happening?

check_and_chop_dataset,
filetype=config.filetype,
node_feature_ndim=config.feature_dim,
edge_property_len=len(EdgeProps),
keep_node_unknown_labels=config.keep_node_unknown_labels,
keep_edge_unknown_labels=config.keep_edge_unknown_labels,
connection=config.connection,
store_permanently=config.store_graph_attributes_permanently,
extractor_fn=config.extractor_fn,
)
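# NOTE (cf. review comment above): `partial` simply pre-binds the keyword
# arguments shared by the train / valid / infer splits, so each call below
# is equivalent to a full `check_and_chop_dataset(...)` call with all of
# the frozen config arguments spelled out.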

# Read the respective datasets:
_, train_dataset = prepare_dataset(
_, train_dataset = check_and_chop_partial(
config.train_image_dir,
config.train_grace_dir,
num_hops=config.num_hops,
connection=config.connection,
)
valid_target_list, valid_dataset = prepare_dataset(
valid_target_list, valid_dataset = check_and_chop_partial(
config.valid_image_dir,
config.valid_grace_dir,
num_hops=config.num_hops,
connection=config.connection,
)
infer_target_list, _ = prepare_dataset(
infer_target_list, _ = check_and_chop_partial(
config.infer_image_dir,
config.infer_grace_dir,
num_hops="whole",
@@ -211,14 +185,13 @@ def prepare_dataset(
save_figure=run_dir / "infer",
show_figure=False,
)
if config.classifier_type != "GAT":
continue
GLP.visualise_attention_weights_on_graph(
G=infer_graph,
graph_filename=fn,
save_figure=run_dir / "infer",
show_figure=False,
)
if config.classifier_type == "GAT":
GLP.visualise_attention_weights_on_graph(
G=infer_graph,
graph_filename=fn,
save_figure=run_dir / "infer",
show_figure=False,
)

# Generate GT & optimised graphs:
true_graph = generate_ground_truth_graph(infer_graph)
170 changes: 170 additions & 0 deletions grace/training/build.py
@@ -0,0 +1,170 @@
from pathlib import Path
from tqdm.auto import tqdm

from grace.styling import LOGGER
from grace.base import GraphAttrs, EdgeProps
from grace.models.datasets import dataset_from_graph

from grace.io.image_dataset import ImageGraphDataset
from grace.io.store_node_features import store_node_features_in_graph
from grace.io.store_edge_properties import store_edge_properties_in_graph


def check_and_chop_dataset(
image_dir: str | Path,
grace_dir: str | Path,
filetype: str,
node_feature_ndim: int,
edge_property_len: int,
keep_node_unknown_labels: bool,
keep_edge_unknown_labels: bool,
num_hops: int | str,
connection: str = "spiderweb",
store_permanently: bool = False,
extractor_fn: str | Path | None = None,
) -> tuple[list, list]:
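"""Check that every graph in `grace_dir` carries node features &
edge properties, (re)computing & storing them if `store_permanently`
is True, then chop each annotated graph into subgraphs for training.
Returns the list of annotated targets & the subgraph dataset."""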
# Check if datasets are ready for training:
Collaborator comment (@crangelsmith, Oct 16, 2023): Similar to what was described for run.py: maybe add a docstring describing what these functions are about?

dataset_ready_for_training = check_dataset_requirements(
image_dir=image_dir,
grace_dir=grace_dir,
filetype=filetype,
node_feature_ndim=node_feature_ndim,
edge_property_len=edge_property_len,
)
if not dataset_ready_for_training:
if store_permanently:
assert extractor_fn is not None, "Provide feature extractor"

# Inform the user about the delay:
LOGGER.warning(
"\n\nComputing node features & edge properties for data in "
f"{grace_dir}. Expect to take ~30-40 seconds per file...\n\n"
)
store_node_features_in_graph(grace_dir, extractor_fn)
store_edge_properties_in_graph(grace_dir)
else:
raise GraceGraphError(grace_dir=grace_dir)

# Now that you have the files with node features & edge properties:
target_list, subgraph_dataset = prepare_dataset_subgraphs(
image_dir=image_dir,
grace_dir=grace_dir,
image_filetype=filetype,
keep_node_unknown_labels=keep_node_unknown_labels,
keep_edge_unknown_labels=keep_edge_unknown_labels,
num_hops=num_hops,
connection=connection,
)
return target_list, subgraph_dataset


def check_dataset_requirements(
image_dir: str | Path,
grace_dir: str | Path,
filetype: str,
node_feature_ndim: int,
edge_property_len: int,
) -> bool:
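"""Return True if every node carries a feature vector of length
`node_feature_ndim` & every edge carries the full set of `EdgeProps`
edge properties; return False otherwise."""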
# Read the data & iterate through images, checking their graph attributes:
dataset_ready_for_training = True

input_data = ImageGraphDataset(
image_dir=image_dir,
grace_dir=grace_dir,
image_filetype=filetype,
verbose=False,
)

# Iterate through the dataset & verify the required graph attributes:
desc = f"Checking graph requirements in {grace_dir}"
for _, target in tqdm(input_data, desc=desc):
# Graph sanity checks: NODE_FEATURES:

for _, node in target["graph"].nodes(data=True):
if GraphAttrs.NODE_FEATURES not in node:
dataset_ready_for_training = False
break
node_features = node[GraphAttrs.NODE_FEATURES]
if node_features is None:
dataset_ready_for_training = False
break
if node_features.shape[0] != node_feature_ndim:
dataset_ready_for_training = False
break

# Graph sanity checks: EDGE_PROPERTIES:
for _, _, edge in target["graph"].edges(data=True):
if GraphAttrs.EDGE_PROPERTIES not in edge:
dataset_ready_for_training = False
break
edge_properties = edge[GraphAttrs.EDGE_PROPERTIES]
if edge_properties is None:
dataset_ready_for_training = False
break
edge_properties = edge_properties.properties_dict
if edge_properties is None:
dataset_ready_for_training = False
break
if len(edge_properties) < edge_property_len:
dataset_ready_for_training = False
break
if not all([item in edge_properties for item in EdgeProps]):
dataset_ready_for_training = False
break

return dataset_ready_for_training


def prepare_dataset_subgraphs(
image_dir: str | Path,
grace_dir: str | Path,
*,
image_filetype: str,
keep_node_unknown_labels: bool,
keep_edge_unknown_labels: bool,
num_hops: int | str,
connection: str = "spiderweb",
) -> tuple[list, list]:
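"""Read the annotated image-graph dataset & chop each graph into
subgraphs; returns a (target_list, subgraph_dataset) tuple."""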
# Read the data & iterate through the annotated images:
input_data = ImageGraphDataset(
image_dir=image_dir,
grace_dir=grace_dir,
image_filetype=image_filetype,
keep_node_unknown_labels=keep_node_unknown_labels,
keep_edge_unknown_labels=keep_edge_unknown_labels,
)

# Process the (sub)graph data into torch_geometric dataset:
target_list, subgraph_dataset = [], []
for _, target in input_data:
# Store the valid graph list with the updated target:
target_list.append(target)

# Now, process the graph with all attributes & chop into subgraphs & store:
graph_data = dataset_from_graph(
target["graph"],
num_hops=num_hops,
connection=connection,
)
subgraph_dataset.extend(graph_data)

return target_list, subgraph_dataset


class GraceGraphError(Exception):
def __init__(self, grace_dir):
super().__init__(
"\n\nThe GRACE annotation files don't contain the proper node "
"features & edge attributes for training \nin the `grace_dir` "
f"= '{grace_dir}'\n\nPlease consider:\n\n(i) changing your config"
" 'store_graph_attributes_permanently' argument to 'True', which "
"will automatically compute & store the graph attributes & or \n"
"(ii) manually run the scripts below for all your data paths, "
"incl. 'train', 'valid' & 'infer' before launching the next run:"
"\n\n\t`python3 grace/io/store_edge_properties.py --data_path="
"/path/to/your/data` \nand"
"\n\t`python3 grace/io/store_node_features.py --data_path="
"/path/to/your/data --extractor_fn=/path/to/feature/extractor.pt`"
"\n\nThis will compute required graph attributes & store them "
"in the GRACE annotation file collection, avoiding this error.\n"
)
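For reference, a minimal usage sketch of the new helper; the paths, filetype & num_hops values below are hypothetical stand-ins for the corresponding config values:

from grace.base import EdgeProps
from grace.training.build import check_and_chop_dataset, GraceGraphError

try:
    targets, subgraphs = check_and_chop_dataset(
        image_dir="/path/to/images",  # hypothetical
        grace_dir="/path/to/grace",  # hypothetical
        filetype="mrc",  # hypothetical
        node_feature_ndim=2048,
        edge_property_len=len(EdgeProps),
        keep_node_unknown_labels=False,
        keep_edge_unknown_labels=False,
        num_hops=1,
        connection="spiderweb",
        store_permanently=False,
    )
except GraceGraphError:
    # Raised when graph attributes are missing & store_permanently is False;
    # run the store_* scripts named in the error message, then retry.
    raise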
3 changes: 2 additions & 1 deletion grace/training/config.py
@@ -27,10 +27,11 @@ class Config:
keep_edge_unknown_labels: bool = False

# Feature extraction:
store_graph_attributes_permanently: bool = False
extractor_fn: Optional[os.PathLike] = None
patch_size: tuple[int] = (224, 224)
feature_dim: int = 2048
normalize: tuple[bool] = (False, False)
feature_dim: int = 2048

# Augmentations:
img_graph_augs: list[str] = field(
5 changes: 4 additions & 1 deletion grace/training/config.yaml
@@ -15,10 +15,13 @@ keep_edge_unknown_labels: False # relabels UNKNOWN edges to TRUE_NEGATIVE

# Feature extractor path & patch normalization:
extractor_fn: /path/to/your/extractor/resnet152.pt
store_graph_attributes_permanently: False # write out graph attributes

feature_dim: 2048 # output dimensionality: 1D feature vector

patch_size: # input dimensionality: 2D patch shape
- 224
- 224
feature_dim: 2048 # output dimensionality: 1D feature vector
normalise: # [0-1] image patch standardisation
- False # before augmentations
- False # after augmentations