Skip to content

Commit

Permalink
Modified tests based on updated classifier architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
KristinaUlicna committed Oct 3, 2023
1 parent 58f7761 commit 84f019e
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 263 deletions.
58 changes: 0 additions & 58 deletions tests/_utils.py

This file was deleted.

51 changes: 50 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,62 @@
import pandas as pd
import networkx as nx
import numpy as np
import numpy.typing as npt

import torch

from grace.base import GraphAttrs, graph_from_dataframe
from grace.io.core import Annotation
from pathlib import Path

from _utils import random_image_and_graph

def random_image_and_graph(
rng,
*,
num_nodes: int = 4,
image_size: tuple[int] = (128, 128),
feature_ndim: int = 32,
) -> tuple[npt.NDArray, list[nx.Graph]]:
"""Create a random image and graph."""
image = np.zeros(image_size, dtype=np.uint16)

features = [rng.uniform(size=(feature_ndim,)) for _ in range(num_nodes)]

node_coords = rng.integers(0, image.shape[1], size=(num_nodes, 2))
node_ground_truth = rng.choice(
[Annotation.TRUE_NEGATIVE, Annotation.TRUE_POSITIVE], size=(num_nodes,)
)
df = pd.DataFrame(
{
GraphAttrs.NODE_X: node_coords[:, 0],
GraphAttrs.NODE_Y: node_coords[:, 1],
GraphAttrs.NODE_FEATURES: features,
GraphAttrs.NODE_GROUND_TRUTH: node_ground_truth,
GraphAttrs.NODE_CONFIDENCE: rng.uniform(
size=(num_nodes),
),
}
)

image[tuple(node_coords[:, 1]), tuple(node_coords[:, 0])] = 1
graph = graph_from_dataframe(df, triangulate=True)

graph.update(
edges=[
(
src,
dst,
{
GraphAttrs.EDGE_GROUND_TRUTH: rng.choice(
[Annotation.TRUE_NEGATIVE, Annotation.TRUE_POSITIVE],
)
},
)
for src, dst in graph.edges
]
)

return image, graph


@pytest.fixture(scope="session")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
RandomXYTranslation,
)

from _utils import random_image_and_graph
from conftest import random_image_and_graph


expected_outputs_random_edge_crop = [
Expand Down
140 changes: 140 additions & 0 deletions tests/test_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import pytest

import networkx as nx

from grace.base import GraphAttrs, Annotation
from grace.models.datasets import dataset_from_graph
from grace.models.classifier import GCNModel

from conftest import random_image_and_graph


@pytest.mark.parametrize("input_channels", [1, 2])
@pytest.mark.parametrize("node_output_classes", [2, 4])
@pytest.mark.parametrize("edge_output_classes", [2, 4])
@pytest.mark.parametrize(
"hidden_graph_channels",
[
[16, 4],
[
128,
],
[],
],
)
@pytest.mark.parametrize(
"hidden_dense_channels",
[
[16, 4],
[
128,
],
[],
],
)
class TestGCN:
@pytest.fixture
def gcn(
self,
input_channels,
hidden_graph_channels,
hidden_dense_channels,
node_output_classes,
edge_output_classes,
):
return GCNModel(
input_channels=input_channels,
hidden_graph_channels=hidden_graph_channels,
hidden_dense_channels=hidden_dense_channels,
node_output_classes=node_output_classes,
edge_output_classes=edge_output_classes,
)

def test_model_building(
self,
input_channels,
hidden_graph_channels,
hidden_dense_channels,
node_output_classes,
edge_output_classes,
gcn,
):
"""Test building the model with different dimensions."""
# torch.nn.ModuleList objects are created with no hidden layers:
if not hidden_graph_channels:
assert gcn.conv_layer_list is None
else:
assert gcn.conv_layer_list is not None

if not hidden_dense_channels:
assert gcn.node_dense_list is None
else:
assert gcn.node_dense_list is not None

# match shape of first list items based on hidden features:
if gcn.conv_layer_list is not None:
assert gcn.conv_layer_list[0].in_channels == input_channels

if gcn.conv_layer_list is None and gcn.node_dense_list is not None:
assert gcn.node_dense_list[0].in_features == input_channels

# control final classifier layers based on hidden features:
if hidden_dense_channels:
assert gcn.node_classifier.in_features == hidden_dense_channels[-1]
assert gcn.node_classifier.out_features == node_output_classes

assert (
gcn.edge_classifier.in_features
== hidden_dense_channels[-1] * 2
)
assert gcn.edge_classifier.out_features == edge_output_classes

elif hidden_graph_channels:
assert gcn.node_classifier.in_features == hidden_graph_channels[-1]
assert gcn.node_classifier.out_features == node_output_classes

assert (
gcn.edge_classifier.in_features
== hidden_graph_channels[-1] * 2
)
assert gcn.edge_classifier.out_features == edge_output_classes

else:
assert gcn.node_classifier.in_features == input_channels
assert gcn.node_classifier.out_features == node_output_classes

assert gcn.edge_classifier.in_features == input_channels * 2
assert gcn.edge_classifier.out_features == edge_output_classes

@pytest.mark.parametrize("num_nodes", [4, 5, 8, 10])
def test_output_sizes(
self,
input_channels,
node_output_classes,
edge_output_classes,
gcn,
num_nodes,
default_rng,
):
_, graph = random_image_and_graph(
default_rng, num_nodes=num_nodes, feature_ndim=input_channels
)
graph.update(
edges=[
(
src,
dst,
{GraphAttrs.EDGE_GROUND_TRUTH: Annotation.TRUE_POSITIVE},
)
for src, dst in graph.edges
]
)
data = dataset_from_graph(graph, mode="sub")[0]

subgraph = nx.ego_graph(graph, 0)
num_nodes = subgraph.number_of_nodes()
num_edges = subgraph.number_of_edges()
node_x, edge_x = gcn(x=data.x, edge_index=data.edge_index)

assert node_x.size() == (num_nodes, node_output_classes)
assert edge_x.size() == (num_edges, edge_output_classes)
70 changes: 13 additions & 57 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest
import torch
import networkx as nx

from grace.base import GraphAttrs, Annotation
from grace.io.image_dataset import ImageGraphDataset
from grace.models.datasets import dataset_from_graph

from _utils import random_image_and_graph


def test_image_graph_dataset(mrc_image_and_annotations_dir):
Expand All @@ -16,63 +13,22 @@ def test_image_graph_dataset(mrc_image_and_annotations_dir):
mrc_image_and_annotations_dir,
image_filetype="mrc",
)

# # all currently fail
# image, graph = dataset[0]

# assert isinstance(image, torch.Tensor)
# assert isinstance(graph, nx.Graph)

# assert image.shape == (1, 128, 128)
# assert graph.number_of_nodes() == 4

assert len(dataset) == num_images

# unwrap sample image & target
image, target = dataset[0]
graph = target["graph"]
metadata = target["metadata"]
filename = metadata["image_filename"]

def test_dataset_ignores_subgraph_if_all_edges_unknown(default_rng):
_, graph = random_image_and_graph(default_rng)

edge_update = [
(src, dst, {GraphAttrs.EDGE_GROUND_TRUTH: Annotation.UNKNOWN})
for src, dst in graph.edges
]
graph.update(edges=edge_update)
# this action is not currently required since edges are by default UNKNOWN;
# however it enables testing of this condition should the default label be changed

assert dataset_from_graph(graph, mode="sub") == []


@pytest.mark.parametrize("num_unknown", [7, 17])
def test_dataset_ignores_subgraph_if_central_node_unknown(
num_unknown, default_rng
):
num_nodes_total = 20
_, graph = random_image_and_graph(
default_rng,
num_nodes=num_nodes_total,
)

edge_update = [
(src, dst, {GraphAttrs.EDGE_GROUND_TRUTH: Annotation.TRUE_POSITIVE})
for src, dst in graph.edges
]
graph.update(edges=edge_update)
assert isinstance(image, torch.Tensor)
assert image.shape == (128, 128) # 2D image

node_update = [
(node, {GraphAttrs.NODE_GROUND_TRUTH: Annotation.UNKNOWN})
for node in list(graph.nodes)[:num_unknown]
]
node_update += [
(node, {GraphAttrs.NODE_GROUND_TRUTH: Annotation.TRUE_POSITIVE})
for node in list(graph.nodes)[num_unknown:]
]
graph.update(nodes=node_update)
assert isinstance(metadata, dict)
assert isinstance(filename, str)

assert (
len(dataset_from_graph(graph, mode="sub"))
== num_nodes_total - num_unknown
)
assert isinstance(graph, nx.Graph)
assert graph.number_of_nodes() == 4


def test_dataset_only_takes_common_filenames(tmp_path):
Expand Down
7 changes: 2 additions & 5 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import numpy as np

from grace.base import GraphAttrs
from grace.evaluation.process import update_graph_with_dummy_predictions

from grace.evaluation.process import (
update_graph_with_dummy_predictions,
)

from _utils import random_image_and_graph
from conftest import random_image_and_graph


def test_update_dummy_graph_predictions(default_rng):
Expand Down
Loading

0 comments on commit 84f019e

Please sign in to comment.