Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alifahrri committed Dec 23, 2024
1 parent c23cc2b commit 7cc9090
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions tests/functional/src/misc/ct_digraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,11 +1024,147 @@ TEST_CASE("contracted_edge(case1)" * doctest::test_suite("ct_digraph"))
// NMTOOLS_ASSERT_GRAPH_EQUAL( result, expected );
CHECK( result.nodes(0_ct) == expected.nodes(0_ct) );
CHECK( result.nodes(1_ct) == expected.nodes(1_ct) );

NMTOOLS_ASSERT_EQUAL( result.nodes(769_ct), expected.nodes(769_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(447_ct), expected.nodes(447_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(722_ct), expected.nodes(722_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(635_ct), expected.nodes(635_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(765_ct), expected.nodes(765_ct) );

NMTOOLS_ASSERT_EQUAL( result.out_edges(0_ct), expected.out_edges(0_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(1_ct), expected.out_edges(1_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(769_ct), expected.out_edges(769_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(447_ct), expected.out_edges(447_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(722_ct), expected.out_edges(722_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(635_ct), expected.out_edges(635_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(765_ct), expected.out_edges(765_ct) );
}

// softmax
TEST_CASE("contracted_edge(case2)" * doctest::test_suite("ct_digraph"))
{
auto input_shape = array{3,4};
auto input = na::reshape(na::arange(ix::product(input_shape)),input_shape);

auto graph = utility::ct_digraph()
.add_node(0_ct,&input)
.add_node(263_ct,fn::reduce_maximum[/*axis=*/0])
.add_node(975_ct,fn::subtract)
.add_node(547_ct,fn::exp)
.add_node(407_ct,fn::sum[/*axis=*/0])
.add_node(850_ct,fn::divide)
.add_edges(0_ct,tuple{263_ct,975_ct})
.add_edge(263_ct,975_ct)
.add_edge(975_ct,547_ct)
.add_edges(547_ct,tuple{407_ct,850_ct})
.add_edge(407_ct,850_ct)
;

auto fused = fn::subtract * fn::exp;
auto result = utility::contracted_edge(graph,tuple{975_ct,547_ct},111_ct,fused);

auto expected = utility::ct_digraph()
.add_node(0_ct,&input)
.add_node(263_ct,fn::reduce_maximum[/*axis=*/0])
.add_node(407_ct,fn::sum[/*axis=*/0])
.add_node(850_ct,fn::divide)
.add_node(111_ct,fused)
.add_edges(0_ct,tuple{263_ct,111_ct})
.add_edge(263_ct,111_ct)
.add_edge(407_ct,850_ct)
.add_edges(111_ct,tuple{407_ct,850_ct})
;

NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() );
NMTOOLS_ASSERT_EQUAL( result.nodes(), expected.nodes() );

CHECK( result.nodes(0_ct) == expected.nodes(0_ct) );

NMTOOLS_ASSERT_EQUAL( result.nodes(263_ct), expected.nodes(263_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(407_ct), expected.nodes(407_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(850_ct), expected.nodes(850_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(111_ct), expected.nodes(111_ct) );

NMTOOLS_ASSERT_EQUAL( result.out_edges(0_ct), expected.out_edges(0_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(263_ct), expected.out_edges(263_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(407_ct), expected.out_edges(407_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(850_ct), expected.out_edges(850_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(111_ct), expected.out_edges(111_ct) );
}

// var
TEST_CASE("contracted_edge(case3)" * doctest::test_suite("ct_digraph"))
{
auto input_shape = array{3,4};
auto input = na::reshape(na::arange(ix::product(input_shape)),input_shape);

auto graph = utility::ct_digraph()
.add_node(0_ct,&input)
.add_node(472_ct,3)
.add_node(470_ct,fn::reduce_add[/*axis=*/0])
.add_node(51_ct,fn::divide)
.add_node(428_ct,fn::subtract)
.add_node(391_ct,fn::fabs)
.add_node(591_ct,fn::square)
.add_node(433_ct,fn::reduce_add[/*axis=*/0])
.add_node(435_ct,3)
.add_node(1022_ct,fn::divide)
.add_edges(0_ct,tuple{470_ct,428_ct})
.add_edge(472_ct,51_ct)
.add_edge(470_ct,51_ct)
.add_edge(51_ct,428_ct)
.add_edge(428_ct,391_ct)
.add_edge(391_ct,591_ct)
.add_edge(591_ct,433_ct)
.add_edge(433_ct,1022_ct)
.add_edge(435_ct,1022_ct)
;

auto fused = fn::square * fn::fabs;
auto result = utility::contracted_edge(graph,tuple{391_ct,591_ct},391_ct,fused);

auto expected = utility::ct_digraph()
.add_node(0_ct,&input)
.add_node(472_ct,3)
.add_node(470_ct,fn::reduce_add[/*axis=*/0])
.add_node(51_ct,fn::divide)
.add_node(428_ct,fn::subtract)
.add_node(433_ct,fn::reduce_add[/*axis=*/0])
.add_node(435_ct,3)
.add_node(1022_ct,fn::divide)
.add_node(391_ct,fused)
.add_edges(0_ct,tuple{470_ct,428_ct})
.add_edge(472_ct,51_ct)
.add_edge(470_ct,51_ct)
.add_edge(51_ct,428_ct)
.add_edge(428_ct,391_ct)
.add_edge(433_ct,1022_ct)
.add_edge(435_ct,1022_ct)
.add_edge(391_ct,433_ct)
;

NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() );
NMTOOLS_ASSERT_EQUAL( result.nodes(), expected.nodes() );

CHECK( result.nodes(0_ct) == expected.nodes(0_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(472_ct), expected.nodes(472_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(470_ct), expected.nodes(470_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(51_ct), expected.nodes(51_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(428_ct), expected.nodes(428_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(433_ct), expected.nodes(433_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(435_ct), expected.nodes(435_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(1022_ct), expected.nodes(1022_ct) );
NMTOOLS_ASSERT_EQUAL( result.nodes(391_ct), expected.nodes(391_ct) );

NMTOOLS_ASSERT_EQUAL( result.out_edges(0_ct), expected.out_edges(0_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(472_ct), expected.out_edges(472_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(470_ct), expected.out_edges(470_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(51_ct), expected.out_edges(51_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(428_ct), expected.out_edges(428_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(433_ct), expected.out_edges(433_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(435_ct), expected.out_edges(435_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(1022_ct), expected.out_edges(1022_ct) );
NMTOOLS_ASSERT_EQUAL( result.out_edges(391_ct), expected.out_edges(391_ct) );
}

// matmul
Expand Down

0 comments on commit 7cc9090

Please sign in to comment.