Skip to content

Commit

Permalink
Pass all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Oct 8, 2024
1 parent dcd2e13 commit 39c8f1c
Show file tree
Hide file tree
Showing 25 changed files with 471 additions and 57 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ features = [

includes = [
"pcg/file_format/v1/v1_computation_graph.dtg.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h",
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h",
]

src_includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h",
"pcg/file_format/v1/v1_binary_sp_decomposition/json.h",
]

[[fields]]
name = "sp_decomposition"
type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<int>"
type = "::FlexFlow::V1BinarySPDecomposition"

[[fields]]
name = "computation_graph"
Expand Down
4 changes: 1 addition & 3 deletions bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ tl::expected<JsonSPModelExport, std::string>
to_v1_including_node_numbering(computation_graph);
V1ComputationGraph v1_cg = v1_result.first;
bidict<int, layer_guid_t> layer_numbering = v1_result.second;
GenericBinarySPDecompositionTree<int> v1_sp_decomposition =
transform(sp_decomposition.raw_tree,
[&](layer_guid_t const &l) { return layer_numbering.at_r(l); });
V1BinarySPDecomposition v1_sp_decomposition = to_v1(sp_decomposition, layer_numbering);

return JsonSPModelExport{
v1_sp_decomposition,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h"
#include "pcg/computation_graph.dtg.h"
#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h"
#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h"
#include "utils/overload.h"

namespace FlexFlow {

Expand All @@ -32,6 +34,9 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &);
std::unordered_multiset<layer_guid_t>
get_layers(ComputationGraphBinarySPDecomposition const &);

V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &,
bidict<int, layer_guid_t> const &layer_numbering);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H

#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H

#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,31 @@ std::unordered_multiset<layer_guid_t>
return get_leaves(tree, generic_impl_for_computation_graph_sp_tree());
}

V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &tree,
bidict<int, layer_guid_t> const &layer_numbering) {
return tree.visit<V1BinarySPDecomposition>(overload {
[&](ComputationGraphBinarySeriesSplit const &series) {
return V1BinarySPDecomposition{
V1BinarySeriesSplit{
to_v1(series.get_left_child(), layer_numbering),
to_v1(series.get_right_child(), layer_numbering),
},
};
},
[&](ComputationGraphBinaryParallelSplit const &parallel) {
return V1BinarySPDecomposition{
V1BinaryParallelSplit{
to_v1(parallel.get_left_child(), layer_numbering),
to_v1(parallel.get_right_child(), layer_numbering),
},
};
},
[&](layer_guid_t const &layer) {
return V1BinarySPDecomposition{
layer_numbering.at_r(layer),
};
}
});
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h"
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"

namespace FlexFlow {

BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &pcg_split) {
return BinaryParallelSplit{
binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()),
binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()),
};
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "compiler/series_parallel/pcg/pcg_binary_series_split.h"
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"

namespace FlexFlow {

BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &pcg_split) {
return BinarySeriesSplit{
binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()),
binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()),
};
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h"
#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h"
#include "compiler/series_parallel/pcg/pcg_binary_series_split.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h"
Expand Down Expand Up @@ -55,7 +54,10 @@ BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecompositi
},
[](PCGBinaryParallelSplit const &parallel) -> BinarySPDecompositionTree {
return BinarySPDecompositionTree{
binary_parallel_split_from_pcg_parallel_split(parallel),
BinaryParallelSplit{
binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()),
binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()),
},
};
},
[](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H
#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H

#include <nlohmann/json.hpp>
#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h"

namespace nlohmann {

template <>
struct adl_serializer<::FlexFlow::V1BinarySPDecomposition> {
static ::FlexFlow::V1BinarySPDecomposition from_json(json const &);
static void to_json(json &, ::FlexFlow::V1BinarySPDecomposition const &);
};

template <>
struct adl_serializer<::FlexFlow::V1BinarySeriesSplit> {
static ::FlexFlow::V1BinarySeriesSplit from_json(json const &);
static void to_json(json &, ::FlexFlow::V1BinarySeriesSplit const &);
};

template <>
struct adl_serializer<::FlexFlow::V1BinaryParallelSplit> {
static ::FlexFlow::V1BinaryParallelSplit from_json(json const &);
static void to_json(json &, ::FlexFlow::V1BinaryParallelSplit const &);
};

} // namespace nlohmann

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace = "FlexFlow"
name = "V1BinaryParallelSplit"
features = [
"eq",
"hash",
"fmt",
]

fwd_decls = [
"struct V1BinarySPDecomposition"
]

post_includes = [
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h",
]

[[fields]]
name = "left_child"
type = "::FlexFlow::V1BinarySPDecomposition"
indirect = true

[[fields]]
name = "right_child"
type = "::FlexFlow::V1BinarySPDecomposition"
indirect = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace = "FlexFlow"
name = "V1BinarySeriesSplit"
features = [
"eq",
"hash",
"fmt",
]

fwd_decls = [
"struct V1BinarySPDecomposition"
]

post_includes = [
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h",
]

[[fields]]
name = "left_child"
type = "::FlexFlow::V1BinarySPDecomposition"
indirect = true

[[fields]]
name = "right_child"
type = "::FlexFlow::V1BinarySPDecomposition"
indirect = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
namespace = "FlexFlow"
name = "V1BinarySPDecomposition"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h",
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h",
]

[[values]]
type = "::FlexFlow::V1BinarySeriesSplit"
key = "series"

[[values]]
type = "::FlexFlow::V1BinaryParallelSplit"
key = "parallel"

[[values]]
type = "int"
key = "leaf"
75 changes: 75 additions & 0 deletions lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h"
#include "utils/exception.h"
#include "utils/overload.h"
#include "utils/fmt/json.h"

using namespace ::FlexFlow;

namespace nlohmann {

V1BinarySPDecomposition adl_serializer<V1BinarySPDecomposition>::from_json(json const &j) {
std::string type = j.at("type").get<std::string>();

if (type == "series") {
return V1BinarySPDecomposition{
j.get<V1BinarySeriesSplit>(),
};
} else if (type == "parallel") {
return V1BinarySPDecomposition{
j.get<V1BinaryParallelSplit>(),
};
} else if (type == "leaf") {
return V1BinarySPDecomposition{
j.at("value").get<int>(),
};
} else {
throw mk_runtime_error(fmt::format("Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" in json object: {}", type, j));
}
}

void adl_serializer<V1BinarySPDecomposition>::to_json(json &j, V1BinarySPDecomposition const &tree) {
tree.visit<std::monostate>(overload {
[&](V1BinarySeriesSplit const &split) {
j = split;
j["type"] = "series";
return std::monostate{};
},
[&](V1BinaryParallelSplit const &split) {
j = split;
j["type"] = "parallel";
return std::monostate{};
},
[&](int leaf) {
j["value"] = leaf;
j["type"] = "leaf";
return std::monostate{};
},
});
}

V1BinarySeriesSplit adl_serializer<V1BinarySeriesSplit>::from_json(json const &j) {
return V1BinarySeriesSplit{
/*lhs=*/j.at("left_child").get<V1BinarySPDecomposition>(),
/*rhs=*/j.at("right_child").get<V1BinarySPDecomposition>(),
};
}

void adl_serializer<V1BinarySeriesSplit>::to_json(json &j, V1BinarySeriesSplit const &series) {
j["left_child"] = series.get_left_child();
j["right_child"] = series.get_right_child();
}

V1BinaryParallelSplit adl_serializer<V1BinaryParallelSplit>::from_json(json const &j) {
return V1BinaryParallelSplit{
/*lhs=*/j.at("left_child").get<V1BinarySPDecomposition>(),
/*rhs=*/j.at("right_child").get<V1BinarySPDecomposition>(),
};
}

void adl_serializer<V1BinaryParallelSplit>::to_json(json &j, V1BinaryParallelSplit const &series) {
j["left_child"] = series.get_left_child();
j["right_child"] = series.get_right_child();
}


} // namespace FlexFlow
Loading

0 comments on commit 39c8f1c

Please sign in to comment.