Simplify run.py #308 (Closed)
Commits (10):
85e1a65: Builder & checker for train, valid & infer datasets (KristinaUlicna)
225308f: Include tqdm progress bar (KristinaUlicna)
a48cfae: Simplify run.py script (KristinaUlicna)
3058983: Uncommented failing tests with @pytest.mark.skip (KristinaUlicna)
cf3dce9: Revitalise test_run with xfail (KristinaUlicna)
d3fae06: Build dataset for training run session (KristinaUlicna)
876bb8e: Optional hparam to overwrite files (KristinaUlicna)
893753c: Simplify run with partial fn specs (KristinaUlicna)
31c1a4e: Merge branch 'development' into extractor (KristinaUlicna)
a60029d: Add storing hparam to example config.yaml (KristinaUlicna)
@@ -0,0 +1,170 @@
from pathlib import Path
from tqdm.auto import tqdm

from grace.styling import LOGGER
from grace.base import GraphAttrs, EdgeProps
from grace.models.datasets import dataset_from_graph

from grace.io.image_dataset import ImageGraphDataset
from grace.io.store_node_features import store_node_features_in_graph
from grace.io.store_edge_properties import store_edge_properties_in_graph


def check_and_chop_dataset(
    image_dir: str | Path,
    grace_dir: str | Path,
    filetype: str,
    node_feature_ndim: int,
    edge_property_len: int,
    keep_node_unknown_labels: bool,
    keep_edge_unknown_labels: bool,
    num_hops: int | str,
    connection: str = "spiderweb",
    store_permanently: bool = False,
    extractor_fn: str | Path | None = None,
) -> tuple[list, list]:
    # Check if datasets are ready for training:

[Review comment] Similar to what was described for run.py: maybe add some docstring description of what these functions are about?
    dataset_ready_for_training = check_dataset_requirements(
        image_dir=image_dir,
        grace_dir=grace_dir,
        filetype=filetype,
        node_feature_ndim=node_feature_ndim,
        edge_property_len=edge_property_len,
    )
    if not dataset_ready_for_training:
        if store_permanently:
            assert extractor_fn is not None, "Provide feature extractor"

            # Inform the user about the delay:
            LOGGER.warning(
                "\n\nComputing node features & edge properties for data in "
                f"{grace_dir}. Expect to take ~30-40 seconds per file...\n\n"
            )
            store_node_features_in_graph(grace_dir, extractor_fn)
            store_edge_properties_in_graph(grace_dir)
        else:
            raise GraceGraphError(grace_dir=grace_dir)

    # Now that you have the files with node features & edge properties:
    target_list, subgraph_dataset = prepare_dataset_subgraphs(
        image_dir=image_dir,
        grace_dir=grace_dir,
        image_filetype=filetype,
        keep_node_unknown_labels=keep_node_unknown_labels,
        keep_edge_unknown_labels=keep_edge_unknown_labels,
        num_hops=num_hops,
        connection=connection,
    )
    return target_list, subgraph_dataset


def check_dataset_requirements(
    image_dir: str | Path,
    grace_dir: str | Path,
    filetype: str,
    node_feature_ndim: int,
    edge_property_len: int,
) -> bool:
    # Read the data & iterate through the images to check node features:
    dataset_ready_for_training = True

    input_data = ImageGraphDataset(
        image_dir=image_dir,
        grace_dir=grace_dir,
        image_filetype=filetype,
        verbose=False,
    )

    # Process the (sub)graph data into torch_geometric dataset:
    desc = f"Chopping subgraphs from {grace_dir}"
    for _, target in tqdm(input_data, desc=desc):
        # Graph sanity checks: NODE_FEATURES:
        for _, node in target["graph"].nodes(data=True):
            if GraphAttrs.NODE_FEATURES not in node:
                dataset_ready_for_training = False
                break
            node_features = node[GraphAttrs.NODE_FEATURES]
            if node_features is None:
                dataset_ready_for_training = False
                break
            if node_features.shape[0] != node_feature_ndim:
                dataset_ready_for_training = False
                break

        # Graph sanity checks: EDGE_PROPERTIES:
        for _, _, edge in target["graph"].edges(data=True):
            if GraphAttrs.EDGE_PROPERTIES not in edge:
                dataset_ready_for_training = False
                break
            edge_properties = edge[GraphAttrs.EDGE_PROPERTIES]
            if edge_properties is None:
                dataset_ready_for_training = False
                break
            edge_properties = edge_properties.properties_dict
            if edge_properties is None:
                dataset_ready_for_training = False
                break
            if len(edge_properties) < edge_property_len:
                dataset_ready_for_training = False
                break
            if not all(item in edge_properties for item in EdgeProps):
                dataset_ready_for_training = False
                break

    return dataset_ready_for_training


def prepare_dataset_subgraphs(
    image_dir: str | Path,
    grace_dir: str | Path,
    *,
    image_filetype: str,
    keep_node_unknown_labels: bool,
    keep_edge_unknown_labels: bool,
    num_hops: int | str,
    connection: str = "spiderweb",
) -> tuple[list, list]:
    # Read the data & iterate through the image-graph pairs:
    input_data = ImageGraphDataset(
        image_dir=image_dir,
        grace_dir=grace_dir,
        image_filetype=image_filetype,
        keep_node_unknown_labels=keep_node_unknown_labels,
        keep_edge_unknown_labels=keep_edge_unknown_labels,
    )

    # Process the (sub)graph data into torch_geometric dataset:
    target_list, subgraph_dataset = [], []
    for _, target in input_data:
        # Store the valid graph list with the updated target:
        target_list.append(target)

        # Process the graph with all attributes, chop into subgraphs & store:
        graph_data = dataset_from_graph(
            target["graph"],
            num_hops=num_hops,
            connection=connection,
        )
        subgraph_dataset.extend(graph_data)

    return target_list, subgraph_dataset


class GraceGraphError(Exception):
    def __init__(self, grace_dir):
        super().__init__(
            "\n\nThe GRACE annotation files don't contain the proper node "
            "features & edge attributes for training \nin the `grace_dir` "
            f"= '{grace_dir}'\n\nPlease consider:\n\n(i) changing your config "
            "'store_graph_attributes_permanently' argument to 'True', which "
            "will automatically compute & store the graph attributes, and/or\n"
            "(ii) manually running the scripts below for all your data paths, "
            "incl. 'train', 'valid' & 'infer', before launching the next run:"
            "\n\n\t`python3 grace/io/store_edge_properties.py --data_path="
            "/path/to/your/data` \nand"
            "\n\t`python3 grace/io/store_node_features.py --data_path="
            "/path/to/your/data --extractor_fn=/path/to/feature/extractor.pt`"
            "\n\nThis will compute the required graph attributes & store them "
            "in the GRACE annotation file collection, avoiding this error.\n"
        )
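
For orientation, here is how a caller such as the simplified run.py might wire these helpers together. This is a minimal sketch, not code from this PR: the paths, filetype, and dimension values are illustrative assumptions read off the function signatures above.

# Minimal usage sketch (paths, filetype & dimensions below are
# illustrative assumptions, not values taken from this PR):
from grace.base import EdgeProps

train_targets, train_subgraphs = check_and_chop_dataset(
    image_dir="/path/to/train/images",
    grace_dir="/path/to/train/annotations",
    filetype="mrc",  # hypothetical image filetype
    node_feature_ndim=2048,  # must match the extractor's output length
    edge_property_len=len(EdgeProps),  # expect the full property set
    keep_node_unknown_labels=False,
    keep_edge_unknown_labels=False,
    num_hops=1,
    connection="spiderweb",
    store_permanently=True,  # compute & store missing attributes
    extractor_fn="/path/to/feature/extractor.pt",
)

If the annotation files already store valid node features and edge properties, the checking pass is cheap; otherwise `store_permanently=True` triggers the one-off computation that the warning message estimates at ~30-40 seconds per file.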
[Review comment] I find the name of this function a bit confusing; it is not very clear what is happening here. As in the main run.py script, maybe we should add more comments to explain what is happening?
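
One way to address both review comments above would be a short docstring on each helper. A possible wording for check_and_chop_dataset, offered here as a suggestion rather than as part of the PR:

def check_and_chop_dataset(
    # ... same signature as in the diff above ...
):
    """Validate a GRACE dataset, then chop its graphs into subgraphs.

    Checks that every annotation file in `grace_dir` stores node features
    of length `node_feature_ndim` and the expected set of edge properties.
    If the check fails and `store_permanently` is True, the missing
    attributes are computed with `extractor_fn` and written back to disk;
    otherwise a GraceGraphError is raised. The validated graphs are then
    chopped into `num_hops`-hop subgraphs ready for training.
    """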