diff --git a/grace/run.py b/grace/run.py
index 1dea46e..350e27d 100644
--- a/grace/run.py
+++ b/grace/run.py
@@ -1,19 +1,19 @@
 from typing import Union
+from functools import partial

 import os
 import click
 import torch

 from datetime import datetime
-from tqdm.auto import tqdm

 from grace.styling import LOGGER
-from grace.io.image_dataset import ImageGraphDataset
+from grace.base import EdgeProps
-from grace.models.datasets import dataset_from_graph
 from grace.models.classifier import Classifier
 from grace.models.optimiser import optimise_graph

 from grace.training.train import train_model
+from grace.training.build import check_and_chop_dataset
 from grace.training.config import (
     validate_required_config_hparams,
     load_config_params,
@@ -57,57 +57,31 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None:
         subfolder_path = run_dir / subfolder
         subfolder_path.mkdir(parents=True, exist_ok=True)

-    # Process the datasets as desired:
-    def prepare_dataset(
-        image_dir: Union[str, os.PathLike],
-        grace_dir: Union[str, os.PathLike],
-        num_hops: int | str,
-        connection: str = "spiderweb",
-        verbose: bool = True,
-    ) -> tuple[list]:
-        # Read the data & terate through images & extract node features:
-        input_data = ImageGraphDataset(
-            image_dir=image_dir,
-            grace_dir=grace_dir,
-            image_filetype=config.filetype,
-            keep_node_unknown_labels=config.keep_node_unknown_labels,
-            keep_edge_unknown_labels=config.keep_edge_unknown_labels,
-        )
-
-        # Process the (sub)graph data into torch_geometric dataset:
-        target_list, subgraph_dataset = [], []
-        desc = "Extracting patch features from images... "
-        for _, target in tqdm(input_data, desc=desc, disable=not verbose):
-            file_name = target["metadata"]["image_filename"]
-            LOGGER.info(f"Processing file: {file_name}")
-
-            # Store the valid graph list:
-            target_list.append(target)
-
-            # Chop graph into subgraphs & store:
-            graph_data = dataset_from_graph(
-                target["graph"],
-                num_hops=num_hops,
-                connection=connection,
-            )
-            subgraph_dataset.extend(graph_data)
-
-        return target_list, subgraph_dataset
+    # Create a transform function with frozen arguments:
+    check_and_chop_partial = partial(
+        check_and_chop_dataset,
+        filetype=config.filetype,
+        node_feature_ndim=config.feature_dim,
+        edge_property_len=len(EdgeProps),
+        keep_node_unknown_labels=config.keep_node_unknown_labels,
+        keep_edge_unknown_labels=config.keep_edge_unknown_labels,
+        connection=config.connection,
+        store_permanently=config.store_graph_attributes_permanently,
+        extractor_fn=config.extractor_fn,
+    )

     # Read the respective datasets:
-    _, train_dataset = prepare_dataset(
+    _, train_dataset = check_and_chop_partial(
         config.train_image_dir,
         config.train_grace_dir,
         num_hops=config.num_hops,
-        connection=config.connection,
     )
-    valid_target_list, valid_dataset = prepare_dataset(
+    valid_target_list, valid_dataset = check_and_chop_partial(
         config.valid_image_dir,
         config.valid_grace_dir,
         num_hops=config.num_hops,
-        connection=config.connection,
     )
-    infer_target_list, _ = prepare_dataset(
+    infer_target_list, _ = check_and_chop_partial(
         config.infer_image_dir,
         config.infer_grace_dir,
         num_hops="whole",
@@ -211,14 +185,13 @@ def prepare_dataset(
             save_figure=run_dir / "infer",
             show_figure=False,
         )
-        if config.classifier_type != "GAT":
-            continue
-        GLP.visualise_attention_weights_on_graph(
-            G=infer_graph,
-            graph_filename=fn,
-            save_figure=run_dir / "infer",
-            show_figure=False,
-        )
+        if config.classifier_type == "GAT":
+            GLP.visualise_attention_weights_on_graph(
+                G=infer_graph,
+                graph_filename=fn,
+                save_figure=run_dir / "infer",
+                show_figure=False,
+            )

         # Generate GT & optimised graphs:
         true_graph = generate_ground_truth_graph(infer_graph)
diff --git a/grace/training/build.py b/grace/training/build.py
new file mode 100644
index 0000000..63513ac
--- /dev/null
+++ b/grace/training/build.py
@@ -0,0 +1,170 @@
+from pathlib import Path
+from tqdm.auto import tqdm
+
+from grace.styling import LOGGER
+from grace.base import GraphAttrs, EdgeProps
+from grace.models.datasets import dataset_from_graph
+
+from grace.io.image_dataset import ImageGraphDataset
+from grace.io.store_node_features import store_node_features_in_graph
+from grace.io.store_edge_properties import store_edge_properties_in_graph
+
+
+def check_and_chop_dataset(
+    image_dir: str | Path,
+    grace_dir: str | Path,
+    filetype: str,
+    node_feature_ndim: int,
+    edge_property_len: int,
+    keep_node_unknown_labels: bool,
+    keep_edge_unknown_labels: bool,
+    num_hops: int | str,
+    connection: str = "spiderweb",
+    store_permanently: bool = False,
+    extractor_fn: str | Path | None = None,
+) -> tuple[list, list]:
+    # Check if datasets are ready for training:
+    dataset_ready_for_training = check_dataset_requirements(
+        image_dir=image_dir,
+        grace_dir=grace_dir,
+        filetype=filetype,
+        node_feature_ndim=node_feature_ndim,
+        edge_property_len=edge_property_len,
+    )
+    if not dataset_ready_for_training:
+        if store_permanently:
+            assert extractor_fn is not None, "Provide a feature extractor"
+
+            # Inform the user about the delay:
+            LOGGER.warning(
+                "\n\nComputing node features & edge properties for data in "
+                f"{grace_dir}. This may take ~30-40 seconds per file...\n\n"
+            )
+            store_node_features_in_graph(grace_dir, extractor_fn)
+            store_edge_properties_in_graph(grace_dir)
+        else:
+            raise GraceGraphError(grace_dir=grace_dir)
+
+    # Now that the files contain node features & edge properties:
+    target_list, subgraph_dataset = prepare_dataset_subgraphs(
+        image_dir=image_dir,
+        grace_dir=grace_dir,
+        image_filetype=filetype,
+        keep_node_unknown_labels=keep_node_unknown_labels,
+        keep_edge_unknown_labels=keep_edge_unknown_labels,
+        num_hops=num_hops,
+        connection=connection,
+    )
+    return target_list, subgraph_dataset
+
+
+def check_dataset_requirements(
+    image_dir: str | Path,
+    grace_dir: str | Path,
+    filetype: str,
+    node_feature_ndim: int,
+    edge_property_len: int,
+) -> bool:
+    # Read the data & iterate through the image-graph pairs:
+    dataset_ready_for_training = True
+
+    input_data = ImageGraphDataset(
+        image_dir=image_dir,
+        grace_dir=grace_dir,
+        image_filetype=filetype,
+        verbose=False,
+    )
+
+    # Validate that every graph carries the required attributes:
+    desc = f"Checking graph attributes in {grace_dir}"
+    for _, target in tqdm(input_data, desc=desc):
+        # Graph sanity checks: NODE_FEATURES:
+
+        for _, node in target["graph"].nodes(data=True):
+            if GraphAttrs.NODE_FEATURES not in node:
+                dataset_ready_for_training = False
+                break
+            node_features = node[GraphAttrs.NODE_FEATURES]
+            if node_features is None:
+                dataset_ready_for_training = False
+                break
+            if node_features.shape[0] != node_feature_ndim:
+                dataset_ready_for_training = False
+                break
+
+        # Graph sanity checks: EDGE_PROPERTIES:
+        for _, _, edge in target["graph"].edges(data=True):
+            if GraphAttrs.EDGE_PROPERTIES not in edge:
+                dataset_ready_for_training = False
+                break
+            edge_properties = edge[GraphAttrs.EDGE_PROPERTIES]
+            if edge_properties is None:
+                dataset_ready_for_training = False
+                break
+            edge_properties = edge_properties.properties_dict
+            if edge_properties is None:
+                dataset_ready_for_training = False
+                break
+            if len(edge_properties) < edge_property_len:
+                dataset_ready_for_training = False
+                break
+            if not all(item in edge_properties for item in EdgeProps):
+                dataset_ready_for_training = False
+                break
+
+    return dataset_ready_for_training
+
+
+def prepare_dataset_subgraphs(
+    image_dir: str | Path,
+    grace_dir: str | Path,
+    *,
+    image_filetype: str,
+    keep_node_unknown_labels: bool,
+    keep_edge_unknown_labels: bool,
+    num_hops: int | str,
+    connection: str = "spiderweb",
+) -> tuple[list, list]:
+    # Read the data & iterate through the image-graph pairs:
+    input_data = ImageGraphDataset(
+        image_dir=image_dir,
+        grace_dir=grace_dir,
+        image_filetype=image_filetype,
+        keep_node_unknown_labels=keep_node_unknown_labels,
+        keep_edge_unknown_labels=keep_edge_unknown_labels,
+    )
+
+    # Process the (sub)graph data into a torch_geometric dataset:
+    target_list, subgraph_dataset = [], []
+    for _, target in input_data:
+        # Store the valid graph list with the updated target:
+        target_list.append(target)
+
+        # Chop the fully-attributed graph into subgraphs & store them:
+        graph_data = dataset_from_graph(
+            target["graph"],
+            num_hops=num_hops,
+            connection=connection,
+        )
+        subgraph_dataset.extend(graph_data)
+
+    return target_list, subgraph_dataset
+
+
+class GraceGraphError(Exception):
+    def __init__(self, grace_dir):
+        super().__init__(
+            "\n\nThe GRACE annotation files don't contain the proper node "
+            "features & edge attributes for training \nin the `grace_dir` "
+            f"= '{grace_dir}'\n\nPlease consider:\n\n(i) changing your config"
+            " 'store_graph_attributes_permanently' argument to 'True', which "
+            "will automatically compute & store the graph attributes, or \n"
+            "(ii) manually running the scripts below for all your data paths, "
+            "incl. 'train', 'valid' & 'infer' before launching the next run:"
+            "\n\n\t`python3 grace/io/store_edge_properties.py --data_path="
+            "/path/to/your/data` \nand"
+            "\n\t`python3 grace/io/store_node_features.py --data_path="
+            "/path/to/your/data --extractor_fn=/path/to/feature/extractor.pt`"
+            "\n\nThis will compute required graph attributes & store them "
+            "in the GRACE annotation file collection, avoiding this error.\n"
+        )
diff --git a/grace/training/config.py b/grace/training/config.py
index 63484f7..f76ee06 100644
--- a/grace/training/config.py
+++ b/grace/training/config.py
@@ -27,10 +27,11 @@ class Config:
     keep_edge_unknown_labels: bool = False

     # Feature extraction:
+    store_graph_attributes_permanently: bool = False
     extractor_fn: Optional[os.PathLike] = None
     patch_size: tuple[int] = (224, 224)
-    feature_dim: int = 2048
     normalize: tuple[bool] = (False, False)
+    feature_dim: int = 2048

     # Augmentations:
     img_graph_augs: list[str] = field(
diff --git a/grace/training/config.yaml b/grace/training/config.yaml
index f9b439a..58983d2 100644
--- a/grace/training/config.yaml
+++ b/grace/training/config.yaml
@@ -15,10 +15,13 @@ keep_edge_unknown_labels: False  # relabels UNKNOWN edges to TRUE_NEGATIVE

 # Feature extractor path & patch normalization:
 extractor_fn: /path/to/your/extractor/resnet152.pt
+store_graph_attributes_permanently: False  # write out graph attributes
+
+feature_dim: 2048  # output dimensionality: 1D feature vector
+
 patch_size:  # input dimensionality: 2D patch shape
   - 224
   - 224
-feature_dim: 2048  # output dimensionality: 1D feature vector
 normalise:  # [0-1] image patch standardisation
   - False  # before augmentations
   - False  # after augmentations
diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py
index 89d75e4..85c1c86 100644
--- a/tests/test_augmentation.py
+++ b/tests/test_augmentation.py
@@ -9,7 +9,8 @@ from grace.base import GraphAttrs, Annotation

 from grace.utils.augment_image import (
     RandomEdgeCrop,
-)  # , RandomImageGraphRotate
+    RandomImageGraphRotate,
+)
 from grace.utils.augment_graph import (
     find_average_annotation,
     RandomEdgeAdditionAndRemoval,
@@ -146,51 +147,52 @@ def test_augment_random_edge_crop(n, max_fraction, fraction, num_rot):
     ]


-# @pytest.mark.parametrize(
-#     "n, rot_angle, rot_center",
-#     [
-#         (0, 0, None),
-#         (1, 90, None),
-#         (2, 45, None),
-#         (3, 30, None),
-#         (4, 28, None),
-#         (5, 28, [2, 2]),
-#     ],
-# )
-# def test_augment_rotate_image_and_graph(n, rot_angle, rot_center):
-#     print ("hello-1")
-#     with patch("numpy.random.default_rng") as mock:
-#         rng = mock.return_value
-#         rng.uniform.return_value = 0
-#         rng.integers.return_value = augment_rotate_coords[n]
-#         rng.choice.return_value = [1] * 4
-#         image, graph = random_image_and_graph(rng, image_size=(6, 6))
-#     print ("hello-2")
-
-#     image = torch.tensor(image.astype("int16"))
-#     target = {"graph": graph}
-
-#     with patch("numpy.random.default_rng") as mock:
-#         rng = mock.return_value
-#         rng.uniform.return_value = rot_angle
-#         transform = RandomImageGraphRotate(rot_center=rot_center, rng=rng)
-#         image, target = transform(image, target)
-#     print ("hello-3")
-
-#     augmented_img_coords = np.where(image.squeeze().numpy())
-#     augmented_float_coords = np.array(
-#         [
-#             [f[GraphAttrs.NODE_X], f[GraphAttrs.NODE_Y]]
-#             for f in target["graph"].nodes.values()
-#         ],
-#         dtype=np.float32,
-#     )
-
-#     print ("hello-4")
-#     assert np.array_equal(expected_end_coords_img[n], augmented_img_coords)
-#     assert np.allclose(
-#         expected_end_coords_float[n], augmented_float_coords, atol=0.01
-#     )
+@pytest.mark.skip(reason="rotation augmentation raises an error; not yet diagnosed")
+@pytest.mark.parametrize(
+    "n, rot_angle, rot_center",
+    [
+        (0, 0, None),
+        (1, 90, None),
+        (2, 45, None),
+        (3, 30, None),
+        (4, 28, None),
+        (5, 28, [2, 2]),
+    ],
+)
+def test_augment_rotate_image_and_graph(n, rot_angle, rot_center):
+    # Sample a random image & graph with a mocked RNG:
+    with patch("numpy.random.default_rng") as mock:
+        rng = mock.return_value
+        rng.uniform.return_value = 0
+        rng.integers.return_value = augment_rotate_coords[n]
+        rng.choice.return_value = [1] * 4
+        image, graph = random_image_and_graph(rng, image_size=(6, 6))
+    # Image & graph sampled with deterministic coordinates.
+
+    image = torch.tensor(image.astype("int16"))
+    target = {"graph": graph}
+
+    with patch("numpy.random.default_rng") as mock:
+        rng = mock.return_value
+        rng.uniform.return_value = rot_angle
+        transform = RandomImageGraphRotate(rot_center=rot_center, rng=rng)
+        image, target = transform(image, target)
+    # Rotation applied to both the image & the graph coordinates.
+
+    augmented_img_coords = np.where(image.squeeze().numpy())
+    augmented_float_coords = np.array(
+        [
+            [f[GraphAttrs.NODE_X], f[GraphAttrs.NODE_Y]]
+            for f in target["graph"].nodes.values()
+        ],
+        dtype=np.float32,
+    )
+    # Compare the augmented coordinates against the expected values.
+
+    assert np.array_equal(expected_end_coords_img[n], augmented_img_coords)
+    assert np.allclose(
+        expected_end_coords_float[n], augmented_float_coords, atol=0.01
+    )


 @pytest.mark.parametrize(
diff --git a/tests/test_extractor.py b/tests/test_extractor.py
index 1bf5fb5..0b4b71b 100644
--- a/tests/test_extractor.py
+++ b/tests/test_extractor.py
@@ -1,8 +1,8 @@
-# import math
+import math
 import torch
 import pytest

-from grace.base import GraphAttrs  # , Annotation
+from grace.base import GraphAttrs, Annotation
 from grace.models.feature_extractor import FeatureExtractor

 from conftest import random_image_and_graph
@@ -55,47 +55,48 @@ def test_feature_extractor_forward(self, bbox_size, model, vars):
         if bbox_image.shape == bbox_size:
             assert features == model(bbox_image)

-    # @pytest.mark.parametrize("keep_patch_fraction", [0.3, 0.5, 0.7])
-    # def test_feature_extractor_rejects_edge_touching_boxes(
-    #     self, bbox_size, model, vars, keep_patch_fraction
-    # ):
-    #     """TODO: There's a bug here if parametrised with:
-    #     "keep_patch_fraction", [0.3, 0.5, 0.7] """
-    #     extractor = vars["extractor"]
-    #     image = vars["image"]
-    #     graph = vars["graph"]
-
-    #     setattr(extractor, "keep_patch_fraction", keep_patch_fraction)
-    #     graph.add_node(
-    #         4,
-    #         **{
-    #             GraphAttrs.NODE_X: bbox_size[0] * 0.5
-    #             + math.ceil(
-    #                 image.size(-1) - bbox_size[0] * keep_patch_fraction * 0.99
-    #             ),
-    #             GraphAttrs.NODE_Y: 0,
-    #             GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
-    #         },
-    #     )
-
-    #     graph.add_node(
-    #         5,
-    #         **{
-    #             GraphAttrs.NODE_X: 0,
-    #             GraphAttrs.NODE_Y: bbox_size[1] * 0.5
-    #             + math.floor(-bbox_size[1] * keep_patch_fraction * 1.01),
-    #             GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
-    #         },
-    #     )
-
-    #     _, target_out = extractor(image, {"graph": graph})
-    #     graph_out = target_out["graph"]
-    #     print(bbox_size)
-    #     print(graph_out.number_of_nodes(), graph_out.nodes(data=True))
-    #     labels = [
-    #         node_attr[GraphAttrs.NODE_GROUND_TRUTH]
-    #         for _, node_attr in graph_out.nodes(data=True)
-    #     ]
-
-    #     num_unknown = len([lab for lab in labels if lab == Annotation.UNKNOWN])
-
-    #     assert num_unknown == 2
+    @pytest.mark.skip(reason="keep patch fraction needs to be re-implemented")
+    @pytest.mark.parametrize("keep_patch_fraction", [0.3, 0.5, 0.7])
+    def test_feature_extractor_rejects_edge_touching_boxes(
+        self, bbox_size, model, vars, keep_patch_fraction
+    ):
+        """TODO: There's a bug here if parametrised with:
+        "keep_patch_fraction", [0.3, 0.5, 0.7]"""
+        extractor = vars["extractor"]
+        image = vars["image"]
+        graph = vars["graph"]
+
+        setattr(extractor, "keep_patch_fraction", keep_patch_fraction)
+        graph.add_node(
+            4,
+            **{
+                GraphAttrs.NODE_X: bbox_size[0] * 0.5
+                + math.ceil(
+                    image.size(-1) - bbox_size[0] * keep_patch_fraction * 0.99
+                ),
+                GraphAttrs.NODE_Y: 0,
+                GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
+            },
+        )
+        graph.add_node(
+            5,
+            **{
+                GraphAttrs.NODE_X: 0,
+                GraphAttrs.NODE_Y: bbox_size[1] * 0.5
+                + math.floor(-bbox_size[1] * keep_patch_fraction * 1.01),
+                GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
+            },
+        )
+
+        _, target_out = extractor(image, {"graph": graph})
+        graph_out = target_out["graph"]
+        print(bbox_size)
+        print(graph_out.number_of_nodes(), graph_out.nodes(data=True))
+        labels = [
+            node_attr[GraphAttrs.NODE_GROUND_TRUTH]
+            for _, node_attr in graph_out.nodes(data=True)
+        ]
+
+        num_unknown = len([lab for lab in labels if lab == Annotation.UNKNOWN])
+
+        assert num_unknown == 2
diff --git a/tests/test_run.py b/tests/test_run.py
index 7ba62a3..2a753b8 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,33 +1,55 @@
-# import torch
-
-# from grace.run import run_grace
-# from grace.training.config import Config, write_config_file
-
-
-# def test_run_grace(mrc_image_and_annotations_dir, simple_extractor):
-#     tmp_data_dir = mrc_image_and_annotations_dir
-
-#     # temp extractor
-#     extractor_fn = tmp_data_dir / "extractor.pt"
-#     torch.save(simple_extractor, extractor_fn)
-
-#     config = Config(
-#         train_image_dir=tmp_data_dir,
-#         train_grace_dir=tmp_data_dir,
-#         valid_image_dir=tmp_data_dir,
-#         valid_grace_dir=tmp_data_dir,
-#         infer_image_dir=tmp_data_dir,
-#         infer_grace_dir=tmp_data_dir,
-#         log_dir=tmp_data_dir,
-#         run_dir=tmp_data_dir,
-#         extractor_fn=extractor_fn,
-#         epochs=3,
-#         batch_size=1,
-#         patch_size=(1, 1),
-#         feature_dim=2,
-#     )
-#     write_config_file(config, filetype="json")
-
-#     # run
-#     config_fn = tmp_data_dir / "config_hyperparams.json"
-#     run_grace(config_file=config_fn)
+import pytest
+import torch
+
+from grace.run import run_grace
+from grace.training.config import Config, write_config_file
+
+
+def run_grace_training(
+    mrc_image_and_annotations_dir,
+    simple_extractor,
+    store_graph_attributes_permanently,
+):
+    tmp_data_dir = mrc_image_and_annotations_dir
+
+    # temp extractor
+    extractor_fn = tmp_data_dir / "extractor.pt"
+    torch.save(simple_extractor, extractor_fn)
+
+    config = Config(
+        train_image_dir=tmp_data_dir,
+        train_grace_dir=tmp_data_dir,
+        valid_image_dir=tmp_data_dir,
+        valid_grace_dir=tmp_data_dir,
+        infer_image_dir=tmp_data_dir,
+        infer_grace_dir=tmp_data_dir,
+        log_dir=tmp_data_dir,
+        run_dir=tmp_data_dir,
+        extractor_fn=extractor_fn,
+        epochs=1,
+        batch_size=1,
+        patch_size=(1, 1),
+        feature_dim=2,
+        store_graph_attributes_permanently=store_graph_attributes_permanently,
+    )
+    write_config_file(config, filetype="json")
+
+    # run
+    config_fn = tmp_data_dir / "config_hyperparams.json"
+    run_grace(config_file=config_fn)
+
+
+@pytest.mark.parametrize("store_graph_attributes_permanently", [False, True])
+@pytest.mark.xfail(
+    reason="sample graph contains no node features & edge properties"
+)
+def test_run_grace(
+    mrc_image_and_annotations_dir,
+    simple_extractor,
+    store_graph_attributes_permanently,
+):
+    run_grace_training(
+        mrc_image_and_annotations_dir,
+        simple_extractor,
+        store_graph_attributes_permanently,
+    )
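
Note on the refactor in grace/run.py: the inline `prepare_dataset` closure is replaced by `functools.partial` over `check_and_chop_dataset`, so the config-derived keyword arguments are frozen once and each dataset split supplies only its directories and `num_hops`. A minimal sketch of the pattern (stand-in function and hypothetical paths, not the actual GRACE API):

```python
from functools import partial


def check_and_chop(image_dir, grace_dir, *, filetype, connection, num_hops):
    # Stand-in for grace.training.build.check_and_chop_dataset:
    # returns (target_list, subgraph_dataset) for one dataset split.
    return [], []


# Freeze the config-derived kwargs once...
check_and_chop_partial = partial(
    check_and_chop, filetype="mrc", connection="spiderweb"
)

# ...then each split varies only its directories & the num_hops value:
_, train_dataset = check_and_chop_partial("img/train", "grace/train", num_hops=1)
_, valid_dataset = check_and_chop_partial("img/valid", "grace/valid", num_hops=1)
infer_targets, _ = check_and_chop_partial("img/infer", "grace/infer", num_hops="whole")
```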
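The new `check_and_chop_dataset` only chops subgraphs after validating that every node carries `NODE_FEATURES` of the configured dimensionality and every edge carries the full `EDGE_PROPERTIES` set; otherwise it either backfills the attributes (when `store_graph_attributes_permanently: True`) or raises `GraceGraphError`. A compressed sketch of that validate-then-backfill gate, with hypothetical `dataset_ok` / `compute_and_store` stand-ins rather than the real checking and storing functions:

```python
from pathlib import Path


def ensure_dataset_ready(
    grace_dir: Path,
    *,
    store_permanently: bool,
    extractor_fn: Path | None = None,
) -> None:
    # Hypothetical stand-in: report whether all graphs carry the
    # required node features & edge properties.
    def dataset_ok(path: Path) -> bool:
        return (path / ".attributes_ok").exists()

    # Hypothetical stand-in for the store_node_features /
    # store_edge_properties scripts that write attributes to disk.
    def compute_and_store(path: Path, extractor: Path) -> None:
        (path / ".attributes_ok").touch()

    if dataset_ok(grace_dir):
        return
    if store_permanently:
        assert extractor_fn is not None, "Provide a feature extractor"
        compute_and_store(grace_dir, extractor_fn)
    else:
        raise RuntimeError(
            f"'{grace_dir}' lacks node features / edge properties; set "
            "store_graph_attributes_permanently: True or run the "
            "store_edge_properties / store_node_features scripts."
        )
```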