From 7cc909067cbb8ee787435dc1b40cfc55c4760cb1 Mon Sep 17 00:00:00 2001 From: Fahri Ali Rahman Date: Mon, 23 Dec 2024 09:49:26 +0700 Subject: [PATCH] update tests --- tests/functional/src/misc/ct_digraph.cpp | 136 +++++++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/tests/functional/src/misc/ct_digraph.cpp b/tests/functional/src/misc/ct_digraph.cpp index 29627d287..2d016a0b5 100644 --- a/tests/functional/src/misc/ct_digraph.cpp +++ b/tests/functional/src/misc/ct_digraph.cpp @@ -1024,11 +1024,147 @@ TEST_CASE("contracted_edge(case1)" * doctest::test_suite("ct_digraph")) // NMTOOLS_ASSERT_GRAPH_EQUAL( result, expected ); CHECK( result.nodes(0_ct) == expected.nodes(0_ct) ); CHECK( result.nodes(1_ct) == expected.nodes(1_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(769_ct), expected.nodes(769_ct) ); NMTOOLS_ASSERT_EQUAL( result.nodes(447_ct), expected.nodes(447_ct) ); NMTOOLS_ASSERT_EQUAL( result.nodes(722_ct), expected.nodes(722_ct) ); NMTOOLS_ASSERT_EQUAL( result.nodes(635_ct), expected.nodes(635_ct) ); NMTOOLS_ASSERT_EQUAL( result.nodes(765_ct), expected.nodes(765_ct) ); + + NMTOOLS_ASSERT_EQUAL( result.out_edges(0_ct), expected.out_edges(0_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(1_ct), expected.out_edges(1_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(769_ct), expected.out_edges(769_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(447_ct), expected.out_edges(447_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(722_ct), expected.out_edges(722_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(635_ct), expected.out_edges(635_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(765_ct), expected.out_edges(765_ct) ); +} + +// softmax +TEST_CASE("contracted_edge(case2)" * doctest::test_suite("ct_digraph")) +{ + auto input_shape = array{3,4}; + auto input = na::reshape(na::arange(ix::product(input_shape)),input_shape); + + auto graph = utility::ct_digraph() + .add_node(0_ct,&input) + .add_node(263_ct,fn::reduce_maximum[/*axis=*/0]) + .add_node(975_ct,fn::subtract) + .add_node(547_ct,fn::exp) + .add_node(407_ct,fn::sum[/*axis=*/0]) + .add_node(850_ct,fn::divide) + .add_edges(0_ct,tuple{263_ct,975_ct}) + .add_edge(263_ct,975_ct) + .add_edge(975_ct,547_ct) + .add_edges(547_ct,tuple{407_ct,850_ct}) + .add_edge(407_ct,850_ct) + ; + + auto fused = fn::subtract * fn::exp; + auto result = utility::contracted_edge(graph,tuple{975_ct,547_ct},111_ct,fused); + + auto expected = utility::ct_digraph() + .add_node(0_ct,&input) + .add_node(263_ct,fn::reduce_maximum[/*axis=*/0]) + .add_node(407_ct,fn::sum[/*axis=*/0]) + .add_node(850_ct,fn::divide) + .add_node(111_ct,fused) + .add_edges(0_ct,tuple{263_ct,111_ct}) + .add_edge(263_ct,111_ct) + .add_edge(407_ct,850_ct) + .add_edges(111_ct,tuple{407_ct,850_ct}) + ; + + NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() ); + NMTOOLS_ASSERT_EQUAL( result.nodes(), expected.nodes() ); + + CHECK( result.nodes(0_ct) == expected.nodes(0_ct) ); + + NMTOOLS_ASSERT_EQUAL( result.nodes(263_ct), expected.nodes(263_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(407_ct), expected.nodes(407_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(850_ct), expected.nodes(850_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(111_ct), expected.nodes(111_ct) ); + + NMTOOLS_ASSERT_EQUAL( result.out_edges(0_ct), expected.out_edges(0_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(263_ct), expected.out_edges(263_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(407_ct), expected.out_edges(407_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(850_ct), expected.out_edges(850_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(111_ct), expected.out_edges(111_ct) ); +} + +// var +TEST_CASE("contracted_edge(case3)" * doctest::test_suite("ct_digraph")) +{ + auto input_shape = array{3,4}; + auto input = na::reshape(na::arange(ix::product(input_shape)),input_shape); + + auto graph = utility::ct_digraph() + .add_node(0_ct,&input) + .add_node(472_ct,3) + .add_node(470_ct,fn::reduce_add[/*axis=*/0]) + .add_node(51_ct,fn::divide) + .add_node(428_ct,fn::subtract) + .add_node(391_ct,fn::fabs) + .add_node(591_ct,fn::square) + .add_node(433_ct,fn::reduce_add[/*axis=*/0]) + .add_node(435_ct,3) + .add_node(1022_ct,fn::divide) + .add_edges(0_ct,tuple{470_ct,428_ct}) + .add_edge(472_ct,51_ct) + .add_edge(470_ct,51_ct) + .add_edge(51_ct,428_ct) + .add_edge(428_ct,391_ct) + .add_edge(391_ct,591_ct) + .add_edge(591_ct,433_ct) + .add_edge(433_ct,1022_ct) + .add_edge(435_ct,1022_ct) + ; + + auto fused = fn::square * fn::fabs; + auto result = utility::contracted_edge(graph,tuple{391_ct,591_ct},391_ct,fused); + + auto expected = utility::ct_digraph() + .add_node(0_ct,&input) + .add_node(472_ct,3) + .add_node(470_ct,fn::reduce_add[/*axis=*/0]) + .add_node(51_ct,fn::divide) + .add_node(428_ct,fn::subtract) + .add_node(433_ct,fn::reduce_add[/*axis=*/0]) + .add_node(435_ct,3) + .add_node(1022_ct,fn::divide) + .add_node(391_ct,fused) + .add_edges(0_ct,tuple{470_ct,428_ct}) + .add_edge(472_ct,51_ct) + .add_edge(470_ct,51_ct) + .add_edge(51_ct,428_ct) + .add_edge(428_ct,391_ct) + .add_edge(433_ct,1022_ct) + .add_edge(435_ct,1022_ct) + .add_edge(391_ct,433_ct) + ; + + NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() ); + NMTOOLS_ASSERT_EQUAL( result.nodes(), expected.nodes() ); + + CHECK( result.nodes(0_ct) == expected.nodes(0_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(472_ct), expected.nodes(472_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(470_ct), expected.nodes(470_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(51_ct), expected.nodes(51_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(428_ct), expected.nodes(428_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(433_ct), expected.nodes(433_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(435_ct), expected.nodes(435_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(1022_ct), expected.nodes(1022_ct) ); + NMTOOLS_ASSERT_EQUAL( result.nodes(391_ct), expected.nodes(391_ct) ); + + NMTOOLS_ASSERT_EQUAL( result.out_edges(0_ct), expected.out_edges(0_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(472_ct), expected.out_edges(472_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(470_ct), expected.out_edges(470_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(51_ct), expected.out_edges(51_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(428_ct), expected.out_edges(428_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(433_ct), expected.out_edges(433_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(435_ct), expected.out_edges(435_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(1022_ct), expected.out_edges(1022_ct) ); + NMTOOLS_ASSERT_EQUAL( result.out_edges(391_ct), expected.out_edges(391_ct) ); } // matmul