
Commit

Merge pull request #286 from alan-turing-institute/graph-features
Store graph `NODE_FEATURES` in the graph once for development 🌐
KristinaUlicna authored Oct 3, 2023
2 parents b0caf2c + b0258db commit c88895a
Showing 4 changed files with 108 additions and 17 deletions.
38 changes: 23 additions & 15 deletions grace/run.py
@@ -56,28 +56,35 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None:
         subfolder_path = run_dir / subfolder
         subfolder_path.mkdir(parents=True, exist_ok=True)
 
-    # Prepare the feature extractor:
-    extractor_model = torch.load(config.extractor_fn)
-    patch_augs = get_transforms(config, "patch")
+    # Augmentations, if any:
+    img_patch_augs = get_transforms(config, "patch")
     img_graph_augs = get_transforms(config, "graph")
-    feature_extractor = FeatureExtractor(
-        model=extractor_model,
-        augmentations=patch_augs,
-        normalize=config.normalize,
-        bbox_size=config.patch_size,
-        keep_patch_fraction=config.keep_patch_fraction,
-    )
+
+    def return_unchanged(image, graph):
+        return image, graph
+
+    # Prepare the feature extractor:
+    if config.extractor_fn is not None:
+        # Feature extractor:
+        extractor_model = torch.load(config.extractor_fn)
+        feature_extractor = FeatureExtractor(
+            model=extractor_model,
+            augmentations=img_patch_augs,
+            normalize=config.normalize,
+            bbox_size=config.patch_size,
+            keep_patch_fraction=config.keep_patch_fraction,
+        )
+    else:
+        feature_extractor = return_unchanged
 
     # Condition the augmentations to train mode only:
     def transform(
         image: torch.Tensor, graph: dict, *, in_train_mode: bool = True
     ) -> Callable:
         # Ensure augmentations are only run on train data:
-        if in_train_mode:
-            image_aug, graph_aug = img_graph_augs(image, graph)
-            return feature_extractor(image_aug, graph_aug)
-        else:
-            return feature_extractor(image, graph)
+        if in_train_mode is True:
+            image, graph = img_graph_augs(image, graph)
+        return feature_extractor(image, graph)
 
     # Process the datasets as desired:
     def prepare_dataset(
@@ -89,6 +96,7 @@ def prepare_dataset(
         verbose: bool = True,
     ) -> tuple[list]:
         # Read the data & iterate through images & extract node features:
+        print(transform_method)
         input_data = ImageGraphDataset(
             image_dir=image_dir,
             grace_dir=grace_dir,
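
The run.py change above makes the feature extractor optional: when `config.extractor_fn` is `None`, the identity function `return_unchanged` is swapped in, so graphs that already carry pre-computed node features pass through untouched. A minimal, self-contained sketch of that dispatch pattern (toy callables only, not the grace API):

```python
from typing import Callable


def return_unchanged(image, graph):
    # Identity stand-in: node features are assumed to be stored in the graph already.
    return image, graph


def build_feature_extractor(extractor_fn: str | None) -> Callable:
    # Toy illustration of the branch above: only build an extractor
    # when a checkpoint path is actually configured.
    if extractor_fn is not None:
        def extract(image, graph):
            graph = dict(graph, node_features=f"computed with {extractor_fn}")
            return image, graph
        return extract
    return return_unchanged


feature_extractor = build_feature_extractor(None)
print(feature_extractor("image", {"num_nodes": 3}))
# -> ('image', {'num_nodes': 3})  (the graph passes through unchanged)
```
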
8 changes: 8 additions & 0 deletions grace/simulator/README.md
@@ -29,4 +29,12 @@ DRAWING = "squares" # type of modification of the patch (extends to 'circles'
PADDING = (112, 112) # padding of the image in case boundary nodes & their patches need to be modified, too - otherwise nodes lying too close to the boundary will be left untouched
```

## Storing the node features directly in the graph:

If you want to perform a hyperparameter grid search for GNN training and you know the node features of your graph dataset won't change, run this script to append the ResNet-extracted features to your dataset graphs once. Feature extraction takes ~30-40 seconds per image, so storing the features up front saves significant time when launching multiple runs on an otherwise constant dataset.

```sh
python3 grace/simulator/store_features.py --data_path=/Users/kulicna/Desktop/dataset/playground/infer/ --extractor_fn=/Users/kulicna/Desktop/classifier/extractor/resnet152.pt
```
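
After the script has run, every node in the saved `.grace` files should carry a numpy `NODE_FEATURES` array. A quick sanity check, sketched under the assumption that `ImageGraphDataset` accepts a plain pass-through callable as its `transform` (mirroring the `return_unchanged` identity used in `run.py`); the path below is a placeholder for your own dataset folder:

```python
from pathlib import Path

from grace.base import GraphAttrs
from grace.io.image_dataset import ImageGraphDataset

data_path = Path("/path/to/your/dataset")  # placeholder: the folder passed as --data_path

# Re-read the image / annotation pairs with an identity transform, so the node
# features now come from the stored .grace files rather than the extractor:
dataset = ImageGraphDataset(
    image_dir=data_path,
    grace_dir=data_path,
    transform=lambda image, graph: (image, graph),
)

for _, target in dataset:
    graph = target["graph"]
    _, node = next(iter(graph.nodes(data=True)))
    features = node[GraphAttrs.NODE_FEATURES]
    print(target["metadata"]["image_filename"], features.shape)
```
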

Happy simulating :-)
74 changes: 74 additions & 0 deletions grace/simulator/store_features.py
@@ -0,0 +1,74 @@
import click
import torch

from pathlib import Path
from grace.base import GraphAttrs
from grace.io import write_graph
from grace.io.image_dataset import ImageGraphDataset
from grace.models.feature_extractor import FeatureExtractor


# Define a click command to input the file name directly:
@click.command(name="Graph Storage")
@click.option(
    "--data_path",
    type=click.Path(exists=True),
    help="Path to images and grace annotations",
)
@click.option(
    "--extractor_fn",
    type=click.Path(exists=True),
    help="Path to feature extractor model",
)
@click.option(
    "--bbox_size",
    type=(int, int),
    default=(224, 224),
    help="Image patch shape for feature extraction",
)
def store_node_features_in_graph(
    data_path: str | Path,
    extractor_fn: str | Path,
    bbox_size: tuple[int, int] = (224, 224),
) -> None:
    # Process & check the paths:
    if isinstance(data_path, str):
        data_path = Path(data_path)
    assert data_path.is_dir()

    if isinstance(extractor_fn, str):
        extractor_fn = Path(extractor_fn)
    assert extractor_fn.is_file()

    # Load the feature extractor:
    pre_trained_resnet = torch.load(extractor_fn)
    feature_extractor = FeatureExtractor(
        model=pre_trained_resnet,
        bbox_size=bbox_size,
    )

    # Organise the image + grace annotation pairs:
    dataset = ImageGraphDataset(
        image_dir=data_path, grace_dir=data_path, transform=feature_extractor
    )

    # Unwrap each item & store the node features:
    for _, target in dataset:
        fn = target["metadata"]["image_filename"]
        graph = target["graph"]

        for _, node in graph.nodes(data=True):
            node[GraphAttrs.NODE_FEATURES] = node[
                GraphAttrs.NODE_FEATURES
            ].numpy()

        write_graph(
            filename=data_path / f"{fn}.grace",
            graph=graph,
            metadata=target["metadata"],
            annotation=target["annotation"],
        )


if __name__ == "__main__":
    store_node_features_in_graph()
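
Because `store_node_features_in_graph` is a `click` command, it can also be invoked from Python (e.g. in a test or a notebook) via click's test runner rather than the shell. A sketch with placeholder paths; both options use `click.Path(exists=True)`, so point them at folders/files that actually exist:

```python
from click.testing import CliRunner

from grace.simulator.store_features import store_node_features_in_graph

runner = CliRunner()
result = runner.invoke(
    store_node_features_in_graph,
    [
        "--data_path", "/path/to/dataset/infer",             # placeholder paths
        "--extractor_fn", "/path/to/extractor/resnet152.pt",
    ],
)
print(result.exit_code)  # 0 on success
print(result.output)
```
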
5 changes: 3 additions & 2 deletions grace/training/config.py
@@ -168,8 +168,9 @@ def validate_required_config_hparams(config: Config) -> None:
             raise PathNotDefinedError(path_name=dr)
 
     # Check extractor is there:
-    if not config.extractor_fn.is_file():
-        raise PathNotDefinedError(path_name=dr)
+    if config.extractor_fn is not None:
+        if not config.extractor_fn.is_file():
+            raise PathNotDefinedError(path_name=dr)
 
     # Validate the learning rate schedule is implemented:
     assert config.scheduler_type in {"none", "step", "expo"}
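
The relaxed check above means `extractor_fn` may now be left unset when the graphs already store their node features; a non-`None` value still has to point at an existing checkpoint file. A standalone restatement of that rule, using `FileNotFoundError` in place of grace's `PathNotDefinedError`:

```python
from pathlib import Path


def check_extractor(extractor_fn: Path | None) -> None:
    # None is allowed: node features are expected to be stored in the graphs.
    # Anything else must be an existing checkpoint file on disk.
    if extractor_fn is not None and not extractor_fn.is_file():
        raise FileNotFoundError(extractor_fn)


check_extractor(None)  # accepted: pre-stored NODE_FEATURES
try:
    check_extractor(Path("missing_resnet152.pt"))
except FileNotFoundError as err:
    print("rejected:", err)
```
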
