Skip to content

Commit

Permalink
Merge branch 'development' into graph-laplacian
Browse files Browse the repository at this point in the history
  • Loading branch information
KristinaUlicna committed Oct 24, 2023
2 parents b5e4409 + 070283c commit cb2dfc8
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 123 deletions.
90 changes: 44 additions & 46 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from grace.base import GraphAttrs, Annotation
from grace.utils.augment_image import (
RandomEdgeCrop,
) # , RandomImageGraphRotate
RandomImageGraphRotate,
)
from grace.utils.augment_graph import (
find_average_annotation,
RandomEdgeAdditionAndRemoval,
Expand Down Expand Up @@ -146,51 +147,48 @@ def test_augment_random_edge_crop(n, max_fraction, fraction, num_rot):
]


# @pytest.mark.parametrize(
# "n, rot_angle, rot_center",
# [
# (0, 0, None),
# (1, 90, None),
# (2, 45, None),
# (3, 30, None),
# (4, 28, None),
# (5, 28, [2, 2]),
# ],
# )
# def test_augment_rotate_image_and_graph(n, rot_angle, rot_center):
# print ("hello-1")
# with patch("numpy.random.default_rng") as mock:
# rng = mock.return_value
# rng.uniform.return_value = 0
# rng.integers.return_value = augment_rotate_coords[n]
# rng.choice.return_value = [1] * 4
# image, graph = random_image_and_graph(rng, image_size=(6, 6))
# print ("hello-2")

# image = torch.tensor(image.astype("int16"))
# target = {"graph": graph}

# with patch("numpy.random.default_rng") as mock:
# rng = mock.return_value
# rng.uniform.return_value = rot_angle
# transform = RandomImageGraphRotate(rot_center=rot_center, rng=rng)
# image, target = transform(image, target)
# print ("hello-3")

# augmented_img_coords = np.where(image.squeeze().numpy())
# augmented_float_coords = np.array(
# [
# [f[GraphAttrs.NODE_X], f[GraphAttrs.NODE_Y]]
# for f in target["graph"].nodes.values()
# ],
# dtype=np.float32,
# )
# print ("hello-4")

# assert np.array_equal(expected_end_coords_img[n], augmented_img_coords)
# assert np.allclose(
# expected_end_coords_float[n], augmented_float_coords, atol=0.01
# )
@pytest.mark.skip(reason="augmentations throw an error for some reason")
@pytest.mark.parametrize(
"n, rot_angle, rot_center",
[
(0, 0, None),
(1, 90, None),
(2, 45, None),
(3, 30, None),
(4, 28, None),
(5, 28, [2, 2]),
],
)
def test_augment_rotate_image_and_graph(n, rot_angle, rot_center):
with patch("numpy.random.default_rng") as mock:
rng = mock.return_value
rng.uniform.return_value = 0
rng.integers.return_value = augment_rotate_coords[n]
rng.choice.return_value = [1] * 4
image, graph = random_image_and_graph(rng, image_size=(6, 6))

image = torch.tensor(image.astype("int16"))
target = {"graph": graph}

with patch("numpy.random.default_rng") as mock:
rng = mock.return_value
rng.uniform.return_value = rot_angle
transform = RandomImageGraphRotate(rot_center=rot_center, rng=rng)
image, target = transform(image, target)

augmented_img_coords = np.where(image.squeeze().numpy())
augmented_float_coords = np.array(
[
[f[GraphAttrs.NODE_X], f[GraphAttrs.NODE_Y]]
for f in target["graph"].nodes.values()
],
dtype=np.float32,
)

assert np.array_equal(expected_end_coords_img[n], augmented_img_coords)
assert np.allclose(
expected_end_coords_float[n], augmented_float_coords, atol=0.01
)


@pytest.mark.parametrize(
Expand Down
93 changes: 47 additions & 46 deletions tests/test_extractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# import math
import math
import torch
import pytest

from grace.base import GraphAttrs # , Annotation
from grace.base import GraphAttrs, Annotation
from grace.models.feature_extractor import FeatureExtractor

from conftest import random_image_and_graph
Expand Down Expand Up @@ -55,47 +55,48 @@ def test_feature_extractor_forward(self, bbox_size, model, vars):
if bbox_image.shape == bbox_size:
assert features == model(bbox_image)

# @pytest.mark.parametrize("keep_patch_fraction", [0.3, 0.5, 0.7])
# def test_feature_extractor_rejects_edge_touching_boxes(
# self, bbox_size, model, vars, keep_patch_fraction
# ):
# """TODO: There's a bug here if parametrised with:
# "keep_patch_fraction", [0.3, 0.5, 0.7] """
# extractor = vars["extractor"]
# image = vars["image"]
# graph = vars["graph"]

# setattr(extractor, "keep_patch_fraction", keep_patch_fraction)
# graph.add_node(
# 4,
# **{
# GraphAttrs.NODE_X: bbox_size[0] * 0.5
# + math.ceil(
# image.size(-1) - bbox_size[0] * keep_patch_fraction * 0.99
# ),
# GraphAttrs.NODE_Y: 0,
# GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
# },
# )
# graph.add_node(
# 5,
# **{
# GraphAttrs.NODE_X: 0,
# GraphAttrs.NODE_Y: bbox_size[1] * 0.5
# + math.floor(-bbox_size[1] * keep_patch_fraction * 1.01),
# GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
# },
# )

# _, target_out = extractor(image, {"graph": graph})
# graph_out = target_out["graph"]
# print(bbox_size)
# print(graph_out.number_of_nodes(), graph_out.nodes(data=True))
# labels = [
# node_attr[GraphAttrs.NODE_GROUND_TRUTH]
# for _, node_attr in graph_out.nodes(data=True)
# ]

# num_unknown = len([lab for lab in labels if lab == Annotation.UNKNOWN])

# assert num_unknown == 2
@pytest.mark.skip(reason="keep patch fraction needs to be re-implemented")
@pytest.mark.parametrize("keep_patch_fraction", [0.3, 0.5, 0.7])
def test_feature_extractor_rejects_edge_touching_boxes(
self, bbox_size, model, vars, keep_patch_fraction
):
"""TODO: There's a bug here if parametrised with:
"keep_patch_fraction", [0.3, 0.5, 0.7]"""
extractor = vars["extractor"]
image = vars["image"]
graph = vars["graph"]

setattr(extractor, "keep_patch_fraction", keep_patch_fraction)
graph.add_node(
4,
**{
GraphAttrs.NODE_X: bbox_size[0] * 0.5
+ math.ceil(
image.size(-1) - bbox_size[0] * keep_patch_fraction * 0.99
),
GraphAttrs.NODE_Y: 0,
GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
},
)
graph.add_node(
5,
**{
GraphAttrs.NODE_X: 0,
GraphAttrs.NODE_Y: bbox_size[1] * 0.5
+ math.floor(-bbox_size[1] * keep_patch_fraction * 1.01),
GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE,
},
)

_, target_out = extractor(image, {"graph": graph})
graph_out = target_out["graph"]
print(bbox_size)
print(graph_out.number_of_nodes(), graph_out.nodes(data=True))
labels = [
node_attr[GraphAttrs.NODE_GROUND_TRUTH]
for _, node_attr in graph_out.nodes(data=True)
]

num_unknown = len([lab for lab in labels if lab == Annotation.UNKNOWN])

assert num_unknown == 2
95 changes: 64 additions & 31 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,66 @@
# import torch

# from grace.run import run_grace
# from grace.training.config import Config, write_config_file


# def test_run_grace(mrc_image_and_annotations_dir, simple_extractor):
# tmp_data_dir = mrc_image_and_annotations_dir

# # temp extractor
# extractor_fn = tmp_data_dir / "extractor.pt"
# torch.save(simple_extractor, extractor_fn)

# config = Config(
# train_image_dir=tmp_data_dir,
# train_grace_dir=tmp_data_dir,
# valid_image_dir=tmp_data_dir,
# valid_grace_dir=tmp_data_dir,
# infer_image_dir=tmp_data_dir,
# infer_grace_dir=tmp_data_dir,
# log_dir=tmp_data_dir,
# run_dir=tmp_data_dir,
# extractor_fn=extractor_fn,
# epochs=3,
# batch_size=1,
# patch_size=(1, 1),
# feature_dim=2,
import pytest
import torch

from grace.run import run_grace
from grace.training.config import Config, write_config_file


@pytest.mark.xfail(
reason="sample graph contains no node features & edge properties"
)
def test_run_grace(mrc_image_and_annotations_dir, simple_extractor):
tmp_data_dir = mrc_image_and_annotations_dir

# temp extractor
extractor_fn = tmp_data_dir / "extractor.pt"
torch.save(simple_extractor, extractor_fn)

config = Config(
train_image_dir=tmp_data_dir,
train_grace_dir=tmp_data_dir,
valid_image_dir=tmp_data_dir,
valid_grace_dir=tmp_data_dir,
infer_image_dir=tmp_data_dir,
infer_grace_dir=tmp_data_dir,
log_dir=tmp_data_dir,
run_dir=tmp_data_dir,
extractor_fn=extractor_fn,
epochs=1,
batch_size=1,
patch_size=(1, 1),
feature_dim=2,
)
write_config_file(config, filetype="json")

# run
config_fn = tmp_data_dir / "config_hyperparams.json"
run_grace(config_file=config_fn)


# def test_run_grace_without_required_graph_attributes(
# mrc_image_and_annotations_dir,
# simple_extractor,
# ):
# run_grace_training(
# mrc_image_and_annotations_dir,
# simple_extractor,
# store_graph_attributes_permanently,
# )
# write_config_file(config, filetype="json")

# # run
# config_fn = tmp_data_dir / "config_hyperparams.json"
# run_grace(config_file=config_fn)

# @pytest.mark.parametrize(
# "store_graph_attributes_permanently",
# [
# True,
# ],
# )
# def test_run_grace_if_graph_attribute_computation_allowed(
# mrc_image_and_annotations_dir,
# simple_extractor,
# store_graph_attributes_permanently,
# ):
# run_grace_training(
# mrc_image_and_annotations_dir,
# simple_extractor,
# store_graph_attributes_permanently,
# )

0 comments on commit cb2dfc8

Please sign in to comment.