diff --git a/.github/workflows/helpers/build_libs.sh b/.github/workflows/helpers/build_target.sh similarity index 100% rename from .github/workflows/helpers/build_libs.sh rename to .github/workflows/helpers/build_target.sh diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_target.sh similarity index 100% rename from .github/workflows/helpers/test_libs.sh rename to .github/workflows/helpers/test_target.sh diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 639f4d82b5..b54ef25819 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -23,6 +23,9 @@ jobs: - name: Add helpers directory to path run: echo "${PWD}/.github/workflows/helpers" >> $GITHUB_PATH + - name: Free additional space on runner + run: free_space_on_runner.sh + - name: Install nix uses: cachix/install-nix-action@v25 with: @@ -62,71 +65,79 @@ jobs: - name: Build utils run: | - build_libs.sh utils + build_target.sh utils - name: Build op-attrs run: | - build_libs.sh op-attrs + build_target.sh op-attrs - name: Build pcg run: | - build_libs.sh pcg + build_target.sh pcg - name: Build kernels run: | - build_libs.sh kernels + build_target.sh kernels - name: Build substitutions run: | - build_libs.sh substitutions + build_target.sh substitutions - name: Build compiler run: | - build_libs.sh compiler + build_target.sh compiler - name: Build substitution-generator run: | - build_libs.sh substitution-generator + build_target.sh substitution-generator - name: Build local-execution run: | - build_libs.sh local-execution + build_target.sh local-execution - name: Build models run: | - build_libs.sh models + build_target.sh models + + - name: Build substitution-to-dot + run: | + build_target.sh substitution-to-dot + + - name: Build export-model-arch + run: | + build_target.sh export-model-arch - name: Test utils run: | - test_libs.sh utils + test_target.sh utils - name: Test op-attrs run: | - test_libs.sh op-attrs + test_target.sh op-attrs - name: Test pcg run: | - test_libs.sh pcg + test_target.sh pcg - name: Test substitutions run: | - test_libs.sh substitutions + test_target.sh substitutions - # - name: Test compiler - # run: | - # test_libs.sh compiler + - name: Test compiler + run: | + test_target.sh compiler - name: Test substitution-generator run: | - test_libs.sh substitution-generator + test_target.sh substitution-generator - name: Test local-execution run: | - test_libs.sh local-execution + test_target.sh local-execution - name: Test models run: | - test_libs.sh models + test_target.sh models - name: Generate code coverage run: | diff --git a/.proj.toml b/.proj.toml index 721d212e31..5592f184ad 100644 --- a/.proj.toml +++ b/.proj.toml @@ -13,6 +13,8 @@ build_targets = [ "substitution-generator", "local-execution", "models", + "export-model-arch", + "substitution-to-dot", ] test_targets = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index a518931ac5..792126449b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) +option(FF_BUILD_BIN_EXPORT_MODEL_ARCH "build export-model-arch utility" ON) set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") if (FF_CUDA_ARCH STREQUAL "") diff --git a/bin/CMakeLists.txt 
b/bin/CMakeLists.txt index fcc19b33b9..1cd7068cfd 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,9 +7,13 @@ if(FF_BUILD_SUBSTITUTION_TOOL) endif() if(FF_BUILD_VISUALIZATION_TOOL) - add_subdirectory(substitutions-to-dot) + add_subdirectory(substitution-to-dot) endif() if(FF_BUILD_ARG_PARSER) add_subdirectory(arg_parser) endif() + +if(FF_BUILD_BIN_EXPORT_MODEL_ARCH) + add_subdirectory(export-model-arch) +endif() diff --git a/bin/export-model-arch/CMakeLists.txt b/bin/export-model-arch/CMakeLists.txt new file mode 100644 index 0000000000..b931668594 --- /dev/null +++ b/bin/export-model-arch/CMakeLists.txt @@ -0,0 +1,12 @@ +ff_add_executable( + NAME + export-model-arch + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + include/ + DEPS + utils + models + compiler +) diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml new file mode 100644 index 0000000000..efaf10c255 --- /dev/null +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "JsonSPModelExport" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/file_format/v1/v1_computation_graph.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +src_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/json.h", +] + +[[fields]] +name = "sp_decomposition" +type = "::FlexFlow::V1BinarySPDecomposition" + +[[fields]] +name = "computation_graph" +type = "::FlexFlow::V1ComputationGraph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc new file mode 100644 index 0000000000..64419acce4 --- /dev/null +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -0,0 +1,220 @@ +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/bert/bert.h" +#include "models/candle_uno/candle_uno.h" +#include "models/inception_v3/inception_v3.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "utils/cli/cli_get_help_message.h" +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_parse_result.h" +#include "utils/cli/cli_spec.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" + +using namespace ::FlexFlow; + +ComputationGraph get_single_operator_computation_graph() { + ComputationGraphBuilder b; + + size_t batch_size = 8; + size_t in_channels = 16; + size_t out_channels = 12; + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + batch_size, + in_channels, + out_channels, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + InitializerAttrs kernel_initializer = + InitializerAttrs{GlorotUniformAttrs{/*seed=*/12}}; + InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}}; + tensor_guid_t output = b.dense(input, + in_channels, + Activation::RELU, + /*use_bias=*/true, + DataType::FLOAT, + 
+                                 kernel_initializer,
+                                 bias_initializer,
+                                 "my_example_operator");
+
+  return b.computation_graph;
+}
+
+ComputationGraph get_default_transformer_computation_graph() {
+  TransformerConfig config = get_default_transformer_config();
+  ComputationGraph cg = get_transformer_computation_graph(config);
+
+  return cg;
+}
+
+tl::expected<ComputationGraph, std::string>
+    get_model_computation_graph(std::string const &model_name) {
+  if (model_name == "transformer") {
+    return get_default_transformer_computation_graph();
+  } else if (model_name == "inception_v3") {
+    return get_inception_v3_computation_graph(
+        get_default_inception_v3_training_config());
+  } else if (model_name == "candle_uno") {
+    return get_candle_uno_computation_graph(get_default_candle_uno_config());
+  } else if (model_name == "bert") {
+    return get_bert_computation_graph(get_default_bert_config());
+  } else if (model_name == "split_test") {
+    int batch_size = 8;
+    return get_split_test_computation_graph(batch_size);
+  } else if (model_name == "single_operator") {
+    return get_single_operator_computation_graph();
+  } else {
+    return tl::unexpected(fmt::format("Unknown model name: {}", model_name));
+  }
+}
+
+tl::expected<JsonSPModelExport, std::string>
+    get_sp_model_export(std::string const &model_name) {
+  ComputationGraph computation_graph = ({
+    tl::expected<ComputationGraph, std::string> result =
+        get_model_computation_graph(model_name);
+    if (!result.has_value()) {
+      return tl::unexpected(result.error());
+    }
+    result.value();
+  });
+
+  ComputationGraphBinarySPDecomposition sp_decomposition = ({
+    std::optional<ComputationGraphBinarySPDecomposition> result =
+        get_computation_graph_right_assoc_binary_sp_decomposition(
+            computation_graph);
+    if (!result.has_value()) {
+      return tl::unexpected("Failed to generate series-parallel decomposition "
+                            "of computation graph.");
+    }
+    result.value();
+  });
+
+  std::pair> v1_result =
+      to_v1_including_node_numbering(computation_graph);
+  V1ComputationGraph v1_cg = v1_result.first;
+  bidict layer_numbering = v1_result.second;
+  V1BinarySPDecomposition v1_sp_decomposition =
+      to_v1(sp_decomposition, layer_numbering);
+
+  return JsonSPModelExport{
+      v1_sp_decomposition,
+      v1_cg,
+  };
+}
+
+int main(int argc, char **argv) {
+  CLISpec cli = empty_cli_spec();
+
+  CLIArgumentKey arg_key_help = cli_add_help_flag(cli);
+
+  CLIArgumentKey key_sp_decomposition =
+      cli_add_flag(cli,
+                   CLIFlagSpec{"sp-decomposition",
+                               std::nullopt,
+                               "also output a series parallel decomposition of "
+                               "the model's computation graph"});
+
+  CLIArgumentKey key_dot = cli_add_flag(
+      cli,
+      CLIFlagSpec{
+          "dot",
+          std::nullopt,
+          "output a dot representation of the model's computation graph"});
+
+  CLIArgumentKey key_preprocessed_dot = cli_add_flag(
+      cli,
+      CLIFlagSpec{"preprocessed-dot",
+                  std::nullopt,
+                  "output a dot representation of the model's computation "
+                  "graph, preprocessed to help check its series-parallel "
+                  "structure"});
+
+  std::vector<std::string> model_options = {"transformer",
+                                            "inception_v3",
+                                            "candle_uno",
+                                            "bert",
+                                            "split_test",
+                                            "single_operator"};
+  CLIArgumentKey key_model_name = cli_add_positional_argument(
+      cli,
+      CLIPositionalArgumentSpec{
+          "model", model_options, "name of the model to export"});
+
+  assert(argc >= 1);
+  std::string prog_name = argv[0];
+
+  CLIParseResult parsed = ({
+    tl::expected<CLIParseResult, std::string> result =
+        cli_parse(cli, argc, argv);
+    if (!result.has_value()) {
+      std::string error_msg = result.error();
+      std::cerr << cli_get_help_message(prog_name, cli);
+      std::cerr << std::endl;
+      std::cerr << "error: " << error_msg << std::endl;
+      return 1;
+    }
+
+    result.value();
+  });
+
+  bool help = cli_get_flag(parsed, arg_key_help);
+  if (help) {
+    std::cerr << cli_get_help_message(prog_name, cli);
+    return 1;
+  }
+
+  std::string model_name = cli_get_argument(parsed, key_model_name);
+  bool sp_decomposition = cli_get_flag(parsed, key_sp_decomposition);
+  bool dot = cli_get_flag(parsed, key_dot);
+  bool preprocessed_dot = cli_get_flag(parsed, key_preprocessed_dot);
+
+  auto handle_error = [](auto const &result) {
+    if (!result.has_value()) {
+      std::cerr << "error: " << result.error() << std::endl;
+      exit(1);
+    }
+
+    return result.value();
+  };
+
+  if (dot) {
+    ComputationGraph cg = handle_error(get_model_computation_graph(model_name));
+
+    std::cout << as_dot(cg) << std::endl;
+    return 0;
+  }
+
+  if (preprocessed_dot) {
+    ComputationGraph cg = handle_error(get_model_computation_graph(model_name));
+    std::string rendered =
+        render_preprocessed_computation_graph_for_sp_decomposition(cg);
+
+    std::cout << rendered << std::endl;
+    return 0;
+  }
+
+  nlohmann::json json_output;
+  if (sp_decomposition) {
+    JsonSPModelExport model_export =
+        handle_error(get_sp_model_export(model_name));
+
+    json_output = model_export;
+  } else {
+    ComputationGraph cg = handle_error(get_model_computation_graph(model_name));
+
+    json_output = to_v1(cg);
+  }
+  std::cout << json_output.dump(2) << std::endl;
+
+  return 0;
+}
diff --git a/bin/substitutions-to-dot/CMakeLists.txt b/bin/substitution-to-dot/CMakeLists.txt
similarity index 100%
rename from bin/substitutions-to-dot/CMakeLists.txt
rename to bin/substitution-to-dot/CMakeLists.txt
diff --git a/bin/substitutions-to-dot/substitution_to_dot.cc b/bin/substitution-to-dot/substitution_to_dot.cc
similarity index 89%
rename from bin/substitutions-to-dot/substitution_to_dot.cc
rename to bin/substitution-to-dot/substitution_to_dot.cc
index 49a199ddd3..1b5f715bcd 100644
--- a/bin/substitutions-to-dot/substitution_to_dot.cc
+++ b/bin/substitution-to-dot/substitution_to_dot.cc
@@ -1,4 +1,4 @@
-#include "substitution-generator/json.h"
+#include "substitution-generator/legacy_rules.h"
 #include "utils/dot_file.h"
 #include
 #include
@@ -24,10 +24,11 @@ int main(int argc, char **argv) {
   std::string json_path(argv[1]);
   std::string rule_name(argv[2]);
-  RuleCollection rule_collection = load_rule_collection_from_path(json_path);
+  LegacyRuleCollection rule_collection =
+      load_rule_collection_from_path(json_path);
 
-  std::optional<Rule> found = std::nullopt;
-  for (Rule const &r : rule_collection.rules) {
+  std::optional<LegacyRule> found = std::nullopt;
+  for (LegacyRule const &r : rule_collection.rules) {
     if (r.name == rule_name) {
       found = r;
       break;
     }
   }
@@ -39,7 +40,7 @@
     return 1;
   }
 
-  Rule r = found.value();
+  LegacyRule r = found.value();
 
   using Node = std::tuple;
 
@@ -82,14 +83,14 @@
   };
 
   for (int i = 0; i < r.srcOp.size(); i++) {
-    Operator const &o = r.srcOp[i];
+    LegacyOperator const &o = r.srcOp[i];
     Node srcOpNode = {NodeType::SRC, i, 0};
     {
      dot.add_node(srcOpNode, label_map(fmt::to_string(o.op_type), srcOpNode));
      dot.add_node_to_subgraph(srcOpNode, src_body_subgraph);
     }
-    for (Tensor const &t : o.input) {
+    for (LegacyTensor const &t : o.input) {
      if (t.opId < 0) {
        assert(t.tsId == 0);
        Node inputOpNode = {NodeType::SRC_INPUT_TENSOR, t.opId, 0};
@@ -106,14 +107,14 @@
     }
   }
   for (int j = 0; j < r.dstOp.size(); j++) {
-    Operator const &o = r.dstOp[j];
+    LegacyOperator const &o = r.dstOp[j];
     Node dstOpNode = {NodeType::DST, j, 0};
     {
      dot.add_node(dstOpNode, label_map(fmt::to_string(o.op_type), dstOpNode));
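Note on the error-handling style in `export_model_arch.cc` above: `get_sp_model_export` and `main` unwrap a `tl::expected` with GNU statement expressions (`({ ... })`), early-returning the error from the enclosing function on failure, and `main` funnels the remaining failures through the `handle_error` lambda. A minimal standalone sketch of that unwrap-or-return idiom, assuming the single-header `tl::expected` library and a GCC/Clang toolchain (the `parse_positive`/`doubled` names are made up for illustration):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>
#include <tl/expected.hpp>

static tl::expected<int, std::string> parse_positive(std::string const &s) {
  int value = std::atoi(s.c_str());
  if (value <= 0) {
    return tl::unexpected("not a positive integer: " + s);
  }
  return value;
}

static tl::expected<int, std::string> doubled(std::string const &s) {
  // The ({ ... }) block evaluates to its final expression; on error we
  // early-return the unexpected value from the *enclosing* function.
  int value = ({
    tl::expected<int, std::string> result = parse_positive(s);
    if (!result.has_value()) {
      return tl::unexpected(result.error());
    }
    result.value();
  });

  return 2 * value;
}

int main() {
  std::cout << doubled("21").value() << std::endl;   // prints 42
  std::cout << doubled("nope").error() << std::endl; // prints the error text
}
```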
dot.add_node_to_subgraph(dstOpNode, dst_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::DST_INPUT_TENSOR, t.opId, 0}; @@ -128,7 +129,7 @@ int main(int argc, char **argv) { } } } - for (MapOutput const &mo : r.mappedOutput) { + for (LegacyMapOutput const &mo : r.mappedOutput) { Node srcOutputNode = {NodeType::SRC_OUTPUT_TENSOR, mo.srcOpId, mo.srcTsId}; Node dstOutputNode = {NodeType::DST_OUTPUT_TENSOR, mo.dstOpId, mo.dstTsId}; { diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 1dbd16bdb1..90e100bb1b 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -149,6 +149,11 @@ function(ff_add_executable) ${FF_EXEC_NAME} ${SRC}) + target_include_directories( + ${FF_EXEC_NAME} + PRIVATE + ${FF_EXEC_PRIVATE_INCLUDE}) + target_link_libraries( ${FF_EXEC_NAME} ${FF_EXEC_DEPS}) diff --git a/flake.lock b/flake.lock index 1aad68ae29..87fae7f446 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1722923482, - "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=", + "lastModified": 1728341842, + "narHash": "sha256-XMS52KBSS6z3k2VaiVcHyZQD6b2QUm1wIvTClel4xwg=", "owner": "lockshaw", "repo": "proj", - "rev": "c650b0e52337652ea7190131988c0370e0ee7f25", + "rev": "830fb5b1a0c7087752693990e90bbbf021168dfe", "type": "github" }, "original": { diff --git a/lib/compiler/include/compiler/allowed_machine_views.h b/lib/compiler/include/compiler/allowed_machine_views.h new file mode 100644 index 0000000000..9bb73fd1a9 --- /dev/null +++ b/lib/compiler/include/compiler/allowed_machine_views.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H +#define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H + +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/operator_task_space.dtg.h" + +namespace FlexFlow { + +bool is_valid_machine_view(MachineView const &mv, + OperatorTaskSpace const &task, + MachineSpecification const &ms); + +std::unordered_set + get_allowed_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType device_type); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h deleted file mode 100644 index 2e4ff8448b..0000000000 --- a/lib/compiler/include/compiler/cost_estimate.h +++ /dev/null @@ -1,61 +0,0 @@ - -#ifndef _FLEXFLOW_COMPILER_COST_ESTIMATE_H -#define _FLEXFLOW_COMPILER_COST_ESTIMATE_H - -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" - -namespace FlexFlow { - -struct ICostEstimator { - virtual float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const = 0; - virtual float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const = 0; - - ICostEstimator() = default; - ICostEstimator(ICostEstimator const &) = delete; - ICostEstimator &operator=(ICostEstimator const &) = delete; - - virtual ~ICostEstimator() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); - -struct CostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - 
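The new `compiler/allowed_machine_views.h` above is the enumeration entry point for candidate placements. A small usage sketch; `candidate_gpu_views` is an illustrative wrapper (not part of this PR) and it assumes `DeviceType::GPU` from pcg's generated `DeviceType` enum:

```cpp
#include <cassert>
#include <unordered_set>
#include "compiler/allowed_machine_views.h"
#include "pcg/device_type.dtg.h"

namespace FlexFlow {

std::unordered_set<MachineView>
    candidate_gpu_views(MachineSpecification const &machine_spec,
                        OperatorTaskSpace const &task) {
  // All machine views that can place `task` on the GPUs of `machine_spec`.
  std::unordered_set<MachineView> views =
      get_allowed_machine_views(machine_spec, task, DeviceType::GPU);

  // Every view returned by the enumeration should pass the validity predicate.
  for (MachineView const &mv : views) {
    assert(is_valid_machine_view(mv, task, machine_spec));
  }

  return views;
}

} // namespace FlexFlow
```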
std::vector const &outputs, - MachineView const &mv) const { - return this->implementation_ptr->estimate_cost( - op, inputs, weights, outputs, mv); - } - - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const { - return this->implementation_ptr->estimate_cost(tensor_shape, src, dst); - } - - template - static typename std::enable_if::value, - CostEstimator>::type - create(Args &&...args) { - return CostEstimator(std::make_shared(std::forward(args)...)); - } - -private: - CostEstimator(std::shared_ptr implementation_ptr) - : implementation_ptr(implementation_ptr) {} - std::shared_ptr implementation_ptr; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h new file mode 100644 index 0000000000..65bae0c76a --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/machine_view.dtg.h" +#include + +namespace FlexFlow { + +struct ICostEstimator { + virtual float estimate_cost(OpCostEstimateKey const &) const = 0; + virtual float estimate_cost(TensorSetMovement const &) const = 0; + + ICostEstimator() = default; + ICostEstimator(ICostEstimator const &) = delete; + ICostEstimator &operator=(ICostEstimator const &) = delete; + + virtual ~ICostEstimator() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); + +struct CostEstimator { + float estimate_cost(OpCostEstimateKey const &k) const; + float estimate_cost(TensorSetMovement const &m) const; + + template + static typename std::enable_if::value, + CostEstimator>::type + create(Args &&...args) { + return CostEstimator(std::make_shared(std::forward(args)...)); + } + +private: + CostEstimator(std::shared_ptr implementation_ptr); + +private: + std::shared_ptr implementation_ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..8fd860d00d --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml @@ -0,0 +1,40 @@ +namespace = "FlexFlow" +name = "OpCostEstimateKey" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml 
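For orientation on the new type-erased `CostEstimator` above: concrete estimators subclass `ICostEstimator` and are wrapped through `CostEstimator::create<T>(...)`. A hedged sketch; `ConstantCostEstimator` is a made-up example, not something added by this PR:

```cpp
#include "compiler/cost_estimator/cost_estimator.h"

namespace FlexFlow {

struct ConstantCostEstimator : public ICostEstimator {
  explicit ConstantCostEstimator(float cost) : cost(cost) {}

  // Cost of running a single operator under a concrete machine view.
  float estimate_cost(OpCostEstimateKey const &) const override {
    return this->cost;
  }

  // Cost of moving a set of tensors between machine views.
  float estimate_cost(TensorSetMovement const &) const override {
    return this->cost;
  }

  float cost;
};

CostEstimator make_constant_cost_estimator(float cost) {
  return CostEstimator::create<ConstantCostEstimator>(cost);
}

} // namespace FlexFlow
```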
b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml new file mode 100644 index 0000000000..70f73ebe51 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "SingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/machine_view.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml new file mode 100644 index 0000000000..3625605239 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::SingleTensorMovement>" diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml new file mode 100644 index 0000000000..22f29cbd59 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_result.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "GraphOptimizeResult" +features = [ ] + +includes = [ + "compiler/machine_mapping/machine_mapping.dtg.h", + "pcg/parallel_computation_graph/parallel_computation_graph.h" +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h new file mode 100644 index 0000000000..2de2321ba6 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H +#define _FLEXFLOW_COMPILER_MCMC_STATE_H + +#include "compiler/graph_optimize_result.dtg.h" + +namespace FlexFlow { + +struct GraphOptimizeState { + GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, + float runtime); + + GraphOptimizeResult graph_optimize_result; + float runtime; + + bool operator==(GraphOptimizeState const &other) const; + bool operator!=(GraphOptimizeState const &other) const; + bool operator<(GraphOptimizeState const &other) const; +}; + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::GraphOptimizeState> { + size_t operator()(::FlexFlow::GraphOptimizeState const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h deleted file mode 100644 index 1370357837..0000000000 --- a/lib/compiler/include/compiler/graph_utils.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_GRAPH_UTILS_H -#define _FLEXFLOW_COMPILER_GRAPH_UTILS_H - -#include "compiler/unity_algorithm.h" -#include 
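As with the other `*.struct.toml` files in this PR, `single_tensor_movement.struct.toml` above is a dtgen spec rather than hand-written C++; the build generates a corresponding `single_tensor_movement.dtg.h`. Very roughly, the requested `eq`/`hash`/`fmt` features correspond to declarations along the lines of the sketch below, which is an approximation for orientation only, not the actual generated code:

```cpp
#include <functional>
#include <unordered_set>
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "pcg/machine_view.dtg.h"

namespace FlexFlow {

struct SingleTensorMovement {
  ParallelTensorShape parallel_tensor_shape;
  std::unordered_set<MachineView> src_machine_views;
  std::unordered_set<MachineView> dst_machine_views;

  // "eq" feature: memberwise equality.
  bool operator==(SingleTensorMovement const &) const;
  bool operator!=(SingleTensorMovement const &) const;
};

// "fmt" feature: a printable/formattable representation (exact form omitted).

} // namespace FlexFlow

// "hash" feature: a std::hash specialization so the type can key hash containers.
namespace std {
template <>
struct hash<::FlexFlow::SingleTensorMovement> {
  size_t operator()(::FlexFlow::SingleTensorMovement const &) const;
};
} // namespace std
```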
"pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); - -// NOTE(@wmdi): I think we should have the following interfaces in the graph -// library eventually. - -template -void minimize(T &t, T const &v) { - if (v < t) { - t = v; - } -} - -template -void minimize(T &t, T const &v, Compare comp) { - if (comp(v, t)) { - t = v; - } -} - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h deleted file mode 100644 index 5d17cbb373..0000000000 --- a/lib/compiler/include/compiler/machine_mapping.h +++ /dev/null @@ -1,69 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H - -#include "compiler/machine_mapping.dtg.h" -#include "compiler/optimal_cost_state.dtg.h" -#include "cost_estimate.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &, MachineMapping const &); - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - -struct OptimalCostResult { - static OptimalCostResult sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult infinity(); - - float runtime; - req machine_mapping; -}; -FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); - -struct OptimalCostRuntimeCmp { - bool operator()(OptimalCostResult const &, OptimalCostResult const &); -}; - -class OptimalCostCache { -public: - OptimalCostCache() = default; - - std::optional load(OptimalCostState const &) const; - void save(OptimalCostState const &, OptimalCostResult const &); - -private: - std::unordered_map cache; -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs); - -} // namespace FlexFlow - -// namespace std { -// -// template <> -// struct hash> { -// size_t operator()( -// std::unordered_map const &g) -// const; -// }; - -// }; // namespace std - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml new file mode 100644 index 0000000000..449a448706 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = 
"AbstractedSingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h new file mode 100644 index 0000000000..5b7e2f3613 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); + +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &); + +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &, + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml new file mode 100644 index 0000000000..4cf184706b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "AbstractedTensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..8567a7a3e6 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H + +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml new file mode 100644 index 0000000000..e71cfc540f --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "FeasibleMachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", +] + +[[fields]] +name = "runtime" +type = "float" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h new file mode 100644 index 0000000000..990c1c8205 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H + +#include "pcg/machine_specification.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h new file mode 100644 index 0000000000..62da90bfcb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + 
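A sketch of how the abstracted movement introduced above appears intended to be used: it records producer/consumer positions as `BinaryTreePath`s, and only once a candidate mapping assigns concrete `MachineView`s to those paths is it lowered to a `TensorSetMovement` that the cost estimator can price. The helper below is illustrative wiring only, not part of the PR:

```cpp
#include "compiler/cost_estimator/cost_estimator.h"
#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h"

namespace FlexFlow {

float estimate_comm_cost_for_candidate(
    CostEstimator const &estimator,
    AbstractedTensorSetMovement const &abstracted,
    ParallelLayerGuidObliviousMachineMapping const &pre_mapping,
    ParallelLayerGuidObliviousMachineMapping const &post_mapping) {
  // Resolve the abstract source/destination paths against the candidate
  // mappings for the two sides of the split, then price the concrete movement.
  TensorSetMovement concrete = concretize_abstracted_tensor_set_movement(
      abstracted, /*pre=*/pre_mapping, /*post=*/post_mapping);
  return estimator.estimate_cost(concrete);
}

} // namespace FlexFlow
```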
MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..2aed9a20e4 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml new file mode 100644 index 0000000000..b9a7f9ac59 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "IncludeUnconstrained" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [] + +[[fields]] +name = "raw_bool" +type = "bool" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h new file mode 100644 index 0000000000..06cbbf942d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping.dtg.h" + +namespace FlexFlow { + +MachineMapping combine_disjoint_mappings(MachineMapping const &, + MachineMapping const &); + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml similarity index 50% rename from lib/compiler/include/compiler/machine_mapping.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml index 4c4912a3fd..92517c1110 100644 --- 
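Taken together, the `get_optimal_machine_mapping` overloads above suggest the following top-level flow. The `optimize_mapping` wrapper below is purely illustrative wiring: it also uses `get_machine_mapping_problem_tree`, `get_all_leaf_paths`, `get_unconstrained_solution_for_layers`, and `empty_machine_mapping_cache`, which appear later in this diff, and it assumes `get_all_leaf_paths` yields the `BinaryTreePath` set that the constraints helper expects:

```cpp
#include "compiler/machine_mapping/get_optimal_machine_mapping.h"
#include "compiler/machine_mapping/machine_mapping_cache.h"
#include "compiler/machine_mapping/machine_mapping_constraints.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h"

namespace FlexFlow {

MachineMappingResult
    optimize_mapping(ParallelComputationGraph const &pcg,
                     PCGBinarySPDecomposition const &sp_decomposition,
                     MachineMappingContext const &context,
                     MachineSpecification const &resources) {
  // 1. Lower the PCG plus its SP decomposition into the mapping problem tree.
  MachineMappingProblemTree problem_tree =
      get_machine_mapping_problem_tree(pcg, sp_decomposition);

  // 2. Start with no layer pinned to any particular MachineView.
  MachineMappingConstraints constraints =
      get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree));

  // 3. Memoize subproblem results across the recursive search.
  MachineMappingCache cache = empty_machine_mapping_cache();

  return get_optimal_machine_mapping(
      cache, context, problem_tree, resources, constraints);
}

} // namespace FlexFlow
```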
a/lib/compiler/include/compiler/machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml @@ -9,13 +9,16 @@ features = [ "fmt", ] -includes = [ - "utils/graph/node/node.dtg.h", +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", "pcg/machine_view.dtg.h", +] + +src_includes = [ "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/fmt/unordered_map.h", ] [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h new file mode 100644 index 0000000000..3a0fcf0e15 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" + +namespace FlexFlow { + +MachineMappingCache empty_machine_mapping_cache(); +std::optional + machine_mapping_cache_load(MachineMappingCache const &, + MachineMappingState const &); +void machine_mapping_cache_save(MachineMappingCache &, + MachineMappingState const &, + MachineMappingResult const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml new file mode 100644 index 0000000000..a76ff26eb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "MachineMappingCache" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/machine_mapping_result.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h new file mode 100644 index 0000000000..d314ab493b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H + +#include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &); + +std::unordered_set + get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); + 
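The cache above is a plain memo table keyed by `MachineMappingState` (problem subtree, resources, constraints). A sketch of the load-or-solve-then-save pattern it is presumably consumed with; the `solve_uncached` callback is a stand-in for the real recursive solver:

```cpp
#include <functional>
#include <optional>
#include "compiler/machine_mapping/machine_mapping_cache.h"

namespace FlexFlow {

MachineMappingResult solve_with_memo(
    MachineMappingCache &cache,
    MachineMappingState const &state,
    std::function<MachineMappingResult(MachineMappingState const &)> const
        &solve_uncached) {
  // Return a previously computed result for this exact subproblem, if any.
  if (std::optional<MachineMappingResult> cached =
          machine_mapping_cache_load(cache, state)) {
    return cached.value();
  }

  // Otherwise solve it once and record the answer for later lookups.
  MachineMappingResult result = solve_uncached(state);
  machine_mapping_cache_save(cache, state, result);
  return result;
}

} // namespace FlexFlow
```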
+std::optional + get_machine_view_for_layer(MachineMappingConstraints const &, + BinaryTreePath const &); + +MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &, + BinaryTreePathEntry const &); +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &); +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &); + +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &, + ParallelLayerGuidObliviousMachineMapping const &); + +std::optional require_only_root(MachineMappingConstraints const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml new file mode 100644 index 0000000000..8e13abedb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineMappingConstraints" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "machine_views" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, std::optional<::FlexFlow::MachineView>>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml new file mode 100644 index 0000000000..81e26f491d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingContext" +features = [] + +includes = [ + "compiler/cost_estimator/cost_estimator.h", + "pcg/machine_view.dtg.h", + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "cost_estimator" +type = "::FlexFlow::CostEstimator" + +[[fields]] +name = "allowed_machine_views" +type = "std::function(::FlexFlow::UnmappedOpCostEstimateKey const &, ::FlexFlow::MachineSpecification const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h new file mode 100644 index 0000000000..68d02aaa54 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree + get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h new file mode 100644 index 0000000000..29e9e7c90b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree(); + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); + +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &); +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &); + +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml new file mode 100644 index 0000000000..1949f143cb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::MMProblemTreeSeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::MMProblemTreeParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::UnmappedOpCostEstimateKey" +key = "leaf" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml new file mode 100644 index 0000000000..5247b2006a --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + 
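Small usage sketch for the problem-tree accessors above (illustrative helpers only): each leaf of the tree is one `UnmappedOpCostEstimateKey`, i.e. one operator still awaiting a `MachineView`, and subtrees are addressed by `BinaryTreePath`.

```cpp
#include <cstddef>
#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h"

namespace FlexFlow {

size_t num_mapping_leaves(MachineMappingProblemTree const &tree) {
  // One leaf per operator that still needs to be assigned a MachineView.
  return get_leaves(tree).size();
}

bool has_subtree_at(MachineMappingProblemTree const &tree,
                    BinaryTreePath const &path) {
  // Paths that walk off the tree yield std::nullopt.
  return mm_problem_tree_get_subtree_at_path(tree, path).has_value();
}

} // namespace FlexFlow
```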
+includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml new file mode 100644 index 0000000000..d4f61bb3f5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", +] + +[[fields]] +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h new file mode 100644 index 0000000000..9fbad4a1d0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &, parallel_layer_guid_t const &); + +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..fe76683eb7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "UnmappedOpCostEstimateKey" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = 
"input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h new file mode 100644 index 0000000000..b21fea5f24 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H + +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" + +namespace FlexFlow { + +[[nodiscard]] MachineMappingResult infeasible_machine_mapping_result(); +[[nodiscard]] bool is_infeasible(MachineMappingResult const &); +FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); + +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &); + +[[nodiscard]] MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &pre_result, + MachineMappingResult const &post_result, + std::optional const + ¶llel_split_transformation); +[[nodiscard]] MachineMappingResult + parallel_combine(MachineMappingResult const &lhs_result, + MachineMappingResult const &rhs_result); + +[[nodiscard]] MachineMappingResult + minimize_runtime(MachineMappingResult const &m1, + MachineMappingResult const &m2); + +[[nodiscard]] MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml new file mode 100644 index 0000000000..92a2873af5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "MachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/feasible_machine_mapping_result.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "raw_result" +type = "std::optional<::FlexFlow::FeasibleMachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml new file mode 100644 index 0000000000..1346f6ebe7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingState" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +[[fields]] +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" + +[[fields]] +name = "resources" +type = "::FlexFlow::MachineSpecification" + +[[fields]] +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h 
b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h new file mode 100644 index 0000000000..cb3af9c689 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +ParallelLayerGuidObliviousMachineMapping + restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &); +ParallelLayerGuidObliviousMachineMapping + restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &); + +std::optional + get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml new file mode 100644 index 0000000000..f00fcc8490 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ParallelLayerGuidObliviousMachineMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_mapping" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml new file mode 100644 index 0000000000..8247c0cbdc --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParallelSplitTransformation" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LthenR" + +[[values]] +name = "RthenL" diff --git a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml new file mode 100644 index 0000000000..155e526672 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "PCGSplitBoundaryLayers" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h 
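Sketch of the intended algebra of `MachineMappingResult` from `machine_mapping_result.h` earlier in this diff: infeasibility is encoded as `raw_result == std::nullopt`, and `minimize_runtime` presumably treats the infeasible result as an identity element, so candidates can be folded as below. The `best_of` helper is illustrative only and mirrors what `get_mapping_with_minimal_runtime` appears to provide for a set of results:

```cpp
#include <vector>
#include "compiler/machine_mapping/machine_mapping_result.h"

namespace FlexFlow {

MachineMappingResult
    best_of(std::vector<MachineMappingResult> const &candidates) {
  // Start from the infeasible identity element and keep the lowest-runtime
  // feasible candidate seen so far.
  MachineMappingResult best = infeasible_machine_mapping_result();
  for (MachineMappingResult const &candidate : candidates) {
    best = minimize_runtime(best, candidate);
  }
  return best;
}

} // namespace FlexFlow
```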
b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h new file mode 100644 index 0000000000..2b2bc9bf84 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H + +#include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &); + +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &); + +std::unordered_set + pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, + PCGBinarySeriesSplit const &); + +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); + +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml new file mode 100644 index 0000000000..bb76ec2ff7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "TransitiveReducedPCG" +features = [] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h", +] + +[[fields]] +name = "full_pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml deleted file mode 100644 index 50496f661b..0000000000 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "OptimalCostState" -features = [ - "eq", - # "ord", - "hash", - # "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", - "pcg/machine_specification.dtg.h", - "pcg/machine_view.dtg.h", - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "subgraph" -type = "::FlexFlow::SerialParallelDecomposition" - -[[fields]] -name = "resource" -type = "::FlexFlow::MachineSpecification" - -[[fields]] -name = "given_machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" - -[[fields]] -name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" \ No newline at end of file diff --git 
a/lib/compiler/include/compiler/optimizer_config.struct.toml b/lib/compiler/include/compiler/optimizer_config.struct.toml new file mode 100644 index 0000000000..b7f4f71e9c --- /dev/null +++ b/lib/compiler/include/compiler/optimizer_config.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OptimizerConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "alpha" +type = "float" + +[[fields]] +name = "budget" +type = "int" + +[[fields]] +name = "threshold" +type = "float" + +[[fields]] +name = "max_num_ops" +type = "int" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..9654a2546e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml new file mode 100644 index 0000000000..aa66c80b43 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h new file mode 100644 index 0000000000..fdc80a1e37 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree(); + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &); + +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &); + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &); +bool is_left_associative(ComputationGraphBinarySPDecomposition const &); +bool is_right_associative(ComputationGraphBinarySPDecomposition const &); +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &); + +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, + bidict const &layer_numbering); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..452470620b --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::ComputationGraphBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::ComputationGraphBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h new file mode 100644 index 0000000000..e85843ed26 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h new file mode 100644 index 0000000000..d43edaa79d --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h @@ -0,0 +1,11 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h new file mode 100644 index 0000000000..d4ae77541a --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h new file mode 100644 index 0000000000..f348b1a851 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..f7f7026716 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h new file mode 100644 index 0000000000..0842ffb48f --- 
/dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +BinarySeriesSplit + binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml new file mode 100644 index 0000000000..af2c8c4dae --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h new file mode 100644 index 0000000000..86fa1a59aa --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree(); + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &); + +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &, + parallel_layer_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..52372fb270 --- 
/dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "PCGBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::PCGBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::PCGBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::parallel_layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index abddef37ed..232f2b9563 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -1,39 +1,17 @@ #ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H -#include "compiler/machine_mapping.h" -#include "cost_estimate.h" -#include "machine_mapping.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/graph_optimize_result.dtg.h" +#include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" -namespace FlexFlow { - -struct Strategy { - ParallelComputationGraph pcg; - MachineMapping machine_mapping; - req runtime; - friend bool operator!=(Strategy const &lhs, Strategy const &rhs) { - return (lhs.machine_mapping != rhs.machine_mapping) || - (lhs.runtime != rhs.runtime); - } -}; - -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); -struct StrategyRuntimeCmp { - bool operator()(Strategy const &, Strategy const &); -}; - -struct OptimizerConfig { - float alpha; - int budget; - float threshold; - int max_num_ops; -}; +namespace FlexFlow { -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc new file mode 100644 index 0000000000..1c226f79b0 --- /dev/null +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -0,0 +1,122 @@ +#include "compiler/allowed_machine_views.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.h" +#include "pcg/multi_dimensional_stride.dtg.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/all_of.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/extend.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_all_permutations_with_repetition.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/replicate.h" +#include "utils/containers/sorted.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip.h" +#include "utils/overload.h" + +namespace FlexFlow { + +bool is_valid_machine_view(MachineView const &mv, + OperatorTaskSpace const &task, + MachineSpecification const &ms) { + std::optional maximum_device_coord = + get_machine_space_coordinate( + task, mv, get_task_space_maximum_coordinate(task), ms); + 
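// The view is valid iff the task's maximum coordinate still maps to an in-range machine-space coordinate (i.e. the optional above is non-empty). +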
return maximum_device_coord.has_value(); +} + +/* + * Generates a set of candidate `MachineView`s. + * The returned set includes all valid machine views, and might contain invalid + * ones. This function should not be used externally (see + * `get_allowed_machine_views` instead). There is no guarantee that a non-empty + * returned set contains a valid machine view (i.e. it's possible for all + * the returned `MachineView`s to be invalid) + */ +static std::unordered_set<MachineView> + get_candidate_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType const &device_type) { + + auto get_max_stride_upper_bound = [](std::vector<int> const &tensor_dims, + int total_devices) -> int { + int min_num_devices_with_full_stride_volume = product(transform( + tensor_dims, [](int const &num_devices) { return num_devices - 1; })); + return std::ceil(total_devices / min_num_devices_with_full_stride_volume); + }; + + auto candidate_strides = [&](std::vector<int> const &tensor_dims, + int total_devices) + -> std::unordered_multiset<MultiDimensionalStride> { + int max_stride_upper_bound = + get_max_stride_upper_bound(tensor_dims, total_devices); + + std::vector<stride_t> single_stride_range = + transform(range(1, max_stride_upper_bound + 1), + [](int stride) { return stride_t{stride}; }); + std::unordered_multiset<std::vector<stride_t>> raw_stride_vectors = + cartesian_product(replicate(tensor_dims.size(), single_stride_range)); + std::unordered_multiset<MultiDimensionalStride> strides = + transform(raw_stride_vectors, [](auto const &stride_vec) { + return MultiDimensionalStride{stride_vec}; + }); + return strides; + }; + + auto candidate_starts = [](MachineSpecification const &ms, + DeviceType const &device_type) { + std::unordered_set<MachineSpaceCoordinate> result; + for (int node_idx : range(ms.num_nodes)) { + for (int device_idx : range(get_num_devices_per_node(ms, device_type))) { + result.insert( + MachineSpaceCoordinate{node_idx, device_idx, device_type}); + } + } + return result; + }; + + auto candidate_dimensions = [](OperatorTaskSpace const &task) { + std::unordered_set<MachineSpecificationDimension> options = { + MachineSpecificationDimension::INTER_NODE, + MachineSpecificationDimension::INTRA_NODE}; + return get_all_permutations_with_repetition(options, num_dims(task)); + }; + + std::vector<int> tensor_dims = task.degrees; + int total_devices = get_num_devices(machine_spec, device_type); + + std::unordered_set<MachineView> machine_views; + + for (MultiDimensionalStride const &strides : + candidate_strides(tensor_dims, total_devices)) { + for (MachineSpaceCoordinate start : + candidate_starts(machine_spec, device_type)) { + for (std::vector<MachineSpecificationDimension> const &dims : + candidate_dimensions(task)) { + machine_views.insert( + machine_view_from_strides_and_machine_spec_dimensions( + start, strides.raw_strides, dims)); + } + } + } + return machine_views; +} + +std::unordered_set<MachineView> + get_allowed_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType device_type) { + + std::unordered_set<MachineView> views = + get_candidate_machine_views(machine_spec, task, device_type); + return filter(views, [&](MachineView const &mv) { + return is_valid_machine_view(mv, task, machine_spec); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc new file mode 100644 index 0000000000..051ffcd190 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc @@ -0,0 +1,16 @@ +#include "compiler/cost_estimator/cost_estimator.h" + +namespace FlexFlow { + +CostEstimator::CostEstimator(std::shared_ptr<ICostEstimator>
implementation_ptr) + : implementation_ptr(implementation_ptr) {} + +float CostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->implementation_ptr->estimate_cost(k); +} + +float CostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->implementation_ptr->estimate_cost(m); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc new file mode 100644 index 0000000000..4b4f323ea4 --- /dev/null +++ b/lib/compiler/src/compiler/graph_optimize_state.cc @@ -0,0 +1,85 @@ +#include "compiler/graph_optimize_state.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +GraphOptimizeState::GraphOptimizeState( + GraphOptimizeResult const &graph_optimize_result, float runtime) + : graph_optimize_result(graph_optimize_result), runtime(runtime) {} + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + // Note(@wmdi): This is a hack to implement a partially correct homomorphism + // check. Switch to the homomorphism check used in substitutions right after + // https://github.com/flexflow/FlexFlow/pull/1471 is merged. + auto layers1 = topological_ordering(graph_optimize_result.pcg); + auto layers2 = topological_ordering(other.graph_optimize_result.pcg); + if (layers1.size() != layers2.size()) { + return false; + } + std::unordered_map mapping; + for (size_t i = 0; i < layers1.size(); ++i) { + if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != + get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { + return false; + } + auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); + auto inputs2 = + get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); + if (inputs1.size() != inputs2.size()) { + return false; + } + for (size_t j = 0; j < inputs1.size(); ++j) { + if (inputs1[j] != mapping.at(inputs2[j])) { + return false; + } + } + auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); + auto outputs2 = + get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); + if (outputs1.size() != outputs2.size()) { + return false; + } + for (size_t j = 0; j < outputs1.size(); ++j) { + mapping.emplace(outputs2[j], outputs1[j]); + } + } + return true; +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime < other.runtime; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph hash like + // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash + size_t seed = 0; + auto layers = topological_ordering(state.graph_optimize_result.pcg); + ::FlexFlow::hash_combine(seed, layers.size()); + for (auto layer : layers) { + ::FlexFlow::hash_combine( + seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); + auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); + ::FlexFlow::hash_combine(seed, inputs.size()); + for (auto input : inputs) { + for (size_t i = 0; i < layers.size(); ++i) { + if (get_source_layer(input) == layers[i]) { + 
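// Hash the producing layer's index in the topological ordering rather than the raw guid, so the hash does not depend on guid values. +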
::FlexFlow::hash_combine(seed, i); + break; + } + } + } + } + return seed; +} + +} // namespace std diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc new file mode 100644 index 0000000000..6f3deca138 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -0,0 +1,62 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { + return AbstractedTensorSetMovement{{}}; +} + +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.src_machine_views; + }); +} + +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.dst_machine_views; + }); +} + +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &abstracted, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping mapping = + binary_combine_mappings(/*lhs=*/pre_mapping, + /*rhs=*/post_mapping); + + auto concretize_tensor_movement = + [&](AbstractedSingleTensorMovement const &a) { + return SingleTensorMovement{ + /*parallel_tensor_shape=*/a.parallel_tensor_shape, + /*src_machine_views=*/ + transform( + a.src_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(pre_mapping, path).value(); + }), + /*dst_machine_views=*/ + transform( + a.dst_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(post_mapping, path).value(); + }), + }; + }; + + return TensorSetMovement{ + /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, + concretize_tensor_movement), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..0e0f60c891 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,63 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" +#include 
"utils/containers/values.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + std::unordered_set edges_across_split = + pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); + + auto get_movement_for_tensor = + [&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement { + std::unordered_set tensor_edges = + filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { + return get_parallel_tensor(e) == t; + }); + + std::unordered_set src_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_src_layer(e); + }); + + std::unordered_set dst_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_dst_layer(e); + }); + + return AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + /*src_machine_views=*/ + transform(src_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(split.get_left_child(), l)); + }), + /*dst_machine_views=*/ + transform(dst_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(split.get_right_child(), l)); + }), + }; + }; + + std::unordered_map + single_tensor_movements = generate_map( + pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); + + return AbstractedTensorSetMovement{ + values(single_tensor_movements), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..5126d9687e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,32 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "utils/hash/pair.h" + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource) { + std::unordered_set> + result; + + for (int i = 1; i < resource.num_nodes; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_nodes = i; + sub_resource2.num_nodes = resource.num_nodes - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + + for (int i = 1; i < resource.num_gpus_per_node; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_gpus_per_node = i; + sub_resource2.num_gpus_per_node = resource.num_gpus_per_node - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..10abd7ff90 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,254 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include 
"compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/contains.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingState state = MachineMappingState{ + problem_tree, + resources, + constraints, + }; + + { + std::optional cached_result = + machine_mapping_cache_load(result_cache, state); + if (cached_result) { + return cached_result.value(); + } + } + + MachineMappingResult result = + problem_tree.visit(overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + return get_optimal_machine_mapping( + result_cache, + context, + series_split, + resources, + constraints, + /*parallel_split_transformation=*/std::nullopt); + }, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping(result_cache, + context, + decomp_tree_node, + resources, + constraints); + }, + }); + + machine_mapping_cache_save(result_cache, state, result); + return result; +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation) { + + auto get_boundary_machine_view_assignments = + [&](std::unordered_set const &boundary_layers) + -> std::unordered_set { + std::unordered_map> + allowed = generate_map( + boundary_layers, + [&](BinaryTreePath const &l) -> std::unordered_set { + UnmappedOpCostEstimateKey leaf = + mm_problem_tree_get_subtree_at_path( + MachineMappingProblemTree{series_split}, l) + .value() + .get(); + return context.allowed_machine_views(leaf, resources); + }); + return transform( + get_all_assignments(allowed), + [](std::unordered_map const &m) { + return ParallelLayerGuidObliviousMachineMapping{m}; + }); + }; + + auto eval_pre_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views) { + MachineMappingConstraints pre_candidate = with_additional_constraints( + restrict_to_left_child(constraints), assigned_pre_machine_views); + + MachineMappingResult pre_result = + get_optimal_machine_mapping(result_cache, + context, + series_split.get_left_child(), + resources, + pre_candidate); + + return pre_result; + }; + + auto 
eval_post_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views) { + MachineMappingConstraints post_candidate = with_additional_constraints( + restrict_to_right_child(constraints), assigned_post_machine_views); + + MachineMappingResult post_result = + get_optimal_machine_mapping(result_cache, + context, + series_split.get_right_child(), + resources, + post_candidate); + + return post_result; + }; + + MachineMappingResult result = infeasible_machine_mapping_result(); + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views : + get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + + MachineMappingResult pre_result = + eval_pre_boundary_mapping(assigned_pre_machine_views); + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views : + get_boundary_machine_view_assignments( + get_dst_layers(tensor_movement))) { + + MachineMappingResult post_result = + eval_post_boundary_mapping(assigned_post_machine_views); + + TensorSetMovement comm_across_split = + concretize_abstracted_tensor_set_movement( + tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); + float cost_across_split = + context.cost_estimator.estimate_cost(comm_across_split); + + result = minimize_runtime(result, + series_combine(cost_across_split, + pre_result, + post_result, + parallel_split_transformation)); + } + } + + return result; +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingProblemTree lhs = parallel_split.get_left_child(); + MachineMappingProblemTree rhs = parallel_split.get_right_child(); + + MachineMappingResult series_result = [&] { + MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, + }; + + return get_optimal_machine_mapping(result_cache, + context, + series_split, + resources, + constraints, + ParallelSplitTransformation::LthenR); + }(); + + MachineMappingConstraints left_constraints = + restrict_to_left_child(constraints); + MachineMappingConstraints right_constraints = + restrict_to_right_child(constraints); + + auto evaluate_resource_split = + [&](std::pair const + &resource_split) { + MachineMappingResult left_result = get_optimal_machine_mapping( + result_cache, context, lhs, resource_split.first, left_constraints); + MachineMappingResult right_result = + get_optimal_machine_mapping(result_cache, + context, + rhs, + resource_split.second, + right_constraints); + + return parallel_combine(left_result, right_result); + }; + + std::unordered_set parallel_results = transform( + get_machine_resource_splits(resources), evaluate_resource_split); + + return minimize_runtime(series_result, + get_mapping_with_minimal_runtime(parallel_results)); +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resource, + MachineMappingConstraints const &constraints) { + + std::unordered_set candidates = [&] { + std::optional machine_view = 
require_only_root(constraints); + if (machine_view.has_value()) { + return std::unordered_set{machine_view.value()}; + } else { + return context.allowed_machine_views(leaf, resource); + } + }(); + + auto get_mapping_result = [&](MachineView const &machine_view) { + OpCostEstimateKey mapped = + map_unmapped_op_cost_estimate_key(leaf, machine_view); + float cost = context.cost_estimator.estimate_cost(mapped); + + return make_singleton_machine_mapping_result(cost, machine_view); + }; + + std::unordered_set candidate_results = + transform(candidates, get_mapping_result); + + return get_mapping_with_minimal_runtime(candidate_results); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..6cc3f4329c --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,26 @@ +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/sum.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + AbstractedTensorSetMovement abstracted = + get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + return concretize_abstracted_tensor_set_movement( + abstracted, pre_mapping, post_mapping); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..6f350d8773 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,18 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "utils/containers.h" +#include "utils/containers/are_disjoint.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" + +namespace FlexFlow { + +MachineMapping combine_disjoint_mappings(MachineMapping const &s1, + MachineMapping const &s2) { + return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; +} + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc new file mode 100644 index 0000000000..fbfccf737f --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -0,0 +1,30 @@ +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include 
"utils/containers/contains_key.h" +#include "utils/containers/try_at.h" + +namespace FlexFlow { + +MachineMappingCache empty_machine_mapping_cache() { + return MachineMappingCache{{}}; +} + +std::optional + machine_mapping_cache_load(MachineMappingCache const &cache, + MachineMappingState const &k) { + return try_at(cache.raw_map, k); +} + +void machine_mapping_cache_save(MachineMappingCache &cache, + MachineMappingState const &k, + MachineMappingResult const &v) { + if (contains_key(cache.raw_map, k)) { + throw mk_runtime_error( + fmt::format("machine_mapping_cache_save expected key to not already " + "exist, but received existing key {}", + k)); + } + + cache.raw_map.emplace(k, v); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc new file mode 100644 index 0000000000..2cee866a01 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -0,0 +1,112 @@ +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "utils/containers/filter.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/restrict_keys.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &layers) { + return MachineMappingConstraints{ + generate_map(layers, + [](BinaryTreePath const &) -> std::optional { + return std::nullopt; + }), + }; +} + +std::unordered_set + get_all_layers(MachineMappingConstraints const &partial_solution, + IncludeUnconstrained const &include_unconstrained) { + std::unordered_set with_unconstrained = + keys(partial_solution.machine_views); + + if (include_unconstrained.raw_bool) { + return with_unconstrained; + } else { + return filter(with_unconstrained, [&](BinaryTreePath const &l) { + return partial_solution.machine_views.at(l).has_value(); + }); + } +} + +std::optional get_machine_view_for_layer( + MachineMappingConstraints const &partial_solution, + BinaryTreePath const &layer) { + return partial_solution.machine_views.at(layer); +} + +MachineMappingConstraints + restrict_to_child(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { + return MachineMappingConstraints{filtermap_keys( + constraints.machine_views, + [&](BinaryTreePath const &path) -> std::optional { + BinaryTreePathEntry head = binary_tree_path_get_top_level(path); + + if (head == prefix) { + BinaryTreePath rest = binary_tree_path_get_non_top_level(path); + return rest; + } else { + return std::nullopt; + } + })}; +} + +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::LEFT_CHILD); +} + +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::RIGHT_CHILD); +} + +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &constraints, + ParallelLayerGuidObliviousMachineMapping const &additional) { + MachineMappingConstraints result = constraints; + + for (auto const &[layer, machine_view] : additional.raw_mapping) { + std::optional current_machine_view = + result.machine_views.at(layer); + + if 
(!current_machine_view.has_value()) { + result.machine_views.at(layer) = machine_view; + } else { + if (current_machine_view.value() != machine_view) { + throw mk_runtime_error( + fmt::format("with_additional_layer_machine_views received machine " + "view assignment for layer {} " + "to machine view {}, but that layer is already " + "assigned to machine view {}.", + layer, + machine_view, + current_machine_view.value())); + } + } + } + + return result; +} + +std::optional + require_only_root(MachineMappingConstraints const &constraints) { + if (keys(constraints.machine_views) != + std::unordered_set{binary_tree_root_path()}) { + throw mk_runtime_error( + fmt::format("require_only_root expected constraints to have only a " + "single key (the root path), but received {}", + constraints)); + } + + return constraints.machine_views.at(binary_tree_root_path()); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..367af3701e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -0,0 +1,53 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree( + ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp_decomposition_tree) { + TransitiveReducedPCG tr_pcg = pcg_get_transitive_reduction(pcg); + + std::function + to_problem_tree; + + to_problem_tree = + [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { + return sp.visit(overload{ + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(series.get_left_child()), + /*rhs=*/to_problem_tree(series.get_right_child()), + }, + }; + }, + [&](PCGBinaryParallelSplit const ¶llel) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + to_problem_tree(parallel.get_left_child()), + to_problem_tree(parallel.get_right_child()), + }, + }; + }, + [&](parallel_layer_guid_t const &leaf) { + return MachineMappingProblemTree{ + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), + }; + }, + }); + }; + + return to_problem_tree(sp_decomposition_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..1e39a7be19 --- /dev/null +++ 
b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -0,0 +1,91 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree() { + return GenericBinarySPDecompositionTreeImplementation< + MachineMappingProblemTree, + MMProblemTreeSeriesSplit, + MMProblemTreeParallelSplit, + UnmappedOpCostEstimateKey>{ + /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeSeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](MachineMappingProblemTree const &tree) + -> UnmappedOpCostEstimateKey const & { + return tree.get(); + }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(MachineMappingProblemTree const &tree) { + return tree.visit(overload{ + [](MMProblemTreeSeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](UnmappedOpCostEstimateKey const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &tree) { + return get_leaves(tree, generic_binary_sp_impl_for_mm_problem_tree()); +} + +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &tree) { + return get_all_leaf_paths(tree, generic_binary_sp_impl_for_mm_problem_tree()); +} + +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, + BinaryTreePath const &path) { + return get_subtree_at_path( + tree, generic_binary_sp_impl_for_mm_problem_tree(), path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc new file mode 100644 index 0000000000..990b287f8b --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -0,0 +1,36 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include 
"pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { + auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { + return get_parallel_tensor_shape(pcg, t); + }; + + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/pcg_get_op_attrs(pcg, layer), + /*input_shapes=*/ + transform(get_incoming_inputs(pcg, layer), get_tensor_shape), + /*weight_shapes=*/ + transform(get_incoming_weights(pcg, layer), get_tensor_shape), + /*output_shapes=*/ + transform(get_layer_outputs(pcg, layer), get_tensor_shape), + }; +} + +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view) { + return OpCostEstimateKey{ + /*op_attrs=*/unmapped.op_attrs, + /*input_shapes=*/unmapped.input_shapes, + /*weight_shapes=*/unmapped.weight_shapes, + /*output_shapes=*/unmapped.output_shapes, + /*machine_view=*/machine_view, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..3409f7f871 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,138 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingResult infeasible_machine_mapping_result() { + return MachineMappingResult{std::nullopt}; +} + +bool is_infeasible(MachineMappingResult const &result) { + return !result.raw_result.has_value(); +} + +FeasibleMachineMappingResult + require_feasible(MachineMappingResult const &result) { + return result.raw_result.value(); +} + +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &candidates) { + MachineMappingResult result = infeasible_machine_mapping_result(); + + for (MachineMappingResult const &candidate : candidates) { + result = minimize_runtime(result, candidate); + } + + return result; +} + +MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &maybe_pre_result, + MachineMappingResult const &maybe_post_result, + std::optional const + ¶llel_split_transformation) { + FeasibleMachineMappingResult pre_result = ({ + if (is_infeasible(maybe_pre_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_pre_result); + }); + + FeasibleMachineMappingResult post_result = ({ + if (is_infeasible(maybe_post_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_post_result); + }); + + ParallelLayerGuidObliviousMachineMapping mapping = [&] { + if (parallel_split_transformation.has_value() && + parallel_split_transformation.value() == + ParallelSplitTransformation::RthenL) { + return binary_combine_mappings(/*lhs=*/post_result.machine_mapping, + /*rhs=*/pre_result.machine_mapping); + } else { + return binary_combine_mappings(/*lhs=*/pre_result.machine_mapping, + /*rhs=*/post_result.machine_mapping); + } + }(); + + return MachineMappingResult{ + 
FeasibleMachineMappingResult{ + /*runtime=*/pre_result.runtime + comm_cost + post_result.runtime, + /*machine_mapping=*/mapping, + }, + }; +} + +MachineMappingResult + parallel_combine(MachineMappingResult const &maybe_lhs_result, + MachineMappingResult const &maybe_rhs_result) { + FeasibleMachineMappingResult lhs_result = ({ + if (is_infeasible(maybe_lhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_lhs_result); + }); + + FeasibleMachineMappingResult rhs_result = ({ + if (is_infeasible(maybe_rhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_rhs_result); + }); + + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), + /*machine_mapping=*/ + binary_combine_mappings(/*lhs=*/lhs_result.machine_mapping, + /*rhs=*/rhs_result.machine_mapping), + }, + }; +} + +MachineMappingResult minimize_runtime(MachineMappingResult const &maybe_m1, + MachineMappingResult const &maybe_m2) { + FeasibleMachineMappingResult m1 = ({ + if (is_infeasible(maybe_m1)) { + return maybe_m2; + } + require_feasible(maybe_m1); + }); + + FeasibleMachineMappingResult m2 = ({ + if (is_infeasible(maybe_m2)) { + return maybe_m1; + } + require_feasible(maybe_m2); + }); + + if (m2.runtime < m1.runtime) { + return maybe_m2; + } else { + return maybe_m1; + } +} + +MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view) { + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/runtime, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), machine_view}, + }}, + }, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc new file mode 100644 index 0000000000..715a4c2e3d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -0,0 +1,24 @@ +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/try_at.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &lhs, + ParallelLayerGuidObliviousMachineMapping const &rhs) { + return ParallelLayerGuidObliviousMachineMapping{ + merge_maps(map_keys(lhs.raw_mapping, nest_inside_left_child), + map_keys(rhs.raw_mapping, nest_inside_right_child)), + }; +} + +std::optional get_machine_view_for_path( + ParallelLayerGuidObliviousMachineMapping const &mapping, + BinaryTreePath const &path) { + return try_at(mapping.raw_mapping, path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc new file mode 100644 index 0000000000..96c8106cad --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -0,0 +1,93 @@ +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" 
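+// A TransitiveReducedPCG pairs the full PCG with the transitive reduction of
+// its digraph, so the split queries below only report edges, tensors, and
+// boundary layers that survive transitive reduction.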
+#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &tr_pcg) { + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, + /*transitive_reduction=*/tr_pcg.transitive_reduction, + }; +} + +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &pcg) { + DiGraphView raw_digraph = pcg.raw_graph; + DiGraphView transitive_reduced = transitive_reduction(raw_digraph); + + return TransitiveReducedPCG{ + /*pcg=*/pcg, + /*transitive_reduction=*/transitive_reduced, + }; +} + +std::unordered_set + pcg_get_transitive_reduced_edges_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + std::unordered_set raw_edges = + get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); + + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + std::unordered_set raw_outputs = + get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + SplitBoundaryNodes raw_boundary = + get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); + + return PCGSplitBoundaryLayers{ + /*pre_split_boundary=*/transform( + raw_boundary.pre_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + /*post_split_boundary=*/ + transform(raw_boundary.post_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc new file mode 100644 
index 0000000000..32fb53b58a --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,192 @@ +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t>{ + /*series_get_left_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> SPDecompositionTreeNodeType { return get_node_type(tree); }, + /*require_series=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> layer_guid_t const & { return tree.get(); }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](ComputationGraphBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](ComputationGraphBinaryParallelSplit const ¶llel) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](layer_guid_t const &leaf) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &tree) { + return tree.get(); +} + +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &bin) { + return bin.visit(overload{ + [](BinarySeriesSplit const &series) { + return 
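+            // Recursively rebuild each split, converting the raw Node leaves
+            // of the generic tree into layer_guid_t leaves; the series and
+            // parallel structure is carried over unchanged.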
ComputationGraphBinarySPDecomposition{ + ComputationGraphBinarySeriesSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinaryParallelSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) { + return ComputationGraphBinarySPDecomposition{ + layer_guid_t{node}, + }; + }, + }); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_left_associative( + tree, generic_impl_for_computation_graph_sp_tree()); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_right_associative( + tree, generic_impl_for_computation_graph_sp_tree()); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); +} + +V1BinarySPDecomposition + to_v1(ComputationGraphBinarySPDecomposition const &tree, + bidict const &layer_numbering) { + return tree.visit( + overload{[&](ComputationGraphBinarySeriesSplit const &series) { + return V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + to_v1(series.get_left_child(), layer_numbering), + to_v1(series.get_right_child(), layer_numbering), + }, + }; + }, + [&](ComputationGraphBinaryParallelSplit const ¶llel) { + return V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + to_v1(parallel.get_left_child(), layer_numbering), + to_v1(parallel.get_right_child(), layer_numbering), + }, + }; + }, + [&](layer_guid_t const &layer) { + return V1BinarySPDecomposition{ + layer_numbering.at_r(layer), + }; + }}); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..8f78d423b3 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,98 @@ +#include 
"compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph/computation_graph_edge.h" +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &cg) { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + // dot has is incapable of rendering the number of edges in the all-to-all + // connection, so for visualization purposes we instead insert a "fake" node + // to reduce the n^2 edges to 2*n edges + DiGraph preprocessed_digraph = + materialize_digraph_view(cg.raw_graph); + Node fake_node = preprocessed_digraph.add_node(); + for (layer_guid_t const &src : weight_and_input_layers) { + preprocessed_digraph.add_edge(DirectedEdge{src.raw_node, fake_node}); + } + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + preprocessed_digraph.add_edge(DirectedEdge{fake_node, dst.raw_node}); + } + + std::function get_node_label = + [&](Node const &n) -> std::string { + if (n == fake_node) { + return "FAKE"; + } + LayerAttrs a = cg.raw_graph.at(n); + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + std::string preprocessed_dot = digraph_as_dot( + transitive_reduction(preprocessed_digraph), get_node_label); + + return preprocessed_dot; +} + +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &cg) { + + { + DiGraphView unpreprocessed_digraph = cg.raw_graph; + std::optional unpreprocessed_sp_decomposition = + get_series_parallel_decomposition(unpreprocessed_digraph); + if (unpreprocessed_sp_decomposition.has_value()) { + return unpreprocessed_sp_decomposition.value(); + } + } + + DiGraphView preprocessed_digraph = [&] { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + DiGraph digraph = materialize_digraph_view(cg.raw_graph); + for (layer_guid_t const &src : weight_and_input_layers) { + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + digraph.add_edge(DirectedEdge{src.raw_node, dst.raw_node}); + } + } + + return digraph; + }(); + + return get_series_parallel_decomposition(preprocessed_digraph); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc new file mode 100644 index 
0000000000..220614bb8b --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc @@ -0,0 +1,10 @@ +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc new file mode 100644 index 0000000000..657a3c3166 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc @@ -0,0 +1,14 @@ +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &pcg_split) { + return BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc new file mode 100644 index 0000000000..304ad224b1 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc @@ -0,0 +1,14 @@ +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinarySeriesSplit binary_series_split_from_pcg_series_split( + PCGBinarySeriesSplit const &pcg_split) { + return BinarySeriesSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc new file mode 100644 index 0000000000..5eb993c6ef --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -0,0 +1,115 @@ +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t>{ + /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + 
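+          // Children are handed back by reference so the generic SP-tree
+          // algorithms (get_parallel_layers, find_paths_to_leaf below) can
+          // walk the decomposition without copying subtrees.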
return split.get_right_child(); + }, + /*get_node_type=*/ + [](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](PCGBinarySPDecomposition const &tree) + -> PCGBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](PCGBinarySPDecomposition const &tree) + -> parallel_layer_guid_t const & { + return tree.get(); + }, + }; +} + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { + return pcg_tree.visit(overload{ + [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_series_split_from_pcg_series_split(series), + }; + }, + [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), + }, + }; + }, + [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + layer.raw_graph_node, + }; + }, + }); +} + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_pcg_sp_tree()); +} + +SPDecompositionTreeNodeType + get_node_type(PCGBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](PCGBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](PCGBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](parallel_layer_guid_t const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &tree, + parallel_layer_guid_t const &leaf) { + return find_paths_to_leaf(tree, generic_impl_for_pcg_sp_tree(), leaf); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc deleted file mode 100644 index 08db219a21..0000000000 --- a/lib/compiler/src/graph_utils.cc +++ /dev/null @@ -1,153 +0,0 @@ -#include "compiler/graph_utils.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/containers/without_order.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -namespace FlexFlow { - -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return get_serial_parallel_decomposition(pcg.raw_graph); -} - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { - NOT_IMPLEMENTED(); -} - -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return view_output_labelled_as_output_labelled_open(pcg.raw_graph); -} - -// std::vector -// get_sorted_node_input_edges(ParallelComputationGraph const &pcg, -// Node const &n) { -// std::unordered_map> -// incoming_edges = -// get_incoming_edges_by_idx(pcg, n); - -// std::vector result; -// for (auto const &p_id_edge_set : incoming_edges) { -// 
result.push_back(get_only(p_id_edge_set.second)); -// } - -// return result; -// } - -// std::unordered_map -// infer_tensor_shapes(ParallelComputationGraph const &pcg) { -// std::unordered_map result; -// for (Node const &n : get_topological_ordering(pcg)) { -// PCGOperatorAttrs op = pcg.raw_graph.at(n); - -// std::vector input_tensor_shapes = -// vector_transform([&](MultiDiEdge const &e) { return result.at(e); }, -// get_sorted_node_input_edges(pcg, n)); - -// std::vector output_tensor_shapes = -// get_output_shapes(op, input_tensor_shapes); - -// auto outgoing_edges = get_outgoing_edges_by_idx(pcg, n); - -// int i = 0; - -// for (auto const &[node_port, edges] : outgoing_edges) { -// for (MultiDiEdge const &e : edges) { -// result.insert({e, output_tensor_shapes[i++]}); -// } -// } -// } - -// assert(result.size() == get_edges(pcg.raw_graph).size()); - -// return result; -// } - -/* template */ -/* LabelledOpenMultiDiGraph */ -/* get_subgraph(LabelledOpenMultiDiGraph const &g, */ -/* std::unordered_set const &nodes, */ -/* InputSettings input_settings, */ -/* OutputSettings output_settings) { */ - -/* auto iview = LabelledOpenMultiDiGraphView(g) */ -/* .unsafe(); */ - -/* if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledOpenMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view(subgraph_view); */ -/* } else if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::EXCLUDE_OUTPUTS) { */ -/* LabelledUpwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else if (input_settings == InputSettings::EXCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledDownwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else { */ -/* LabelledMultiDiSubgraphView subgraph_view(*iview, - */ -/* nodes); - */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); - */ -/* } */ -/* } */ - -// struct GetNodes { -// template -// std::unordered_set operator()(T const &t) { -// return get_nodes(t); -// } -// }; - -// std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { -// return std::visit(GetNodes{}, sp.raw_variant); -// } - -// std::unordered_set get_nodes(SerialSplit const &serial) { -// return set_union( -// transform(serial.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(ParallelSplit const ¶llel) { -// return set_union( -// transform(parallel.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(Node const &node) { -// return {node}; -// } - -} // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc deleted file mode 100644 index 741be06e68..0000000000 --- a/lib/compiler/src/machine_mapping.cc +++ /dev/null @@ -1,360 +0,0 @@ -#include "compiler/machine_mapping.h" -#include "compiler/cost_estimate.h" -#include "compiler/graph_utils.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" 
-#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers/are_disjoint.h" -#include "utils/containers/as_vector.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/get_only.h" -#include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" -#include "utils/exception.h" -#include "utils/graph/graph_split.dtg.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &s1, MachineMapping const &s2) { - return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; -} - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); -} - -OptimalCostResult - OptimalCostResult::sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{s1.runtime + s2.runtime, - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult - OptimalCostResult::parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{std::max(s1.runtime, s2.runtime), - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult OptimalCostResult::infinity() { - return {std::numeric_limits::infinity(), - MachineMapping{std::unordered_map{}}}; -} - -bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, - OptimalCostResult const &rhs) { - return lhs.runtime < rhs.runtime; -} - -std::optional - OptimalCostCache::load(OptimalCostState const &state) const { - if (contains_key(cache, state)) { - OptimalCostResult result = cache.at(state); - return std::make_optional(result); - } - return std::nullopt; -} - -void OptimalCostCache::save(OptimalCostState const &state, - OptimalCostResult const &result) { - assert(!contains_key(cache, state)); - cache.emplace(state, result); -} - -std::vector> - get_resource_split(MachineSpecification const &resource) { - std::vector> result; - for (int i = 1; i < resource.num_nodes; ++i) { - MachineSpecification sub_resource1 = resource, sub_resource2 = resource; - sub_resource1.num_nodes = i; - sub_resource2.num_nodes = resource.num_nodes - i; - result.push_back(std::make_pair(sub_resource1, sub_resource2)); - } - return result; -} - -// We may replace this by having unflattened AST -std::pair - decompose(SerialSplit const &serial) { - if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; - } - SerialSplit decompn1 = serial; - decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; -} - -std::pair - decompose(ParallelSplit const ¶llel) { - if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); - }); - return {children[0], children[1]}; - } - ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); - decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; -} - -GraphSplit - get_graph_split(SerialParallelDecomposition const 
&pre_decomposition, - SerialParallelDecomposition const &post_decomposition) { - return GraphSplit{get_nodes(pre_decomposition), - get_nodes(post_decomposition)}; -} - -float estimate_cost(SubParallelComputationGraph const &g, - CostEstimator const &estimator, - MachineMapping const &device_mapping, - std::unordered_map const - &frontier_machine_views) { - // TODO: Consider parallelism - float cost = 0; - // for (Node const &node : get_nodes(g.raw_graph)) { - // std::vector incoming_edges = - // get_incoming_edges(g.raw_graph, node); - // std::vector inputs = - // transform(incoming_edges, - // [&](OpenDataflowEdge const &input_edge) { - // return g.raw_graph.at(input_edge).get_shape(); - // }); - // cost += estimator.estimate_cost( - // g.raw_graph.at(node).op_attrs, inputs, - // device_mapping.machine_views.at(node)); - // } - return cost; -} - -void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { - minimize(m1, m2, OptimalCostRuntimeCmp{}); -} - -struct MachineMappingSearcher { - MachineMappingSearcher( - CostEstimator cost_estimator, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) - : cost_estimator(cost_estimator), - allowed_machine_views(allowed_machine_views), - cached_subgraph_costs(cached_subgraph_costs) {} - - CostEstimator cost_estimator; - std::function(ParallelLayerAttrs const &, - MachineSpecification const &)> - allowed_machine_views; - OptimalCostCache &cached_subgraph_costs; - - struct OptimalCostFunctor { - OptimalCostFunctor( - MachineMappingSearcher *searcher, - SubParallelComputationGraph const &g, - MachineSpecification resource, - std::unordered_map given_machine_views, - std::unordered_map - frontier_machine_views) - : searcher(searcher), g(g), resource(resource), - given_machine_views(given_machine_views), - frontier_machine_views(frontier_machine_views) {} - - MachineMappingSearcher *searcher; - SubParallelComputationGraph const &g; - MachineSpecification resource; - std::unordered_map given_machine_views; - std::unordered_map frontier_machine_views; - - template - OptimalCostResult operator()(T const &t) { - OptimalCostState state{SerialParallelDecomposition{t}, - resource, - given_machine_views, - frontier_machine_views}; - std::optional cached_result = - searcher->cached_subgraph_costs.load(state); - - if (cached_result) { - return cached_result.value(); - } - OptimalCostResult result = searcher->optimal_cost( - t, g, resource, given_machine_views, frontier_machine_views); - - searcher->cached_subgraph_costs.save(state, result); - return result; - } - }; - - OptimalCostResult - optimal_cost(SubParallelComputationGraph const &g, - MachineSpecification resource, - SerialParallelDecomposition const &sp_decomposition) { - return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), - sp_decomposition.raw_variant); - } - - OptimalCostResult optimal_cost( - SerialSplit const &serial, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - NOT_IMPLEMENTED(); - // OptimalCostResult optimal_result = OptimalCostResult::infinity(); - - // auto decomposed = decompose(serial); - // SerialParallelDecomposition pre_decompn = decomposed.first; - // SerialParallelDecomposition post_decompn = decomposed.second; - - // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); - // SubParallelComputationGraph 
pre_graph = - // get_subgraph(g, graph_split.first); - // SubParallelComputationGraph post_graph = - // get_subgraph(g, graph_split.second); - - // std::unordered_set post_graph_sources = - // get_closed_sources(post_graph); - - // assert(post_graph_sources.size() == 1); // assume perfect SP - - // Node split_point = get_only(post_graph_sources); - // OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); - - // for (MachineView const &mv : - // allowed_machine_views(g.raw_graph.at(split_point), resource)) { - // std::unordered_map new_given_machine_views = - // given_machine_views; - // new_given_machine_views.emplace(split_point, mv); - // std::unordered_map - // new_frontier_machine_views = frontier_machine_views; - // new_frontier_machine_views.emplace(split_edge, mv); - // minimize_runtime( - // optimal_result, - // OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // pre_graph, - // resource, - // given_machine_views, - // new_frontier_machine_views), - // pre_decompn.raw_variant), - // std::visit(OptimalCostFunctor(this, - // post_graph, - // resource, - // new_given_machine_views, - // frontier_machine_views), - // post_decompn.raw_variant))); - // } - - // return optimal_result; - } - - OptimalCostResult optimal_cost( - ParallelSplit const ¶llel, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - - NOT_IMPLEMENTED(); - // auto decomposed = decompose(parallel); - // SerialParallelDecomposition decompn1 = decomposed.first; - // SerialParallelDecomposition decompn2 = decomposed.second; - - // GraphSplit graph_split = get_graph_split(decompn1, decompn2); - // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), - // g2 = get_subgraph(g, graph_split.second); - - // OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant)); - - // for (auto const &resource_split : get_resource_split(resource)) { - // minimize_runtime( - // optimal_result, - // OptimalCostResult::parallel_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource_split.first, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource_split.second, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant))); - // } - - // return optimal_result; - } - - OptimalCostResult optimal_cost( - Node const &node, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - if (contains_key(given_machine_views, node)) { - assert(contains(allowed_machine_views(g.raw_graph.at(node), resource), - given_machine_views.at(node))); - MachineMapping mv_map{given_machine_views}; - return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}; - } else { - OptimalCostResult optimal_result = OptimalCostResult::infinity(); - for (auto mv : allowed_machine_views(g.raw_graph.at(node), resource)) { - MachineMapping mv_map{{{node, mv}}}; - minimize_runtime( - optimal_result, 
- {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}); - } - return optimal_result; - } - } -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = - get_serial_parallel_decomposition(g); - SubParallelComputationGraph subpcg = pcg_to_subpcg(g); - MachineMappingSearcher searcher( - cost_estimator, allowed_machine_views, cached_subgraph_costs); - return searcher.optimal_cost(subpcg, resources, sp_decomposition); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ba6ef28daa..86a211c535 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,20 +1,16 @@ #include "compiler/unity_algorithm.h" -#include "compiler/graph_utils.h" -#include "compiler/machine_mapping.h" +#include "compiler/graph_optimize_state.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { -bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { - return lhs.runtime < rhs.runtime; -} - /* * Gets all substitutions applicable to a PCG */ -std::unordered_set +std::vector get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } @@ -22,14 +18,14 @@ std::unordered_set /* * Applies a substitution to all possible positions in PCG */ -std::unordered_set +std::vector apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { NOT_IMPLEMENTED(); } -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( @@ -37,58 +33,61 @@ Strategy graph_optimize( &allowed_machine_views, OptimizerConfig const &opt_config) { NOT_IMPLEMENTED(); - // ParallelComputationGraph pcg = cg_to_pcg(cg); - - // std::unordered_set subs = - // get_all_applicable_substitutions(pcg); - - // OptimalCostCache cached_subgraph_costs; - // DeduplicatedPriorityQueue, - // StrategyRuntimeCmp> - // candidates; - - // OptimalCostResult initial_pcg_result = optimal_cost(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy initial_result{ - // pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; - - // Strategy best_result = initial_result; - // candidates.push(initial_result); + // std::vector substitutions = + // get_all_applicable_substitutions(pcg); + // + // MachineMappingCache cached_subgraph_costs; + // DeduplicatedPriorityQueue candidates; + // + // MachineMappingResult original_pcg_cost = + // get_optimal_machine_mapping(pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // + // GraphOptimizeState initial_state = { + // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), + // original_pcg_cost.runtime}; + // + // GraphOptimizeState best_state = initial_state; + // candidates.push(initial_state); + // // for (int iteration = 0; !candidates.empty() && iteration < // 
opt_config.budget; // ++iteration) { - // Strategy const ¤t_result = candidates.top(); + // GraphOptimizeState current_state = candidates.top(); // candidates.pop(); - - // if (current_result.runtime < best_result.runtime) { - // best_result = current_result; - // } else if (current_result.runtime > - // best_result.runtime * opt_config.alpha) { + // + // if (current_state.runtime < best_state.runtime) { + // best_state = current_state; + // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) + // { // continue; // } - - // for (auto const &sub : subs) { - // for (auto const &new_pcg : apply_substitution(current_result.pcg, sub)) - // { - // OptimalCostResult c = optimal_cost(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; - // if (new_result.runtime <= opt_config.threshold && + // + // for (Substitution const &substitution : substitutions) { + // for (ParallelComputationGraph const &new_pcg : apply_substitution( + // current_state.graph_optimize_result.pcg, substitution)) { + // MachineMappingResult new_pcg_cost = + // get_optimal_machine_mapping(new_pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // GraphOptimizeState new_state{ + // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), + // new_pcg_cost.runtime}; + // if (new_pcg_cost.runtime <= opt_config.threshold && // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_result); + // candidates.push(new_state); // } // } // } // } - // return best_result; + // return best_state.graph_optimize_result; } } // namespace FlexFlow diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index 13b1fd3b83..3399a45f0f 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -10,4 +10,5 @@ ff_add_test_executable( compiler doctest utils-test-common + models ) diff --git a/lib/compiler/test/src/allowed_machine_views.cc b/lib/compiler/test/src/allowed_machine_views.cc new file mode 100644 index 0000000000..936894ad2d --- /dev/null +++ b/lib/compiler/test/src/allowed_machine_views.cc @@ -0,0 +1,104 @@ +#include "compiler/allowed_machine_views.h" +#include "doctest/doctest.h" +#include "utils/containers/extend.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip.h" +#include "utils/fmt/unordered_set.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_allowed_machine_views") { + + SUBCASE("1 degree of parallelism") { + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/1, + /*num_cpus_per_node=*/5, + /*num_gpus_per_node=*/5, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + + OperatorTaskSpace task = OperatorTaskSpace{{3}}; + + std::unordered_set correct = { + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/2, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + 
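+              // (a 3-device task on 5 GPUs fits at start offsets 0, 1, and 2
+              //  with stride 1; the remaining view below uses stride 2 and
+              //  occupies devices 0, 2, and 4)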
MachineSpecificationDimension::INTRA_NODE}}, + }, + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + + SUBCASE("2 degrees of parallelism") { + + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/3, + /*num_cpus_per_node=*/3, + /*num_gpus_per_node=*/3, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + OperatorTaskSpace task = OperatorTaskSpace{{2, 3}}; + + auto make_2d_view = [&](int start_node_idx, + int start_device_idx, + int stride1, + int stride2, + MachineSpecificationDimension m1, + MachineSpecificationDimension m2) { + return MachineView{ + MachineSpaceCoordinate{ + start_node_idx, start_device_idx, DeviceType::GPU}, + {MachineViewDimension{stride_t{stride1}, m1}, + MachineViewDimension{stride_t{stride2}, m2}}, + }; + }; + + auto intra = MachineSpecificationDimension::INTRA_NODE; + auto inter = MachineSpecificationDimension::INTER_NODE; + std::unordered_set correct = { + make_2d_view(0, 0, /*stride1=*/1, /*stride2=*/1, inter, intra), + make_2d_view(1, 0, /*stride1=*/1, /*stride2=*/1, inter, intra), + make_2d_view(0, 0, /*stride1=*/2, /*stride2=*/1, inter, intra), + + make_2d_view(0, 0, /*stride1=*/1, /*stride2=*/1, intra, inter), + make_2d_view(0, 1, /*stride1=*/1, /*stride2=*/1, intra, inter), + make_2d_view(0, 0, /*stride1=*/2, /*stride2=*/1, intra, inter), + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..5c8ea1c0f1 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,300 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_abstracted_tensor_set_movement_across_split") { + auto make_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + 
/*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAttrs ew_add_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + SUBCASE("no edges across split") { + ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_leaf(input1.parallel_layer), + make_leaf(input2.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{}, + }; + + CHECK(result == correct); + } + + SUBCASE("single edge across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_leaf(layer_2.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_series_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_leaf(layer_3.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + 
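+                        // layer_2 sits at path (RIGHT, RIGHT) in the pre-split
+                        // subtree; the direct layer_1 -> layer_3 edge is
+                        // dropped by transitive reduction, so layer_2 is the
+                        // only source reported here.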
BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_parallel_split(make_leaf(layer_2.parallel_layer), + make_leaf(layer_3.parallel_layer)), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_4 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_parallel_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_parallel_split(make_leaf(layer_3.parallel_layer), + make_leaf(layer_4.parallel_layer))}; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git 
a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc new file mode 100644 index 0000000000..9ee596af3e --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -0,0 +1,41 @@ +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" + +namespace FlexFlow { + +TestCostEstimator::TestCostEstimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost) + : get_operator_cost(get_operator_cost), + get_communication_cost(get_communication_cost) {} + +float TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->get_operator_cost(k); +} + +float TestCostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->get_communication_cost(m); +} + +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost) { + + return CostEstimator::create(get_operator_cost, + get_communication_cost); +} + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map) { + return make_fake_cost_estimator( + [op_cost_map](OpCostEstimateKey const &k) { return op_cost_map.at(k); }, + [comm_cost_map](TensorSetMovement const &m) { + return comm_cost_map.at(m); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h new file mode 100644 index 0000000000..7c1d06207a --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H +#define _FLEXFLOW_TEST_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +struct TestCostEstimator : public ICostEstimator { + std::function get_operator_cost; + std::function get_communication_cost; + + TestCostEstimator() = delete; + TestCostEstimator(decltype(get_operator_cost) const &get_operator_cost, + decltype(get_communication_cost) + const &get_communication_cost); + + float estimate_cost(OpCostEstimateKey const &) const override; + + float estimate_cost(TensorSetMovement const &) const override; +}; + +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost); + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..499b111f8f --- /dev/null +++ 
b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,235 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_resource_splits") { + auto make_machine_spec = [](int num_nodes, int num_gpus_per_node) { + return MachineSpecification{ + /*num_nodes=*/num_nodes, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/num_gpus_per_node, + /*inter_node_bandwidth=*/1.0, + /*intra_node_bandwidth=*/1.0, + }; + }; + + SUBCASE("returns no splits if no splits are possible") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1); + + std::unordered_set> + result = get_machine_resource_splits(input); + std::unordered_set> + correct = {}; + + CHECK(result == correct); + } + + SUBCASE( + "returns splits in gpu and node dimensions, but not at the same time") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/2); + + std::unordered_set> + result = get_machine_resource_splits(input); + + std::unordered_set> + correct = { + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + + }; + + CHECK(result == correct); + } + + SUBCASE("returns splits in node dimension in powers of two") { + SUBCASE("num_nodes is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/8, + /*num_gpus_per_node=*/1); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_nodes is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + } + + SUBCASE("returns 
splits in gpu dimension in powers of two") { + SUBCASE("num_gpus_per_node is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/8); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_gpus_per_node is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + } + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..a0d06fe930 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,263 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_optimal_machine_mapping") { + auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto make_series_split = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + 
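The cost model consumed further below in this test case comes from make_fake_cost_estimator, the table-backed helper added earlier in this diff. As a rough standalone illustration of that lookup-table pattern (the key types here are simplified stand-ins, not the FlexFlow OpCostEstimateKey or TensorSetMovement API):

#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Simplified stand-in keys: an operator cost is looked up by (op name,
// machine-view id), and a communication cost by a tensor name.
using OpKey = std::pair<std::string, int>;

struct FakeCostEstimator {
  std::function<float(OpKey const &)> op_cost;
  std::function<float(std::string const &)> comm_cost;
};

// Answer cost queries from fixed tables, mirroring the shape of the
// map-based make_fake_cost_estimator overload in this diff.
FakeCostEstimator make_table_backed_estimator(
    std::map<OpKey, float> op_costs, std::map<std::string, float> comm_costs) {
  return FakeCostEstimator{
      [op_costs](OpKey const &k) { return op_costs.at(k); },
      [comm_costs](std::string const &t) { return comm_costs.at(t); },
  };
}

int main() {
  FakeCostEstimator est = make_table_backed_estimator(
      {{{"relu", /*machine_view=*/1}, 1.0f}, {{"relu", /*machine_view=*/2}, 1.5f}},
      {{"input_tensor", 0.1f}});

  assert(est.op_cost({"relu", 2}) == 1.5f);
  assert(est.comm_cost("input_tensor") == 0.1f);
}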
+ auto make_parallel_split = [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + MachineView mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + MachineSpecification split_machine_spec = MachineSpecification{ + /*num_nodes=*/1, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, + MachineSpecification const &resources) { + if (resources == full_machine_spec) { + return std::unordered_set{mv1, mv2}; + } else { + return std::unordered_set{mv2}; + } + }; + + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + ParallelTensorShape tensor_shape1 = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/tensor_shape1, + /*src_machine_views=*/{}, + /*dst_machine_views=*/{}, + }, + }}; + + ParallelLayerGuidObliviousMachineMapping mm1 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + {map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, + {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, + {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, + {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, + }}, + std::unordered_map{{ + {TensorSetMovement{{}}, 0.0}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), + 0.1}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), + 0.2}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), + 0.3}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), + 0.4}, + }}); + + MachineMappingContext context = MachineMappingContext{ + cost_estimator, + allowed_machine_views1, + }; + + MachineMappingCache cache = empty_machine_mapping_cache(); + + SUBCASE("single layer") { + MachineMappingProblemTree problem_tree = 
make_leaf(k1); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in sequence") { + MachineMappingProblemTree problem_tree = + make_series_split(movement1, make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0 + 2.0 + 0.1, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in parallel") { + MachineMappingProblemTree problem_tree = + make_parallel_split(make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.5, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..e22f715d82 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,339 @@ +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_tensor_set_movement_across_split") { + auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_pcg_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + + auto make_pcg_leaf_node = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape 
input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + ParallelLayerAddedResult relu_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult relu_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(relu_1.outputs)}, {relu_output_attrs}); + + MachineView pre_mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView pre_mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView post_mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{3}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView post_mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{4}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + SUBCASE("single edge across split") { + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_leaf_node(relu_2.parallel_layer), + }; + + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + {BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv1}, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + post_mv1, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") {} + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult relu_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(relu_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_parallel_split(make_pcg_leaf_node(relu_2.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer)), + }; + + 
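The two subcases that follow pin down how destination views are aggregated for a tensor that crosses the split: consumers mapped to the same view contribute a single destination view, while distinct views all appear. A minimal standalone sketch of that deduplication, using plain ints in place of MachineView and making no claim about the internals of get_tensor_set_movement_across_split:

#include <cassert>
#include <set>
#include <vector>

// Collect the destination machine views of one tensor crossing the split:
// each consumer contributes its view, and duplicates collapse into a set.
std::set<int> dst_views_for_tensor(std::vector<int> const &consumer_views) {
  return std::set<int>(consumer_views.begin(), consumer_views.end());
}

int main() {
  int post_mv1 = 1;
  int post_mv2 = 2;

  // Two consumers on the same view -> one destination view.
  assert((dst_views_for_tensor({post_mv1, post_mv1}) == std::set<int>{post_mv1}));

  // Two consumers on different views -> both views appear.
  assert((dst_views_for_tensor({post_mv1, post_mv2}) ==
          std::set<int>{post_mv1, post_mv2}));
}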
SUBCASE("consumers have same view") { + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv1, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + post_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + post_mv1, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("consumers have different views") { + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv1, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + post_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + post_mv2, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, + }, + }, + }; + + CHECK(result == correct); + } + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult relu_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult relu_4 = add_parallel_layer( + pcg, + relu_attrs, + // relu's don't have two inputs, but for the + // purposes of this test it's fine. 
+ {get_only(relu_1.outputs), get_only(relu_3.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_parallel_split( + make_pcg_leaf_node(relu_1.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer))), + make_pcg_parallel_split(make_pcg_leaf_node(relu_2.parallel_layer), + make_pcg_leaf_node(relu_4.parallel_layer)), + }; + + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + pre_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv2, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + post_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + post_mv2, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, + }, + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv2}, + /*dst_machine_views=*/{post_mv2}, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..221cca3ae1 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,111 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "cost_estimator_for_test.h" +#include "doctest/doctest.h" +#include "pcg/machine_view.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("combine_disjoint_mappings(MachineMapping, MachineMappping)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineMapping machine_mapping_0 = MachineMapping({ + {parallel_layer_guid_t{Node{0}}, machine_view_0}, + }); + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t{Node{1}}, machine_view_1}, + }); + MachineMapping correct = MachineMapping{{ + {parallel_layer_guid_t{Node{0}}, machine_view_0}, + {parallel_layer_guid_t{Node{1}}, machine_view_1}, + }}; + MachineMapping result = + combine_disjoint_mappings(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + TEST_CASE("nodes_are_disjoint(MachineMapping, MachineMappping)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + 
}; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineMapping machine_mapping_0 = MachineMapping({ + {parallel_layer_guid_t{Node{0}}, machine_view_0}, + }); + + SUBCASE("nodes are disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t{Node{1}}, machine_view_1}, + }); + + bool correct = true; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + SUBCASE("nodes are not disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t{Node{0}}, machine_view_0}, + {parallel_layer_guid_t{Node{1}}, machine_view_1}, + }); + bool correct = false; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..06ab1e5b8c --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -0,0 +1,289 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_mapping_problem_tree") { + auto pcg_make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinarySeriesSplit{ + lhs, + rhs, + }, + }; + }; + + auto pcg_make_parallel = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinaryParallelSplit{ + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto mm_problem_tree_make_series = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + tensor_set_movement, + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_parallel = + [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + lhs, + rhs, + }, + }; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + auto make_output_attrs = [](ParallelTensorShape const &shape) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + }; 
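The pcg_make_* and mm_problem_tree_make_* helpers defined at the top of this test case build recursive leaf/series/parallel trees. The following self-contained C++17 sketch shows that shape with illustrative stand-in types (an int payload per leaf), not the actual FlexFlow MachineMappingProblemTree:

#include <cassert>
#include <memory>
#include <variant>

// Stand-in tree: internal nodes are series or parallel splits, leaves carry
// an opaque payload (an int id here).
struct ProblemTree;
using ProblemTreePtr = std::shared_ptr<ProblemTree const>;

struct Leaf { int id; };
struct SeriesNode { ProblemTreePtr left; ProblemTreePtr right; };
struct ParallelNode { ProblemTreePtr left; ProblemTreePtr right; };

struct ProblemTree {
  std::variant<Leaf, SeriesNode, ParallelNode> node;
};

ProblemTreePtr make_leaf(int id) {
  return std::make_shared<ProblemTree>(ProblemTree{Leaf{id}});
}
ProblemTreePtr make_series(ProblemTreePtr l, ProblemTreePtr r) {
  return std::make_shared<ProblemTree>(ProblemTree{SeriesNode{std::move(l), std::move(r)}});
}
ProblemTreePtr make_parallel(ProblemTreePtr l, ProblemTreePtr r) {
  return std::make_shared<ProblemTree>(ProblemTree{ParallelNode{std::move(l), std::move(r)}});
}

// Number of leaves, i.e. how many layers the mapping problem covers.
int num_leaves(ProblemTree const &t) {
  if (std::holds_alternative<Leaf>(t.node)) {
    return 1;
  }
  if (auto const *s = std::get_if<SeriesNode>(&t.node)) {
    return num_leaves(*s->left) + num_leaves(*s->right);
  }
  ParallelNode const &p = std::get<ParallelNode>(t.node);
  return num_leaves(*p.left) + num_leaves(*p.right);
}

int main() {
  // series(leaf(0), parallel(leaf(1), leaf(2))) covers three layers.
  ProblemTreePtr t =
      make_series(make_leaf(0), make_parallel(make_leaf(1), make_leaf(2)));
  assert(num_leaves(*t) == 3);
}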
+ + auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }; + }; + + PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; + + auto make_input_key = + [&](ParallelTensorShape const ¶llel_tensor_shape) { + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/input_attrs, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{parallel_tensor_shape}, + }; + }; + + SUBCASE("single layer") { + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + + PCGBinarySPDecomposition sp_decomposition = + PCGBinarySPDecomposition{input_layer}; + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree correct = MachineMappingProblemTree{input_key}; + + CHECK(result == correct); + } + + SUBCASE("two layers in series") { + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + parallel_tensor_guid_t input = get_only(input_added.outputs); + + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + + PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }; + ParallelTensorShape relu_output_shape = input_shape; + ParallelLayerAddedResult relu_added = + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {input}, + {make_output_attrs(relu_output_shape)}); + parallel_layer_guid_t relu_layer = relu_added.parallel_layer; + parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + + UnmappedOpCostEstimateKey relu_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/relu_attrs, + /*input_shapes=*/{input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{relu_output_shape}, + }; + + PCGBinarySPDecomposition sp_decomposition = pcg_make_series( + pcg_make_leaf(input_layer), pcg_make_leaf(relu_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{}}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }}, + mm_problem_tree_make_leaf(input_key), + mm_problem_tree_make_leaf(relu_key)); + + CHECK(result == correct); + } + + SUBCASE("two layers in parallel") { + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); + + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); + + PCGBinarySPDecomposition sp_decomposition = pcg_make_parallel( + pcg_make_leaf(input1_layer), pcg_make_leaf(input2_layer)); + + MachineMappingProblemTree result = + 
get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)); + + CHECK(result == correct); + } + + SUBCASE("multiple tensors across split") { + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); + + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); + + PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }; + ParallelTensorShape ew_op_output_shape = input_shape; + ParallelLayerAddedResult ew_op_added = + add_parallel_layer(pcg, + make_layer_attrs(ew_op_attrs), + {input1_tensor, input2_tensor}, + {make_output_attrs(ew_op_output_shape)}); + parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + UnmappedOpCostEstimateKey ew_op_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/ew_op_attrs, + /*input_shapes=*/{input_shape, input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{ew_op_output_shape}, + }; + + PCGBinarySPDecomposition sp_decomposition = + pcg_make_series(pcg_make_parallel(pcg_make_leaf(input1_layer), + pcg_make_leaf(input2_layer)), + pcg_make_leaf(ew_op_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }}, + /*pre=*/ + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), + /*post=*/mm_problem_tree_make_leaf(ew_op_key)); + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..73b921fc98 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,423 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "pcg/machine_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("series_combine") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + 
/*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + float pre_cost = 2.0; + MachineMappingResult pre = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + float post_cost = 4.0; + MachineMappingResult post = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + float comm_cost = 3.0; + + SUBCASE("pre is infeasible") { + MachineMappingResult result = series_combine( + comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("post is infeasible") { + MachineMappingResult result = series_combine( + comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = + series_combine(comm_cost, + infeasible, + infeasible, + ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult no_parallel_split_transform = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + SUBCASE("parallel_split_transformation = std::nullopt") { + MachineMappingResult result = + series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = LthenR") { + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = RthenL") { + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } + } + + 
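The series_combine cases above and the parallel_combine cases below fix the combination arithmetic: a series split adds the two runtimes plus the communication cost and re-roots the per-layer views under left/right path prefixes, while a parallel split keeps both sub-mappings and takes the larger runtime. A rough standalone sketch of that behaviour, modelling infeasibility with std::optional and using int view ids and vector<int> paths as stand-ins for MachineView and BinaryTreePath:

#include <algorithm>
#include <cassert>
#include <map>
#include <optional>
#include <vector>

// Stand-in result: a runtime plus a map from a binary-tree path
// (0 = left child, 1 = right child) to an opaque machine-view id.
struct Result {
  float runtime;
  std::map<std::vector<int>, int> views;
};

using MaybeResult = std::optional<Result>; // nullopt models infeasibility

static std::map<std::vector<int>, int>
    prefix_paths(int entry, std::map<std::vector<int>, int> const &views) {
  std::map<std::vector<int>, int> out;
  for (auto const &[path, view] : views) {
    std::vector<int> prefixed = {entry};
    prefixed.insert(prefixed.end(), path.begin(), path.end());
    out[prefixed] = view;
  }
  return out;
}

MaybeResult series_combine(float comm_cost, MaybeResult const &pre, MaybeResult const &post) {
  if (!pre || !post) {
    return std::nullopt; // either side infeasible -> whole split infeasible
  }
  Result out;
  out.runtime = pre->runtime + comm_cost + post->runtime;
  out.views = prefix_paths(0, pre->views);
  for (auto const &[path, view] : prefix_paths(1, post->views)) {
    out.views[path] = view;
  }
  return out;
}

MaybeResult parallel_combine(MaybeResult const &lhs, MaybeResult const &rhs) {
  if (!lhs || !rhs) {
    return std::nullopt;
  }
  Result out;
  out.runtime = std::max(lhs->runtime, rhs->runtime);
  out.views = prefix_paths(0, lhs->views);
  for (auto const &[path, view] : prefix_paths(1, rhs->views)) {
    out.views[path] = view;
  }
  return out;
}

int main() {
  MaybeResult pre = Result{2.0f, {{{}, /*view=*/7}}};
  MaybeResult post = Result{4.0f, {{{}, /*view=*/8}}};

  MaybeResult in_series = series_combine(/*comm_cost=*/3.0f, pre, post);
  assert(in_series && in_series->runtime == 9.0f); // 2 + 3 + 4

  MaybeResult in_parallel = parallel_combine(pre, post);
  assert(in_parallel && in_parallel->runtime == 4.0f); // max(2, 4)

  assert(!series_combine(3.0f, std::nullopt, post)); // infeasibility propagates
}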
TEST_CASE("parallel_combine") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineMappingResult lhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult rhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasible") { + MachineMappingResult result = parallel_combine(infeasible, rhs); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasible") { + MachineMappingResult result = parallel_combine(lhs, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = parallel_combine(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult result = parallel_combine(lhs, rhs); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("minimize_runtime") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineMappingResult faster = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult slower = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + 
/*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasible") { + MachineMappingResult result = minimize_runtime(infeasible, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasible") { + MachineMappingResult result = minimize_runtime(slower, infeasible); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = minimize_runtime(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + SUBCASE("lhs is faster") { + MachineMappingResult result = minimize_runtime(faster, slower); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("rhs is faster") { + MachineMappingResult result = minimize_runtime(slower, faster); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("lhs and rhs have the same speed") { + MachineMappingResult result = minimize_runtime(slower, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..2b59669aad --- /dev/null +++ b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,397 @@ +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include "models/bert/bert.h" +#include "models/candle_uno/candle_uno.h" +#include "models/inception_v3/inception_v3.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_computation_graph_series_parallel_decomposition(ComputationGraph)") { + SUBCASE("empty computation graph") { + ComputationGraph cg = make_empty_computation_graph(); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + // technically an empty graph is non-SP + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("just a single input") { + std::string input_layer_name = "my input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{input_layer.raw_node}; + + CHECK(result == correct); + } + + SUBCASE("single operator plus inputs and weights") { + std::string input_layer_name = "my input"; + std::string projection_weights_layer_name = "my projection weights"; + std::string bias_weights_layer_name = "my bias weights"; + std::string operator_name = "my operator"; + 
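The correct value constructed below nests a ParallelSplit of the input and weight layers inside a SeriesSplit that ends at the operator. A standalone sketch of that nesting, using plain strings for layer names rather than the FlexFlow SeriesParallelDecomposition type:

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for a series-parallel decomposition over layer
// names: a node is a leaf, a series of children, or a parallel set of
// children.
struct SPNode {
  enum class Kind { Leaf, Series, Parallel };
  Kind kind;
  std::string layer;            // meaningful for leaves only
  std::vector<SPNode> children; // meaningful for splits only
};

SPNode leaf(std::string name) {
  return SPNode{SPNode::Kind::Leaf, std::move(name), {}};
}
SPNode series(std::vector<SPNode> children) {
  return SPNode{SPNode::Kind::Series, "", std::move(children)};
}
SPNode parallel(std::vector<SPNode> children) {
  return SPNode{SPNode::Kind::Parallel, "", std::move(children)};
}

int count_leaves(SPNode const &n) {
  if (n.kind == SPNode::Kind::Leaf) {
    return 1;
  }
  int total = 0;
  for (SPNode const &child : n.children) {
    total += count_leaves(child);
  }
  return total;
}

int main() {
  // The input and both weight layers are parallel sources, followed in
  // series by the operator they feed.
  SPNode dense_decomposition = series({
      parallel({leaf("my input"),
                leaf("my projection weights"),
                leaf("my bias weights")}),
      leaf("my operator"),
  });
  assert(count_leaves(dense_decomposition) == 4);
}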
ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/operator_name, + /*projection_name=*/projection_weights_layer_name, + /*bias_name=*/bias_weights_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + layer_guid_t projection_weights_layer = + get_layer_by_name(cg, projection_weights_layer_name); + layer_guid_t bias_weights_layer = + get_layer_by_name(cg, bias_weights_layer_name); + layer_guid_t operator_layer = get_layer_by_name(cg, operator_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ + input_layer.raw_node, + projection_weights_layer.raw_node, + bias_weights_layer.raw_node, + }}, + operator_layer.raw_node, + }}}; + + CHECK(result == correct); + } + + SUBCASE("SP without weight nodes but non-SP with weight nodes") { + // A minimal computation graph where without weights (w1 and w2) the + // computation graph is series-parallel, but with weight nodes it is not + // + // w1 input w2 + // \ / \ / + // op1 op2 + + std::string w1_name = "w1"; + std::string input_name = "input"; + std::string w2_name = "w2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op1_name, + /*projection_name=*/w1_name); + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op2_name, + /*projection_name=*/w2_name); + + return b.computation_graph; + }(); + + layer_guid_t w1 = get_layer_by_name(cg, w1_name); + layer_guid_t input = get_layer_by_name(cg, input_name); + layer_guid_t w2 = get_layer_by_name(cg, w2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ + w1.raw_node, + input.raw_node, + w2.raw_node, + }}, + ParallelSplit{{ + op1.raw_node, + op2.raw_node, + }}, + }}}; + } + + SUBCASE("SP with or without preprocessing, but preprocessing would SP " + "decomposition") { + // computation graph: + // + // input1 input2 + // | | + // op1 op2 + + std::string input1_name = "input1"; + std::string input2_name = "input2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + 
tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + tensor_guid_t input2 = + b.create_input(input_shape, CreateGrad::YES, input2_name); + + b.relu(input1, op1_name); + b.relu(input2, op2_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t input2 = get_layer_by_name(cg, input2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{{ + SeriesSplit{{ + input1.raw_node, + op1.raw_node, + }}, + SeriesSplit{{ + input2.raw_node, + op2.raw_node, + }}, + }}}; + } + + SUBCASE("not SP with or without weight nodes") { + // computation graph: + // + // input1 + // / \ + // op1 op2 + // | \ | + // | \ | + // op3 op4 + + std::string input1_name = "input1"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + std::string op3_name = "op3"; + std::string op4_name = "op4"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + + tensor_guid_t op1_output = b.relu(input1, op1_name); + tensor_guid_t op2_output = b.relu(input1, op2_name); + b.relu(op1_output, op3_name); + b.add(op1_output, op2_output, op4_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + layer_guid_t op3 = get_layer_by_name(cg, op3_name); + layer_guid_t op4 = get_layer_by_name(cg, op4_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = std::nullopt; + } + + SUBCASE("real models") { + SUBCASE("split_test") { + ComputationGraph cg = + get_split_test_computation_graph(/*batch_size=*/8); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("inception_v3") { + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("candle_uno") { + ComputationGraph cg = + get_candle_uno_computation_graph(get_default_candle_uno_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("bert") { + ComputationGraph cg = + get_bert_computation_graph(get_default_bert_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + } + } + + TEST_CASE("render_preprocessed_computation_graph_for_sp_decomposition(" + "ComputationGraph)") { + // currently there's not really a good way to test this, and its arguable + // how much its output really should be validated as 
its primarily for + // visualization and so there's not really a strict definition of + // correctness, so for now we just run it on some models and make sure it + // doesn't crash. Don't use this as an example. + + SUBCASE("basic single-operator model") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + b.dense(input, /*outDim=*/14); + + return b.computation_graph; + }(); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("split_test") { + ComputationGraph cg = get_split_test_computation_graph(/*batch_size=*/8); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("inception_v3") { + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("candle_uno") { + ComputationGraph cg = + get_candle_uno_computation_graph(get_default_candle_uno_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("bert") { + ComputationGraph cg = + get_bert_computation_graph(get_default_bert_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + } +} diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc new file mode 100644 index 0000000000..46177ad420 --- /dev/null +++ b/lib/compiler/test/src/graph_optimize_state.cc @@ -0,0 +1,80 @@ +#include "compiler/graph_optimize_state.h" +#include "doctest/doctest.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("GraphOptimizeState::operator==") { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); + parallel_tensor_guid_t dense0 = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + parallel_tensor_guid_t dense1 = builder.dense(dense0, + 4, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense1"); + + ParallelComputationGraph pcg = builder.pcg; + + // `machine_mapping` is determined by the PCG and the device mapping + // algorithm, and `runtime` is determined by the PCG and the device mapping, + // so their values here do not matter. 
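As the comment above notes, the placeholder mapping and runtime are not what this comparison is about. A standalone sketch of the property the test appears to rely on, with a toy state type whose equality is driven by the wrapped graph while cached values are ignored (illustrative only; the actual GraphOptimizeState comparison is defined by the library):

#include <cassert>
#include <string>

// Toy stand-in: equality of an optimization-search state follows the graph
// it wraps; the cached runtime is treated as derived data.
struct ToyOptimizeState {
  std::string graph_fingerprint; // stand-in for the PCG
  float runtime;                 // derived; ignored by operator==

  bool operator==(ToyOptimizeState const &other) const {
    return this->graph_fingerprint == other.graph_fingerprint;
  }
};

int main() {
  ToyOptimizeState a{"pcg: input0 -> dense0 -> dense1", 0.0f};
  ToyOptimizeState b{"pcg: input0 -> dense0 -> dense1", 123.0f};
  ToyOptimizeState c{"pcg: input0 -> dense0", 0.0f};

  assert(a == b);    // same underlying graph -> equal states
  assert(!(a == c)); // different graphs -> different states
}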
+ std::unordered_map empty_machine_views; + MachineMapping empty_machine_mapping(empty_machine_views); + bool result1 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0); + bool correct1 = true; + CHECK(result1 == correct1); + + ParallelComputationGraphBuilder builder_; + + parallel_tensor_guid_t input0_ = + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); + parallel_tensor_guid_t dense0_ = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + ParallelComputationGraph pcg_ = builder.pcg; + + bool result2 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), 0); + bool correct2 = false; + CHECK(result2 == correct2); + } +} diff --git a/lib/compiler/test/src/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h deleted file mode 100644 index 9417b863e4..0000000000 --- a/lib/compiler/test/src/test_cost_estimator.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H -#define _FLEXFLOW_TEST_COST_ESTIMATOR_H - -#include "compiler/cost_estimate.h" - -namespace FlexFlow { - -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const override { - return 0.1; - } - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { - return 0.1; - } -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h deleted file mode 100644 index d6b8222968..0000000000 --- a/lib/compiler/test/src/test_generator.h +++ /dev/null @@ -1,174 +0,0 @@ -#ifndef _FLEXFLOW_TEST_GENERATOR_H -#define _FLEXFLOW_TEST_GENERATOR_H - -#include "compiler/machine_mapping.h" -#include "pcg/computation_graph.h" -#include "rapidcheck.h" -#include "substitutions/sub_parallel_computation_graph.h" - -using namespace FlexFlow; - -// Rapidcheck does not work for now -// /* -// Generates computation graphs with trivial layers and tensors, which are -// used for tests focusing on graph structures. -// */ -// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, -// [](Tensor(MultiDiOutput const &)) { -// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; -// })); -// } - -// /* -// Generates parallel computation graphs with trivial layers and tensors, -// which are used for tests focusing on graph structures. 
-// */ -// ParallelComputationGraph -// test_parallel_computation_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, -// [](Operator(MultiDiOutput const &)) { -// return ParallelTensor(ParallelTensorDims(TensorDims({})), -// DataType::FLOAT); -// })); -// } - -// rc::Gen small_integer_generator() { -// return rc::gen::inRange(1, 4); -// } - -// namespace rc { - -// Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), -// multidigraph_from_sp_decomposition); -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_computataion_graph); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_parallel_computation_graph); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>(gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>( -// gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Serial::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Parallel::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? 
gen::construct( -// gen::arbitrary()) -// : gen::construct( -// gen::arbitrary()); -// }); -// } -// }; - -// template -// struct Arbitrary { -// static Gen< -// std::enable_if, -// Tag>::value>::type> arbitrary() { -// return gen::construct(gen::arbitrary()); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::apply(make_1d_machine_view, -// gen::arbitrary, -// gen::arbitrary, -// small_integer_generator()); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineMapping::machine_views, -// gen::container>( -// gen::arbitrary(), -// gen::arbitrary()))); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, -// 64)), gen::set(&MachineSpecification::num_gpus_per_node, -// gen::inRange(1, 16)), -// gen::set(&MachineSpecification::inter_node_bandwidth, -// gen::nonZero()), -// gen::set(&MachineSpecification::intra_node_bandwidth, -// gen::nonZero())); -// } -// } - -// } // namespace rc - -#endif diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc deleted file mode 100644 index 59fa0f1e5e..0000000000 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ /dev/null @@ -1,132 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// // #include "rapidcheck.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { -// auto g = OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); -// Node n2 = g.add_node(); -// Node n3 = g.add_node(); -// Node n4 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); -// NodePort p2 = g.add_node_port(); -// NodePort p3 = g.add_node_port(); -// NodePort p4 = g.add_node_port(); -// NodePort p5 = g.add_node_port(); -// NodePort p6 = g.add_node_port(); -// NodePort p7 = g.add_node_port(); -// NodePort p8 = g.add_node_port(); -// NodePort p9 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; -// MultiDiEdge e1{n2, p2, n0, p0}; -// MultiDiEdge e2{n3, p5, n1, p3}; -// MultiDiEdge e3{n3, p6, n2, p4}; -// MultiDiEdge e4{n4, p8, n3, p7}; -// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// std::unordered_set node_set0{n3, n4}; - -// auto subgraph0 = get_subgraph(g, node_set0); -// auto subgraph1 = get_subgraph(g, -// node_set0); auto subgraph2 = -// get_subgraph(g, node_set0); -// auto subgraph3 = get_subgraph(g, node_set0); - -// CHECK(bool(get_nodes(subgraph0) == node_set0)); -// CHECK(bool(get_nodes(subgraph1) == node_set0)); -// CHECK(bool(get_nodes(subgraph2) == node_set0)); -// CHECK(bool(get_nodes(subgraph3) == node_set0)); - -// std::unordered_set input_set{split_edge(e2).second, -// split_edge(e3).second}; -// std::unordered_set output_set{e5}; - -// CHECK(bool(get_open_inputs(subgraph0) == input_set)); -// CHECK(bool(get_open_inputs(subgraph1) == input_set)); -// CHECK(bool(get_open_inputs(subgraph2).empty())); -// CHECK(bool(get_open_inputs(subgraph3).empty())); - -// CHECK(bool(get_open_outputs(subgraph0) == output_set)); -// CHECK(bool(get_open_outputs(subgraph1).empty())); 
-// CHECK(bool(get_open_outputs(subgraph2) == output_set)); -// CHECK(bool(get_open_outputs(subgraph3).empty())); - -// CHECK(bool(get_edges(subgraph0) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4, e5})); -// CHECK(bool(get_edges(subgraph1) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4})); -// CHECK(bool(get_edges(subgraph2) == -// std::unordered_set{e4, e5})); -// CHECK( -// bool(get_edges(subgraph3) == -// std::unordered_set{e4})); - -// CHECK(bool(get_closed_sources(subgraph2) == -// std::unordered_set{n3})); -// } - -// TEST_CASE("view OutputLabelledMultiDiGraph as open") { -// OutputLabelledMultiDiGraph g = -// OutputLabelledMultiDiGraph::create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_output(e0, 2); - -// CHECK(bool(get_edges(g).size() == 1)); - -// OutputLabelledOpenMultiDiGraphView open_graph = -// view_output_labelled_as_output_labelled_open(g); - -// CHECK(bool(open_graph.at(n0) == 0)); -// CHECK(bool(open_graph.at(n1) == 1)); -// CHECK(bool(open_graph.at(e0) == 2)); - -// CHECK(get_edges(open_graph).size() == 1); -// } - -// TEST_CASE("OutputLabelledOpenMultiDiGraph") { -// OutputLabelledOpenMultiDiGraph g = -// OutputLabelledOpenMultiDiGraph::create< -// UnorderedOutputLabelledOpenMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_label(e0, 2); - -// CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); -// CHECK(bool(get_edges(g).size() == 1)); -// } -// } diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc deleted file mode 100644 index a4caceef66..0000000000 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "doctest/doctest.h" -#include "test_generator.h" - -TEST_SUITE(FF_TEST_SUITE) { - // TEST_CASE("MachineMapping::combine") { - // RC_SUBCASE([](MachineMapping const &m0, MachineMapping const &m1) { - // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - - // MachineMapping comb = MachineMapping::combine(m0, m1); - - // RC_ASSERT(comb.machine_views.size() == - // m0.machine_views.size() + m1.machine_views.size()); - // RC_ASSERT(is_submapeq_of(comb.machine_views, m0.machine_views)); - // RC_ASSERT(is_submapeq_of(comb.machine_views, m1.machine_views)); - // }); - // } - - // TEST_CASE("OptimalCostResult::infinity") { - // RC_SUBCASE([](OptimalCostResult const &c) { - // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - // }); - // } -} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc deleted file mode 100644 index e3426aa293..0000000000 --- a/lib/compiler/test/src/test_open_graph.cc +++ /dev/null @@ -1,81 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "utils/graph/algorithms.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_source_sink_open_graph") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// NodePort p0 = g.add_node_port(); -// InputMultiDiEdge e0{ -// n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; -// 
g.add_edge(e0); - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{})); -// } - -// TEST_CASE("get_source_sink_open_graph:unconnected") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; -// OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; -// g.add_edge(e0); -// g.add_edge(e1); - -// /* -// g: ->n0 -// n1-> -// */ - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); -// } - -// TEST_CASE("get_cut") { -// auto g = OpenMultiDiGraph::create(); - -// std::vector ns = add_nodes(g, 5); - -// MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; -// MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; -// MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; -// OutputMultiDiEdge e5{ -// ns[4], g.add_node_port(), std::make_pair(ns[4].value(), -// ns[4].value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; -// CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, -// e2})); - -// GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; -// CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, -// e4})); -// } -// } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc deleted file mode 100644 index 133558f83a..0000000000 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ /dev/null @@ -1,72 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "test_cost_estimator.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// // Rapidcheck infrastructures for graphs does not work for now -// /* -// Tests whether optimal_cost can give a valid result given random PCG, -// trivial allowed machine views, trivial cost estimator and random machine -// specification. 
-// */ -// // TEST_CASE("optimal_cost") { -// // auto test_allowed_machine_views = [](Operator const &, -// // MachineSpecification const &) { -// // return std::unordered_set{make_1d_machine_view(0, 1, -// 1)}; -// // }; -// // RC_SUBCASE([](ParallelComputationGraph const &g, -// // MachineSpecification const &machine_spec) { -// // OptimalCostCache cached_subgraph_costs; -// // OptimalCostResult result = optimal_cost(g, -// // test_allowed_machine_views, -// // TestCostEstimator{}, -// // machine_spec, -// // cached_subgraph_costs); -// // RC_ASSERT(result.runtime > 0); -// // RC_ASSERT(keys(result.machine_mapping.machine_views) == -// get_nodes(g)); -// // }); -// // } - -// TEST_CASE("optimal_cost_0") { -// auto pcg = -// OutputLabelledMultiDiGraph::template -// create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); -// Node n1 = pcg.add_node(Operator{ -// LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, -// std::nullopt}, "linear"}); - -// MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; -// pcg.add_edge(e); -// ParallelDim dim = {2, 1, false}; -// ParallelTensorDims dims = {FFOrdered{dim}}; -// pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, -// CreateGrad::YES)); - -// auto test_allowed_machine_views = [](Operator const &, -// MachineSpecification const &) { -// return std::unordered_set{ -// make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; -// }; - -// CostEstimator estimator = CostEstimator::create(); - -// MachineSpecification machine_spec{1, 1, 1, 1, 1}; - -// OptimalCostCache cached_results; - -// OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), -// test_allowed_machine_views, -// estimator, -// machine_spec, -// cached_results); - -// CHECK(bool(result.runtime > 0)); -// } -// } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc similarity index 93% rename from lib/compiler/test/src/test_unity_algorithm.cc rename to lib/compiler/test/src/unity_algorithm.cc index ed5e895a75..8ff0978ea5 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/unity_algorithm.cc @@ -1,7 +1,5 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" -#include "test_cost_estimator.h" -#include "test_generator.h" TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index d3221474c0..5fbcd91a06 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -16,10 +16,10 @@ class GenericTensorAccessorW { template typename data_type_enum_to_class
::type *get() const { if (this->data_type == DT) { - return static_cast *>(this->ptr); + return static_cast *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -47,10 +47,10 @@ class GenericTensorAccessorR { template typename data_type_enum_to_class
::type const *get() const { if (this->data_type == DT) { - return static_cast const *>(this->ptr); + return static_cast const *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -94,17 +94,17 @@ template typename data_type_enum_to_class
::type * get(GenericTensorAccessorW const &a) { if (a.data_type == DT) { - return static_cast *>(a.ptr); + return static_cast *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } template -std::vector *> +std::vector *> get(std::vector const &accs) { - std::vector *> out; + std::vector *> out; for (auto acc : accs) { out.push_back(get
(acc)); } @@ -115,10 +115,10 @@ template typename data_type_enum_to_class
::type const * get(GenericTensorAccessorR const &a) { if (a.data_type == DT) { - return static_cast const *>(a.ptr); + return static_cast const *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } @@ -139,9 +139,9 @@ std::vector get_half_ptrs(std::vector const &); template -std::vector const *> +std::vector const *> get(std::vector const &accs) { - std::vector const *> out; + std::vector const *> out; for (auto acc : accs) { out.push_back(get
(acc)); } diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 5de9fae7ad..96a3b3b281 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_ARRAY_SHAPE_H #include "legion_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" #include "utils/visitable.h" #include diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 575de57f09..eb5a1b8198 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -5,7 +5,6 @@ #include "kernels/allocation.h" #include "kernels/device.h" #include "kernels/ff_handle.h" -#include "op-attrs/ops/attention.h" #include namespace FlexFlow { diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 38be2118fa..bfd72647b0 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -4,7 +4,6 @@ #include "device.h" #include "kernels/allocation.h" #include "kernels/ff_handle.h" -#include "utils/visitable.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/include/kernels/datatype_dispatch.h b/lib/kernels/include/kernels/datatype_dispatch.h index e6ab9fa8cc..e83fc3325d 100644 --- a/lib/kernels/include/kernels/datatype_dispatch.h +++ b/lib/kernels/include/kernels/datatype_dispatch.h @@ -22,7 +22,7 @@ Out dispatch(DataType dt, Args &&...args) { case DataType::BOOL: return F{}(std::forward(args)...); default: - throw mk_runtime_error("Unknown datatype {}", dt); + throw mk_runtime_error(fmt::format("Unknown datatype {}", dt)); } } diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 14bb9d2cd2..52609a303f 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -3,6 +3,7 @@ #include "accessor.h" #include "kernels/cpu.h" +#include "op-attrs/datatype_value.dtg.h" #include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index aafbd2cdcb..e4dd9723b8 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H #include "kernels/legion_dim_t.dtg.h" -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/linear_kernels.h b/lib/kernels/include/kernels/linear_kernels.h index c761eaf1d9..3128e39fd0 100644 --- a/lib/kernels/include/kernels/linear_kernels.h +++ b/lib/kernels/include/kernels/linear_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/kernels/src/allocation.cc b/lib/kernels/src/allocation.cc index a892e14a54..ccd88580db 100644 --- a/lib/kernels/src/allocation.cc +++ b/lib/kernels/src/allocation.cc @@ -1,4 +1,5 @@ #include "kernels/allocation.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/kernels/src/cpu/initializer_kernels.cc b/lib/kernels/src/cpu/initializer_kernels.cc index f3b4c9b8fd..91f4f46ef8 100644 --- 
a/lib/kernels/src/cpu/initializer_kernels.cc +++ b/lib/kernels/src/cpu/initializer_kernels.cc @@ -24,7 +24,7 @@ struct ConstantInitKernel { void operator()(GenericTensorAccessorW const &tensor, DataTypeValue value) const { auto arr = get
(tensor); - auto unwrapped_value = get>(value); + auto unwrapped_value = value.get>(); for (size_t i = 0; i < get_volume(tensor.shape); i++) { arr[i] = unwrapped_value; } diff --git a/lib/kernels/src/cuda/embedding_kernels.cu b/lib/kernels/src/cuda/embedding_kernels.cu index 371b45f760..e6a614ba70 100644 --- a/lib/kernels/src/cuda/embedding_kernels.cu +++ b/lib/kernels/src/cuda/embedding_kernels.cu @@ -358,7 +358,7 @@ struct ForwardKernel { weight.data_type == DataType::DOUBLE); if (!aggr.has_value()) { - embed_forward_no_aggr, real_type> + embed_forward_no_aggr, real_type_t> <<, real_type> + embed_forward_with_aggr, real_type_t> <<, real_type> + embed_backward_no_aggr, real_type_t> <<, real_type> + embed_backward_with_aggr, real_type_t> <<> + add_kernel> <<>>( input_grad.get
(), output_grad.get
(), num_elements); } diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 3eb9c486f2..a35d28fa8c 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -267,16 +267,16 @@ struct ForwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_forward_kernel> + elewise_scalar_unary_forward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, input.get(), output.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_forward_kernel> + elewise_unary_forward_kernel> <<>>( num_elements, op_type, input.get(), output.get()); } @@ -313,10 +313,10 @@ struct BackwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_backward_kernel> + elewise_scalar_unary_backward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -324,7 +324,7 @@ struct BackwardKernel { input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_backward_kernel> + elewise_unary_backward_kernel> <<>>( num_elements, op_type, diff --git a/lib/kernels/src/cuda/ops/partition_kernels.cu b/lib/kernels/src/cuda/ops/partition_kernels.cu index e356f83d2a..1d07efb5fa 100644 --- a/lib/kernels/src/cuda/ops/partition_kernels.cu +++ b/lib/kernels/src/cuda/ops/partition_kernels.cu @@ -41,12 +41,12 @@ struct BackwardKernel { RepartitionPerDeviceState const &m, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad) { - add_kernel><<>>(input_grad.get(), - output_grad.get(), - input_grad.shape.num_elements()); + add_kernel><<>>(input_grad.get(), + output_grad.get(), + input_grad.shape.num_elements()); } }; diff --git a/lib/kernels/src/cuda/ops/reduction_kernels.cu b/lib/kernels/src/cuda/ops/reduction_kernels.cu index 992d27fe60..0c6ba7d8e3 100644 --- a/lib/kernels/src/cuda/ops/reduction_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduction_kernels.cu @@ -42,7 +42,7 @@ struct ForwardKernel { size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - reduction_forward_kernel> + reduction_forward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/replicate_kernels.cu b/lib/kernels/src/cuda/ops/replicate_kernels.cu index 0c87418f58..76bfbe2658 100644 --- a/lib/kernels/src/cuda/ops/replicate_kernels.cu +++ b/lib/kernels/src/cuda/ops/replicate_kernels.cu @@ -54,7 +54,7 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - replicate_backward_kernel> + replicate_backward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/reshape_kernels.cu b/lib/kernels/src/cuda/ops/reshape_kernels.cu index c4da408952..5b7843a3a5 100644 --- a/lib/kernels/src/cuda/ops/reshape_kernels.cu +++ b/lib/kernels/src/cuda/ops/reshape_kernels.cu @@ -45,14 +45,14 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - apply_add_with_scale> + apply_add_with_scale> <<>>(input.get(), output.get(), input.shape.num_elements(), - static_cast>(alpha)); + static_cast>(alpha)); } }; diff --git 
a/lib/kernels/src/hip/ops/replicate_kernels.cpp b/lib/kernels/src/hip/ops/replicate_kernels.cpp index 9a5fc813c3..8d27bb1908 100644 --- a/lib/kernels/src/hip/ops/replicate_kernels.cpp +++ b/lib/kernels/src/hip/ops/replicate_kernels.cpp @@ -55,15 +55,16 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - hipLaunchKernelGGL(HIP_KERNEL_NAME(replicate_backward_kernel>), - GET_BLOCKS(total_elements), - CUDA_NUM_THREADS, - 0, - stream, - input.get(), - output.get(), - input.shape.num_elements(), - num_replicas); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(replicate_backward_kernel>), + GET_BLOCKS(total_elements), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + input.shape.num_elements(), + num_replicas); } } diff --git a/lib/kernels/src/hip/ops/reshape_kernels.cpp b/lib/kernels/src/hip/ops/reshape_kernels.cpp index 941495c0fd..47978a5f4a 100644 --- a/lib/kernels/src/hip/ops/reshape_kernels.cpp +++ b/lib/kernels/src/hip/ops/reshape_kernels.cpp @@ -47,7 +47,7 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), GET_BLOCKS(input.shape.num_elements()), CUDA_NUM_THREADS, 0, @@ -55,7 +55,7 @@ struct BackwardKernel { input.get(), output.get(), input.shape.num_elements(), - static_cast> alpha); + static_cast> alpha); } } diff --git a/lib/local-execution/include/local-execution/cost_estimate.h b/lib/local-execution/include/local-execution/cost_estimate.h index 33954827bd..7020089ccf 100644 --- a/lib/local-execution/include/local-execution/cost_estimate.h +++ b/lib/local-execution/include/local-execution/cost_estimate.h @@ -4,11 +4,10 @@ #include "local-execution/cost_details.dtg.h" #include "local-execution/local_training_backing.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" - namespace FlexFlow { struct ICostEstimator { diff --git a/lib/local-execution/include/local-execution/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h index 3a36e02327..4035aaf7cf 100644 --- a/lib/local-execution/include/local-execution/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -28,10 +28,11 @@ struct DeviceSpecific { T const *get(size_t curr_device_idx) const { if (curr_device_idx != this->device_idx) { - throw mk_runtime_error("Invalid access to DeviceSpecific: attempted " - "device_idx {} != correct device_idx {})", - curr_device_idx, - this->device_idx); + throw mk_runtime_error( + fmt::format("Invalid access to DeviceSpecific: attempted " + "device_idx {} != correct device_idx {})", + curr_device_idx, + this->device_idx)); } return (T const *)this->ptr.get(); } diff --git a/lib/local-execution/include/local-execution/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h index f1d2ad252a..2f2ed50d41 100644 --- a/lib/local-execution/include/local-execution/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -4,8 +4,9 @@ #include "kernels/legion_dim.h" #include "op-attrs/datatype.h" #include "op-attrs/ff_dim.h" -#include 
"op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" +#include "utils/visitable.h" #include namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 6a0c28e988..5b826c7022 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -7,6 +7,9 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h index b398bb8cc3..6789624076 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -3,6 +3,7 @@ #include "local-execution/local_slots_backing.h" #include "local-execution/task_registry.h" +#include "pcg/computation_graph.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h index 20d6ccb1c5..102a8d4362 100644 --- a/lib/local-execution/include/local-execution/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -5,7 +5,7 @@ #include "local-execution/device_specific.h" #include "local-execution/op_arg_ref_type.dtg.h" #include "local-execution/per_device_op_state.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index 73a0460554..0f351c3a0e 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -13,10 +13,6 @@ #include "local-execution/slot_grad_id.dtg.h" #include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" -#include "op-attrs/computation_graph_op_attrs.h" -#include "pcg/computation_graph.h" -#include "utils/bidict/bidict.h" -#include "utils/stack_map.h" #include #include #include diff --git a/lib/local-execution/include/local-execution/permissions.h b/lib/local-execution/include/local-execution/permissions.h index ce19e38e7e..f34969f233 100644 --- a/lib/local-execution/include/local-execution/permissions.h +++ b/lib/local-execution/include/local-execution/permissions.h @@ -42,8 +42,8 @@ struct formatter<::FlexFlow::Permissions> : formatter { name = "READ_WRITE"; break; default: - throw ::FlexFlow::mk_runtime_error("Unknown permission {}", - static_cast(p)); + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } return formatter::format(name, ctx); } diff --git a/lib/local-execution/include/local-execution/serialization.h b/lib/local-execution/include/local-execution/serialization.h index a260519a55..2fc4b4b706 100644 --- a/lib/local-execution/include/local-execution/serialization.h +++ b/lib/local-execution/include/local-execution/serialization.h @@ -3,7 +3,7 @@ #include "kernels/device.h" #include "kernels/nccl.h" -#include "op-attrs/dim_ordered.h" +#include 
"op-attrs/dim_ordered/dim_ordered.h" #include "utils/required.h" #include "utils/strong_typedef.h" #include "utils/type_traits.h" diff --git a/lib/local-execution/include/local-execution/sim_environment.h b/lib/local-execution/include/local-execution/sim_environment.h index 3ba17ea3ff..7c81cba408 100644 --- a/lib/local-execution/include/local-execution/sim_environment.h +++ b/lib/local-execution/include/local-execution/sim_environment.h @@ -7,7 +7,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/task_argument_accessor.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/machine_view.h" #include diff --git a/lib/local-execution/include/local-execution/task_registry.struct.toml b/lib/local-execution/include/local-execution/task_registry.struct.toml index 308527efac..ada467a67d 100644 --- a/lib/local-execution/include/local-execution/task_registry.struct.toml +++ b/lib/local-execution/include/local-execution/task_registry.struct.toml @@ -15,6 +15,7 @@ includes = [ src_includes = [ "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", ] [[fields]] diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc index b3a045bab4..bce29fafeb 100644 --- a/lib/local-execution/src/legion_tensor_shape.cc +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -1,4 +1,5 @@ #include "local-execution/legion_tensor_shape.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index d4e0467cbf..6d82e26511 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -5,6 +5,7 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" #include "pcg/computation_graph_builder.h" +#include "pcg/machine_view.dtg.h" #include "pcg/parallel_tensor_attrs.h" #include "utils/containers/transform.h" @@ -51,7 +52,7 @@ CostDetails LocalCostEstimator::estimate_cost( for (ParallelTensorShape const &input : inputs) { TensorShape tensor_shape = get_piece_shape(input); tensor_guid_t tensor_id = - cg_builder.create_tensor(tensor_shape, CreateGrad::YES); + cg_builder.create_input(tensor_shape, CreateGrad::YES); GenericTensorAccessorW tensor_backing = allocator.allocate_tensor(tensor_shape); tensor_backing_map.insert({tensor_id, tensor_backing}); @@ -69,7 +70,10 @@ CostDetails LocalCostEstimator::estimate_cost( std::vector output_tensor_ids = cg_builder.add_layer(layer_attrs, input_tensor_ids, - get_vector_piece_attrs(weights), + transform(get_vector_piece_attrs(weights), + [&](TensorAttrs const &a) { + return cg_builder.create_weight(a); + }), get_vector_piece_attrs(outputs)); LocalTrainingBacking local_backing(allocator, diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index 0ec9068c6a..ac35d63c0b 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -1,4 +1,6 @@ #include "local-execution/local_slots_backing.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "utils/containers/contains_key.h" #include "utils/overload.h" diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 
5d0156201e..54eca7e514 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -30,7 +30,7 @@ GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( @@ -49,7 +49,7 @@ VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return variadic_tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index a2ee06a95a..0fdf1761e3 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,5 +1,6 @@ #include "local-execution/local_training_backing.h" #include "local-execution/task_signature_impl.h" +#include "pcg/computation_graph.h" #include "utils/containers/reversed.h" #include "utils/exception.h" diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 36a1dd708d..932b330453 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -1,4 +1,5 @@ #include "local-execution/op_task_signature.h" +#include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" #include "utils/fmt/unordered_set.h" diff --git a/lib/local-execution/src/ops/batch_matmul.h b/lib/local-execution/src/ops/batch_matmul.h index c082dec020..a7e29b1931 100644 --- a/lib/local-execution/src/ops/batch_matmul.h +++ b/lib/local-execution/src/ops/batch_matmul.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/op_task_signature.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_matmul.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/batch_norm.h b/lib/local-execution/src/ops/batch_norm.h index 1f6cceec19..36aa8ffa4e 100644 --- a/lib/local-execution/src/ops/batch_norm.h +++ b/lib/local-execution/src/ops/batch_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/cast.h b/lib/local-execution/src/ops/cast.h index b4a1e91c91..e7af6aca6b 100644 --- a/lib/local-execution/src/ops/cast.h +++ b/lib/local-execution/src/ops/cast.h @@ -17,7 +17,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/combine.h b/lib/local-execution/src/ops/combine.h index c6157a2955..e85e8fba39 100644 --- a/lib/local-execution/src/ops/combine.h +++ b/lib/local-execution/src/ops/combine.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/combine_attrs.dtg.h" namespace 
FlexFlow { diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 35f663b1cd..4c3462e694 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -50,7 +50,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto inputs = acc.get_variadic_tensor(INPUTS); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(inputs.size() <= MAX_NUM_INPUTS); return profile(forward_kernel, profiling, @@ -68,7 +68,7 @@ static std::optional auto input_grads = acc.get_variadic_tensor_grad(INPUTS); auto output_grad = acc.get_tensor_grad(OUTPUT); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(input_grads.size() <= MAX_NUM_INPUTS); return profile(backward_kernel, profiling, diff --git a/lib/local-execution/src/ops/concat.h b/lib/local-execution/src/ops/concat.h index 1f1443f25d..eab70d621c 100644 --- a/lib/local-execution/src/ops/concat.h +++ b/lib/local-execution/src/ops/concat.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/conv_2d.h b/lib/local-execution/src/ops/conv_2d.h index f70d36d514..0358d71eea 100644 --- a/lib/local-execution/src/ops/conv_2d.h +++ b/lib/local-execution/src/ops/conv_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h index 84b67a29c2..a3dc5ff8af 100644 --- a/lib/local-execution/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" #include "local-execution/task_id_t.dtg.h" -#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_binary.h b/lib/local-execution/src/ops/element_binary.h index 05273e34b4..72c0976df8 100644 --- a/lib/local-execution/src/ops/element_binary.h +++ b/lib/local-execution/src/ops/element_binary.h @@ -3,7 +3,7 @@ #include "local-execution/sim_environment.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index a52ebb8089..4ee609bd6c 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -1,6 +1,7 @@ #include "element_unary.h" #include "kernels/element_unary_kernels.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_unary.h b/lib/local-execution/src/ops/element_unary.h index 4d1783f1f6..04a72e2e12 100644 --- a/lib/local-execution/src/ops/element_unary.h +++ b/lib/local-execution/src/ops/element_unary.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" namespace FlexFlow { diff --git 
a/lib/local-execution/src/ops/embedding.h b/lib/local-execution/src/ops/embedding.h index 0463984122..995d2296e1 100644 --- a/lib/local-execution/src/ops/embedding.h +++ b/lib/local-execution/src/ops/embedding.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/flat.h b/lib/local-execution/src/ops/flat.h index d1501f85ca..e019bfc654 100644 --- a/lib/local-execution/src/ops/flat.h +++ b/lib/local-execution/src/ops/flat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_FLAT_H #include "local-execution/sim_environment.h" -#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/flat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/gather.h b/lib/local-execution/src/ops/gather.h index 74db276e35..e339683381 100644 --- a/lib/local-execution/src/ops/gather.h +++ b/lib/local-execution/src/ops/gather.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/gather_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/src/ops/input.h index 97985585e1..baad25b798 100644 --- a/lib/local-execution/src/ops/input.h +++ b/lib/local-execution/src/ops/input.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_INPUT_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" +#include "op-attrs/ops/input_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/layer_norm.h b/lib/local-execution/src/ops/layer_norm.h index 4f8d87153b..8e034ac519 100644 --- a/lib/local-execution/src/ops/layer_norm.h +++ b/lib/local-execution/src/ops/layer_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/linear.h b/lib/local-execution/src/ops/linear.h index 2c76483df4..2aaf13a95a 100644 --- a/lib/local-execution/src/ops/linear.h +++ b/lib/local-execution/src/ops/linear.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/src/ops/noop.h index 959f7dc054..1097adeb5e 100644 --- a/lib/local-execution/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -2,9 +2,7 @@ #define _FLEXFLOW_NOOP_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" -#include "op-attrs/ops/noop.h" -#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/pool_2d.h b/lib/local-execution/src/ops/pool_2d.h index e8624185ac..908fd5462f 100644 --- a/lib/local-execution/src/ops/pool_2d.h +++ b/lib/local-execution/src/ops/pool_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/pool_2d.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduce.h b/lib/local-execution/src/ops/reduce.h index 92f0578757..7900c28159 100644 --- a/lib/local-execution/src/ops/reduce.h +++ 
b/lib/local-execution/src/ops/reduce.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduce.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduction.h b/lib/local-execution/src/ops/reduction.h index a0af4f3aea..56833602e6 100644 --- a/lib/local-execution/src/ops/reduction.h +++ b/lib/local-execution/src/ops/reduction.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/repartition.h b/lib/local-execution/src/ops/repartition.h index b38a93f8b1..5187d04ca0 100644 --- a/lib/local-execution/src/ops/repartition.h +++ b/lib/local-execution/src/ops/repartition.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/replicate.h b/lib/local-execution/src/ops/replicate.h index 77bda411c1..85d1dff41a 100644 --- a/lib/local-execution/src/ops/replicate.h +++ b/lib/local-execution/src/ops/replicate.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reshape.h b/lib/local-execution/src/ops/reshape.h index 06a6b32597..37f07534ee 100644 --- a/lib/local-execution/src/ops/reshape.h +++ b/lib/local-execution/src/ops/reshape.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reshape.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reverse.h b/lib/local-execution/src/ops/reverse.h index 10072860b0..7c16073be7 100644 --- a/lib/local-execution/src/ops/reverse.h +++ b/lib/local-execution/src/ops/reverse.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reverse.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/softmax.h b/lib/local-execution/src/ops/softmax.h index b5756d92ff..d440fe7239 100644 --- a/lib/local-execution/src/ops/softmax.h +++ b/lib/local-execution/src/ops/softmax.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/softmax.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/split.h b/lib/local-execution/src/ops/split.h index c82152b06a..dde46c20bf 100644 --- a/lib/local-execution/src/ops/split.h +++ b/lib/local-execution/src/ops/split.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/split.h" +#include "op-attrs/ops/split_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/topk.h b/lib/local-execution/src/ops/topk.h index b04d807400..c8f3175ebd 100644 --- a/lib/local-execution/src/ops/topk.h +++ b/lib/local-execution/src/ops/topk.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include 
"local-execution/sim_environment.h" -#include "op-attrs/ops/topk.h" +#include "op-attrs/ops/topk_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/transpose.h b/lib/local-execution/src/ops/transpose.h index 3feffc7d86..0f3a2e80a0 100644 --- a/lib/local-execution/src/ops/transpose.h +++ b/lib/local-execution/src/ops/transpose.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/transpose.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/permissions.cc b/lib/local-execution/src/permissions.cc index e5c46b42f8..2286215987 100644 --- a/lib/local-execution/src/permissions.cc +++ b/lib/local-execution/src/permissions.cc @@ -33,7 +33,8 @@ static int as_int(Permissions p) { case Permissions::RW: return 2; default: - throw mk_runtime_error("Unknown permission {}", static_cast(p)); + throw mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } } diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 2bd0acc222..da3af6e3ad 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -1,77 +1,79 @@ -#include "doctest/doctest.h" -#include "kernels/local_cuda_allocator.h" -#include "kernels/managed_per_device_ff_handle.h" -#include "local-execution/local_cost_estimator.h" -#include "pcg/computation_graph_builder.h" -#include "test_utils.h" +// #include "doctest/doctest.h" +// #include "kernels/local_cuda_allocator.h" +// #include "kernels/managed_per_device_ff_handle.h" +// #include "local-execution/local_cost_estimator.h" +// #include "op-attrs/ops/attention.h" +// #include "op-attrs/parallel_tensor_shape.h" +// #include "pcg/computation_graph_builder.h" +// #include "test_utils.h" -namespace FlexFlow { +// using namespace ::FlexFlow; -TEST_SUITE(FF_CUDA_TEST_SUITE) { - TEST_CASE("Local Cost Estimator") { - // local backing initialization - ManagedPerDeviceFFHandle managed_handle{}; +// TEST_SUITE(FF_CUDA_TEST_SUITE) { +// TEST_CASE("Local Cost Estimator") { +// // local backing initialization +// ManagedPerDeviceFFHandle managed_handle{}; - RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ - DeviceSpecific::create(managed_handle.raw_handle()), - EnableProfiling::YES, - ProfilingSettings{/*warmup_iters=*/0, - /*measure_iters=*/1}}; +// RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ +// DeviceSpecific::create(managed_handle.raw_handle()), +// EnableProfiling::YES, +// ProfilingSettings{/*warmup_iters=*/0, +// /*measure_iters=*/1}}; - LocalCostEstimator cost_estimator = LocalCostEstimator{runtime_arg_config}; +// LocalCostEstimator cost_estimator = +// LocalCostEstimator{runtime_arg_config}; - SUBCASE("Estimate cost -- Attention Op") { - int embed_dim = 32; - int num_heads = 10; - MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ - /*embed_dim=*/embed_dim, - /*num_heads=*/num_heads, - /*kdim=*/embed_dim, - /*vdim=*/embed_dim, - /*dropout=*/0.0, - /*bias=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, - }; +// SUBCASE("Estimate cost -- Attention Op") { +// int embed_dim = 32; +// int num_heads = 10; +// MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ +// /*embed_dim=*/embed_dim, +// /*num_heads=*/num_heads, +// /*kdim=*/embed_dim, +// /*vdim=*/embed_dim, +// /*dropout=*/0.0, +// /*bias=*/true, +// /*add_bias_kv=*/false, +// 
/*add_zero_attn=*/false, +// }; - size_t batch_size = 40; - size_t seq_len = 48; - size_t feature_size = 36; +// size_t batch_size = 40; +// size_t seq_len = 48; +// size_t feature_size = 36; - DataType dtype = DataType::FLOAT; - ParallelTensorShape inputs_shape = lift_to_parallel(TensorShape{ - TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, - DataType::FLOAT, - }); +// DataType dtype = DataType::FLOAT; +// ParallelTensorShape inputs_shape = lift_to_parallel(TensorShape{ +// TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, +// DataType::FLOAT, +// }); - ParallelTensorShape weights_shape = throw_if_unexpected( - get_weights_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); - ParallelTensorAttrs weight_attrs = - ParallelTensorAttrs{weights_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; +// ParallelTensorShape weights_shape = throw_if_unexpected( +// get_weights_shape(attrs, inputs_shape, inputs_shape, +// inputs_shape)); +// ParallelTensorAttrs weight_attrs = +// ParallelTensorAttrs{weights_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// CreateGrad::YES}; - ParallelTensorShape output_shape = throw_if_unexpected( - get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); - ParallelTensorAttrs output_attrs = - ParallelTensorAttrs{output_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; +// ParallelTensorShape output_shape = throw_if_unexpected( +// get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); +// ParallelTensorAttrs output_attrs = +// ParallelTensorAttrs{output_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// CreateGrad::YES}; - CostDetails result = cost_estimator.estimate_cost( - PCGOperatorAttrs{attrs}, - std::vector{ - inputs_shape, inputs_shape, inputs_shape}, - std::vector{weight_attrs}, - std::vector{output_attrs}, - make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1})); +// CostDetails result = cost_estimator.estimate_cost( +// PCGOperatorAttrs{attrs}, +// std::vector{ +// inputs_shape, inputs_shape, inputs_shape}, +// std::vector{weight_attrs}, +// std::vector{output_attrs}, +// make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1})); - CHECK(result.total_elapsed_time > 0); - CHECK(result.total_mem_usage > 0); - } - } -} - -} // namespace FlexFlow +// CHECK(result.total_elapsed_time > 0); +// CHECK(result.total_mem_usage > 0); +// } +// } +// } diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc index 542aa66087..1ec441fbca 100644 --- a/lib/local-execution/test/src/test_local_slots_backing.cc +++ b/lib/local-execution/test/src/test_local_slots_backing.cc @@ -1,15 +1,19 @@ -#include "doctest/doctest.h" #include "kernels/attention_kernels.h" #include "local-execution/local_cost_estimator.h" #include "local-execution/local_cpu_allocator.h" #include "local-execution/local_slots_backing.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/variant.h" +#include "test/utils/doctest/fmt/vector.h" #include "test_utils.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" +#include -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) 
{ TEST_CASE("LocalSlotsBacking -- Attention Op") { @@ -37,11 +41,11 @@ TEST_SUITE(FF_TEST_SUITE) { // build graph ComputationGraphBuilder cg_builder; tensor_guid_t query_guid = - cg_builder.create_tensor(query_shape, CreateGrad::YES); + cg_builder.create_input(query_shape, CreateGrad::YES); tensor_guid_t key_guid = - cg_builder.create_tensor(key_shape, CreateGrad::YES); + cg_builder.create_input(key_shape, CreateGrad::YES); tensor_guid_t value_guid = - cg_builder.create_tensor(value_shape, CreateGrad::YES); + cg_builder.create_input(value_shape, CreateGrad::YES); std::string layer_name = "attn1"; tensor_guid_t output_guid = @@ -269,5 +273,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc index 0637faaf1c..f52fccb1ed 100644 --- a/lib/local-execution/test/src/test_local_task_arg_accessor.cc +++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc @@ -4,7 +4,7 @@ #include "local-execution/task_signature_impl.h" #include "utils/fmt/variant.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalTaskArgumentAccessor") { @@ -140,5 +140,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc index fa3b068425..e18b7ea2de 100644 --- a/lib/local-execution/test/src/test_task_registry.cc +++ b/lib/local-execution/test/src/test_task_registry.cc @@ -7,7 +7,7 @@ #include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Task Registry") { @@ -127,5 +127,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/models/CMakeLists.txt b/lib/models/CMakeLists.txt index 7dd7f48700..4f4b22ed47 100644 --- a/lib/models/CMakeLists.txt +++ b/lib/models/CMakeLists.txt @@ -11,6 +11,7 @@ ff_add_library( op-attrs utils pcg + rapidcheck ) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/models/include/models/bert/bert.h b/lib/models/include/models/bert/bert.h new file mode 100644 index 0000000000..0047996b78 --- /dev/null +++ b/lib/models/include/models/bert/bert.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H + +#include "models/bert/bert_config.dtg.h" +#include "pcg/computation_graph_builder.h" + +namespace FlexFlow { + +// Helper functions to construct the BERT model +tensor_guid_t create_bert_feedforward_network(ComputationGraphBuilder &, + BertConfig const &, + tensor_guid_t const &); +tensor_guid_t create_bert_encoder_layer(ComputationGraphBuilder &, + BertConfig const &, + tensor_guid_t const &); +tensor_guid_t create_bert_encoder(ComputationGraphBuilder &, + BertConfig const &, + tensor_guid_t const &); + +/** + * @brief Get the base config of the BERT model. + * + * @details Refer to + * https://huggingface.co/docs/transformers/v4.18.0/en/model_doc/bert#transformers.BertConfig + * for default configs. + */ +BertConfig get_default_bert_config(); + +/** + * @brief Get the BERT computation graph. + * + * @note This is a plain encoder-only model for pre-training. + * + * @param BertConfig The config of BERT model. + * @return ComputationGraph The computation graph of a BERT model. 
+ */ +ComputationGraph get_bert_computation_graph(BertConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/bert/bert_config.struct.toml b/lib/models/include/models/bert/bert_config.struct.toml new file mode 100644 index 0000000000..398210cf48 --- /dev/null +++ b/lib/models/include/models/bert/bert_config.struct.toml @@ -0,0 +1,71 @@ +namespace = "FlexFlow" +name = "BertConfig" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/activation.dtg.h", +] + +[[fields]] +name = "vocab_size" +type = "size_t" + +[[fields]] +name = "hidden_size" +type = "size_t" + +[[fields]] +name = "num_encoder_layers" +type = "size_t" + +[[fields]] +name = "num_heads" +type = "size_t" + +[[fields]] +name = "dim_feedforward" +type = "size_t" + +[[fields]] +name = "hidden_act" +type = "::FlexFlow::Activation" + +[[fields]] +name = "hidden_dropout_prob" +type = "float" + +[[fields]] +name = "attention_probs_dropout_prob" +type = "float" + +[[fields]] +name = "initializer_range" +type = "float" + +[[fields]] +name = "layer_norm_eps" +type = "float" + +[[fields]] +name = "position_embedding_type" +type = "std::string" + +[[fields]] +name = "classifier_dropout" +type = "float" + +[[fields]] +name = "sequence_length" +type = "size_t" + +[[fields]] +name = "batch_size" +type = "size_t" diff --git a/lib/models/include/models/candle_uno/candle_uno.h b/lib/models/include/models/candle_uno/candle_uno.h new file mode 100644 index 0000000000..a2d21f2830 --- /dev/null +++ b/lib/models/include/models/candle_uno/candle_uno.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_CANDLE_UNO_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_CANDLE_UNO_H + +#include "candle_uno_config.dtg.h" +#include "pcg/computation_graph_builder.h" +#include +#include +#include + +namespace FlexFlow { + +// Helper functions to construct the Candle Uno model +tensor_guid_t create_candle_uno_feature_model(ComputationGraphBuilder &, + CandleUnoConfig const &, + tensor_guid_t const &); + +/** + * @brief Get the default configs of the Candle Uno model. + * + * @details The default configs come from the dataset used by the original + * model: + * https://github.com/ECP-CANDLE/Benchmarks/tree/f6a3da8818308c9edcd9fedbcf831dd5968efcdd/Pilot1/Uno + */ +CandleUnoConfig get_default_candle_uno_config(); + +/** + * @brief Get the Candle Uno computation graph. + * + * @details CandleUnoConfig.feature_shapes is a map from feature name to the + * number of channels for the feature, and CandleUnoConfig.input_features is a + * map from specific data identifier in the dataset to the feature name used in + * this model. + * + * @param CandleUnoConfig The config of the Candle Uno model. + * @return ComputationGraph The computation graph of the Candle Uno model.
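 *
 * Illustrative usage sketch (a minimal example; the default config already
 * populates feature_shapes and input_features):
 * @code
 * CandleUnoConfig config = get_default_candle_uno_config();
 * ComputationGraph cg = get_candle_uno_computation_graph(config);
 * @endcode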
+ */ +ComputationGraph get_candle_uno_computation_graph(CandleUnoConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/candle_uno/candle_uno_config.struct.toml b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml new file mode 100644 index 0000000000..667a6531c3 --- /dev/null +++ b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml @@ -0,0 +1,52 @@ +namespace = "FlexFlow" +name = "CandleUnoConfig" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/fmt/map.h", + "utils/hash/vector.h", + "utils/hash/map.h", +] + +[[fields]] +name = "batch_size" +type = "size_t" + +[[fields]] +name = "dense_layers" +type = "std::vector" + +[[fields]] +name = "dense_feature_layers" +type = "std::vector" + +[[fields]] +name = "feature_shapes" +type = "std::map" + +[[fields]] +name = "input_features" +type = "std::map" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "residual" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3.h b/lib/models/include/models/inception_v3/inception_v3.h new file mode 100644 index 0000000000..5c4754e441 --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 + +#include "models/inception_v3/inception_v3_config.dtg.h" +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the default training config from https://arxiv.org/abs/1512.00567. + */ +InceptionV3Config get_default_inception_v3_training_config(); + +/** + * @brief Get a computation graph for Inception-v3 as described in + * https://arxiv.org/abs/1512.00567. 
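 *
 * Illustrative usage sketch (a minimal example, using only the declarations
 * in this header):
 * @code
 * InceptionV3Config config = get_default_inception_v3_training_config();
 * ComputationGraph cg = get_inception_v3_computation_graph(config);
 * @endcode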
+ */ +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml new file mode 100644 index 0000000000..a2a75c83bb --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "InceptionV3Config" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_classes" +type = "int" + +[[fields]] +name = "batch_size" +type = "int" + +[[fields]] +name = "aux_logits" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml new file mode 100644 index 0000000000..066e6df02b --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "InceptionV3Output" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "standard_logits" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "aux_logits" +type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/include/models/split_test/split_test.h b/lib/models/include/models/split_test/split_test.h new file mode 100644 index 0000000000..b03e45b2d2 --- /dev/null +++ b/lib/models/include/models/split_test/split_test.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H + +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the computation graph of the old FlexFlow test model + * split_test + * + * @note This is a tiny model developed for testing the original Unity + * implementation. It is not a "real" model and has never been trained. 
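 *
 * Illustrative usage sketch (the batch size value here is arbitrary):
 * @code
 * ComputationGraph cg = get_split_test_computation_graph(8); // batch_size = 8
 * @endcode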
+ */ +ComputationGraph get_split_test_computation_graph(int batch_size); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/transformer.h b/lib/models/include/models/transformer/transformer.h similarity index 90% rename from lib/models/include/models/transformer.h rename to lib/models/include/models/transformer/transformer.h index e50fa37709..385100a4c9 100644 --- a/lib/models/include/models/transformer.h +++ b/lib/models/include/models/transformer/transformer.h @@ -1,7 +1,7 @@ -#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H -#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H -#include "models/transformer_config.dtg.h" +#include "models/transformer/transformer_config.dtg.h" #include "pcg/computation_graph_builder.h" namespace FlexFlow { diff --git a/lib/models/include/models/transformer_config.struct.toml b/lib/models/include/models/transformer/transformer_config.struct.toml similarity index 100% rename from lib/models/include/models/transformer_config.struct.toml rename to lib/models/include/models/transformer/transformer_config.struct.toml diff --git a/lib/models/src/models/bert/bert.cc b/lib/models/src/models/bert/bert.cc new file mode 100644 index 0000000000..cf48f2399b --- /dev/null +++ b/lib/models/src/models/bert/bert.cc @@ -0,0 +1,160 @@ +#include "models/bert/bert.h" +#include "op-attrs/tensor_shape.h" +#include "pcg/computation_graph.h" +#include "pcg/initializers/truncated_normal_initializer_attrs.dtg.h" + +namespace FlexFlow { + +BertConfig get_default_bert_config() { + return BertConfig{/*vocab_size=*/30522, + /*hidden_size=*/768, + /*num_encoder_layers=*/12, + /*num_heads=*/12, + /*dim_feedforward=*/3072, + /*hidden_act=*/Activation::GELU, + /*hidden_dropout_prob=*/0.1, + /*attention_probs_dropout_prob=*/0.1, + /*initializer_range=*/0.02, + /*layer_norm_eps=*/1e-12, + /*position_embedding_type=*/"absolute", + /*classifier_dropout=*/0.1, + /*sequence_length=*/512, + /*batch_size=*/64}; +} + +tensor_guid_t + create_feedforward_network(ComputationGraphBuilder &cgb, + BertConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &bias_initializer, + InitializerAttrs const &projection_initializer) { + tensor_guid_t layer1_out = + cgb.dense(input, + config.dim_feedforward, + /*activation=*/config.hidden_act, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/projection_initializer, + /*bias_initializer=*/bias_initializer); + tensor_guid_t dropout_out = + cgb.dropout(layer1_out, config.hidden_dropout_prob); + tensor_guid_t layer2_out = + cgb.dense(dropout_out, + config.hidden_size, + /*activation=*/std::nullopt, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/projection_initializer, + /*bias_initializer=*/bias_initializer); + return cgb.dropout(layer2_out, config.hidden_dropout_prob); +}; + +tensor_guid_t + create_bert_encoder_layer(ComputationGraphBuilder &cgb, + BertConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &bias_initializer, + InitializerAttrs const &projection_initializer) { + assert(num_dims(cgb.get_shape(input)) == 3); + std::vector layer_norm_axis = {2}; // Apply layernorm across the last dim + int kdim = config.dim_feedforward / config.num_heads; + int vdim = config.dim_feedforward / config.num_heads; + tensor_guid_t self_attention = + cgb.multihead_attention(input, + input, 
+ input, + config.hidden_size, + config.num_heads, + kdim, + vdim, + /*dropout=*/config.attention_probs_dropout_prob, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + /*initializer=*/projection_initializer); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, self_attention)); + + tensor_guid_t normalized = cgb.layer_norm(cgb.add(self_attention, input), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, normalized)); + + tensor_guid_t feedforward_output = create_feedforward_network( + cgb, config, normalized, bias_initializer, projection_initializer); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, feedforward_output)); + return cgb.layer_norm(cgb.add(normalized, feedforward_output), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); +} + +tensor_guid_t + create_bert_encoder(ComputationGraphBuilder &cgb, + BertConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &bias_initializer, + InitializerAttrs const &projection_initializer) { + tensor_guid_t t = input; + for (int i = 0; i < config.num_encoder_layers; i++) { + t = create_bert_encoder_layer( + cgb, config, t, bias_initializer, projection_initializer); + } + return t; +}; + +ComputationGraph get_bert_computation_graph(BertConfig const &config) { + if (config.position_embedding_type != "absolute") { + throw mk_runtime_error( + fmt::format("Currently only position_embedding_type=absolute is " + "supported, but found position_embedding_type={}. " + "If you need support for additional " + "position_embedding_type values, please create an issue.", + config.position_embedding_type)); + } + + ComputationGraphBuilder cgb; + InitializerAttrs projection_initializer = + InitializerAttrs{TruncatedNormalInitializerAttrs{ + /*seed=*/0, + /*mean=*/0, + /*stddev=*/config.initializer_range, + /*min_cutoff=*/-2 * config.initializer_range, + /*max_cutoff=*/2 * config.initializer_range}}; + InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}}; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + config.batch_size, config.sequence_length, config.hidden_size}}, + DataType::FLOAT, + }; + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); + + tensor_guid_t encoder_output = create_bert_encoder( + cgb, config, input, bias_initializer, projection_initializer); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, encoder_output)); + + tensor_guid_t out_prob = + cgb.softmax(cgb.dense(encoder_output, + /*outDim=*/config.vocab_size, + /*activation=*/config.hidden_act, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/projection_initializer, + /*bias_initializer=*/bias_initializer)); + assert( + (cgb.get_shape(out_prob) == + TensorShape{ + TensorDims{FFOrdered{ + config.batch_size, config.sequence_length, config.vocab_size}}, + DataType::FLOAT, + })); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/candle_uno/candle_uno.cc b/lib/models/src/models/candle_uno/candle_uno.cc new file mode 100644 index 0000000000..4d52d515fb --- /dev/null +++ b/lib/models/src/models/candle_uno/candle_uno.cc @@ -0,0 +1,123 @@ +#include "models/candle_uno/candle_uno.h" +#include "pcg/initializers/glorot_normal_attrs.dtg.h" + +namespace FlexFlow { + +CandleUnoConfig get_default_candle_uno_config() { + 
CandleUnoConfig config{ + /*batch_size=*/64, + /*dense_layers=*/std::vector(4, 4192), + /*dense_feature_layers=*/std::vector(8, 4192), + /*feature_shapes=*/std::map{}, + /*input_features=*/std::map{}, + /*dropout=*/0.1, + /*residual=*/false}; + + config.feature_shapes["dose"] = 1; + config.feature_shapes["cell.rnaseq"] = 942; + config.feature_shapes["drug.descriptors"] = 5270; + config.feature_shapes["drug.fingerprints"] = 2048; + + config.input_features["dose1"] = "dose"; + config.input_features["dose2"] = "dose"; + config.input_features["cell.rnaseq"] = "cell.rnaseq"; + config.input_features["drug1.descriptors"] = "drug.descriptors"; + config.input_features["drug1.fingerprints"] = "drug.fingerprints"; + config.input_features["drug2.descriptors"] = "drug.descriptors"; + config.input_features["drug2.fingerprints"] = "drug.fingerprints"; + + return config; +} + +tensor_guid_t create_candle_uno_feature_model( + ComputationGraphBuilder &cgb, + CandleUnoConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &kernel_initializer) { + tensor_guid_t t = input; + for (int const dense_dim : config.dense_feature_layers) { + t = cgb.dense(t, + dense_dim, + Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + if (config.dropout > 0) { + t = cgb.dropout(t, config.dropout); + } + } + return t; +} + +ComputationGraph + get_candle_uno_computation_graph(CandleUnoConfig const &config) { + ComputationGraphBuilder cgb; + InitializerAttrs kernel_initializer = + InitializerAttrs{GlorotNormalAttrs{/*seed=*/0}}; + + auto create_input_tensor = + [&](FFOrdered const &dims) -> tensor_guid_t { + TensorShape input_shape = TensorShape{ + TensorDims{dims}, + DataType::FLOAT, + }; + return cgb.create_input(input_shape, CreateGrad::YES); + }; + + std::set input_models; + for (auto const &shape : config.feature_shapes) { + auto const &type = shape.first; + if (type.find(".") != std::string::npos) { + std::string base_type = type.substr(0, type.find(".")); + // The string parsing here is to match with original implementation at + // https://github.com/ECP-CANDLE/Benchmarks/blob/f6a3da8818308c9edcd9fedbcf831dd5968efcdd/Pilot1/Uno/uno_baseline_keras2.py#L178. 
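      // e.g. "cell.rnaseq" has base_type "cell" and "drug.fingerprints" has
      // base_type "drug", so those features are routed through their own
      // feature sub-model, while a dotless feature such as "dose" is later
      // fed into the graph directly.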
+ if (base_type == "cell" || base_type == "drug") { + input_models.insert(type); + } + } + } + + std::vector all_inputs; + std::vector encoded_inputs; + + for (auto const &input_feature : config.input_features) { + std::string const &feature_name = input_feature.second; + size_t shape = config.feature_shapes.at(feature_name); + tensor_guid_t input = create_input_tensor({config.batch_size, shape}); + all_inputs.push_back(input); + + if (contains(input_models, feature_name)) { + encoded_inputs.emplace_back(create_candle_uno_feature_model( + cgb, config, input, kernel_initializer)); + } else { + encoded_inputs.emplace_back(input); + } + } + + tensor_guid_t output = cgb.concat(encoded_inputs, /*axis=*/1); + for (int const &dense_layer_dim : config.dense_layers) { + tensor_guid_t residual_input = output; + output = cgb.dense(output, + dense_layer_dim, + Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + if (config.dropout > 0) { + output = cgb.dropout(output, config.dropout); + } + if (config.residual) { + output = cgb.add(output, residual_input); + } + } + output = cgb.dense(output, + /*outDim=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..f540eae629 --- /dev/null +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,750 @@ +#include "models/inception_v3/inception_v3.h" +#include "models/inception_v3/inception_v3_output.dtg.h" +#include "op-attrs/tensor_shape.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +struct CheckShape { + CheckShape(ComputationGraphBuilder const &cgb, + InceptionV3Config const &config) + : cgb(cgb), config(config) {} + + ComputationGraphBuilder const &cgb; + InceptionV3Config const &config; + + void operator()(tensor_guid_t t, int c, int h, int w) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + size_t_from_int(h), + size_t_from_int(w), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } + + void operator()(tensor_guid_t t, int c) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } +}; + +InceptionV3Config get_default_inception_v3_training_config() { + return InceptionV3Config{ + /*num_classes=*/1000, + + // see section 8 of https://arxiv.org/abs/1512.00567 for the source of the + // batch size + /*batch_size=*/32, + + // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of + // auxiliary logits. 
they are used by default in training + /*aux_logits=*/true, + }; +} + +static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int filters, + int kernel_size_h, + int kernel_size_w, + int stride_h = 1, + int stride_w = 1, + int padding_h = 0, + int padding_w = 0, + bool use_bias = false) { + tensor_guid_t conv = cgb.conv2d(input, + /*outChannels=*/filters, + /*kernelH=*/kernel_size_h, + /*kernelW=*/kernel_size_w, + /*strideH=*/stride_h, + /*strideW=*/stride_w, + /*paddingH=*/padding_h, + /*paddingW=*/padding_w, + /*activation=*/std::nullopt, + /*groups=*/1, + /*use_bias=*/use_bias); + return cgb.batch_norm(conv, + /*affine=*/true, + /*activation=*/Activation::RELU, + /*eps=*/1e-5, + /*momentum=*/0.1); +} + +static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int pool_features) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch5x5 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/48, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/2, + /*padding_w=*/2); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/pool_features, + /*kernel_stride_h=*/1, + /*kernel_stride_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = create_conv_block(cgb, + input, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_stride_h=*/3, + /*kernel_stride_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + int channels_7x7) { + tensor_guid_t 
branch1x1 = create_conv_block(cgb, + input, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(branch1x1, 192, 17, 17); + + tensor_guid_t branch7x7 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + return t; + }(); + check_shape(branch7x7, 192, 17, 17); + + tensor_guid_t branch7x7dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + return t; + }(); + check_shape(branch7x7dbl, 192, 17, 17); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + check_shape(branch_pool, 192, 17, 17); + + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, t, 320, 3, 3, 2, 2); + return t; + }(); + + tensor_guid_t branch7x7x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch7x7x3, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t 
branch1x1 = create_conv_block(cgb, + input, + /*filters=*/320, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/448, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input) { + tensor_guid_t t = input; + + check_shape(t, 3, 299, 299); + + // Conv2d_1a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + check_shape(t, 32, 149, 149); + + // Conv2d_2a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 32, 147, 147); + + // Conv2d_2b_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + check_shape(t, 64, 147, 147); + + // maxpool1 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 64, 73, 73); + + // Conv2d_3b_1x1 + t = create_conv_block(cgb, + t, + /*filters=*/80, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(t, 80, 73, 73); + + // Conv2d_4a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 192, 71, 71); + + // maxpool2 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 192, 35, 35); + + return t; +} + +static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) 
{ + // avgpool + tensor_guid_t x = cgb.pool2d(input, + /*kernelH=*/8, + /*kernelW=*/8, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 2048, 1, 1); + + // dropout + x = cgb.dropout(x, + /*rate=*/0.5); + check_shape(x, 2048, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 2048); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + // softmax (not in pytorch model, but shown in Table 1 on p6 of + // https://arxiv.org/abs/1512.00567) + x = cgb.softmax(x); + check_shape(x, num_classes); + + return x; +} + +static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + tensor_guid_t x = input; + check_shape(x, 768, 17, 17); + + x = cgb.pool2d(x, + /*kernelH=*/5, + /*kernelW=*/5, + /*strideH=*/3, + /*strideW=*/3, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 768, 5, 5); + + // conv0 + x = create_conv_block(cgb, + x, + /*filters=*/128, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(x, 128, 5, 5); + + // conv1 + x = create_conv_block(cgb, + x, + /*filters=*/768, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5); + check_shape(x, 768, 1, 1); + + x = cgb.adaptive_pool2d(x, + /*output_h=*/1, + /*output_w=*/1); + check_shape(x, 768, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 768); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + return x; +} + +static InceptionV3Output create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + // NOTE: the shapes for check_shape (as well as the layer names in comments) + // are pulled from + // https://github.com/pytorch/vision/blob/6d7851bd5e2bedc294e40e90532f0e375fcfee04/torchvision/models/inception.py#L103-L155 + CheckShape check_shape = CheckShape{ + /*cgb=*/cgb, + /*config=*/config, + }; + + tensor_guid_t x = create_initial_layers(cgb, check_shape, input); + check_shape(x, 192, 35, 35); + + // Mixed_5b + x = create_inception_module_a(cgb, x, 32); + check_shape(x, 256, 35, 35); + + // Mixed_5c + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_5d + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_6a + x = create_inception_module_b(cgb, x); + check_shape(x, 768, 17, 17); + + // Mixed_6b + x = create_inception_module_c(cgb, check_shape, x, 128); + check_shape(x, 768, 17, 17); + + // Mixed_6c + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6d + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6e + x = create_inception_module_c(cgb, check_shape, x, 192); + check_shape(x, 768, 17, 17); + + std::optional aux; + if (config.aux_logits) { + aux = create_inception_aux(cgb, check_shape, x, config.num_classes); + check_shape(aux.value(), config.num_classes); + } + + // Mixed_7a + x = create_inception_module_d(cgb, x); + check_shape(x, 1280, 8, 8); + + // Mixed_7b + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + // Mixed_7c + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + x = create_final_layers(cgb, check_shape, x, config.num_classes); + check_shape(x, config.num_classes); + + return InceptionV3Output{ + x, + aux, + }; +} + +ComputationGraph + 
get_inception_v3_computation_graph(InceptionV3Config const &config) { + ComputationGraphBuilder cgb; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + 3, + 299, + 299, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); + InceptionV3Output output = create_inception_v3(cgb, config, input); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/split_test/split_test.cc b/lib/models/src/models/split_test/split_test.cc new file mode 100644 index 0000000000..118f94ec06 --- /dev/null +++ b/lib/models/src/models/split_test/split_test.cc @@ -0,0 +1,39 @@ +#include "models/split_test/split_test.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +ComputationGraph get_split_test_computation_graph(int batch_size) { + ComputationGraphBuilder cgb; + + int layer_dim1 = 256; + int layer_dim2 = 128; + int layer_dim3 = 64; + int layer_dim4 = 32; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(batch_size), + size_t_from_int(layer_dim1), + }}, + DataType::FLOAT, + }; + + tensor_guid_t t = cgb.create_input(input_shape, CreateGrad::YES); + t = cgb.dense(t, layer_dim2); + t = cgb.relu(t); + tensor_guid_t t1 = cgb.dense(t, layer_dim3); + tensor_guid_t t2 = cgb.dense(t, layer_dim3); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t1 = cgb.dense(t, layer_dim4); + t2 = cgb.dense(t, layer_dim4); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t = cgb.softmax(t); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/transformer.cc b/lib/models/src/models/transformer/transformer.cc similarity index 91% rename from lib/models/src/models/transformer.cc rename to lib/models/src/models/transformer/transformer.cc index 874cd85787..173a1b291c 100644 --- a/lib/models/src/models/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" namespace FlexFlow { @@ -42,7 +42,8 @@ tensor_guid_t create_transformer_encoder_layer(ComputationGraphBuilder &cgb, config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention)); @@ -88,7 +89,8 @@ tensor_guid_t config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention)); @@ -100,14 +102,15 @@ tensor_guid_t assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention_normalized)); - tensor_guid_t mha = cgb.multihead_attention(input, + tensor_guid_t mha = cgb.multihead_attention(self_attention_normalized, encoder_output, encoder_output, config.num_features, config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent(cgb.computation_graph, input, mha)); tensor_guid_t mha_normalized = @@ -149,11 +152,13 @@ ComputationGraph config.batch_size, config.sequence_length, config.num_features}}, DataType::FLOAT, }; - tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES, "input"); + tensor_guid_t target = + cgb.create_input(input_shape, CreateGrad::YES, "target"); tensor_guid_t 
encoder_output = create_transformer_encoder(cgb, config, input); tensor_guid_t decoder_output = - create_transformer_decoder(cgb, config, input, encoder_output); + create_transformer_decoder(cgb, config, target, encoder_output); tensor_guid_t out_prob = cgb.softmax(cgb.dense(decoder_output, /*outDim=*/config.vocab_size, diff --git a/lib/models/test/src/models/bert/bert.cc b/lib/models/test/src/models/bert/bert.cc new file mode 100644 index 0000000000..1defc3a1a2 --- /dev/null +++ b/lib/models/test/src/models/bert/bert.cc @@ -0,0 +1,33 @@ +#include "models/bert/bert.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_bert_computation_graph") { + + SUBCASE("default config") { + BertConfig config = get_default_bert_config(); + + ComputationGraph result = get_bert_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 245; + CHECK(result_num_layers == correct_num_layers); + } + } + + SUBCASE("throws on position_embedding_type != absolute as other values are " + "currently unsupported") { + BertConfig config = [] { + BertConfig c = get_default_bert_config(); + c.position_embedding_type = "relative_key"; + return c; + }(); + + CHECK_THROWS(get_bert_computation_graph(config)); + } + } +} diff --git a/lib/models/test/src/models/candle_uno/candle_uno.cc b/lib/models/test/src/models/candle_uno/candle_uno.cc new file mode 100644 index 0000000000..e32c5b5486 --- /dev/null +++ b/lib/models/test/src/models/candle_uno/candle_uno.cc @@ -0,0 +1,19 @@ +#include "models/candle_uno/candle_uno.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_candle_uno_computation_graph") { + CandleUnoConfig config = get_default_candle_uno_config(); + + ComputationGraph result = get_candle_uno_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 142; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..2b0fe82fd6 --- /dev/null +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,19 @@ +#include "models/inception_v3/inception_v3.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_inception_v3_computation_graph") { + InceptionV3Config config = get_default_inception_v3_training_config(); + + ComputationGraph result = get_inception_v3_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 522; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer/transformer.cc similarity index 85% rename from lib/models/test/src/models/transformer.cc rename to lib/models/test/src/models/transformer/transformer.cc index 2133e9965b..20274c4151 100644 --- a/lib/models/test/src/models/transformer.cc +++ b/lib/models/test/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" #include @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = 
get_layers(result).size(); - int correct_num_layers = 317; + int correct_num_layers = 258; CHECK(result_num_layers == correct_num_layers); } } diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index 03f38bb8f9..52e6e12a8c 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H #include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/record_formatter.h" namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &); RecordFormatter as_dot(ComputationGraphOpAttrs const &); +ComputationGraphOpAttrs + compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 6204b9ca49..5af00fb510 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -47,14 +47,7 @@ typename data_type_enum_to_class
::type cast_to(T t) { } template -using real_type = typename data_type_enum_to_class
::type; - -using DataTypeValue = std::variant, - real_type, - real_type, - real_type, - /* real_type, */ - real_type>; +using real_type_t = typename data_type_enum_to_class
::type; size_t size_of_datatype(DataType); diff --git a/lib/op-attrs/include/op-attrs/datatype_value.variant.toml b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml new file mode 100644 index 0000000000..3386e9d131 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "DataTypeValue" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +type = "float" + +[[values]] +type = "double" + +[[values]] +type = "int32_t" + +[[values]] +type = "int64_t" + +[[values]] +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h new file mode 100644 index 0000000000..9b9eaf9b93 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +FFOrdered concat(FFOrdered const &l, FFOrdered const &r) { + std::vector l_vec = std::vector(l.cbegin(), l.cend()); + std::vector r_vec = std::vector(r.cbegin(), r.cend()); + + std::vector raw_result = concat_vectors(l_vec, r_vec); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +template +FFOrdered concat(std::vector> const &inputs) { + std::vector> vec_inputs = + transform(inputs, [](FFOrdered const &input) { + return std::vector(input.cbegin(), input.cend()); + }); + + std::vector raw_result = concat_vectors(vec_inputs); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h similarity index 95% rename from lib/op-attrs/include/op-attrs/dim_ordered.h rename to lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 6868ba083f..6aa23d40fc 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -3,14 +3,14 @@ #include "op-attrs/ff_dim.dtg.h" #include "utils/fmt/vector.h" -#include "utils/json.h" #include "utils/stack_vector.h" +#include namespace FlexFlow { template struct DimOrdered { - DimOrdered() = delete; + DimOrdered() {} DimOrdered(std::initializer_list const &l) : contents(l.begin(), l.end()) {} @@ -138,6 +138,10 @@ struct DimOrdered { return this->contents.size(); } + size_t empty() const { + return this->contents.empty(); + } + size_t num_dims() const { return this->size(); } @@ -202,11 +206,12 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { namespace nlohmann { template struct adl_serializer<::FlexFlow::DimOrdered> { - static ::FlexFlow::DimOrdered from_json(json const &j) { + static ::FlexFlow::DimOrdered from_json(nlohmann::json const &j) { return {j.template get>()}; } - static void to_json(json &j, ::FlexFlow::DimOrdered const &x) { + static void to_json(nlohmann::json &j, + ::FlexFlow::DimOrdered const &x) { j = std::vector{x.cbegin(), x.cend()}; } }; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h index f9f6d00532..38e7da4bb2 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -1,7 +1,7 @@ #ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/bidict/bidict.h" #include "utils/containers/count.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h new file mode 100644 index 0000000000..79d4929797 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/dim_ordered/ff_ordered_of.h" + +namespace FlexFlow { + +template +FFOrdered ff_ordered_from_map(std::map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +template +FFOrdered ff_ordered_from_map(std::unordered_map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h index c843ed3842..8cc1bf3a51 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index 560862677e..7343dc0e69 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/count.h" #include "utils/containers/transform.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index f3dfe5d199..e4c0e8e275 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H -#include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/subvec.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/optional.h" namespace FlexFlow { @@ -18,7 +18,14 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, }; return DimOrdered{ - subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; + subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; +} + +template +FFOrdered slice(FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + return nonoverloaded_slice(d, start, end); } template 
diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index 3a31ea511d..4fd3df0abb 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H -#include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" namespace FlexFlow { @@ -12,7 +12,7 @@ DimOrdered> transform(DimOrdered const &d, F f) { using Out = std::invoke_result_t; - return DimOrdered{vector_transform(as_vector(d), f)}; + return DimOrdered{vector_transform(vector_of(d), f)}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h index 54554afb81..cc8b050f50 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H -#include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" namespace FlexFlow { @@ -11,7 +11,7 @@ template DimOrdered> zip(DimOrdered const &lhs, DimOrdered const &rhs) { return DimOrdered>{ - zip(as_vector(lhs), as_vector(rhs))}; + zip(vector_of(lhs), vector_of(rhs))}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h new file mode 100644 index 0000000000..b395736773 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_INCOMING_TENSOR_ROLES_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_INCOMING_TENSOR_ROLES_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/incoming_tensor_role.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +std::vector + get_incoming_tensor_roles(ComputationGraphOpAttrs const &, int num_inputs); +std::vector + get_incoming_tensor_roles(PCGOperatorAttrs const &, int num_inputs); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml new file mode 100644 index 0000000000..427701c801 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "IncomingTensorRole" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "INPUT" + +[[values]] +name = "WEIGHT" diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h deleted file mode 100644 index 268554b5be..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _OPERATOR_PARAMS_H -#define _OPERATOR_PARAMS_H - -#include "op-attrs/ops/core.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "ops/attention.h" -#include "ops/batch_matmul.h" -#include "ops/batch_norm.h" -#include "ops/broadcast.h" -#include "ops/cast.h" 
-#include "ops/combine.h" -#include "ops/concat.h" -#include "ops/conv_2d.h" -#include "ops/dropout.h" -#include "ops/element_binary.h" -#include "ops/element_unary.h" -#include "ops/embedding.h" -#include "ops/flat.h" -#include "ops/gather.h" -#include "ops/input.h" -#include "ops/layer_norm.h" -#include "ops/linear.h" -#include "ops/noop.h" -#include "ops/pool_2d.h" -#include "ops/reduce.h" -#include "ops/reduction.h" -#include "ops/repartition.h" -#include "ops/replicate.h" -#include "ops/reshape.h" -#include "ops/reverse.h" -#include "ops/softmax.h" -#include "ops/split.h" -#include "ops/topk.h" -#include "ops/transpose.h" -#include "utils/record_formatter.h" -#include "utils/variant.h" -#include - -namespace FlexFlow { - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes); - -bool is_valid(PCGOperatorAttrs const &, - std::vector const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 40f57d08af..e06d795c04 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_ATTENTION_ATTRS_H #define _FLEXFLOW_ATTENTION_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" @@ -37,6 +38,9 @@ int get_kvSeqLength(MultiHeadAttentionInputs const &); int get_num_samples(MultiHeadAttentionParallelInputs const &); int get_num_samples(MultiHeadAttentionInputs const &); +std::vector + get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &); + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, TensorShape const &input_q, @@ -58,6 +62,22 @@ tl::expected TensorShape const &input_k, TensorShape const &input_v); +tl::expected + get_weights_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_input_bias_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_output_bias_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, ParallelTensorShape const &input_q, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 57760d1110..574b4ef579 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H #include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); + bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 8afcbb06b1..f2e95690d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ 
b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,15 +1,42 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &); +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); + +tl::expected get_output_shape(BatchNormAttrs const &, + TensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, TensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); + +tl::expected + get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, + ParallelTensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml index bc82f3c743..fdc3bce1fe 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -10,6 +10,28 @@ features = [ "fmt", ] +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + [[fields]] name = "relu" type = "bool" + +[[fields]] +name = "affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" + +[[fields]] +name = "momentum" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 0a5f057578..4fd7d49234 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -5,11 +5,14 @@ #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/record_formatter.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(BroadcastAttrs); +RecordFormatter as_dot(BroadcastAttrs const &); + tl::expected get_output_shape(BroadcastAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index f3ac8494c0..f07f06df85 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,10 +10,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, 
std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index 4faa870bc4..fab8132993 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -17,7 +17,3 @@ includes = [ [[fields]] name = "axis" type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "num_inputs" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 72d1123c39..ae9f9249c6 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_CONV_2D_ATTRS_H #define _FLEXFLOW_CONV_2D_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" @@ -10,6 +11,9 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Conv2DAttrs); +std::vector + get_conv2d_incoming_tensor_roles(Conv2DAttrs const &); + TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &input); TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index 2fb385b64d..5bef144cd9 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -12,11 +12,12 @@ features = [ includes = [ "", "op-attrs/activation.dtg.h", - "utils/json.h", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] fields = [ diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index 4b9c8a9f45..403bb87592 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -11,12 +11,14 @@ features = [ ] includes = [ - "utils/json.h", - "op-attrs/operator_type.h", + "op-attrs/operator_type.dtg.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 38d5a4371e..66d6f99253 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 676d21c59b..710cbdb44b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -11,8 +12,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(FlatAttrs 
const &, - ParallelTensorShape const &); +tl::expected + get_output_parallel_dim_degrees(FlatAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_output_shape(FlatAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index e445535e29..7349e2a8c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -8,4 +8,23 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "", + "op-attrs/ff_dim.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "start_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "end_dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 29b0b2f514..0fbadae2a1 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -8,6 +9,9 @@ namespace FlexFlow { +std::vector + get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &); + tl::expected get_output_shape(LayerNormAttrs const &, TensorShape const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 795ba19ae8..065cc7e38e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LINEAR_ATTRS_H #define _FLEXFLOW_LINEAR_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -10,20 +11,23 @@ namespace FlexFlow { +std::vector + get_linear_incoming_tensor_roles(LinearAttrs const &); + CHECK_VALID_OP_ATTR(LinearAttrs); RecordFormatter as_dot(LinearAttrs const &); tl::expected - get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); + get_projection_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); + get_projection_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index eaa34cc496..0a35a6c5ec 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -13,11 +13,13 @@ includes = [ "op-attrs/datatype.dtg.h", "op-attrs/activation.dtg.h", "op-attrs/regularizer_attrs.dtg.h", - "utils/json.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git 
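The new start_dim/end_dim fields on FlatAttrs follow the leading/flattened/trailing slicing used in lib/op-attrs/src/op-attrs/ops/flat.cc later in this patch, with end_dim read here as exclusive. A rough stand-alone sketch of that rule on plain std::vector dims; flatten_dims is illustrative only, not a FlexFlow API.

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// illustrative only: mirrors the slice()-based split in flat.cc,
// treating end_dim as exclusive
std::vector<size_t> flatten_dims(std::vector<size_t> const &dims,
                                 int start_dim,
                                 int end_dim) {
  if (start_dim >= end_dim) {
    return dims; // nothing to flatten, like the flattened_dims.empty() case
  }
  std::vector<size_t> result(dims.begin(), dims.begin() + start_dim);
  result.push_back(std::accumulate(dims.begin() + start_dim,
                                   dims.begin() + end_dim,
                                   size_t{1},
                                   std::multiplies<size_t>{}));
  result.insert(result.end(), dims.begin() + end_dim, dims.end());
  return result;
}

int main() {
  // flattening dims [1, 3) of a [2, 3, 4, 5] tensor gives [2, 12, 5]
  assert((flatten_dims({2, 3, 4, 5}, 1, 3) == std::vector<size_t>{2, 12, 5}));
}
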
a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 505fdd9f8c..1af22ad022 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -10,9 +11,22 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation); + +tl::expected get_output_shape(Pool2DAttrs const &, + TensorShape const &); + +tl::expected + get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &, + ParallelTensorDimDegrees const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 56bf682f50..20ca7deabc 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -12,6 +12,13 @@ features = [ includes = [ "op-attrs/pool_op.dtg.h", "op-attrs/activation.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] @@ -44,4 +51,4 @@ type = "::FlexFlow::PoolOp" [[fields]] name = "activation" -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index bd11f0ae91..d6de90903a 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -4,11 +4,13 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TopKAttrs); +TensorShape get_output_shape(TopKAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index 756091f653..0dc30d9a79 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -12,7 +12,7 @@ features = [ includes = [ "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index 5397ad7c68..a12951dec9 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -11,6 +11,7 @@ bool is_replica_dim(ParallelDim const &); ParallelDim with_size_set_to(ParallelDim const &, size_t); ParallelDim with_degree_set_to(ParallelDim const &, int); ParallelDim with_is_replica_set_to(ParallelDim const &, bool); +int get_degree(ParallelDim const &); } // namespace FlexFlow diff --git 
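make_adaptive_pool2d_attrs, declared in the pool_2d.h hunk above and implemented later in pool_2d.cc, only supports the case where the input extent is an exact multiple of the output extent. A small stand-alone sketch of the resulting kernel/stride/padding choice; adaptive_pool_params_1d is illustrative, not a FlexFlow API.

#include <cassert>

struct Pool2DParams {
  int kernel;
  int stride;
  int padding;
};

// one spatial dimension of the adaptive-pool derivation:
// stride = kernel = input / output, padding = 0
Pool2DParams adaptive_pool_params_1d(int input_extent, int output_extent) {
  assert(input_extent % output_extent == 0); // the only case the patch supports
  int stride = input_extent / output_extent;
  return Pool2DParams{/*kernel=*/stride, /*stride=*/stride, /*padding=*/0};
}

int main() {
  // adapting a height of 12 down to 3 uses 4-wide windows with stride 4
  Pool2DParams p = adaptive_pool_params_1d(12, 3);
  assert(p.kernel == 4 && p.stride == 4 && p.padding == 0);

  // sanity check against the standard pooling output formula
  int out = (12 + 2 * p.padding - p.kernel) / p.stride + 1;
  assert(out == 3);
}
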
a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml new file mode 100644 index 0000000000..974b27d2a7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorDimDegrees" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/dim_ordered/dim_ordered.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "shard_degrees" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml new file mode 100644 index 0000000000..9396cbcbe8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "parallel_tensor_dim_idx_t" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.dtg.h", + "op-attrs/replica_type.dtg.h", +] + +[[values]] +type = "::FlexFlow::ff_dim_t" + +[[values]] +type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 8e02e3607b..1b8361abf6 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" @@ -14,6 +15,18 @@ std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ size_t num_shard_dims(ParallelTensorDims const &); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &); + +ParallelTensorDims lift_to_parallel(TensorDims const &); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + SumDegree const &, + DiscardCopyDegree const &, + FFOrdered const &shard_degrees); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + ParallelTensorDimDegrees const &); + int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml index ae6eab1e58..f24fa12309 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", "op-attrs/shard_parallel_dim.dtg.h", "op-attrs/replica_parallel_dim_set.dtg.h", "", diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 76356b39d4..0759dc746e 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,6 +1,9 @@ #ifndef 
_OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" @@ -17,12 +20,17 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); + ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, + SumDegree const &, + DiscardCopyDegree const &, FFOrdered const &shard_degrees); +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + ParallelTensorDimDegrees const &); std::unordered_set replica_dims(ParallelTensorShape const &); @@ -44,6 +52,12 @@ std::vector TensorShape get_reduced_shape(ParallelTensorShape const &); +ParallelDim get_parallel_dim_at_idx(ParallelTensorShape const &shape, + parallel_tensor_dim_idx_t idx); + +std::unordered_set + get_parallel_tensor_dim_indices(ParallelTensorShape const &shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 08167fe3d9..723c05298d 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -8,8 +8,8 @@ namespace FlexFlow { bool is_parallel_op(PCGOperatorAttrs const &); OperatorType get_op_type(PCGOperatorAttrs const &); -ComputationGraphOpAttrs - compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); +PCGOperatorAttrs + pcg_op_attrs_from_compgraph_op_attrs(ComputationGraphOpAttrs const &); RecordFormatter as_dot(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml index 8617c5fd64..a44d712dbf 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -13,6 +13,7 @@ includes = [ "op-attrs/ops/attention_attrs.dtg.h", "op-attrs/ops/batch_matmul.dtg.h", "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", "op-attrs/ops/cast_attrs.dtg.h", "op-attrs/ops/combine_attrs.dtg.h", "op-attrs/ops/concat_attrs.dtg.h", @@ -49,6 +50,10 @@ key = "batch_matmul" type = "::FlexFlow::BatchNormAttrs" key = "batch_norm" +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + [[values]] type = "::FlexFlow::CastAttrs" key = "cast" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index c8af3b02e7..ee44a39170 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -17,13 +17,6 @@ bool tensor_dims_is_broadcastable_to(TensorDims const &curr, std::optional get_broadcast_target_dims(std::unordered_set const &); -ParallelTensorDims lift_to_parallel(TensorDims const &); -ParallelTensorDims - lift_to_parallel_with_degrees(TensorDims const &, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, - FFOrdered const &shard_degrees); - } // namespace FlexFlow #endif diff --git 
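A rough sketch of what the lift_to_parallel_with_degrees overloads above do conceptually: pair each unparallelized dim size with a shard degree and carry the sum/discard-copy degrees alongside. ShardDim and lift_with_degrees below are illustrative stand-ins for the FlexFlow types, not the real implementation.

#include <cassert>
#include <cstddef>
#include <vector>

struct ShardDim {
  size_t size;
  int degree;
};

// pair sizes with shard degrees, as the zip/transform in the real code does
std::vector<ShardDim> lift_with_degrees(std::vector<size_t> const &dims,
                                        std::vector<int> const &shard_degrees) {
  assert(dims.size() == shard_degrees.size());
  std::vector<ShardDim> result;
  for (size_t i = 0; i < dims.size(); i++) {
    result.push_back(ShardDim{dims[i], shard_degrees[i]});
  }
  return result;
}

int main() {
  // lift a [8, 16] tensor with shard degrees [2, 1]
  std::vector<ShardDim> lifted = lift_with_degrees({8, 16}, {2, 1});
  int sum_degree = 1;
  int discard_copy_degree = 1;

  int total_degree = sum_degree * discard_copy_degree;
  for (ShardDim const &d : lifted) {
    total_degree *= d.degree;
  }
  assert(total_degree == 2); // matches the total_parallel_degree bookkeeping
}
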
a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml index cff8e08b0f..b262dd32b6 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml @@ -9,7 +9,7 @@ features = [ "fmt", ] includes = [ - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 108df58dce..14ee637f92 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -11,11 +11,6 @@ size_t &dim_at_idx(TensorShape &, ff_dim_t); size_t get_num_elements(TensorShape const &); size_t get_size_in_bytes(TensorShape const &); -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal); -std::optional - get_broadcast_target_shape(std::unordered_set const &); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 166416cbad..c4ae7b31e5 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -1,5 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/linear.h" +#include "utils/overload.h" namespace FlexFlow { @@ -8,4 +11,34 @@ OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + return attrs.visit(overload{ + [](LinearAttrs const &l) { return as_dot(l); }, + [](BroadcastAttrs const &a) { return as_dot(a); }, + [&](auto const &) { + RecordFormatter r; + r << fmt::to_string(get_op_type(attrs)); + return r; + }, + }); +} + +ComputationGraphOpAttrs + compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { + auto fail_on_parallel_op = [](auto const &attrs) -> ComputationGraphOpAttrs { + throw mk_runtime_error( + fmt::format("Encountered parallel operator in " + "compgraph_op_attrs_from_pcg_op_attrs: {}", + attrs)); + }; + + return op.visit(overload{ + [&](CombineAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](ReductionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](RepartitionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](ReplicateAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [](auto const &attrs) { return ComputationGraphOpAttrs{attrs}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..cb29f708a3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/concat.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..2de88f38c8 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc new file mode 100644 index 0000000000..21efc26466 --- /dev/null +++ 
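The as_dot and compgraph_op_attrs_from_pcg_op_attrs additions above lean on the overload{...} visitor idiom from utils/overload.h. Below is a self-contained approximation using std::variant/std::visit; LinearLike, ParallelLike, and describe are made-up names for illustration only.

#include <iostream>
#include <string>
#include <variant>

// small local reimplementation of the overload{} helper
template <typename... Fs>
struct overload : Fs... {
  using Fs::operator()...;
};
template <typename... Fs>
overload(Fs...) -> overload<Fs...>;

// made-up attribute types standing in for LinearAttrs / parallel-op attrs
struct LinearLike { int out_channels; };
struct ParallelLike {};

using Attrs = std::variant<LinearLike, ParallelLike>;

std::string describe(Attrs const &a) {
  return std::visit(overload{
      [](LinearLike const &l) {
        return "linear(" + std::to_string(l.out_channels) + ")";
      },
      [](auto const &) -> std::string {
        // catch-all arm, analogous to fail_on_parallel_op / the default case
        return "<unsupported>";
      },
  }, a);
}

int main() {
  std::cout << describe(Attrs{LinearLike{16}}) << "\n";  // linear(16)
  std::cout << describe(Attrs{ParallelLike{}}) << "\n";  // <unsupported>
}

The specific-lambda-plus-generic-fallback structure is the same design choice the patch uses to keep parallel-only operators out of ComputationGraphOpAttrs.
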
b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -0,0 +1,104 @@ +#include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::vector get_incoming_tensor_roles( + ComputationGraphOpAttrs const &comp_graph_op_attrs, int num_incoming) { + return get_incoming_tensor_roles( + pcg_op_attrs_from_compgraph_op_attrs(comp_graph_op_attrs), num_incoming); +} + +std::vector + get_incoming_tensor_roles(PCGOperatorAttrs const &pcg_op_attrs, + int num_incoming) { + return pcg_op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT}; + }, + [](BatchNormAttrs const &attrs) { + return get_batch_norm_incoming_tensor_roles(attrs); + }, + [](BroadcastAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](CastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](CombineAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [&](ConcatAttrs const &) { + return std::vector(num_incoming, IncomingTensorRole::INPUT); + }, + [](Conv2DAttrs const &attrs) { + return get_conv2d_incoming_tensor_roles(attrs); + }, + [](DropoutAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ElementBinaryAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT}; + }, + [](ElementUnaryAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](EmbeddingAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT}; + }, + [](FlatAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](GatherAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](InputAttrs const &) { return std::vector{}; }, + [](LayerNormAttrs const &attrs) { + return get_layer_norm_incoming_tensor_roles(attrs); + }, + [](LinearAttrs const &attrs) { + return get_linear_incoming_tensor_roles(attrs); + }, + [](MultiHeadAttentionAttrs const &attrs) { + return get_attention_incoming_tensor_roles(attrs); + }, + [](NoopAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](Pool2DAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReduceAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReductionAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](RepartitionAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReplicateAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReverseAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReshapeAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](SplitAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](SoftmaxAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](TopKAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](TransposeAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](WeightAttrs const &) { return std::vector{}; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index d91d1a1eca..0058ee35a2 100644 --- 
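One plausible way a caller could consume the role vectors produced by get_incoming_tensor_roles above is to pair each incoming tensor with its role and split inputs from weights. The int tensor ids and the split_by_role helper below are purely illustrative.

#include <cassert>
#include <cstddef>
#include <vector>

enum class IncomingTensorRole { INPUT, WEIGHT };

struct SplitIncoming {
  std::vector<int> inputs;
  std::vector<int> weights;
};

// partition incoming tensors by the role reported for each position
SplitIncoming split_by_role(std::vector<int> const &incoming,
                            std::vector<IncomingTensorRole> const &roles) {
  assert(incoming.size() == roles.size());
  SplitIncoming result;
  for (size_t i = 0; i < incoming.size(); i++) {
    if (roles[i] == IncomingTensorRole::INPUT) {
      result.inputs.push_back(incoming[i]);
    } else {
      result.weights.push_back(incoming[i]);
    }
  }
  return result;
}

int main() {
  // e.g. a Linear op with use_bias=true reports {INPUT, WEIGHT, WEIGHT}
  std::vector<IncomingTensorRole> roles = {IncomingTensorRole::INPUT,
                                           IncomingTensorRole::WEIGHT,
                                           IncomingTensorRole::WEIGHT};
  SplitIncoming s = split_by_role({10, 11, 12}, roles);
  assert(s.inputs.size() == 1 && s.weights.size() == 2);
}
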
a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -14,6 +14,7 @@ #include "op-attrs/ops/input.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight.h" #include "utils/overload.h" @@ -29,7 +30,7 @@ std::vector get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; }, [&](BatchNormAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](CastAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; @@ -38,7 +39,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](ConcatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs)}; + return {throw_if_unexpected(get_output_shape(attrs, inputs))}; }, [&](Conv2DAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; @@ -57,7 +58,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](FlatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](GatherAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; @@ -71,6 +72,9 @@ std::vector [&](LinearAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, + [&](Pool2DAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, [&](ReplicateAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; }, diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 036daa6e67..483d832fee 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -3,6 +3,7 @@ #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "utils/containers/extend.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -91,6 +92,24 @@ int get_num_samples(MultiHeadAttentionInputs const &inputs) { return inputs.batch_size; } +std::vector + get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &attrs) { + + std::vector roles = std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.bias) { + extend(roles, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return roles; +} + tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, TensorShape const &input_q, diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index b75c3521c6..f394bb8473 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -1,15 +1,260 @@ #include "op-attrs/ops/batch_norm.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/any_of.h" +#include "utils/containers/extend.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, - 
TensorShape const &input_shape) { +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + +static std::optional + check_input_shape(BatchNormAttrs const &, TensorShape const &input_shape) { + if (num_dims(input_shape) < 2) { + return fmt::format( + "BatchNormAttrs expected input dims >= 2, but received input shape {}", + input_shape); + } + + if (input_shape.data_type != DataType::FLOAT) { + return fmt::format("BatchNormAttrs currently only supports data_type = " + "FLOAT, but received input data_type {}. " + "If you need this feature, please create an issue.", + input_shape.data_type); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + return input_shape; } -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + + return TensorShape{ + TensorDims{FFOrdered{ + num_channels, + }}, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); +} + +static std::optional + check_input_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.shard_degrees.size() < 2) { + return fmt::format("BatchNormAttrs expected input dims >= 2, but received " + "input degrees {}", + input_degrees); + } + + if (input_degrees.sum_degree != SumDegree{1}) { + return fmt::format("Expected sum degree 1, but receieved sum degree {}", + input_degrees.sum_degree); + } + + if (input_degrees.discard_copy_degree != DiscardCopyDegree{1}) { + return fmt::format( + "Expected discard copy degree 1, but receieved discard copy degree {}", + input_degrees.discard_copy_degree); + } + + FFOrdered non_channel_degrees = + concat(slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), + slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + + if (any_of(non_channel_degrees, [](int degree) { return degree != 1; })) { + return fmt::format("Expected parallel degree of all non-channel dimensions " + "to be 1, but received input with degrees {}", + input_degrees); + } + + return std::nullopt; +} + +tl::expected + get_output_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_degrees; +} + +tl::expected + get_gamma_weights_parallel_dim_degrees( + 
BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + ff_dim_t channel_dim = ff_dim_t{1}; + + return ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{input_degrees.shard_degrees.at(channel_dim)}, + }; +} + +tl::expected + get_beta_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_gamma_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_beta_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index bd69864aff..aa3c95f551 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -1,8 +1,26 @@ #include "op-attrs/ops/broadcast.h" #include "op-attrs/tensor_dims.h" +#include "utils/record_formatter.h" namespace FlexFlow { +RecordFormatter as_dot(BroadcastAttrs const &attrs) { + RecordFormatter r; + + auto kv = [](std::string const &label, auto const &val) { + RecordFormatter rr; + rr << label 
<< fmt::to_string(val); + return rr; + }; + + for (int i = 0; i < num_dims(attrs.target_dims); i++) { + r << kv(fmt::format("target_dims[{}]", i), + dim_at_idx(attrs.target_dims, ff_dim_t{i})); + } + + return r; +} + tl::expected get_output_shape(BroadcastAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 02fee70bea..74295f279e 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -1,24 +1,129 @@ #include "op-attrs/ops/concat.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/all_of.h" +#include "utils/containers/are_all_same.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/fmt/map.h" namespace FlexFlow { -/* bool ConcatAttrs::is_valid( */ -/* std::vector const &input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ - -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + auto get_non_axis_dims = [&](TensorShape const &s) { + std::map dim_sizes = enumerate(ff_ordered(s.dims)); + dim_sizes.erase(attrs.axis); + return dim_sizes; + }; + + if (inputs.size() <= 1) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 " + "or more input, but receieved {}", + inputs)); + } + + if (attrs.axis.value < 0) { + return tl::unexpected(fmt::format("ConcatAttrs requires axis >= 0")); + } + + if (!are_all_same(transform( + inputs, [](TensorShape const &s) { return num_dims(s); }))) { + return tl::unexpected( + fmt::format("get_output_shape for Concat expected all inputs to have " + "the same number of dimensions, but receieved {}", + inputs)); + } + + std::map non_axis_dims = ({ + tl::expected, std::string> returned = + require_all_same1(transform(inputs, get_non_axis_dims)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + std::vector axis_dim_sizes = transform( + inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + + size_t output_axis_dim_size = sum(axis_dim_sizes); + + non_axis_dims.insert({attrs.axis, output_axis_dim_size}); + + DataType datatype = ({ + tl::expected returned = require_all_same1( + transform(inputs, [](TensorShape const &s) { return s.data_type; })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return TensorShape{ + TensorDims{ + ff_ordered_from_map(non_axis_dims), + }, + datatype, + }; } -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, transform(inputs, get_reduced_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + SumDegree sum_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_sum_degree)); + if (!returned.has_value()) { + return 
tl::unexpected(returned.error()); + } + SumDegree{returned.value()}; + }); + + DiscardCopyDegree discard_copy_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_discard_copy_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + DiscardCopyDegree{returned.value()}; + }); + + if (!all_of(inputs, [&](ParallelTensorShape const &s) { + return shard_dim_at_idx(s, attrs.axis).degree == 1; + })) { + return tl::unexpected(fmt::format( + "get_output_shape for Concat expected input tensors to have parallel " + "degree 1 in the concat axis dimension, but received {}", + inputs)); + } + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { + return get_parallel_degrees(s); + })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 03ae18a1d9..eac756cc15 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -5,6 +5,20 @@ namespace FlexFlow { +std::vector + get_conv2d_incoming_tensor_roles(Conv2DAttrs const &attrs) { + std::vector result = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.use_bias) { + result.push_back(IncomingTensorRole::WEIGHT); + } + + return result; +} + TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported @@ -40,11 +54,11 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, Conv2DInputShape input = parse_input_shape(raw_input_shape); size_t out_height = - (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / - attrs.stride_h; + (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / attrs.stride_h + + 1; size_t out_width = - (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / - attrs.stride_w; + (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / attrs.stride_w + + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 5d318207ee..e9833d5e3f 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,57 +1,85 @@ #include "op-attrs/ops/flat.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers/any_of.h" +#include "utils/containers/product.h" #include namespace FlexFlow { -TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(FlatAttrs const &attrs, + TensorShape const &input_shape) { + FFOrdered leading_dims = + slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered flattened_dims = + slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); + FFOrdered trailing_dims = + slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); + + if (flattened_dims.empty()) { + return input_shape; + } + + return TensorShape{ + TensorDims{ + concat(std::vector{ + leading_dims, + {product(flattened_dims)}, + trailing_dims, + }), + }, + input_shape.data_type, + }; } -ParallelTensorShape get_output_shape(FlatAttrs const &, - 
ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_parallel_dim_degrees( + FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + FFOrdered flattened_dim_degrees = + slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); + + if (flattened_dim_degrees.empty()) { + return input_degrees; + } + + if (any_of(flattened_dim_degrees, [](int degree) { return degree != 1; })) { + return tl::unexpected( + fmt::format("get_output_parallel_dim_degrees for {} expected all shard " + "degrees of flattened dimensions to be 1, but received {}", + attrs, + input_degrees)); + } + + return ParallelTensorDimDegrees{ + /*sum_degree=*/input_degrees.sum_degree, + /*discard_copy_degree=*/input_degrees.discard_copy_degree, + /*shard_degrees=*/ + concat(std::vector{ + slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + {product(flattened_dim_degrees)}, + slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), + }), + }; } -// namespace Input { -// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, -// REPLICA = 4; -// } -// -// namespace Output { -// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; -// } -// -/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ - -/* bool is_valid = true; */ -/* is_valid &= input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= (input.at(Input::WIDTH).degree == 1); */ - -/* return is_valid; */ -/* } */ - -/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* assert (input.num_dims() == Input::NUMDIM); */ -/* ParallelTensorShape output_dims; */ -/* output_dims.data_type = input.data_type; */ - -/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */ -/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */ - -/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree; - */ -/* assert (input.at(Input::HEIGHT).degree == 1); */ -/* assert (input.at(Input::WIDTH).degree == 1); */ - -/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size * - * input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */ -/* output_dims.at(Output::CHANNEL).parallel_idx = - * input.at(Input::CHANNEL).parallel_idx; */ - -/* return output_dims; */ -/* } */ +tl::expected + get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index b9603d7850..0dd9ac7a17 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -6,10 +6,23 @@ #include "utils/containers/all_of.h" #include "utils/containers/any_of.h" #include "utils/containers/contains.h" +#include "utils/containers/extend.h" #include "utils/containers/filter.h" namespace FlexFlow { +std::vector + get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if 
(attrs.elementwise_affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + static std::optional check_input_shape(LayerNormAttrs const &attrs, TensorShape const &input_shape) { @@ -99,7 +112,7 @@ static std::optional if (get_discard_copy_degree(input_shape) != 1) { return fmt::format( - "Expected discard copy degree 1, but received discartd copy degree {}", + "Expected discard copy degree 1, but received discard copy degree {}", get_discard_copy_degree(input_shape)); } diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 24a8250690..feac647216 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -8,6 +8,20 @@ namespace FlexFlow { +std::vector + get_linear_incoming_tensor_roles(LinearAttrs const &attrs) { + std::vector result = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.use_bias) { + result.push_back(IncomingTensorRole::WEIGHT); + } + + return result; +} + RecordFormatter as_dot(LinearAttrs const &attrs) { RecordFormatter r; @@ -25,7 +39,8 @@ RecordFormatter as_dot(LinearAttrs const &attrs) { } tl::expected - get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + get_projection_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); return TensorShape{ @@ -56,11 +71,11 @@ tl::expected } tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input) { + get_projection_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ tl::expected result_unpar = - get_kernel_shape(attrs, get_reduced_shape(input)); + get_projection_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index e1917efd89..95bcd8b336 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,62 +1,184 @@ #include "op-attrs/ops/pool_2d.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation) { + // AdaptivePool2D semantics pulled from + // https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993 -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} + if (num_dims(input_dims) != 4) { + return tl::unexpected( + fmt::format("make_adaptive_pool2d_attrs expected input tensor to " + "have 4 dims, but received dims {}", + input_dims)); + } -} // namespace FlexFlow + size_t num_samples = dim_at_idx(input_dims, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_dims, ff_dim_t{1}); + size_t input_h = dim_at_idx(input_dims, ff_dim_t{2}); + size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); -/* -#include "op-attrs/ops/pool_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" + if (input_h % output_h != 0) { + return tl::unexpected(fmt::format( + 
"Currently make_adaptive_pool2d_attrs only supports input_h % output_h " + "== 0, but received input_h={} and output_h={} (input_dims={}). If you " + "need input_h % output_h != 0 supported, please create an issue.", + input_h, + output_h, + input_dims)); + } -namespace FlexFlow { + if (input_w % output_w != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_w % output_w " + "== 0, but received input_w={} and output_w={} (input_dims={}). If you " + "need input_w % output_w != 0 supported, please create an issue.", + input_w, + output_w, + input_dims)); + } -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + // Note that for some reason the stack overflow post linked above states that + // `kernel_size = ind - (outd-1)*stride`, but some simplification yields + // `kernel_size` = `ind - (outd - 1)*stride` + // = `ind - (outd - 1) * (ind / outd)` + // = `ind - ind + (ind /outd)` + // = `ind / outd` + // = `stride` -namespace Output { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + int kernel_h = input_h / output_h; + int kernel_w = input_w / output_w; -bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { - ParallelTensorShape output_shape = this->calculate_output_shape(input); + int stride_h = kernel_h; + int stride_w = kernel_w; - return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); -} + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + + TensorShape expected_ouput_shape = TensorShape{ + TensorDims{FFOrdered{ + num_samples, + num_channels, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; -static std::vector - construct_mappings(ParallelTensorShape const &input_shape) { - auto const outputMappings = construct_output_parallel_dims({ - {Input::REPLICA, MappingOperation::PARTITION, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, + TensorShape output_shape = ({ + tl::expected result = + get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); }); - return outputMappings; + if (output_shape != expected_ouput_shape) { + return tl::unexpected( + fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce " + "expected output shape {}, but produced {}. 
This is a bug " + "in FlexFlow, Please create an issue.", + attrs, + expected_ouput_shape, + output_shape)); + } + + return attrs; } -static ParallelDimMappingSolution - solve_mappings(ParallelTensorShape const &input) { - return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { + if (num_dims(input_shape) != 4) { + return tl::unexpected( + fmt::format("get_output_shape for Pool2DAttrs expected input tensor to " + "have 4 dims, but received shape {}", + input_shape)); + } + + size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); + size_t input_width = dim_at_idx(input_shape, ff_dim_t{3}); + + size_t output_height = + (input_height + 2 * attrs.padding_h - attrs.kernel_h) / attrs.stride_h + + 1; + + size_t output_width = + (input_width + 2 * attrs.padding_w - attrs.kernel_w) / attrs.stride_w + 1; + + return TensorShape{TensorDims{FFOrdered{ + num_samples, + num_channels, + output_height, + output_width, + }}, + input_shape.data_type}; } -ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape -const &input) const { return solve_mappings(input).output_shapes.at(0); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected result_degrees = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!result_degrees.has_value()) { + return tl::unexpected(result_degrees.error()); + } + result_degrees.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_output_parallel_dim_degrees( + Pool2DAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.sum_degree.value > 1) { + if (attrs.pool_type == PoolOp::MAX) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX " + "expected input sum degree == 1, but received {}", + input_degrees)); + } else if (attrs.activation.has_value()) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with activation={} " + "expected input sum degree == 1, but received {}", + attrs.activation.value(), + input_degrees)); + } + } + + return input_degrees; } } // namespace FlexFlow -*/ diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc index 9d2fd35a94..7a6868340b 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(TopKAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/parallel_dim.cc b/lib/op-attrs/src/op-attrs/parallel_dim.cc new file mode 100644 index 0000000000..26ba2b3fa1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_dim.cc @@ -0,0 +1,14 @@ +#include "op-attrs/parallel_dim.h" +#include "utils/overload.h" + +namespace FlexFlow { + +int get_degree(ParallelDim const &dim) 
{ + return dim.visit(overload{ + [](ShardParallelDim const &shard_dim) { return shard_dim.degree; }, + [](ReplicaParallelDim const &replica_dim) { + return replica_dim.degree; + }}); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 73c0068826..2955545561 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -1,12 +1,14 @@ #include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/dim_ordered/zip.h" #include "op-attrs/replica_parallel_dim.h" #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" +#include "op-attrs/tensor_dims.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -29,13 +31,57 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { return dims.shard_dims.size(); } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { + return ParallelTensorDimDegrees{ + d.replica_dims.sum_degree, + d.replica_dims.discard_copy_degree, + ff_ordered_shard_degrees(d), + }; +} + +ParallelTensorDims lift_to_parallel(TensorDims const &dims) { + std::vector shard_degrees(num_dims(dims), + 1); // 1 repeated num_dims(dims) times + return lift_to_parallel_with_degrees( + dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees); +} + +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &unpar, + SumDegree const &sum_degree, + DiscardCopyDegree const &discard_copy_degree, + FFOrdered const &shard_degrees) { + std::vector lifted = + transform(zip(vector_of(unpar.ff_ordered), vector_of(shard_degrees)), + [](std::pair const &p) { + size_t size = p.first; + int degree = p.second; + return ShardParallelDim{size, degree}; + }); + + return ParallelTensorDims{FFOrdered{lifted}, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }}; +} + +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &unpar, + ParallelTensorDimDegrees const °rees) { + return lift_to_parallel_with_degrees(unpar, + degrees.sum_degree, + degrees.discard_copy_degree, + degrees.shard_degrees); +} + int total_replica_degree(ParallelTensorDims const &dims) { return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; } int total_shard_degree(ParallelTensorDims const &dims) { - return product(transform(as_vector(dims.shard_dims), + return product(transform(vector_of(dims.shard_dims), [](ShardParallelDim const &d) { return d.degree; })); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 10bf5027a4..dcc567e0ca 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,9 +1,12 @@ #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/tensor_dims.h" +#include "utils/containers/extend.h" #include "utils/containers/product.h" +#include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/hash-utils.h" +#include "utils/overload.h" namespace FlexFlow { @@ -59,19 +62,32 @@ std::optional } } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &s) { + return 
get_parallel_degrees(s.dims); +} + ParallelTensorShape lift_to_parallel(TensorShape const &s) { return ParallelTensorShape{lift_to_parallel(s.dims), s.data_type}; } ParallelTensorShape - lift_to_parallel_with_degrees(TensorShape const &s, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, + lift_to_parallel_with_degrees(TensorShape const &unpar, + SumDegree const &sum_degree, + DiscardCopyDegree const &discard_copy_degree, FFOrdered const &shard_degrees) { return ParallelTensorShape{ lift_to_parallel_with_degrees( - s.dims, sum_degree, discard_copy_degree, shard_degrees), - s.data_type, + unpar.dims, sum_degree, discard_copy_degree, shard_degrees), + unpar.data_type, + }; +} + +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &unpar, + ParallelTensorDimDegrees const °rees) { + return ParallelTensorShape{ + lift_to_parallel_with_degrees(unpar.dims, degrees), + unpar.data_type, }; } @@ -103,4 +119,30 @@ TensorShape get_reduced_shape(ParallelTensorShape const &s) { }; } +ParallelDim get_parallel_dim_at_idx(ParallelTensorShape const &shape, + parallel_tensor_dim_idx_t idx) { + return idx.visit( + overload{[&](ff_dim_t shard_dim) { + return ParallelDim{shape.dims.shard_dims.at(shard_dim)}; + }, + [&](ReplicaType replica_type) { + ReplicaParallelDimSet replicas = shape.dims.replica_dims; + int degree = (ReplicaType::SUM == replica_type + ? replicas.sum_degree.value + : replicas.discard_copy_degree.value); + return ParallelDim{ReplicaParallelDim{degree, replica_type}}; + }}); +} + +std::unordered_set + get_parallel_tensor_dim_indices(ParallelTensorShape const &shape) { + std::unordered_set indices; + extend(indices, transform(range(num_shard_dims(shape.dims)), [](int idx) { + return parallel_tensor_dim_idx_t(ff_dim_t(idx)); + })); + indices.insert(parallel_tensor_dim_idx_t(ReplicaType::SUM)); + indices.insert(parallel_tensor_dim_idx_t(ReplicaType::DISCARD_COPY)); + return indices; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index 0bb134da6b..4fe01c2c1a 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -15,56 +15,6 @@ OperatorType get_op_type(PCGOperatorAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } -ComputationGraphOpAttrs - compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { - return op.visit(overload{ - [](BatchMatmulAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](BatchNormAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](CastAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ConcatAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](Conv2DAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](DropoutAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ElementBinaryAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](ElementUnaryAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](EmbeddingAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](FlatAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](GatherAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](InputAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](LayerNormAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](LinearAttrs const &attrs) { return 
ComputationGraphOpAttrs{attrs}; }, - [](MultiHeadAttentionAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](NoopAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](Pool2DAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReduceAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReverseAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReshapeAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](SplitAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](SoftmaxAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](TopKAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](TransposeAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](auto const &attrs) -> ComputationGraphOpAttrs { - throw mk_runtime_error(fmt::format( - "Cannot convert parallel op to non-parallel, received {}", attrs)); - }, - }); -} - RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { return attrs.visit(overload{ [](LinearAttrs const &l) { return as_dot(l); }, @@ -76,4 +26,11 @@ RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { }); } +PCGOperatorAttrs pcg_op_attrs_from_compgraph_op_attrs( + ComputationGraphOpAttrs const &cg_attrs) { + return cg_attrs.visit(overload{ + [](auto const &attrs) { return PCGOperatorAttrs{attrs}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index e716793a8f..1bb050db52 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -3,9 +3,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/integer_conversions.h" @@ -33,8 +33,8 @@ bool tensor_dims_is_broadcastable_to(TensorDims const &curr, return false; } - std::vector curr_dims = as_vector(curr.ff_ordered); - std::vector goal_dims = as_vector(goal.ff_ordered); + std::vector curr_dims = vector_of(curr.ff_ordered); + std::vector goal_dims = vector_of(goal.ff_ordered); for (auto const &[curr_dim, goal_dim] : zip(reversed(curr_dims), reversed(goal_dims))) { @@ -59,31 +59,4 @@ std::optional return std::nullopt; } -ParallelTensorDims lift_to_parallel(TensorDims const &dims) { - std::vector shard_degrees(num_dims(dims), - 1); // 1 repeated num_dims(dims) times - return lift_to_parallel_with_degrees( - dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees); -} - -ParallelTensorDims - lift_to_parallel_with_degrees(TensorDims const &dims, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, - FFOrdered const &shard_degrees) { - std::vector lifted = - transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), - [](std::pair const &p) { - size_t size = p.first; - int degree = p.second; - return ShardParallelDim(size, degree); - }); - - return ParallelTensorDims{FFOrdered{lifted}, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - }}; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index b604d442cb..07508e3065 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -27,35 +27,4 @@ size_t 
get_size_in_bytes(TensorShape const &s) { return get_num_elements(s) * size_of_datatype(s.data_type); } -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal) { - return tensor_dims_is_broadcastable_to(curr.dims, goal.dims) && - curr.data_type == goal.data_type; -} - -std::optional - get_broadcast_target_shape(std::unordered_set const &shapes) { - std::unordered_set datatypes = - transform(shapes, [](TensorShape const &s) { return s.data_type; }); - - if (datatypes.size() != 1) { - return std::nullopt; - } - - std::unordered_set shapes_dims = - transform(shapes, [](TensorShape const &s) { return s.dims; }); - - std::optional maybe_result_dims = - get_broadcast_target_dims(shapes_dims); - std::optional result = - transform(maybe_result_dims, [&](TensorDims const &result_dims) { - return TensorShape{ - result_dims, - get_only(datatypes), - }; - }); - - return result; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc deleted file mode 100644 index e6459c6819..0000000000 --- a/lib/op-attrs/src/operator_attrs.cc +++ /dev/null @@ -1,287 +0,0 @@ -#include "op-attrs/operator_attrs.h" -#include "utils/fmt.h" -#include "utils/record_formatter.h" -#include "utils/type_traits.h" - -namespace FlexFlow { - -/* OperatorType GetOpType::operator()(BatchMatmulAttrs const &p) const { return - * OP_BATCHMATMUL; } */ -/* OperatorType GetOpType::operator()(Conv2DAttrs const &p) const { return - * OP_CONV2D; } */ -/* OperatorType GetOpType::operator()(ConcatAttrs const &p) const { return - * OP_CONCAT; } */ -/* OperatorType GetOpType::operator()(CastAttrs const &p) const { return - * OP_CAST; } */ -/* OperatorType GetOpType::operator()(ElementBinaryAttrs const &p) const { - * return p.type; } */ -/* OperatorType GetOpType::operator()(ElementUnaryAttrs const &p) const { return - * p.op_type; } */ -/* OperatorType GetOpType::operator()(DropoutAttrs const &p) const { return - * OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(EmbeddingAttrs const &p) const { return - * OP_EMBEDDING; } */ -/* OperatorType GetOpType::operator()(FlatAttrs const &p) const { return - * OP_FLAT; } */ -/* OperatorType GetOpType::operator()(LayerNormAttrs const &p) const { return - * OP_LAYERNORM; } */ -/* OperatorType GetOpType::operator()(LinearAttrs const &p) const { return - * OP_LINEAR; } */ -/* OperatorType GetOpType::operator()(MultiHeadAttentionAttrs const &p) const { - * return OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(Pool2DAttrs const &p) const { return - * OP_POOL2D; } */ -/* OperatorType GetOpType::operator()(ReshapeAttrs const &p) const { return - * OP_RESHAPE; } */ -/* OperatorType GetOpType::operator()(SplitAttrs const &p) const { return - * OP_SPLIT; } */ -/* OperatorType GetOpType::operator()(SoftmaxAttrs const &p) const { return - * OP_SOFTMAX; } */ -/* OperatorType GetOpType::operator()(TransposeAttrs const &p) const { return - * OP_TRANSPOSE; } */ -/* OperatorType GetOpType::operator()(RepartitionAttrs const &p) const { return - * OP_REPARTITION; } */ -/* OperatorType GetOpType::operator()(ReplicateAttrs const &p) const { return - * OP_REPLICATE; } */ -/* OperatorType GetOpType::operator()(ReductionAttrs const &p) const { return - * OP_REDUCTION; } */ -/* OperatorType GetOpType::operator()(CombineAttrs const &p) const { return - * OP_COMBINE; } */ -/* OperatorType GetOpType::operator()(FusedParallelOpAttrs const &p) const { - * return OP_FUSED_PARALLEL; } */ - -/* struct AsOpAttrs { */ -/* template 
*/ -/* OpAttrsInterface const &operator()(T const &p) { */ -/* return p; */ -/* } */ -/* }; */ - -/* OperatorType get_op_type(OpAttrsInterface const &o) { */ -/* return o.op_type(); */ -/* } */ -/* // */ -/* OperatorType get_op_type(CompGraphOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* OperatorType get_op_type(PCGOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* std::vector get_output_shapes(PCGOperatorAttrs const - * &op_params, std::vector const &input_tensor_shapes) { */ -/* return mpark::visit(AsOpAttrs{}, - * op_params).output_shapes(input_tensor_shapes); */ -/* } */ - -/* bool is_parallel_op(PCGOperatorAttrs const &o) { */ -/* return is_parallel_op(get_op_type(o)); */ -/* } */ -template -typename std::enable_if<(is_streamable::value && - !is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - std::ostringstream oss; - oss << t; - r << oss; -} - -template -typename std::enable_if<(is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - r << fmt::to_string(t); -} -void as_dot(int x, RecordFormatter &r) { - r << std::to_string(x); -} - -void as_dot(std::string const &s, RecordFormatter &r) { - r << s; -} - -template -void as_dot(std::vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -template -void as_dot(stack_vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -struct as_dot_visitor { - as_dot_visitor() = delete; - as_dot_visitor(RecordFormatter &result) : result(result) {} - - RecordFormatter &result; - - template - void operator()(char const *name, T const &t) { - RecordFormatter kv; - kv << name; - as_dot(t, result); - result << kv; - } - - template - void operator()(T const &t) { - as_dot(t, result); - } - - /* template */ - /* void operator()(const char *name, std::vector const &t) { */ - /* RecordFormatter kv; */ - /* kv << name; */ - /* RecordFormatter v; */ - /* for (V const &vv : t) { */ - /* v << as_dot_str(vv); */ - /* } */ - /* kv << v; */ - /* } */ -}; - -template -typename std::enable_if::value>::type - as_dot(T const &t, RecordFormatter &r) { - as_dot_visitor vis(r); - visit_struct::for_each(t, vis); -} - -struct AsDot { - template - RecordFormatter operator()(T const &t) { - return as_dot(t); - } -}; - -template -RecordFormatter as_dot(std::variant const &o) { - return std::visit(AsDot{}, o); -} - -struct IsValidFunctor { - IsValidFunctor(std::vector const &_input_shapes) - : input_shapes(_input_shapes) {} - - std::vector const &input_shapes; - - // bool operator()(AggregateAttrs const &attrs) { - // return is_valid(attrs, - // input_shapes.at(0), - // input_shapes.at(1), - // input_shapes.at(2), - // input_shapes.at(3), - // subvec(input_shapes, 4, nullopt)); - // } - - template - bool operator()(T const &) { - return true; // TODO FIXME @lockshaw - } -}; - -bool is_valid(PCGOperatorAttrs const &attrs, - std::vector const &input_shapes) { - NOT_IMPLEMENTED(); -} - -/* int num_outputs(OperatorParameters const &o) { */ -/* switch (get_op_type(o)) { */ -/* case OP_SPLIT: */ -/* } */ -/* } */ - -// tl::optional get_op_parameters(Op const *op) { -// switch (op->op_type) { -// case OP_LINEAR: -// return ((Linear *)op)->get_params(); -// case OP_CONV2D: -// return ((Conv2D *)op)->get_params(); -// case OP_EW_ADD: -// case OP_EW_SUB: -// case OP_EW_MUL: -// case OP_EW_DIV: -// return ((ElementBinary 
*)op)->get_params(); -// case OP_EXP: -// case OP_SIN: -// case OP_COS: -// case OP_SCALAR_MULTIPLY: -// case OP_SCALAR_ADD: -// case OP_SCALAR_SUB: -// case OP_SCALAR_TRUE_DIV: -// case OP_RELU: -// case OP_SIGMOID: -// case OP_TANH: -// case OP_IDENTITY: -// case OP_GELU: -// case OP_ELU: -// return ((ElementUnary *)op)->get_params(); -// case OP_CONCAT: -// return ((Concat *)op)->get_params(); -// case OP_POOL2D: -// return ((Pool2D *)op)->get_params(); -// case OP_CAST: -// return ((Cast *)op)->get_params(); -// case OP_DROPOUT: -// return ((Dropout *)op)->get_params(); -// case OP_EMBEDDING: -// return ((Embedding *)op)->get_params(); -// case OP_FLAT: -// return ((Flat *)op)->get_params(); -// case OP_MULTIHEAD_ATTENTION: -// return ((MultiHeadAttention *)op)->get_params(); -// case OP_LAYERNORM: -// return ((LayerNorm *)op)->get_params(); -// case OP_RESHAPE: -// return ((Reshape *)op)->get_params(); -// case OP_SOFTMAX: -// return ((Softmax *)op)->get_params(); -// case OP_REPARTITION: -// return ((Repartition *)op)->get_params(); -// case OP_REPLICATE: -// return ((Replicate *)op)->get_params(); -// case OP_REDUCTION: -// return ((Reduction *)op)->get_params(); -// case OP_COMBINE: -// return ((Combine *)op)->get_params(); -// case OP_FUSED_PARALLEL: -// return ((FusedParallelOp *)op)->get_params(); -// case OP_TRANSPOSE: -// return ((Transpose *)op)->get_params(); -// case OP_BATCHMATMUL: -// return ((BatchMatmul *)op)->get_params(); -// case OP_SPLIT: -// return ((Split *)op)->get_params(); -// -// // TODO: implement the get_params() function for the operators below -// and -// // uncomment the lines below -// -// // case OP_NOOP: -// // return ((NoOp *)op)->get_params(); -// // case OP_TOPK: -// // return ((TopK *)op)->get_params(); -// // case OP_MEAN: -// // return ((Mean *)op)->get_params(); -// // case OP_CACHE: -// // return ((Cache *)op)->get_params(); -// // case OP_REVERSE: -// // return ((Reverse *)op)->get_params(); -// // case OP_BATCHNORM: -// // return ((BatchNorm *)op)->get_params(); -// -// default: -// return tl::nullopt; -// } -// } - -} // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc new file mode 100644 index 0000000000..84f1861f0b --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -0,0 +1,19 @@ +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ComputationGraphOpAttrs to/from json") { + ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }}; + nlohmann::json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/op-attrs/datatype.cc similarity index 95% rename from lib/op-attrs/test/src/datatype.cc rename to lib/op-attrs/test/src/op-attrs/datatype.cc index cc7e496c60..d45c156d59 100644 --- a/lib/op-attrs/test/src/datatype.cc +++ b/lib/op-attrs/test/src/op-attrs/datatype.cc @@ -1,6 +1,8 @@ #include "op-attrs/datatype.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("can_promote_datatype_from_to(DataType, DataType)") { diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc new file 
mode 100644 index 0000000000..2ac641cfc2 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/concat.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat(FFOrdered, FFOrdered)") { + SUBCASE("inputs have elements") { + FFOrdered l_input = FFOrdered{1, 3, 1}; + FFOrdered r_input = FFOrdered{2, 1}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {1, 3, 1, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + FFOrdered l_input = FFOrdered{}; + FFOrdered r_input = FFOrdered{}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("concat(std::vector>)") { + SUBCASE("inputs have elements") { + std::vector> input = { + {1}, + {2, 1}, + {1}, + }; + + FFOrdered result = concat(input); + FFOrdered correct = { + 1, + 2, + 1, + 1, + }; + + CHECK(result == correct); + } + + SUBCASE("no inputs") { + std::vector> input = {}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + std::vector> input = {{}, {}, {}}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/test_dim_ordered.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc similarity index 89% rename from lib/op-attrs/test/src/test_dim_ordered.cc rename to lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc index ac05767800..d7901a0c53 100644 --- a/lib/op-attrs/test/src/test_dim_ordered.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc @@ -1,5 +1,5 @@ +#include "op-attrs/dim_ordered/dim_ordered.h" #include "doctest/doctest.h" -#include "op-attrs/dim_ordered.h" #include "test/utils/rapidcheck.h" using namespace FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc index d2c758a05f..180bc2a01f 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -1,5 +1,5 @@ #include "op-attrs/dim_ordered/enumerate.h" -#include "utils/fmt/map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..7bc1695e5c --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("ff_ordered_from_map", + T, + std::map, + std::unordered_map) { + SUBCASE("input is empty") { + T m = {}; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is missing keys") { + SUBCASE("missing key is in middle") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 2}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("missing key is 0 idx") { + T m = { + {ff_dim_t{1}, 2}, + {ff_dim_t{2}, 7}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + } + + SUBCASE("input has negative keys") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{-1}, 2}, + }; + + 
CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("input is valid") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{2}, 2}, + {ff_dim_t{3}, 7}, + }; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {4, 5, 2, 7}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc similarity index 88% rename from lib/op-attrs/test/src/dim_ordered/slice.cc rename to lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc index 8640b077dc..8d5f247756 100644 --- a/lib/op-attrs/test/src/dim_ordered/slice.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc @@ -1,5 +1,7 @@ #include "op-attrs/dim_ordered/slice.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE( diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc index 11e09dc43f..8e3d0f1b80 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -1,6 +1,6 @@ #include "op-attrs/dim_ordered/zip.h" #include "op-attrs/ff_dim.dtg.h" -#include "utils/fmt/pair.h" +#include "test/utils/doctest/fmt/pair.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc new file mode 100644 index 0000000000..33cc00c6a1 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -0,0 +1,26 @@ +#include "op-attrs/get_incoming_tensor_roles.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_incoming_tensor_roles(ComputationGraphOpAttrs, int num_incoming)") { + SUBCASE("Concat") { + int num_incoming = 4; + ComputationGraphOpAttrs attrs = + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}}}; + + std::vector result = + get_incoming_tensor_roles(attrs, num_incoming); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/op-attrs/ops/attention.cc similarity index 86% rename from lib/op-attrs/test/src/ops/attention.cc rename to lib/op-attrs/test/src/op-attrs/ops/attention.cc index ade219a6a9..eca8559b21 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/attention.cc @@ -1,9 +1,61 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs)") { + auto make_attrs = [](bool bias) { + return MultiHeadAttentionAttrs{ + /*embed_dim=*/32, + /*num_heads=*/10, + /*kdim=*/32, + /*vdim=*/32, + /*dropout=*/0.0, + /*bias=*/bias, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + }; + + SUBCASE("without bias") { + MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/false); + + tl::expected, std::string> result = + get_attention_incoming_tensor_roles(attrs); + tl::expected, std::string> correct = + std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } 
+ + SUBCASE("with bias") { + MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/true); + + tl::expected, std::string> result = + get_attention_incoming_tensor_roles(attrs); + tl::expected, std::string> correct = + std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " "TensorShape, TensorShape)") { int embed_dim = 32; diff --git a/lib/op-attrs/test/src/ops/batch_matmul.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc similarity index 98% rename from lib/op-attrs/test/src/ops/batch_matmul.cc rename to lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc index 3ff02ccece..56a2e3fa52 100644 --- a/lib/op-attrs/test/src/ops/batch_matmul.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/batch_matmul.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..4196394d00 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,404 @@ +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { + auto make_attrs = [](bool affine) { + return BatchNormAttrs{ + /*relu=*/false, + /*affine=*/affine, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + }; + + SUBCASE("affine = true") { + BatchNormAttrs attrs = make_attrs(/*affine=*/true); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + BatchNormAttrs attrs = make_attrs(/*affine=*/false); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("shape inference (BatchNorm)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + 18, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + TensorShape gamma = TensorShape{ + TensorDims{FFOrdered{ + 14, + }}, + DataType::FLOAT, + }; + + TensorShape beta = gamma; + + SUBCASE("get_output_shape(BatchNormAttrs, TensorShape)") { + tl::expected result = + get_output_shape(attrs_affine_true, input); + tl::expected correct = output; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, input); + tl::expected correct = gamma; + + CHECK(result == correct); + } + + SUBCASE("affine 
= false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, input); + tl::expected correct = beta; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + SUBCASE("partition parallelism (in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + degree, + 1, + 1, + }, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + tl::expected result = + get_output_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_gamma_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_beta_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + SUBCASE("partition parallelism (not in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, degree, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + 
"ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("sum parallelism") { + SumDegree sum_degree = SumDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + sum_degree, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy parallelism") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + discard_copy_degree, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel shape inference (BatchNormAttrs)") { + // since most of the edge cases are already tested in the above test cases + // (i.e., shape inference and parallel degree inference) + // here we just do a basic check that they compose + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/true, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{18, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_gamma_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + 
ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_beta_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc new file mode 100644 index 0000000000..3d86576279 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -0,0 +1,20 @@ +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("BatchNormAttrs to/from json") { + BatchNormAttrs correct = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }; + + nlohmann::json j = correct; + BatchNormAttrs result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/op-attrs/ops/cast.cc similarity index 95% rename from lib/op-attrs/test/src/ops/cast.cc rename to lib/op-attrs/test/src/op-attrs/ops/cast.cc index 31030ca0f9..c7395316ad 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/cast.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/cast.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Cast shape inference") { diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/op-attrs/ops/combine.cc similarity index 93% rename from lib/op-attrs/test/src/ops/combine.cc rename to lib/op-attrs/test/src/op-attrs/ops/combine.cc index ac18bbc798..bf74a072e0 100644 --- a/lib/op-attrs/test/src/ops/combine.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/combine.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/combine.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Combine shape inference") { diff --git a/lib/op-attrs/test/src/op-attrs/ops/concat.cc b/lib/op-attrs/test/src/op-attrs/ops/concat.cc new file mode 100644 index 0000000000..9e842c3ebe --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/concat.cc @@ -0,0 +1,331 @@ +#include "op-attrs/ops/concat.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { + ConcatAttrs attrs = ConcatAttrs{ + ff_dim_t{1}, + }; + + SUBCASE("empty input shapes list passed") { + std::vector input_shapes = {}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + size_t dim0_size = 12; + size_t dim2_size = 20; + TensorShape input_shape1 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 14, + dim2_size, + }}, + DataType::FLOAT, + }; + + SUBCASE("single element input shapes list passed") { + std::vector input_shapes = {input_shape1}; + + 
std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + TensorShape input_shape2 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 16, + dim2_size, + }}, + DataType::FLOAT, + }; + + TensorShape input_shape3 = TensorShape{ + TensorDims{FFOrdered{dim0_size, 18, dim2_size}}, + DataType::FLOAT, + }; + + SUBCASE("input shapes do not share the same num_dims") { + TensorShape mismatched_num_dims = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 20, + dim2_size, + 1, + }}, + DataType::FLOAT, + }; + + std::vector input_shapes = { + input_shape1, input_shape2, input_shape3, mismatched_num_dims}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("concat axis is out of bounds") { + attrs = ConcatAttrs{ + ff_dim_t{3}, + }; + + std::vector input_shapes = { + input_shape1, input_shape2, input_shape3}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input shapes are valid") { + std::vector input_shapes = { + input_shape1, input_shape2, input_shape3}; + + tl::expected result = + get_output_shape(attrs, input_shapes); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 14 + 16 + 18, + dim2_size, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { + ConcatAttrs attrs = ConcatAttrs{ + ff_dim_t{1}, + }; + + size_t dim0_size = 12; + size_t dim2_size = 20; + + TensorShape input_shape1 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 14, + dim2_size, + }}, + DataType::FLOAT, + }; + + TensorShape input_shape2 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 16, + dim2_size, + }}, + DataType::FLOAT, + }; + + TensorShape input_shape3 = TensorShape{ + TensorDims{FFOrdered{dim0_size, 18, dim2_size}}, + DataType::FLOAT, + }; + + TensorShape output_shape = TensorShape{ + TensorDims{FFOrdered{dim0_size, 14 + 16 + 18, dim2_size}}, + DataType::FLOAT, + }; + + auto lift_input1 = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input_shape1, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto lift_input2 = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input_shape2, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto lift_input3 = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input_shape3, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto lift_output = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + output_shape, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + SUBCASE("sum reduction parallelism") { + SUBCASE("matching") { + SumDegree sum_degree = SumDegree{2}; + + std::vector inputs = { + lift_input1(sum_degree, DiscardCopyDegree{1}, 1, 1, 1), + lift_input2(sum_degree, DiscardCopyDegree{1}, 1, 1, 1), + lift_input3(sum_degree, DiscardCopyDegree{1}, 1, 1, 1), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = + lift_output(sum_degree, DiscardCopyDegree{1}, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector 
inputs = { + lift_input1(SumDegree{2}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input2(SumDegree{4}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input3(SumDegree{4}, DiscardCopyDegree{1}, 1, 1, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy reduction parallelism") { + SUBCASE("matching") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + std::vector inputs = { + lift_input1(SumDegree{1}, discard_copy_degree, 1, 1, 1), + lift_input2(SumDegree{1}, discard_copy_degree, 1, 1, 1), + lift_input3(SumDegree{1}, discard_copy_degree, 1, 1, 1), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = + lift_output(SumDegree{1}, discard_copy_degree, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{2}, 1, 1, 1), + lift_input2(SumDegree{1}, DiscardCopyDegree{2}, 1, 1, 1), + lift_input3(SumDegree{1}, DiscardCopyDegree{4}, 1, 1, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("parallelism in axis dim") { + SUBCASE("matching") { + int degree = 2; + + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1), + lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1), + lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 1, 2, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("parallelism in non-axis shard dims") { + SUBCASE("matching") { + int degree0 = 2; + int degree2 = 4; + + std::vector inputs = { + lift_input1( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2), + lift_input2( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2), + lift_input3( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = lift_output( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 2, 1, 4), + lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 4, 1, 2), + lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 4, 1, 2), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("parallelism degrees are not mutually exclusive") { + SumDegree sum_degree = SumDegree{3}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{5}; + int degree0 = 2; + int degree2 = 4; + + std::vector inputs = { + lift_input1(sum_degree, discard_copy_degree, degree0, 1, degree2), + lift_input2(sum_degree, discard_copy_degree, degree0, 1, degree2), + lift_input3(sum_degree, discard_copy_degree, 
degree0, 1, degree2), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = + lift_output(sum_degree, discard_copy_degree, degree0, 1, degree2); + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc similarity index 84% rename from lib/op-attrs/test/src/ops/conv_2d.cc rename to lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index c4462eb7ec..7abb98f3e3 100644 --- a/lib/op-attrs/test/src/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -5,6 +5,48 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_conv2d_incoming_tensor_roles(Conv2DAttrs)") { + auto make_attrs = [](bool use_bias) { + return Conv2DAttrs{/*out_channels=*/4, + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*groups=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/use_bias}; + }; + + SUBCASE("with bias") { + Conv2DAttrs attrs = make_attrs(/*use_bias=*/true); + + std::vector result = + get_conv2d_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("without bias") { + Conv2DAttrs attrs = make_attrs(/*use_bias=*/false); + + std::vector result = + get_conv2d_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("Conv2D shape inference") { int out_channels = 4; int kernel_h = 3; @@ -31,8 +73,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; size_t num_samples = 7; - size_t input_channels = 6; - size_t input_height = 10; + size_t input_channels = 4; + size_t input_height = 11; size_t input_width = 15; TensorShape input = TensorShape{ @@ -45,8 +87,8 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - size_t output_height = 3; - size_t output_width = 6; + size_t output_height = 6; + size_t output_width = 8; TensorShape output = TensorShape{ TensorDims{FFOrdered{ diff --git a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc index 17a68ccbc8..7580de24e5 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/dropout.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/ops/element_binary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc similarity index 98% rename from lib/op-attrs/test/src/ops/element_binary.cc rename to lib/op-attrs/test/src/op-attrs/ops/element_binary.cc index 0ed695eb89..b091833f10 100644 --- a/lib/op-attrs/test/src/ops/element_binary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_binary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("EWAdd shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc similarity index 95% rename from lib/op-attrs/test/src/ops/element_unary.cc rename to lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 4239782d55..94c382356e 100644 --- 
a/lib/op-attrs/test/src/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_unary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ReLU shape inference") { diff --git a/lib/op-attrs/test/src/ops/embedding.cc b/lib/op-attrs/test/src/op-attrs/ops/embedding.cc similarity index 98% rename from lib/op-attrs/test/src/ops/embedding.cc rename to lib/op-attrs/test/src/op-attrs/ops/embedding.cc index 9180f7055d..134737f6c0 100644 --- a/lib/op-attrs/test/src/ops/embedding.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/embedding.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Sum embedding shape inference") { diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc new file mode 100644 index 0000000000..d81ab95c35 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -0,0 +1,244 @@ +#include "op-attrs/ops/flat.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(FlatAttrs, TensorShape)") { + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + SUBCASE("flatten all dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4 * 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten trailing dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten leading dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten middle dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4 * 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim == end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim > end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{1}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = 
input_shape; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { + FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}}; + + SUBCASE("allows shard parallelism in non-flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 3}, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 3}, + }; + + CHECK(result == correct); + } + + SUBCASE("does not allow shard parallelism in flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 2, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("allows sum parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(FlatAttrs, ParallelTensorShape)") { + // since most of the edge cases are already tested in + // get_output_shape(FlatAttrs, TensorShape) and + // get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), + // here we just do a basic check that they compose + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8, 1}, + ShardParallelDim{6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + tl::expected result = + get_output_shape(attrs, input_shape); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8 * 6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index b9dd66df5d..f45ea91dac 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -1,13 +1,49 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + 
TEST_CASE("get_layer_norm_incoming_tensor_roles(LayerNormAttrs)") { + auto make_attrs = [](bool elementwise_affine) { + return LayerNormAttrs{ + /*axes=*/{ff_dim_t{0}, ff_dim_t{2}}, + elementwise_affine, + /*eps=*/1.0, + }; + }; + + SUBCASE("elementwise_affine = true") { + LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/true); + + std::vector result = + get_layer_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/false); + + std::vector result = + get_layer_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("shape inference (LayerNorm)") { LayerNormAttrs attrs_affine_true = LayerNormAttrs{ /*axes=*/{ff_dim_t{1}, ff_dim_t{3}}, diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/op-attrs/ops/linear.cc similarity index 73% rename from lib/op-attrs/test/src/ops/linear.cc rename to lib/op-attrs/test/src/op-attrs/ops/linear.cc index 0d23dc35df..191515b062 100644 --- a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/linear.cc @@ -1,9 +1,51 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_linear_incoming_tensor_roles(LinearAttrs)") { + auto make_attrs = [](bool use_bias) { + return LinearAttrs{ + /*out_channels=*/16, + /*use_bias=*/use_bias, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + }; + + SUBCASE("use_bias = true") { + LinearAttrs attrs = make_attrs(/*use_bias=*/true); + + std::vector result = + get_linear_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("use_bias = false") { + LinearAttrs attrs = make_attrs(/*use_bias=*/false); + + std::vector result = + get_linear_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("Linear shape inference") { int out_channels = 16; LinearAttrs attrs = LinearAttrs{ @@ -40,7 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape kernel = TensorShape{ + TensorShape projection = TensorShape{ TensorDims{ FFOrdered{ in_channels, @@ -69,10 +111,10 @@ TEST_SUITE(FF_TEST_SUITE) { // get_weight_shape { - tl::expected kernel_result = - get_kernel_shape(attrs, input); - tl::expected kernel_correct = kernel; - CHECK(kernel_result == kernel_correct); + tl::expected projection_result = + get_projection_shape(attrs, input); + tl::expected projection_correct = projection; + CHECK(projection_result == projection_correct); } // get_bias_shape @@ -101,12 +143,12 @@ TEST_SUITE(FF_TEST_SUITE) { output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); }; - auto make_kernel = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_inchannel, - int o_outchannel) { + auto make_projection = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_inchannel, + int o_outchannel) { return lift_to_parallel_with_degrees( - kernel, 
o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + projection, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); }; auto make_bias = @@ -140,12 +182,13 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, - DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, - 1, - 1); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, + DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, + 1, + 1); CHECK(result == correct); } @@ -181,9 +224,10 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); CHECK(result == correct); } @@ -213,9 +257,10 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc new file mode 100644 index 0000000000..0c14c0fc2a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -0,0 +1,400 @@ +#include "op-attrs/ops/pool_2d.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_adaptive_pool2d") { + size_t input_n = 10; + size_t input_c = 11; + size_t input_h = 15; + size_t input_w = 20; + Activation activation = Activation::RELU; + PoolOp op = PoolOp::AVG; + + TensorDims input_dims = + TensorDims{FFOrdered{input_n, input_c, input_h, input_w}}; + + SUBCASE("input_h divisible by output_h && input_w divisible by output_w") { + int output_h = 5; + int output_w = 2; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/10, + /*stride_h=*/3, + /*stride_w=*/10, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + input_n, + input_c, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + SUBCASE("input_h not divisible by output_h") { + int output_h = 6; + int output_w = 2; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_w not divisible by output_w") { + int output_h = 5; + int output_w = 3; + + std::optional result = + 
optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_h == output_h and input_w == output_w") { + int output_h = input_h; + int output_w = input_w; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/1, + /*kernel_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = input_shape; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("fails on non-4d inputs") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + 14, + }}, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("4d input") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{11, 13, 12, 6}}, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{11, 13, 6, 4}}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, " + "ParallelTensorDimDegrees)") { + auto make_attrs = [](PoolOp pool_type, + std::optional const &activation) { + return Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + }; + + SUBCASE("allows data parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows arbitrary input sharding parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 2, + 5, + 6, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{3}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("sum 
parallelism") { + SUBCASE("without activation") { + SUBCASE("PoolOp::MAX does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = + optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("PoolOp::AVG does allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + } + + SUBCASE("with activation does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, ParallelTensorShape)") { + // this function is mostly covered by the tests above, so we + // just do a single test to make sure it works/exists + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("valid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{12, 3}, + ShardParallelDim{6, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{6, 3}, + ShardParallelDim{4, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("invalid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 1}, + ShardParallelDim{16, 1}, + ShardParallelDim{12, 1}, + ShardParallelDim{6, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/op-attrs/ops/reduction.cc similarity index 93% rename from lib/op-attrs/test/src/ops/reduction.cc rename to lib/op-attrs/test/src/op-attrs/ops/reduction.cc index 59ed5bb5ee..0d1c8bdf98 100644 --- a/lib/op-attrs/test/src/ops/reduction.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/reduction.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/reduction.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow;
TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Reduction shape inference") { diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/op-attrs/ops/repartition.cc similarity index 91% rename from lib/op-attrs/test/src/ops/repartition.cc rename to lib/op-attrs/test/src/op-attrs/ops/repartition.cc index af28a6d471..8bc8205183 100644 --- a/lib/op-attrs/test/src/ops/repartition.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/repartition.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/repartition.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/op-attrs/ops/replicate.cc similarity index 93% rename from lib/op-attrs/test/src/ops/replicate.cc rename to lib/op-attrs/test/src/op-attrs/ops/replicate.cc index a0ec40cc14..60a1018479 100644 --- a/lib/op-attrs/test/src/ops/replicate.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/replicate.cc @@ -1,5 +1,7 @@ #include "op-attrs/ops/replicate.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Replicate shape inference") { diff --git a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc index f6a8da016f..65a74932cb 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/softmax.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc new file mode 100644 index 0000000000..ebeaec4d19 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc @@ -0,0 +1,17 @@ +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("PCGOperatorAttrs to/from json") { + PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + }}; + nlohmann::json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc similarity index 83% rename from lib/op-attrs/test/src/test_regularizer_attrs.cc rename to lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc index 35851463bb..6e172d1e8e 100644 --- a/lib/op-attrs/test/src/test_regularizer_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc @@ -1,6 +1,8 @@ #include "op-attrs/regularizer_attrs.dtg.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc index 25c7eb036f..60d87300c1 100644 --- a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc @@ -1,4 +1,5 @@ #include "op-attrs/tensor_dims.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc b/lib/op-attrs/test/src/op-attrs/tensor_shape.cc deleted file mode 100644 index bc715c183a..0000000000 --- 
a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc +++ /dev/null @@ -1,64 +0,0 @@ -#include "op-attrs/tensor_shape.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_broadcast_target_shape(std::unordered_set)") { - SUBCASE("target exists in inputs") { - DataType datatype = DataType::FLOAT; - - TensorShape s1 = TensorShape{ - TensorDims{FFOrdered{ - 1, - }}, - datatype, - }; - - TensorShape s2 = TensorShape{ - TensorDims{FFOrdered{10, 4, 3}}, - datatype, - }; - - TensorShape s3 = TensorShape{ - TensorDims{FFOrdered{ - 4, - 1, - }}, - datatype, - }; - - std::optional result = - get_broadcast_target_shape({s1, s2, s3}); - std::optional correct = s2; - - CHECK(result == correct); - } - - SUBCASE("datatypes don't match") { - TensorDims dims = TensorDims{FFOrdered{10, 4, 3}}; - - TensorShape s1 = TensorShape{ - dims, - DataType::FLOAT, - }; - - TensorShape s2 = TensorShape{ - dims, - DataType::DOUBLE, - }; - - std::optional result = get_broadcast_target_shape({s1, s2}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - - SUBCASE("inputs is empty") { - std::optional result = get_broadcast_target_shape({}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - } -} diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc deleted file mode 100644 index f485b07b02..0000000000 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/json.h" -#include -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; - json j = correct; - auto result = j.get(); - CHECK(result == correct); - } - - TEST_CASE("ComputationGraphAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; - json j = correct; - auto result = j.get(); - - CHECK(result == correct); - } - - TEST_CASE("PCGOperatorAttrs to/from json") { - PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1}, - /*repartition_degree=*/4, - }}; - json j = correct; - auto result = j.get(); - - CHECK(result == correct); - } -} diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index e1875ca694..e6eb182740 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -10,6 +10,7 @@ ff_add_library( DEPS op-attrs utils + rapidcheck ) add_subdirectory(ffi) diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 088139a0f3..b29d683edb 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -1,7 +1,9 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_attrs.dtg.h" @@ -30,11 +32,34 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); +std::vector get_incoming_inputs(ComputationGraph const &, + layer_guid_t const &); +std::vector get_incoming_weights(ComputationGraph const &, + 
layer_guid_t const &); + +std::unordered_set + get_subgraph_incoming_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_outgoing_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_successors(ComputationGraph const &, + std::unordered_set const &); + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +ComputationGraph without_layer_names(ComputationGraph const &); + +bool computation_graphs_are_isomorphic(ComputationGraph const &, + ComputationGraph const &); + +std::string as_dot(ComputationGraph const &); +void debug_print_dot(ComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h new file mode 100644 index 0000000000..2a9a9ee04a --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/computation_graph/computation_graph_edge.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +layer_guid_t get_computation_graph_edge_src_layer(ComputationGraphEdge const &); +layer_guid_t get_computation_graph_edge_dst_layer(ComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml new file mode 100644 index 0000000000..311c47d277 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index c641aed6a4..45cde0de57 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -137,6 +137,13 @@ struct ComputationGraphBuilder { PoolOp type = PoolOp::MAX, std::optional const &activation = std::nullopt, std::optional const &name = std::nullopt); + tensor_guid_t adaptive_pool2d( + tensor_guid_t const &input, + int output_h, + int output_w, + PoolOp type = PoolOp::MAX, + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, @@ -145,7 +152,10 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); tensor_guid_t batch_norm(tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); tensor_guid_t batch_matmul(tensor_guid_t const &A, @@ -159,19 +169,20 @@ struct ComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = 
std::nullopt, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt, + std::optional const &projection_name = std::nullopt, + std::optional const &bias_name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, std::optional const &name = std::nullopt); // Add a concat layer - tensor_guid_t - concat(int n, - std::vector const &tensors, - int axis, - std::optional const &maybe_name = std::nullopt); + tensor_guid_t concat(std::vector const &tensors, + int axis, + std::optional const &name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, @@ -185,6 +196,8 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, + int start_dim = 0, + std::optional const &end_dim = std::nullopt, std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, @@ -225,47 +238,40 @@ struct ComputationGraphBuilder { bool add_zero_attn = false, std::optional initializer = std::nullopt, std::optional const &maybe_name = std::nullopt); - tensor_guid_t create_tensor(TensorShape const &, CreateGrad); + tensor_guid_t + create_input(TensorShape const &, + CreateGrad, + std::optional const &name = std::nullopt); + tensor_guid_t create_weight( TensorShape const &, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &initializer = std::nullopt, - std::optional sync_type = std::nullopt); + std::optional sync_type = std::nullopt, + std::optional const &name = std::nullopt); + tensor_guid_t + create_weight(TensorAttrs const &, + std::optional const &name = std::nullopt); std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; - std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); + std::vector + add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); -private: TensorShape get_shape(tensor_guid_t const &) const; - tensor_guid_t broadcast(tensor_guid_t const &, - TensorShape const &, - std::string const &); +private: + tensor_guid_t + broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs const &output); - - std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output); - - TensorShape get_broadcast_target_shape(std::vector const &); - TensorShape get_broadcast_target_shape(std::vector const &); + TensorDims get_broadcast_target_dims(std::vector const &); + TensorDims get_broadcast_target_dims(std::vector const &); tensor_guid_t element_binary(OperatorType, diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index 1157a2932a..28cf30eaba 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -13,6 +13,7 @@ device_id_t operator+(device_id_t, size_t); DeviceType get_device_type(device_id_t const &device_id); gpu_id_t 
unwrap_gpu(device_id_t); cpu_id_t unwrap_cpu(device_id_t); +int get_raw_id(device_id_t); device_id_t device_id_from_index(int, DeviceType); diff --git a/lib/pcg/include/pcg/file_format/file_format.h b/lib/pcg/include/pcg/file_format/file_format.h deleted file mode 100644 index 823846754c..0000000000 --- a/lib/pcg/include/pcg/file_format/file_format.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H - -#include "graphs.h" -#include "utils/json.h" - -namespace FlexFlow { - -enum class FileFormatVersion { - V1, - UNSTABLE, -}; - -json to_json(ComputationGraph const &, FileFormatVersion); -ComputationGraph from_json(json const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/keyed_variant.h b/lib/pcg/include/pcg/file_format/keyed_variant.h index 11044de12b..5e29d8c252 100644 --- a/lib/pcg/include/pcg/file_format/keyed_variant.h +++ b/lib/pcg/include/pcg/file_format/keyed_variant.h @@ -1,10 +1,11 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H -#include "utils/json.h" +#include "utils/json/is_jsonable.h" #include "utils/sequence.h" #include "utils/strong_typedef.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -29,9 +30,9 @@ struct KeyedVariant { }; struct ToJsonFunctor { - ToJsonFunctor(json &j) : j(j) {} + ToJsonFunctor(nlohmann::json &j) : j(j) {} - json &j; + nlohmann::json &j; template void operator()(T const &t) { @@ -42,20 +43,20 @@ struct ToJsonFunctor { }; template -void to_json(json &j, KeyedVariant const &v) { +void to_json(nlohmann::json &j, KeyedVariant const &v) { static_assert(is_jsonable::value, ""); K key = static_cast(v.value.index()); j["type"] = key; - json &jj = j["value"]; + nlohmann::json &jj = j["value"]; visit(ToJsonFunctor{j["value"]}, v.value); } template struct FromJsonFunctor { - FromJsonFunctor(json const &j, int idx) : j(j), idx(idx) {} + FromJsonFunctor(nlohmann::json const &j, int idx) : j(j), idx(idx) {} - json const &j; + nlohmann::json const &j; int idx; template @@ -68,31 +69,31 @@ struct FromJsonFunctor { template std::string get_json_name(T const &t) { - return json{t}.get(); + return nlohmann::json{t}.get(); } template struct FromJsonMoveOnlyFunctor { - FromJsonMoveOnlyFunctor(json const &j, Key const &key) : j(j) {} + FromJsonMoveOnlyFunctor(nlohmann::json const &j, Key const &key) : j(j) {} - json const &j; + nlohmann::json const &j; Key const &key; template Variant operator()(std::integral_constant const &) const { - return j.get::type>(); + return j.get::type>(); } }; template -Variant from_json_moveonly(json const &j, K const &key) { +Variant from_json_moveonly(nlohmann::json const &j, K const &key) { FromJsonMoveOnlyFunctor func(j); return seq_get(func, idx, seq_count_t::value>{}); } template typename std::enable_if::value>::type - from_json(json const &j, KeyedVariant &v) { + from_json(nlohmann::json const &j, KeyedVariant &v) { K key = j.at("type").get(); std::string key_string = j.at("type").get(); @@ -100,7 +101,7 @@ typename std::enable_if::value>::type } template -KeyedVariant keyed_variant_from_json(json const &j) { +KeyedVariant keyed_variant_from_json(nlohmann::json const &j) { K key = j.at("type").get(); return KeyedVariant{ diff --git a/lib/pcg/include/pcg/file_format/v1/data_type_value.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h index 6e4e5abc54..ec3910aab3 100644 --- 
a/lib/pcg/include/pcg/file_format/v1/data_type_value.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H #include "utils/fp16.h" -#include "utils/json.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h deleted file mode 100644 index 702c79c2b6..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H - -#include "pcg/computation_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "pcg/layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" -#include "utils/json.h" - -namespace FlexFlow { - -using V1ComputationGraph = V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ComputationGraph); -V1ComputationGraph to_v1(ComputationGraph const &); - -using V1ParallelComputationGraph = - V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ParallelComputationGraph); -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index d9aade739c..c332b6b41d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1DataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -13,8 +13,13 @@ includes = [ "", "", "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", +] + +src_includes = [ "utils/fmt/vector.h", + "utils/hash/vector.h", "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index 48203d73ae..fc9dfcef9a 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -13,8 +13,9 @@ namespace FlexFlow { template -V1LabelledDataflowGraph - to_v1(LabelledDataflowGraphView const &g) { +std::pair, bidict> + to_v1_including_node_numbering( + LabelledDataflowGraphView const &g) { bidict nodes = bidict_from_enumerating(get_nodes(g)); @@ -29,8 +30,17 @@ V1LabelledDataflowGraph [&](DataflowOutput const &o) { return g.at(o); }); }); - return V1LabelledDataflowGraph{ - node_labels, output_labels, unlabelled}; + return { + V1LabelledDataflowGraph{ + node_labels, output_labels, unlabelled}, + nodes, + }; +} + +template +V1LabelledDataflowGraph + to_v1(LabelledDataflowGraphView const &g) { + return to_v1_including_node_numbering(g).first; } } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index fd8d4c39c4..b440d0f03d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ 
b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1LabelledDataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -20,6 +20,13 @@ includes = [ "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", ] +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + [[fields]] name = "node_labels" type = "std::unordered_map" @@ -31,4 +38,3 @@ type = "std::unordered_map>" [[fields]] name = "graph" type = "::FlexFlow::V1DataflowGraph" - diff --git a/lib/pcg/include/pcg/file_format/v1/v1.h b/lib/pcg/include/pcg/file_format/v1/v1.h deleted file mode 100644 index e2557af4f5..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H - -#include "graphs.h" -#include "pcg/computation_graph.h" - -namespace FlexFlow {} - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h new file mode 100644 index 0000000000..a1ca0aceed --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H + +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::V1BinarySPDecomposition> { + static ::FlexFlow::V1BinarySPDecomposition from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySPDecomposition const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinarySeriesSplit> { + static ::FlexFlow::V1BinarySeriesSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySeriesSplit const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinaryParallelSplit> { + static ::FlexFlow::V1BinaryParallelSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinaryParallelSplit const &); +}; + +} // namespace nlohmann + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..d2d0c3bc77 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml new file mode 100644 index 0000000000..317fa8b6ce --- /dev/null +++ 
b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..0fe0b1761f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "V1BinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::V1BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::V1BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "int" +key = "leaf" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h new file mode 100644 index 0000000000..5590d6999b --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H + +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &); + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml new file mode 100644 index 0000000000..0d7135ec74 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h new file mode 100644 index 0000000000..aceb59f5af --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/file_format/v1/v1_parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + 
+V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..16be4a9561 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ParallelComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 1ea9ce05a6..c67999ff32 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -11,9 +11,11 @@ features = [ includes = [ "pcg/initializers/glorot_uniform_attrs.dtg.h", + "pcg/initializers/glorot_normal_attrs.dtg.h", "pcg/initializers/zero_initializer_attrs.dtg.h", "pcg/initializers/uniform_initializer_attrs.h", "pcg/initializers/norm_initializer_attrs.dtg.h", + "pcg/initializers/truncated_normal_initializer_attrs.dtg.h", "pcg/initializers/constant_initializer_attrs.dtg.h", ] @@ -21,6 +23,10 @@ includes = [ type = "::FlexFlow::GlorotUniformAttrs" key = "glorot_uniform" +[[values]] +type = "::FlexFlow::GlorotNormalAttrs" +key = "glorot_normal" + [[values]] type = "::FlexFlow::ZeroInitializerAttrs" key = "zero" @@ -33,6 +39,10 @@ key = "uniform" type = "::FlexFlow::NormInitializerAttrs" key = "normal" +[[values]] +type = "::FlexFlow::TruncatedNormalInitializerAttrs" +key = "truncated_normal" + [[values]] type = "::FlexFlow::ConstantInitializerAttrs" key = "constant" diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml index 12917d0989..4e3c31bd36 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -10,12 +10,7 @@ features = [ ] includes = [ - "op-attrs/datatype.h", - "utils/json.h", -] - -src_includes = [ - "utils/fmt/variant.h", + "op-attrs/datatype_value.dtg.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml b/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml new file mode 100644 index 0000000000..fd0d8eb9be --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "GlorotNormalAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml new file mode 100644 index 0000000000..9e4ec0272d --- /dev/null +++ b/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "TruncatedNormalInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + 
"rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" + +[[fields]] +name = "min_cutoff" +type = "float" + +[[fields]] +name = "max_cutoff" +type = "float" diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml index d062f6cd78..8290795174 100644 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -13,11 +13,11 @@ includes = [ "op-attrs/computation_graph_op_attrs.dtg.h", "utils/stack_string.h", "", - "utils/json.h" ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/machine_space_coordinate.struct.toml b/lib/pcg/include/pcg/machine_space_coordinate.struct.toml new file mode 100644 index 0000000000..9b197a74c9 --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_coordinate.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineSpaceCoordinate" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", +] + +[[fields]] +name = "node_idx" +type = "int" + +[[fields]] +name = "device_idx" +type = "int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_space_offset.h b/lib/pcg/include/pcg/machine_space_offset.h new file mode 100644 index 0000000000..2f702cc518 --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_offset.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_MACHINE_SPACE_OFFSET_H +#define _FLEXFLOW_PCG_INCLUDE_MACHINE_SPACE_OFFSET_H + +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/machine_space_offset.dtg.h" + +namespace FlexFlow { + +MachineSpaceOffset get_machine_space_offset_from_coordinate( + MachineSpaceCoordinate const &start, MachineSpaceCoordinate const &coord); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/machine_space_offset.struct.toml b/lib/pcg/include/pcg/machine_space_offset.struct.toml new file mode 100644 index 0000000000..3f6eab38fd --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_offset.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineSpaceOffset" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", +] + +[[fields]] +name = "node_offset" +type = "int" + +[[fields]] +name = "device_offset" +type = "int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index f66723b0ff..6ffa9900c2 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -1,6 +1,25 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H -namespace FlexFlow {} // namespace FlexFlow +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +int get_num_gpus(MachineSpecification const &ms); +int get_num_cpus(MachineSpecification const &ms); +int get_num_devices(MachineSpecification const &ms, + DeviceType const &device_type); +int get_num_devices_per_node(MachineSpecification const &ms, + DeviceType const &device_type); + +bool is_valid_machine_space_coordinate(MachineSpecification const &ms, + 
MachineSpaceCoordinate const &coord); + +device_id_t get_device_id(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord); +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_specification_dimension.enum.toml b/lib/pcg/include/pcg/machine_specification_dimension.enum.toml new file mode 100644 index 0000000000..837b4306da --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification_dimension.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "MachineSpecificationDimension" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "INTER_NODE" + +[[values]] +name = "INTRA_NODE" diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 56abf5aa20..293227b7a1 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,50 +1,41 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "pcg/cpu_id_t.dtg.h" -#include "pcg/device_id.h" +#include "machine_specification.dtg.h" +#include "machine_view.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "pcg/device_type.dtg.h" -#include "pcg/gpu_id_t.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/num_points_t.dtg.h" -#include "pcg/side_size_t.dtg.h" +#include "pcg/operator_task_space.dtg.h" +#include "task_space_coordinate.dtg.h" #include -#include +#include +#include namespace FlexFlow { -std::vector device_ids(MachineView const &); -size_t num_dims(MachineView const &); -std::size_t num_devices(MachineView const &); -DeviceType get_device_type(MachineView const &); - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride = 1); -MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride = 1); -MachineView - make_1d_machine_view(device_id_t start, device_id_t stop, int stride = 1); - -MachineView make_1d_machine_view(gpu_id_t start, - num_points_t num_points, - int stride = 1); -MachineView make_1d_machine_view(cpu_id_t start, - num_points_t num_points, - int stride = 1); -MachineView make_1d_machine_view(device_id_t start, - num_points_t num_points, - int stride = 1); - -MachineView make_1d_machine_view(gpu_id_t start, - side_size_t interval_size, - int stride = 1); -MachineView make_1d_machine_view(cpu_id_t start, - side_size_t interval_size, - int stride = 1); -MachineView make_1d_machine_view(device_id_t start, - side_size_t interval_size, - int stride = 1); - -MachineView make_1d_machine_view(device_id_t start, size_t interval_size); +size_t num_dims(MachineView const &mv); + +DeviceType get_device_type(MachineView const &mv); + +std::vector get_strides(MachineView const &mv); + +std::vector + get_dimensions(MachineView const &mv); + +MachineView machine_view_from_strides_and_machine_spec_dimensions( + MachineSpaceCoordinate const &start, + std::vector const &strides, + std::vector const &dims); + +std::optional + get_machine_space_coordinate(OperatorTaskSpace const &task, + MachineView const &mv, + TaskSpaceCoordinate const &coordinates, + MachineSpecification const &ms); + +std::unordered_set + get_machine_space_coordinates(OperatorTaskSpace const &task, + MachineView const &mv, + MachineSpecification const &ms); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/machine_view.struct.toml b/lib/pcg/include/pcg/machine_view.struct.toml index c97731991f..e4de69eafc 100644 --- a/lib/pcg/include/pcg/machine_view.struct.toml +++ b/lib/pcg/include/pcg/machine_view.struct.toml @@ -9,15 +9,21 @@ features = [ 
"fmt", ] -includes = [ - "pcg/device_id_t.dtg.h", - "pcg/strided_rectangle.dtg.h", +includes = [ + "pcg/machine_view_dimension.dtg.h", + "pcg/machine_space_coordinate.dtg.h" ] +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h" +] + + [[fields]] name = "start" -type = "::FlexFlow::device_id_t" +type = "::FlexFlow::MachineSpaceCoordinate" [[fields]] -name = "rect" -type = "::FlexFlow::StridedRectangle" +name = "dimensions" +type = "std::vector<::FlexFlow::MachineViewDimension>" diff --git a/lib/pcg/include/pcg/machine_view_dimension.struct.toml b/lib/pcg/include/pcg/machine_view_dimension.struct.toml new file mode 100644 index 0000000000..03b0ac51e4 --- /dev/null +++ b/lib/pcg/include/pcg/machine_view_dimension.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MachineViewDimension" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/machine_specification_dimension.dtg.h", + "pcg/stride_t.dtg.h", +] + + +[[fields]] +name = "stride" +type = "::FlexFlow::stride_t" + +[[fields]] +name = "projection" +type = "::FlexFlow::MachineSpecificationDimension" diff --git a/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml b/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml new file mode 100644 index 0000000000..9fa5a77f77 --- /dev/null +++ b/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MultiDimensionalStride" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "pcg/stride_t.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h" + +] + +[[fields]] +name = "raw_strides" +type = "std::vector<::FlexFlow::stride_t>" diff --git a/lib/pcg/include/pcg/operator_task_space.h b/lib/pcg/include/pcg/operator_task_space.h new file mode 100644 index 0000000000..61cab4eff1 --- /dev/null +++ b/lib/pcg/include/pcg/operator_task_space.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H +#define _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H + +#include "pcg/operator_task_space.dtg.h" +#include "pcg/task_space_coordinate.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_set + get_task_space_coordinates(OperatorTaskSpace const &task); + +TaskSpaceCoordinate + get_task_space_maximum_coordinate(OperatorTaskSpace const &task); + +size_t num_dims(OperatorTaskSpace const &task); +size_t num_tasks(OperatorTaskSpace const &task); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_task_space.struct.toml b/lib/pcg/include/pcg/operator_task_space.struct.toml new file mode 100644 index 0000000000..3ab8b83173 --- /dev/null +++ b/lib/pcg/include/pcg/operator_task_space.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OperatorTaskSpace" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h" +] + +[[fields]] +name = "degrees" +type = "std::vector" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 9150681070..c740e1ffd2 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #include 
"pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" @@ -21,17 +22,37 @@ ParallelLayerAddedResult std::vector const &inputs, std::vector const &output_labels); +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape); + +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, + parallel_layer_guid_t const &, + parallel_layer_guid_t const &); + std::vector - get_layer_inputs(ParallelComputationGraph const &, - parallel_layer_guid_t const &); + get_incoming_tensors(ParallelComputationGraph const &, + parallel_layer_guid_t const &); std::vector get_layer_outputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::vector + get_incoming_inputs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); +std::vector + get_incoming_weights(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); +ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); std::vector topological_ordering(ParallelComputationGraph const &); @@ -40,6 +61,11 @@ parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name); +ParallelComputationGraph without_layer_names(ParallelComputationGraph const &); + +bool pcgs_are_isomorphic(ParallelComputationGraph const &, + ParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 20e947ad58..019b120936 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -13,7 +13,7 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t create_input_tensor( ParallelTensorShape const &shape, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &name = std::nullopt); parallel_tensor_guid_t @@ -54,7 +54,8 @@ struct ParallelComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = std::nullopt, std::optional const &name = std::nullopt); @@ -86,7 +87,10 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t batch_norm(parallel_tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); parallel_tensor_guid_t diff --git 
a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 60cfc426cc..027b9f6c80 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -10,13 +10,15 @@ features = [ ] includes = [ - "op-attrs/operator_attrs.h", + "op-attrs/pcg_operator_attrs.dtg.h", "utils/stack_string.h", "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml index d9e6cf113b..323932fec6 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml @@ -19,6 +19,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.h b/lib/pcg/include/pcg/start_invariant_machine_view.h new file mode 100644 index 0000000000..f5091c69d1 --- /dev/null +++ b/lib/pcg/include/pcg/start_invariant_machine_view.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_START_INVARIANT_MACHINE_VIEW_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_START_INVARIANT_MACHINE_VIEW_H + +#include "pcg/machine_space_offset.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/operator_task_space.dtg.h" +#include "pcg/start_invariant_machine_view.dtg.h" +#include "pcg/task_space_coordinate.dtg.h" +#include + +namespace FlexFlow { + +MachineView + machine_view_from_start_invariant(StartInvariantMachineView const &mv, + MachineSpaceCoordinate const &start); +StartInvariantMachineView + start_invariant_from_machine_view(MachineView const &mv); + +size_t num_dims(StartInvariantMachineView const &mv); + +DeviceType get_device_type(StartInvariantMachineView const &mv); + +std::vector get_strides(StartInvariantMachineView const &mv); + +std::vector + get_dimensions(StartInvariantMachineView const &mv); + +StartInvariantMachineView + start_invariant_machine_view_from_strides_and_machine_spec_dimensions( + std::vector const &strides, + std::vector const &dims); + +std::optional + get_machine_space_offset(OperatorTaskSpace const &task, + StartInvariantMachineView const &mv, + TaskSpaceCoordinate const &coordinates, + MachineSpecification const &ms); + +std::unordered_set + get_machine_space_offsets(OperatorTaskSpace const &task, + StartInvariantMachineView const &mv, + MachineSpecification const &ms); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml b/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml new file mode 100644 index 0000000000..a1b2b40524 --- /dev/null +++ b/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "StartInvariantMachineView" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/machine_view_dimension.dtg.h", + "pcg/device_type.dtg.h" +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "dimensions" +type = "std::vector<::FlexFlow::MachineViewDimension>" + + +[[fields]] 
+name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/side_size_t.struct.toml b/lib/pcg/include/pcg/stride_t.struct.toml similarity index 87% rename from lib/pcg/include/pcg/side_size_t.struct.toml rename to lib/pcg/include/pcg/stride_t.struct.toml index dbaad4fedb..a764497b8b 100644 --- a/lib/pcg/include/pcg/side_size_t.struct.toml +++ b/lib/pcg/include/pcg/stride_t.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "side_size_t" +name = "stride_t" features = [ "eq", "ord", diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h deleted file mode 100644 index 9c3b8eeda9..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H - -#include "op-attrs/ff_dim.dtg.h" -#include "pcg/side_size_t.dtg.h" -#include "pcg/strided_rectangle.dtg.h" - -namespace FlexFlow { - -size_t get_num_dims(StridedRectangle const &); -StridedRectangleSide get_side_at_idx(StridedRectangle const &rect, - ff_dim_t const &idx); -num_points_t get_num_points(StridedRectangle const &rect); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml deleted file mode 100644 index 3dfd90e296..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "StridedRectangle" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "pcg/strided_rectangle_side.dtg.h", - "op-attrs/dim_ordered.h", -] - -[[fields]] -name = "sides" -type = "::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>" diff --git a/lib/pcg/include/pcg/strided_rectangle_side.h b/lib/pcg/include/pcg/strided_rectangle_side.h deleted file mode 100644 index 1486b73143..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle_side.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H - -#include "pcg/side_size_t.dtg.h" -#include "pcg/strided_rectangle_side.dtg.h" - -namespace FlexFlow { - -StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride); - -side_size_t get_side_size(StridedRectangleSide const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml deleted file mode 100644 index f26adfafd5..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "StridedRectangleSide" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "pcg/num_points_t.dtg.h", -] - -[[fields]] -name = "num_points" -type = "::FlexFlow::num_points_t" - -[[fields]] -name = "stride" -type = "int" diff --git a/lib/pcg/include/pcg/task_space_coordinate.struct.toml b/lib/pcg/include/pcg/task_space_coordinate.struct.toml new file mode 100644 index 0000000000..65aea167cb --- /dev/null +++ b/lib/pcg/include/pcg/task_space_coordinate.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TaskSpaceCoordinate" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "raw_coord" 
+type = "std::vector" diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml index c0b89cfc99..7f16e60914 100644 --- a/lib/pcg/include/pcg/tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -19,6 +19,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/src/file_format.cc b/lib/pcg/src/file_format.cc deleted file mode 100644 index bb01ac2dbf..0000000000 --- a/lib/pcg/src/file_format.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -/* void thing() { */ -/* static_assert(is_visitable::value, ""); */ - -/* json j; */ -/* auto g = j.get(); */ - -/* /1* IllBehaved v = j.get(); *1/ */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc deleted file mode 100644 index de8d5dddb4..0000000000 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "pcg/file_format/v1/graphs.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" -#include "utils/graph/algorithms.h" -#include "utils/integer_conversions.h" - -namespace FlexFlow { - -V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index deaa440ef8..3d1bc629e4 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,11 +1,22 @@ #include "pcg/computation_graph.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" +#include "utils/record_formatter.h" namespace FlexFlow { @@ -20,6 +31,23 @@ std::unordered_set get_layers(ComputationGraph const &cg) { [&](Node const &n) { return layer_guid_t{n}; }); } +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs) { + std::vector raw_inputs = transform( + inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + + NodeAddedResult added = + computation_graph.raw_graph.add_node(attrs, raw_inputs, outputs); + + return LayerAddedResult{ + layer_guid_t{added.node}, + transform(added.outputs, + [](DataflowOutput const &o) { return tensor_guid_t{o}; }), + }; +} + TensorAttrs get_tensor_attrs(ComputationGraph const &cg, tensor_guid_t 
const &t) { return cg.raw_graph.at(t.raw_graph_output); @@ -39,8 +67,7 @@ std::vector topological_ordering(ComputationGraph const &cg) { std::vector reverse_topological_ordering(ComputationGraph const &cg) { - std::vector layers = - reversed>(get_topological_ordering(cg.raw_graph)); + std::vector layers = reversed(get_topological_ordering(cg.raw_graph)); return transform( layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } @@ -57,6 +84,86 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } +static std::vector + get_incoming_tensors_with_role(ComputationGraph const &cg, + layer_guid_t const &l, + IncomingTensorRole desired_role) { + ComputationGraphOpAttrs attrs = get_layer_attrs(cg, l).attrs; + + std::vector incoming_tensors = get_incoming_tensors(cg, l); + + std::vector incoming_tensor_roles = + get_incoming_tensor_roles(attrs, incoming_tensors.size()); + + assert(incoming_tensors.size() == incoming_tensor_roles.size()); + + std::vector result = + filtrans(zip(incoming_tensors, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + tensor_guid_t tensor = p.first; + IncomingTensorRole role = p.second; + + if (role == desired_role) { + return tensor; + } else { + return std::nullopt; + } + }); + return result; +} + +std::vector get_incoming_inputs(ComputationGraph const &cg, + layer_guid_t const &l) { + return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::INPUT); +} + +std::vector get_incoming_weights(ComputationGraph const &cg, + layer_guid_t const &l) { + return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::WEIGHT); +} + +std::unordered_set get_subgraph_incoming_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_incoming_edges = + get_subgraph_incoming_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_incoming_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_outgoing_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_outgoing_edges = + get_subgraph_outgoing_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_outgoing_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_successors( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_successors = + get_subgraph_successors(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_successors, + [](Node const &n) { return layer_guid_t{n}; }); +} + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) { return cg.raw_graph.at(n.raw_node); } @@ -70,4 +177,60 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +ComputationGraph without_layer_names(ComputationGraph const &cg) { + return ComputationGraph{ + LabelledDataflowGraph::create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels(cg.raw_graph, + [](Node const &n, LayerAttrs const &old_attrs) { + 
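For context on the rewrite_node_labels call in without_layer_names: the lambda clears each layer's optional name so that the isomorphism check compares only structure and operator attributes. A hedged sketch, with lhs and rhs standing for two existing ComputationGraph values:

    // Graphs that differ only in user-assigned layer names compare as isomorphic.
    bool same_structure = computation_graphs_are_isomorphic(lhs, rhs);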
LayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, + ComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + +std::string as_dot(ComputationGraph const &cg) { + std::function get_node_label = + [](LayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](TensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(view_as_labelled_open_dataflow_graph(cg.raw_graph), + get_node_label, + get_input_label); +} + +void debug_print_dot(ComputationGraph const &cg) { + std::cout << as_dot(cg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc new file mode 100644 index 0000000000..0efa0620c4 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc @@ -0,0 +1,15 @@ +#include "pcg/computation_graph/computation_graph_edge.h" + +namespace FlexFlow { + +layer_guid_t + get_computation_graph_edge_src_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.src.node}; +} + +layer_guid_t + get_computation_graph_edge_dst_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.dst.node}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 3f2feaf619..dff647f5a1 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,20 +1,25 @@ #include "pcg/computation_graph_builder.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/concat.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" #include "op-attrs/ops/gather.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/tensor_dims.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" @@ -26,6 +31,16 @@ namespace FlexFlow { +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; +} + +static TensorAttrs make_output_attrs(TensorShape const &shape) { + return TensorAttrs{shape, std::nullopt, std::nullopt, CreateGrad::YES}; +} + ComputationGraphBuilder::ComputationGraphBuilder() : computation_graph(make_empty_computation_graph()) {} @@ -33,83 +48,74 @@ TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { return 
get_tensor_attrs(this->computation_graph, t).shape; } -tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, - CreateGrad create_grad) { +tensor_guid_t ComputationGraphBuilder::create_input( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &maybe_name) { TensorAttrs tensor_attrs = TensorAttrs{shape, std::nullopt, std::nullopt, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, - std::nullopt, + maybe_name, }; - return this->add_layer(layer_attrs, {}, {}, tensor_attrs); + return get_only(this->add_layer(layer_attrs, {}, {}, {tensor_attrs})); } -std::vector ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - std::vector raw_weight_tensors; - for (auto const &kv : enumerate_vector(weights)) { - int weight_idx = kv.first; - TensorAttrs weight_tensor_attrs = kv.second; - - std::optional weight_name = - transform(layer.name, [&](std::string const &layer_name) { - return fmt::format("{}.weights[{}]", layer_name, weight_idx); - }); - LayerAttrs weight_layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{WeightAttrs{weight_tensor_attrs.shape}}, - weight_name, - }; - std::vector weight_layer_inputs = {}; - std::vector weight_output_attrs = {weight_tensor_attrs}; - raw_weight_tensors.push_back(get_only(this->computation_graph.raw_graph - .add_node(weight_layer_attrs, - weight_layer_inputs, - weight_output_attrs) - .outputs)); - } +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorAttrs const &tensor_attrs, + std::optional const &maybe_name) { + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + maybe_name, + }; - std::vector raw_inputs = transform( - inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = - this->computation_graph.raw_graph - .add_node( - layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) - .outputs; - return transform(raw_outputs, - [](DataflowOutput const &o) { return tensor_guid_t{o}; }); + return get_only(this->add_layer(layer_attrs, + std::vector{}, + std::vector{}, + {tensor_attrs})); } -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs const &output) { - std::vector outputs = {output}; - return get_only(this->add_layer(layer, inputs, weights, outputs)); +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &initializer, + std::optional param_sync, + std::optional const &maybe_name) { + TensorAttrs tensor_attrs = + TensorAttrs{shape, initializer, param_sync, create_grad}; + + return this->create_weight(tensor_attrs, maybe_name); +} + +static void check_incoming_tensor_roles(LayerAttrs const &layer, + int num_inputs, + int num_weights) { + std::vector correct = + get_incoming_tensor_roles(layer.attrs, num_inputs + num_weights); + std::vector current = concat_vectors( + std::vector(num_inputs, IncomingTensorRole::INPUT), + std::vector(num_weights, IncomingTensorRole::WEIGHT)); + + if (correct != current) { + throw mk_runtime_error( + fmt::format("check_incoming_tensor_roles found deviation in incoming " + "tensors: expected {}, received {}", + correct, + current)); + } } std::vector ComputationGraphBuilder::add_layer( LayerAttrs const &layer, std::vector const &inputs, - std::vector const &weights, - std::vector const 
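A note on the contract enforced by check_incoming_tensor_roles: when add_layer concatenates its two argument lists, all data inputs are expected to come first, followed by all weights, matching the roles reported by get_incoming_tensor_roles. For example, for an operator whose roles are [INPUT, WEIGHT, WEIGHT], passing inputs = {x} and weights = {w, b} satisfies the check, while placing a weight tensor in the inputs list makes the role vectors differ and mk_runtime_error is thrown.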
&outputs) { - return this->add_layer( - layer, inputs, weights, transform(outputs, [](TensorShape const &s) { - return TensorAttrs{s, std::nullopt, std::nullopt, CreateGrad::YES}; - })); -} + std::vector const &weights, + std::vector const &outputs) { + check_incoming_tensor_roles(layer, inputs.size(), weights.size()); -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output) { - return get_only(this->add_layer( - layer, inputs, weights, std::vector{output})); + LayerAddedResult added = ::FlexFlow::add_layer( + this->computation_graph, layer, concat_vectors(inputs, weights), outputs); + return added.outputs; } tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, @@ -129,25 +135,29 @@ tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, } } -tensor_guid_t - ComputationGraphBuilder::broadcast(tensor_guid_t const &input, - TensorShape const &target_shape, - std::string const &name) { +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, + TensorDims const &target_dims, + std::string const &name) { TensorShape input_shape = this->get_shape(input); - if (!tensor_shape_is_broadcastable_to(input_shape, target_shape)) { + if (input_shape.dims == target_dims) { + return input; + } + + if (!tensor_dims_is_broadcastable_to(input_shape.dims, target_dims)) { throw mk_runtime_error(fmt::format( - "Cannot broadcast input tensor of shape {} to target shape {}", - input_shape, - target_shape)); + "Cannot broadcast input tensor of dims {} to target dims {}", + input_shape.dims, + target_dims)); } - BroadcastAttrs attrs = BroadcastAttrs{target_shape.dims}; + BroadcastAttrs attrs = BroadcastAttrs{target_dims}; LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t @@ -184,7 +194,8 @@ tensor_guid_t ComputationGraphBuilder::element_unary( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -194,18 +205,18 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); - TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); + TensorDims compute_dims = this->get_broadcast_target_dims({lhs, rhs}); DataType compute_type = std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); tensor_guid_t lhs_input = this->as_type( this->broadcast( - lhs, compute_shape, fmt::format("{}_inputl_broadcast", name)), + lhs, compute_dims, fmt::format("{}_inputl_broadcast", name)), compute_type, name + "_inputl_cast"); tensor_guid_t rhs_input = this->as_type( this->broadcast( - rhs, compute_shape, fmt::format("{}_inputr_broadcast", name)), + rhs, compute_dims, fmt::format("{}_inputr_broadcast", name)), compute_type, name + "_inputr_cast"); @@ -217,7 +228,8 @@ tensor_guid_t ComputationGraphBuilder::element_binary( TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), 
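In outline, the element_binary path now promotes its operands before the ElementBinary layer itself is added:

    // 1. shape: both operands are broadcast to get_broadcast_target_dims({lhs, rhs});
    // 2. dtype: both are cast via as_type to the wider of the two operand DataTypes
    //    (std::max over the operands' data_type fields).
    // Only the promoted pair is wired into the ElementBinary layer.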
this->get_shape(rhs_input))); - return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); + return get_only(this->add_layer( + layer, {lhs_input, rhs_input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t @@ -359,12 +371,6 @@ tensor_guid_t return this->element_unary(OperatorType::ELU, input, std::nullopt, name); } -static TensorAttrs make_weight_attrs( - TensorShape const &shape, - std::optional const &initializer_attrs) { - return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; -} - tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t const &x, int outChannels, @@ -413,7 +419,12 @@ tensor_guid_t ComputationGraphBuilder::conv2d( bias_initializer)); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::dropout( @@ -431,7 +442,8 @@ tensor_guid_t ComputationGraphBuilder::dropout( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -459,7 +471,10 @@ tensor_guid_t ComputationGraphBuilder::embedding( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {weight_attrs}, output_shape); + return get_only(this->add_layer(layer, + {input}, + {this->create_weight(weight_attrs)}, + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::gather( @@ -474,64 +489,143 @@ tensor_guid_t ComputationGraphBuilder::gather( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { - throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " - "{} (should be {} or {})", - this->get_shape(input).data_type, - DataType::INT32, - DataType::INT64); + throw mk_runtime_error( + fmt::format("Invalid data type for input tensor 2 for Gather: " + "{} (should be {} or {})", + this->get_shape(input).data_type, + DataType::INT32, + DataType::INT64)); } TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, {}, output_shape); -} - -/* std::vector - * ComputationGraphBuilder::get_shapes(std::vector const &ts) - * const { */ -/* return transform(ts, [&](tensor_guid_t const &t) { return - * this->get_shape(t); }); */ -/* } */ - -// tensor_guid_t ComputationGraphBuilder::aggregate( -// tensor_guid_t const &gate_preds, -// tensor_guid_t const &gate_assign, -// tensor_guid_t const &true_gate_assign, -// tensor_guid_t const &full_gate_gradients, -// std::vector const &exp_preds, -// int n, -// float lambda_bal, -// std::optional const &maybe_name) { -// AggregateAttrs attrs = {n, lambda_bal}; -// std::string name = maybe_name.value_or(get_default_name(attrs)); - -// LayerAttrs layer = {attrs, name}; -// TensorShape output_shape = get_output_shape(attrs, -// this->get_shape(gate_preds), -// this->get_shape(gate_assign), -// this->get_shape(true_gate_assign), -// this->get_shape(full_gate_gradients), -// this->get_shape(exp_preds)); - -// std::vector inputs = { -// gate_preds, gate_assign, 
true_gate_assign, full_gate_gradients}; -// extend(inputs, exp_preds); -// return this->add_layer(layer, inputs, {}, output_shape); -// } + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} +tensor_guid_t ComputationGraphBuilder::pool2d( + tensor_guid_t const &x, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernelH, + /*kernel_w=*/kernelW, + /*stride_h=*/strideH, + /*stride_w=*/strideW, + /*padding_h=*/paddingH, + /*padding_w=*/paddingW, + /*pool_type=*/type, + /*activation=*/activation, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::adaptive_pool2d( + tensor_guid_t const &uncasted_input, + int output_h, + int output_w, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + TensorDims input_dims = this->get_shape(uncasted_input).dims; + + Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, type, activation)); + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t casted_input = + this->as_type(uncasted_input, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, this->get_shape(casted_input))); + + return get_only(this->add_layer( + layer, {casted_input}, {}, {make_output_attrs(output_shape)})); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. 
" + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + std::vector weights; + + if (affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + TensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + TensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } + + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -549,6 +643,20 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::optional initializer, std::optional const &maybe_name) { + if (add_bias_kv) { + throw mk_runtime_error( + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_bias_kv=true. " + "If you need this functionality, please create an issue."); + } + + if (add_zero_attn) { + throw mk_runtime_error( + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_zero_attn=true. 
" + "If you need this functionality, please create an issue."); + } + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{embed_dim, num_heads, kdim, @@ -561,46 +669,70 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + TensorShape query_shape = this->get_shape(query); + TensorShape key_shape = this->get_shape(key); + TensorShape value_shape = this->get_shape(value); + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = - throw_if_unexpected(get_output_shape(attrs, - this->get_shape(query), - this->get_shape(key), - this->get_shape(value))); + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, query_shape, key_shape, value_shape)); - TensorShape weights_shape = - throw_if_unexpected(get_weights_shape(attrs, - this->get_shape(query), - this->get_shape(key), - this->get_shape(value))); - TensorAttrs weight_attrs = make_weight_attrs(weights_shape, initializer); + std::vector weights; - return this->add_layer(layer, - std::vector{query, key, value}, - {weight_attrs}, - output_shape); + TensorShape weights_shape = throw_if_unexpected( + get_weights_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back(make_weight_attrs(weights_shape, initializer)); + + if (bias) { + TensorShape input_bias_shape = throw_if_unexpected( + get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); + // initializer chosen based on + // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 + InitializerAttrs input_bias_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + + weights.push_back( + make_weight_attrs(input_bias_shape, input_bias_initializer)); + + TensorShape output_bias_shape = throw_if_unexpected( + get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); + // initializer chosen based on + // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 + InitializerAttrs output_bias_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + + weights.push_back( + make_weight_attrs(output_bias_shape, output_bias_initializer)); + } + + return get_only(this->add_layer( + layer, + {query, key, value}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( std::vector const &inputs) { - std::vector input_shapes = transform( - inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); + std::vector inputs_dims = transform( + inputs, [&](tensor_guid_t const &t) { return this->get_shape(t).dims; }); - return this->get_broadcast_target_shape(input_shapes); + return this->get_broadcast_target_dims(inputs_dims); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( - std::vector const &input_shapes) { - std::optional maybe_result = - ::FlexFlow::get_broadcast_target_shape(unordered_set_of(input_shapes)); +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( + std::vector const &inputs_dims) { + std::optional maybe_result = + ::FlexFlow::get_broadcast_target_dims(unordered_set_of(inputs_dims)); if (maybe_result.has_value()) { return maybe_result.value(); } else { throw 
mk_runtime_error(fmt::format( - "ComputationGraphBuilder::get_broadcast_target_shape failed to find " - "target tensor shape for input tensor shapes {}", - input_shapes)); + "ComputationGraphBuilder::get_broadcast_target_dims failed to find " + "target tensor dims for input tensor dims {}", + inputs_dims)); } } @@ -610,9 +742,11 @@ tensor_guid_t ComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, - std::optional const &maybe_name) { + std::optional const &maybe_name, + std::optional const &projection_name, + std::optional const &bias_name) { LinearAttrs attrs = LinearAttrs{outDim, use_bias, data_type, activation, std::nullopt}; @@ -623,18 +757,78 @@ tensor_guid_t ComputationGraphBuilder::dense( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - std::vector weights; - TensorShape kernel_shape = - throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + std::vector weights; + + TensorShape projection_shape = + throw_if_unexpected(get_projection_shape(attrs, this->get_shape(input))); + + tensor_guid_t projection_weights = + this->create_weight(projection_shape, + CreateGrad::YES, + projection_initializer, + /*sync_type=*/std::nullopt, + projection_name); + + weights.push_back(projection_weights); if (use_bias) { TensorShape bias_shape = throw_if_unexpected(get_bias_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); + + tensor_guid_t bias_weights = this->create_weight(bias_shape, + CreateGrad::YES, + bias_initializer, + /*sync_type=*/std::nullopt, + bias_name); + weights.push_back(bias_weights); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, {input}, weights, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::concat( + std::vector const &inputs, + int axis, + std::optional const &maybe_name) { + + ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + std::vector input_shapes = transform( + inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shapes)); + + return get_only( + this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::flat( + tensor_guid_t const &input, + int start_dim, + std::optional const &end_dim, + std::optional const &maybe_name) { + int input_num_dims = num_dims(this->get_shape(input)); + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{start_dim}, + /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::layer_norm( @@ -677,17 +871,22 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( TensorShape gamma_shape = 
throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{1}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); TensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{0}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::softmax( @@ -716,7 +915,8 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_id.cc b/lib/pcg/src/pcg/device_id.cc index 35b0c9aeda..a8cfe1f82f 100644 --- a/lib/pcg/src/pcg/device_id.cc +++ b/lib/pcg/src/pcg/device_id.cc @@ -25,8 +25,27 @@ cpu_id_t unwrap_cpu(device_id_t device_id) { return device_id.get(); } -device_id_t device_id_from_index(int, DeviceType) { - NOT_IMPLEMENTED(); +int get_raw_id(device_id_t device_id) { + switch (get_device_type(device_id)) { + case DeviceType::GPU: + return unwrap_gpu(device_id).gpu_index; + case DeviceType::CPU: + return unwrap_cpu(device_id).cpu_index; + default: + throw mk_runtime_error(fmt::format("Unsupported device {}", device_id)); + } +} + +device_id_t device_id_from_index(int idx, DeviceType device_type) { + switch (device_type) { + case DeviceType::GPU: + return device_id_t{gpu_id_t{idx}}; + case DeviceType::CPU: + return device_id_t{cpu_id_t{idx}}; + default: + throw mk_runtime_error( + fmt::format("Unsupported DeviceType {}", device_type)); + } } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..5341e03c0a --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,84 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include "utils/exception.h" +#include "utils/fmt/json.h" +#include "utils/overload.h" + +using namespace ::FlexFlow; + +namespace nlohmann { + +V1BinarySPDecomposition + adl_serializer::from_json(json const &j) { + std::string type = j.at("type").get(); + + if (type == "series") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "parallel") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "leaf") { + return V1BinarySPDecomposition{ + j.at("value").get(), + }; + } else { + throw mk_runtime_error(fmt::format( + "Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" " + "in json object: {}", + type, + j)); + } +} + +void adl_serializer::to_json( + json &j, V1BinarySPDecomposition const &tree) { + tree.visit(overload{ + [&](V1BinarySeriesSplit const &split) { + j = split; + j["type"] = "series"; + return std::monostate{}; + }, + [&](V1BinaryParallelSplit const &split) { + j 
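Based on the serializers in v1_binary_sp_decomposition/json.cc, a decomposition such as series(leaf 0, parallel(leaf 1, leaf 2)) would round-trip through JSON roughly as follows (illustrative only; leaves carry plain integer ids):

    {"type": "series",
     "left_child":  {"type": "leaf", "value": 0},
     "right_child": {"type": "parallel",
                     "left_child":  {"type": "leaf", "value": 1},
                     "right_child": {"type": "leaf", "value": 2}}}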
= split; + j["type"] = "parallel"; + return std::monostate{}; + }, + [&](int leaf) { + j["value"] = leaf; + j["type"] = "leaf"; + return std::monostate{}; + }, + }); +} + +V1BinarySeriesSplit + adl_serializer::from_json(json const &j) { + return V1BinarySeriesSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json( + json &j, V1BinarySeriesSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +V1BinaryParallelSplit + adl_serializer::from_json(json const &j) { + return V1BinaryParallelSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json( + json &j, V1BinaryParallelSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +} // namespace nlohmann diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..975e92dfb7 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,24 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &g) { + return V1ComputationGraph{ + to_v1(g.raw_graph), + }; +} + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &cg) { + std::pair, bidict> + raw = + to_v1_including_node_numbering(cg.raw_graph); + V1ComputationGraph v1_cg = V1ComputationGraph{raw.first}; + bidict v1_node_ids = + map_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); + + return {v1_cg, v1_node_ids}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..9da58fcf6e --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,12 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return V1ParallelComputationGraph{ + to_v1(g.raw_graph), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_space_offset.cc b/lib/pcg/src/pcg/machine_space_offset.cc new file mode 100644 index 0000000000..9990023f8c --- /dev/null +++ b/lib/pcg/src/pcg/machine_space_offset.cc @@ -0,0 +1,25 @@ +#include "pcg/machine_space_offset.h" +#include "utils/exception.h" + +namespace FlexFlow { +MachineSpaceOffset get_machine_space_offset_from_coordinate( + MachineSpaceCoordinate const &start, MachineSpaceCoordinate const &coord) { + if ((coord.device_idx < start.device_idx) || + (coord.node_idx < start.node_idx)) { + throw mk_runtime_error(fmt::format( + "One of the coordinates of start {} is greater than one of the " + "coordinates of coord {}, are you sure you didn't swap them?", + start, + coord)); + } + if (start.device_type != coord.device_type) { + throw mk_runtime_error( + fmt::format("{} has different DeviceType from {}", start, coord)); + } + + return MachineSpaceOffset{coord.node_idx - start.node_idx, + coord.device_idx - start.device_idx, + coord.device_type}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_specification.cc 
b/lib/pcg/src/pcg/machine_specification.cc new file mode 100644 index 0000000000..ca5b8ba047 --- /dev/null +++ b/lib/pcg/src/pcg/machine_specification.cc @@ -0,0 +1,53 @@ +#include "pcg/machine_specification.h" +#include "pcg/device_id.h" +#include "utils/exception.h" +namespace FlexFlow { + +int get_num_gpus(MachineSpecification const &ms) { + return ms.num_nodes * ms.num_gpus_per_node; +} +int get_num_cpus(MachineSpecification const &ms) { + return ms.num_nodes * ms.num_cpus_per_node; +} +int get_num_devices(MachineSpecification const &ms, + DeviceType const &device_type) { + switch (device_type) { + case DeviceType::GPU: + return get_num_gpus(ms); + case DeviceType::CPU: + return get_num_cpus(ms); + default: + throw mk_runtime_error(fmt::format("Unknown DeviceType {}", device_type)); + } +} + +int get_num_devices_per_node(MachineSpecification const &ms, + DeviceType const &device_type) { + switch (device_type) { + case DeviceType::GPU: + return ms.num_gpus_per_node; + case DeviceType::CPU: + return ms.num_cpus_per_node; + default: + throw mk_runtime_error(fmt::format("Unknown DeviceType {}", device_type)); + } +} +bool is_valid_machine_space_coordinate(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord) { + return (coord.node_idx < ms.num_nodes) && + (coord.device_idx < get_num_devices_per_node(ms, coord.device_type)); +} + +device_id_t get_device_id(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord) { + if (!is_valid_machine_space_coordinate(ms, coord)) { + throw mk_runtime_error(fmt::format( + "Invalid coordinate {} for machine specification {}", ms, coord)); + } + int raw_idx = + coord.node_idx * get_num_devices_per_node(ms, coord.device_type) + + coord.device_idx; + return device_id_from_index(raw_idx, coord.device_type); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index c09ab1a3c9..18f6cacb7e 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -1,121 +1,115 @@ #include "pcg/machine_view.h" -#include "pcg/device_id.h" -#include "pcg/strided_rectangle.dtg.h" -#include "pcg/strided_rectangle.h" -#include "pcg/strided_rectangle_side.h" +#include "pcg/machine_specification.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/contains.h" +#include "utils/containers/count.h" +#include "utils/containers/filter.h" +#include "utils/containers/scanl.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" namespace FlexFlow { -std::vector device_ids(MachineView const &) { - NOT_IMPLEMENTED(); -} - -std::size_t num_dims(MachineView const &mv) { - return get_num_dims(mv.rect); -} - -size_t num_devices(MachineView const &mv) { - return get_num_points(mv.rect).unwrapped; +size_t num_dims(MachineView const &mv) { + return get_strides(mv).size(); } DeviceType get_device_type(MachineView const &mv) { - return get_device_type(mv.start); -} - -static StridedRectangle make_1d_rect(int start, int stop, int stride) { - assert(stop > start); - assert(stride > 0); - StridedRectangleSide side = - strided_side_from_size_and_stride(side_size_t{stop - start}, stride); - StridedRectangle rect = - StridedRectangle{std::vector{side}}; - return rect; -} - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.gpu_index, stop.gpu_index, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView make_1d_machine_view(cpu_id_t 
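To make the flattening in get_device_id concrete: with a hypothetical num_gpus_per_node of 4, the GPU coordinate {node_idx = 2, device_idx = 3} maps to raw index 2 * 4 + 3 = 11, which device_id_from_index then wraps as device_id_t{gpu_id_t{11}}; coordinates that fail is_valid_machine_space_coordinate are rejected before the computation.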
start, cpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.cpu_index, stop.cpu_index, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView - make_1d_machine_view(device_id_t start, device_id_t stop, int stride) { - assert(get_device_type(start) == get_device_type(stop)); - if (get_device_type(start) == DeviceType::CPU) { - return make_1d_machine_view(unwrap_cpu(start), unwrap_cpu(stop), stride); + return mv.start.device_type; +} + +std::vector get_strides(MachineView const &mv) { + return transform(mv.dimensions, + [](MachineViewDimension const &dim) { return dim.stride; }); +} + +std::vector + get_dimensions(MachineView const &mv) { + return transform(mv.dimensions, [](MachineViewDimension const &dim) { + return dim.projection; + }); +} + +MachineView machine_view_from_strides_and_machine_spec_dimensions( + MachineSpaceCoordinate const &start, + std::vector const &strides, + std::vector const &dims) { + std::vector dimensions = + transform(zip(strides, dims), [&](auto const &p) { + return MachineViewDimension{p.first, p.second}; + }); + return MachineView{start, dimensions}; +} + +std::optional get_machine_space_coordinate( + OperatorTaskSpace const &task, + MachineView const &machine_view, + TaskSpaceCoordinate const &coord, + MachineSpecification const &machine_specification) { + + auto get_dimension_indices_for_dimension = + [&](MachineSpecificationDimension dimension) { + std::vector mv_dimensions = + get_dimensions(machine_view); + return filter(count(mv_dimensions.size()), [&](size_t idx) { + return mv_dimensions.at(idx) == dimension; + }); + }; + + auto compute_index = [&](int start_idx, + std::vector const &dimension_indices) { + std::vector mv_strides = get_strides(machine_view); + + std::vector sizes = transform(dimension_indices, [&](size_t i) { + return task.degrees.at(i) * mv_strides.at(i).unwrapped; + }); + std::vector coord_points = transform( + dimension_indices, [&](size_t i) { return coord.raw_coord.at(i); }); + std::vector strides = transform(dimension_indices, [&](size_t i) { + return mv_strides.at(i).unwrapped; + }); + + std::vector coeffs = scanl(sizes, 1, std::multiplies()); + + int index = start_idx; + for (auto [coeff, coord_point, stride] : + zip(coeffs, coord_points, strides)) { + index += coeff * coord_point * stride; + } + return index; + }; + + std::vector inter_dimension_indices = + get_dimension_indices_for_dimension( + MachineSpecificationDimension::INTER_NODE); + std::vector intra_dimension_indices = + get_dimension_indices_for_dimension( + MachineSpecificationDimension::INTRA_NODE); + + int node_idx = + compute_index(machine_view.start.node_idx, inter_dimension_indices); + int device_idx = + compute_index(machine_view.start.device_idx, intra_dimension_indices); + MachineSpaceCoordinate ms_coord = MachineSpaceCoordinate{ + node_idx, device_idx, get_device_type(machine_view)}; + + if (!is_valid_machine_space_coordinate(machine_specification, ms_coord)) { + return std::nullopt; } - assert(get_device_type(start) == DeviceType::GPU); - return make_1d_machine_view(unwrap_gpu(start), unwrap_gpu(stop), stride); + return ms_coord; } -static StridedRectangle - make_1d_rect(int start, num_points_t num_points, int stride) { - return make_1d_rect(start, start + num_points.unwrapped * stride, stride); +std::unordered_set get_machine_space_coordinates( + OperatorTaskSpace const &task, + MachineView const &machine_view, + MachineSpecification const &machine_specification) { + return transform( + 
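A worked example for the coordinate mapping in get_machine_space_coordinate (hypothetical 1-D view): with start = {node 0, device 2, GPU}, a single INTRA_NODE dimension of stride 2, and an operator task space of degree 3, task coordinates 0, 1, 2 land on machine coordinates {node 0, device 2}, {node 0, device 4}, and {node 0, device 6}, since for a single dimension device_idx = start.device_idx + coordinate * stride; the function instead returns std::nullopt when the result is not a valid coordinate for the MachineSpecification.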
get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { + return get_machine_space_coordinate( + task, machine_view, coord, machine_specification) + .value(); + }); } -MachineView - make_1d_machine_view(cpu_id_t start, num_points_t num_points, int stride) { - StridedRectangle rect = make_1d_rect(start.cpu_index, num_points, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView - make_1d_machine_view(gpu_id_t start, num_points_t num_points, int stride) { - StridedRectangle rect = make_1d_rect(start.gpu_index, num_points, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView make_1d_machine_view(device_id_t start, - num_points_t num_points, - int stride) { - if (get_device_type(start) == DeviceType::CPU) { - return make_1d_machine_view(unwrap_cpu(start), num_points, stride); - } else { - assert(get_device_type(start) == DeviceType::GPU); - return make_1d_machine_view(unwrap_gpu(start), num_points, stride); - } -} - -static StridedRectangle - make_1d_rect(int start, side_size_t interval_size, int stride) { - return make_1d_rect(start, start + interval_size.unwrapped, stride); -} - -MachineView make_1d_machine_view(cpu_id_t start, - side_size_t interval_size, - int stride) { - StridedRectangle rect = make_1d_rect(start.cpu_index, interval_size, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView make_1d_machine_view(gpu_id_t start, - side_size_t interval_size, - int stride) { - StridedRectangle rect = make_1d_rect(start.gpu_index, interval_size, stride); - return MachineView{device_id_t{start}, rect}; -} -MachineView make_1d_machine_view(device_id_t start, - side_size_t interval_size, - int stride) { - - if (get_device_type(start) == DeviceType::CPU) { - return make_1d_machine_view(unwrap_cpu(start), interval_size, stride); - } else { - assert(get_device_type(start) == DeviceType::GPU); - return make_1d_machine_view(unwrap_gpu(start), interval_size, stride); - } -} -MachineView make_1d_machine_view(device_id_t start, size_t interval_size) { - NOT_IMPLEMENTED(); -} - -/* device_id_t MachineView::at(FFOrdered const &coord) const { */ -/* size_t offset = this->rect.at(coord); */ -/* return this->start + offset; */ -/* } */ - } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc new file mode 100644 index 0000000000..02522ae411 --- /dev/null +++ b/lib/pcg/src/pcg/operator_task_space.cc @@ -0,0 +1,38 @@ +#include "pcg/operator_task_space.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/maximum.h" +#include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +std::unordered_set + get_task_space_coordinates(OperatorTaskSpace const &task) { + + std::vector> coordinate_ranges = transform( + task.degrees, [&](int const &num_points) { return range(num_points); }); + + std::unordered_set> raw_coordinates = + unordered_set_of(cartesian_product(coordinate_ranges)); + std::unordered_set task_space_coordinates = + transform(raw_coordinates, [](std::vector const &point) { + return TaskSpaceCoordinate{point}; + }); + return task_space_coordinates; +} + +TaskSpaceCoordinate + get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { + return maximum(get_task_space_coordinates(task)).value(); +} + +size_t num_dims(OperatorTaskSpace const &task) { + return task.degrees.size(); +} +size_t 
num_tasks(OperatorTaskSpace const &task) { + return product(task.degrees); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 5b178160cd..781c44640c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,9 +1,14 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -40,9 +45,42 @@ ParallelLayerAddedResult }; } +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape) { + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + return add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); +} + +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &src, + parallel_layer_guid_t const &dst) { + std::unordered_set raw_edges = + get_dataflow_edges_from_node_to_node( + pcg.raw_graph, src.raw_graph_node, dst.raw_graph_node); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + std::vector - get_layer_inputs(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &l) { + get_incoming_tensors(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { return transform( get_input_values(pcg.raw_graph, l.raw_graph_node), [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); @@ -56,6 +94,48 @@ std::vector [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } +static std::vector + get_incoming_tensors_with_role(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l, + IncomingTensorRole desired_role) { + PCGOperatorAttrs attrs = get_parallel_layer_attrs(pcg, l).op_attrs; + + std::vector incoming_tensors = + get_incoming_tensors(pcg, l); + + std::vector incoming_tensor_roles = + get_incoming_tensor_roles(attrs, incoming_tensors.size()); + + assert(incoming_tensors.size() == incoming_tensor_roles.size()); + + std::vector result = filtrans( + zip(incoming_tensors, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + parallel_tensor_guid_t tensor = p.first; + IncomingTensorRole role = p.second; + + if (role == desired_role) { + return tensor; + } else { + return std::nullopt; + } + }); + return result; +} + +std::vector + get_incoming_inputs(ParallelComputationGraph 
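One detail worth noting in pcg_add_input_layer: input layers are added unnamed and with CreateGrad::NO, since graph inputs receive no gradients. A minimal sketch, assuming pcg is a ParallelComputationGraph and input_shape a ParallelTensorShape (placeholder names):

    ParallelLayerAddedResult added = pcg_add_input_layer(pcg, input_shape);
    // The single output tensor of this layer has shape input_shape and no gradient.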
const &pcg, + parallel_layer_guid_t const &l) { + return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::INPUT); +} + +std::vector + get_incoming_weights(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::WEIGHT); +} + parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, parallel_tensor_guid_t const &t) { return parallel_layer_guid_t{t.raw_graph_output.node}; @@ -66,12 +146,23 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, return pcg.raw_graph.at(l.raw_graph_node); } +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(pcg, l).op_attrs; +} + ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, parallel_tensor_guid_t const &t) { return pcg.raw_graph.at(t.raw_graph_output); } +ParallelTensorShape + get_parallel_tensor_shape(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + return get_parallel_tensor_attrs(pcg, t).shape; +} + std::vector topological_ordering(ParallelComputationGraph const &pcg) { return transform(get_topological_ordering(pcg.raw_graph), @@ -88,4 +179,28 @@ parallel_layer_guid_t return get_only(found); } +ParallelComputationGraph + without_layer_names(ParallelComputationGraph const &pcg) { + return ParallelComputationGraph{ + LabelledDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels( + pcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool pcgs_are_isomorphic(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 8290a2ff94..f33b4dcd17 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,4 +1,18 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" @@ -35,13 +49,13 @@ ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor( ParallelTensorShape const &shape, - bool create_grad, + CreateGrad create_grad, std::optional const &name) { ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ /*shape=*/shape, /*sync_type=*/std::nullopt, /*initializer=*/std::nullopt, - 
/*create_gradients=*/(create_grad ? CreateGrad::YES : CreateGrad::NO), + /*create_gradients=*/create_grad, }; ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ PCGOperatorAttrs{InputAttrs{}}, @@ -181,7 +195,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, std::optional const &maybe_name) { LinearAttrs attrs = LinearAttrs{ @@ -204,9 +218,10 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::vector weights; { - ParallelTensorShape kernel_shape = - throw_if_unexpected(get_kernel_shape(attrs, input_shape)); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + ParallelTensorShape projection_shape = + throw_if_unexpected(get_projection_shape(attrs, input_shape)); + weights.push_back( + make_weight_attrs(projection_shape, projection_initializer)); } if (use_bias) { @@ -330,18 +345,56 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = - get_output_shape(attrs, this->get_shape(input)); + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (attrs.affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + ParallelTensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + ParallelTensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } return this->add_layer(layer, {input}, {}, {output_shape}); } @@ -580,11 +633,32 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( return parallel_tensor_guid_t{current_raw_weight_tensor}; } +static void check_incoming_tensor_roles(ParallelLayerAttrs const &layer, + int num_inputs, + int num_weights) { + std::vector correct = + get_incoming_tensor_roles(layer.op_attrs, num_inputs + num_weights); + std::vector current = concat_vectors( + std::vector(num_inputs, IncomingTensorRole::INPUT), + 
std::vector(num_weights, IncomingTensorRole::WEIGHT)); + + if (correct != current) { + throw mk_runtime_error( + fmt::format("check_incoming_tensor_roles found deviation in incoming " + "tensors: expected {}, received {}", + correct, + current)); + } +} + std::vector ParallelComputationGraphBuilder::add_layer( ParallelLayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, std::vector const &outputs) { + + check_incoming_tensor_roles(layer, inputs.size(), weights.size()); + std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; @@ -603,6 +677,7 @@ std::vector ParallelComputationGraphBuilder::add_layer( transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = this->pcg.raw_graph .add_node( diff --git a/lib/pcg/src/pcg/start_invariant_machine_view.cc b/lib/pcg/src/pcg/start_invariant_machine_view.cc new file mode 100644 index 0000000000..1fcc3ea12f --- /dev/null +++ b/lib/pcg/src/pcg/start_invariant_machine_view.cc @@ -0,0 +1,86 @@ +#include "pcg/start_invariant_machine_view.h" +#include "pcg/machine_space_offset.h" +#include "pcg/machine_view.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/count.h" +#include "utils/containers/filter.h" +#include "utils/containers/scanl.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" +namespace FlexFlow { + +MachineView machine_view_from_start_invariant( + StartInvariantMachineView const &start_inv_mv, + MachineSpaceCoordinate const &start) { + return MachineView{start, start_inv_mv.dimensions}; +} + +StartInvariantMachineView + start_invariant_from_machine_view(MachineView const &mv) { + return StartInvariantMachineView{mv.dimensions, get_device_type(mv)}; +} + +size_t num_dims(StartInvariantMachineView const &start_inv_mv) { + return start_inv_mv.dimensions.size(); +} + +DeviceType get_device_type(StartInvariantMachineView const &start_inv_mv) { + return start_inv_mv.device_type; +} + +std::vector + get_strides(StartInvariantMachineView const &start_inv_mv) { + return transform(start_inv_mv.dimensions, + [](MachineViewDimension const &dim) { return dim.stride; }); +} + +std::vector + get_dimensions(StartInvariantMachineView const &start_inv_mv) { + return transform( + start_inv_mv.dimensions, + [](MachineViewDimension const &dim) { return dim.projection; }); +} + +StartInvariantMachineView + start_invariant_machine_view_from_strides_and_machine_spec_dimensions( + std::vector const &strides, + std::vector const &dims, + DeviceType device_type) { + std::vector dimensions = + transform(zip(strides, dims), [&](auto const &p) { + return MachineViewDimension{p.first, p.second}; + }); + return StartInvariantMachineView{dimensions, device_type}; +} + +std::optional get_machine_space_offset( + OperatorTaskSpace const &task, + StartInvariantMachineView const &start_inv_machine_view, + TaskSpaceCoordinate const &coord, + MachineSpecification const &machine_specification) { + MachineSpaceCoordinate dummy_start = + MachineSpaceCoordinate{0, 0, get_device_type(start_inv_machine_view)}; + MachineView mv = + machine_view_from_start_invariant(start_inv_machine_view, dummy_start); + std::optional ms_coord = + get_machine_space_coordinate(task, mv, coord, machine_specification); + if (ms_coord == std::nullopt) { + return std::nullopt; + } + return get_machine_space_offset_from_coordinate(dummy_start, + ms_coord.value()); +} + +std::unordered_set get_machine_space_offsets( + 
OperatorTaskSpace const &task, + StartInvariantMachineView const &start_inv_machine_view, + MachineSpecification const &machine_specification) { + return transform( + get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { + return get_machine_space_offset( + task, start_inv_machine_view, coord, machine_specification) + .value(); + }); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle.cc b/lib/pcg/src/pcg/strided_rectangle.cc deleted file mode 100644 index dfb5d0af12..0000000000 --- a/lib/pcg/src/pcg/strided_rectangle.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "pcg/strided_rectangle.h" -#include "op-attrs/dim_ordered/transform.h" -#include "utils/containers/product.h" - -namespace FlexFlow { - -/* size_t StridedRectangle::at(FFOrdered const &coord) const { */ -/* assert(coord.size() == this->num_dims()); */ - -/* size_t _1d_stride = 1; */ -/* size_t idx = 0; */ -/* for (auto dim : inner_to_outer_idxs(this->sides)) { */ -/* idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; */ -/* _1d_stride *= this->sides.at(dim).get_size().value(); */ -/* } */ -/* return idx; */ -/* } */ - -size_t get_num_dims(StridedRectangle const &rect) { - return rect.sides.size(); -} - -num_points_t get_num_points(StridedRectangle const &rect) { - return num_points_t{ - product(transform(rect.sides, [](StridedRectangleSide const &side) { - return side.num_points.unwrapped; - }))}; -} - -StridedRectangleSide get_side_at_idx(StridedRectangle const &rect, - ff_dim_t const &idx) { - return rect.sides.at(idx); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc deleted file mode 100644 index e6caf4cb86..0000000000 --- a/lib/pcg/src/pcg/strided_rectangle_side.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "pcg/strided_rectangle_side.h" -#include "utils/exception.h" - -namespace FlexFlow { - -StridedRectangleSide strided_side_from_size_and_stride(side_size_t side_size, - int stride) { - assert((side_size.unwrapped % stride) == 0); - return StridedRectangleSide{num_points_t{side_size.unwrapped / stride}, - stride}; -} - -side_size_t get_side_size(StridedRectangleSide const &s) { - return side_size_t{s.num_points.unwrapped * s.stride}; -} - -} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/computation_graph.cc b/lib/pcg/test/src/pcg/computation_graph.cc new file mode 100644 index 0000000000..e2ed51b2f1 --- /dev/null +++ b/lib/pcg/test/src/pcg/computation_graph.cc @@ -0,0 +1,206 @@ +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_inputs(ComputationGraph, layer_guid_t)") { + SUBCASE("layer has no inputs") { + std::string input_name = "input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_name); + + std::vector result = get_incoming_inputs(cg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string layer_name = "my op"; + + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + 
DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.relu(input, layer_name); + + ComputationGraph cg = b.computation_graph; + + layer_guid_t layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_inputs(cg, layer); + std::vector correct = {input}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string layer_name = "my op"; + + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/layer_name); + + ComputationGraph cg = b.computation_graph; + + layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_inputs(cg, dense_layer); + std::vector correct = { + input, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_incoming_weights(ComputationGraph, layer_guid_t)") { + SUBCASE("layer has no inputs or weights") { + std::string input_name = "input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_name); + + std::vector result = get_incoming_weights(cg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string layer_name = "my op"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.relu(input, layer_name); + + return b.computation_graph; + }(); + + layer_guid_t layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_weights(cg, layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string layer_name = "my op"; + std::string projection_name = "my projection weight"; + std::string bias_name = "my bias weight"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/layer_name, + /*projection_name=*/projection_name, + /*bias_name=*/bias_name); + + return b.computation_graph; + }(); + + layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); + + layer_guid_t projection_weight_layer = + get_layer_by_name(cg, projection_name); + tensor_guid_t projection_weight = + get_only(get_outgoing_tensors(cg, projection_weight_layer)); + + layer_guid_t bias_weight_layer = get_layer_by_name(cg, bias_name); + tensor_guid_t bias_weight = + get_only(get_outgoing_tensors(cg, bias_weight_layer)); + + std::vector result = 
get_incoming_weights(cg, dense_layer); + std::vector correct = { + projection_weight, + bias_weight, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/pcg/computation_graph_builder.cc similarity index 92% rename from lib/pcg/test/src/test_computation_graph_builder.cc rename to lib/pcg/test/src/pcg/computation_graph_builder.cc index 936c2de00d..e7fa853be9 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/computation_graph_builder.cc @@ -1,6 +1,6 @@ +#include "pcg/computation_graph_builder.h" #include "doctest/doctest.h" #include "pcg/computation_graph.h" -#include "pcg/computation_graph_builder.h" using namespace ::FlexFlow; @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - tensor_guid_t input = b.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); tensor_guid_t output = b.conv2d(input, /*outChannels=*/5, /*kernelH=*/3, diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..9068e14517 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,178 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer") { + V1BinarySPDecomposition example_tree = V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }, + }; + + nlohmann::json example_json = { + {"type", "series"}, + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_tree; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySPDecomposition result = + example_json.get(); + V1BinarySPDecomposition correct = example_tree; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinarySeriesSplit example_split = V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySeriesSplit result = example_json.get(); + V1BinarySeriesSplit correct = example_split; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinaryParallelSplit example_split = V1BinaryParallelSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json 
example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinaryParallelSplit result = example_json.get(); + V1BinaryParallelSplit correct = example_split; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..8336d81bb4 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,30 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ComputationGraph") { + ComputationGraph cg = [] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 12, + 16, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + tensor_guid_t mm_output = b.dense(input, 8); + tensor_guid_t relu_output = b.relu(mm_output); + + return b.computation_graph; + }(); + + V1ComputationGraph v1_cg = to_v1(cg); + nlohmann::json j = v1_cg; + } +} diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..8ce25c4bc5 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,36 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ParallelComputationGraph") { + ParallelComputationGraph pcg = [] { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t mm_output = b.dense(input, 8); + parallel_tensor_guid_t relu_output = b.relu(mm_output); + + return b.pcg; + }(); + + V1ParallelComputationGraph v1_pcg = to_v1(pcg); + nlohmann::json j = v1_pcg; + } +} diff --git a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc index 0b75e3ae1a..703c129da4 100644 --- a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc +++ b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc @@ -1,6 +1,8 @@ #include "pcg/initializers/uniform_initializer_attrs.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/test/src/pcg/machine_specification.cc b/lib/pcg/test/src/pcg/machine_specification.cc new file mode 100644 index 0000000000..c183ae0d31 --- /dev/null +++ b/lib/pcg/test/src/pcg/machine_specification.cc @@ -0,0 +1,54 @@ +#include 
"pcg/machine_specification.h" +#include "pcg/device_id.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("MachineSpecification") { + + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/4, + /*num_cpus_per_node=*/16, + /*num_gpus_per_node=*/8, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + + SUBCASE("get_num_gpus") { + CHECK(get_num_gpus(ms) == 4 * 8); + } + + SUBCASE("get_num_cpus") { + CHECK(get_num_cpus(ms) == 4 * 16); + } + + SUBCASE("get_num_devices") { + CHECK(get_num_devices(ms, DeviceType::GPU) == 4 * 8); + CHECK(get_num_devices(ms, DeviceType::CPU) == 16 * 4); + } + + SUBCASE("get_device_id") { + SUBCASE("valid MachineSpaceCoordinate") { + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/2, + /*device_idx=*/12, + DeviceType::CPU, + }; + device_id_t correct = + device_id_from_index(2 * 16 + 12, DeviceType::CPU); + device_id_t result = get_device_id(ms, coord); + CHECK(correct == result); + } + SUBCASE("MachineSpaceCoordinate out of bounds for given machine spec") { + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/2, + /*device_idx=*/18, + DeviceType::CPU, + }; + CHECK_THROWS(get_device_id(ms, coord)); + } + } + } +} diff --git a/lib/pcg/test/src/pcg/machine_view.cc b/lib/pcg/test/src/pcg/machine_view.cc new file mode 100644 index 0000000000..dcf22d6c00 --- /dev/null +++ b/lib/pcg/test/src/pcg/machine_view.cc @@ -0,0 +1,301 @@ +#include "pcg/machine_view.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/containers/transform.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MachineView - utility functions") { + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}}}; + + SUBCASE("num_dims") { + CHECK(num_dims(mv) == 2); + } + SUBCASE("get_device_type") { + CHECK(get_device_type(mv) == DeviceType::GPU); + } + } + + TEST_CASE("get_machine_space_coordinate") { + SUBCASE("1D case") { + + // This operator has shape (3,), and thus 3 tasks. + // The (only) dimension is projected on the INTER (device) dimension with + // a stride of 2. The start of the projection defined by MachineView + // starts at MachineSpaceCoordinate (0,1), and the machine space has 1 + // node and 6 devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | | (0,) | | (1,) | | (2,) | + * +-------+-------+-------+-------+-------+-------+ + * Where the (x,) are the `TaskSpaceCoordinate`s, and the underlying grid + * is the machine space. 
+ */ + OperatorTaskSpace task = OperatorTaskSpace{{3}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}, + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/1, + /*num_cpus_per_node=*/6, + /*num_gpus_per_node=*/6, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/3, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (2,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/5, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("TaskSpaceCoordinate is out of bounds") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{4}}; + std::optional result = + get_machine_space_coordinate(task, mv, coord, ms); + std::optional correct = std::nullopt; + CHECK(result == correct); + } + + SUBCASE("2D case - projection on different dimensions") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + // The first dimension is projected onto the INTER (node) dimension with + // stride 1, while the second dimension is projected onto the INTRA + // (device) dimension with stride 2. The start of the projection defined + // by MachineView is at MachineSpaceCoordinates (1, 2), and the machine + // space has 3 nodes and 5 devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+ + * | | | | | | + * +-------+-------+-------+-------+-------+ + * | | | (0,0) | | (0,1) | + * +-------+-------+-------+-------+-------+ + * | | | (1,0) | | (1,1) | + * +-------+-------+-------+-------+-------+ + * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. 
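+       *
+       * (Illustrative arithmetic, consistent with the subcases below: the
+       * node index follows 1 + 1 * x and the device index follows 2 + 2 * y,
+       * so task (1,1) lands at machine space coordinate (2,4).)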
+ */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/2, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/3, + /*num_cpus_per_node=*/5, + /*num_gpus_per_node=*/5, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/2, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/4, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/2, /*device_idx=*/2, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/2, /*device_idx=*/4, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("2D case - projection on same dimension") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + // Both dimensions are projected on the INTRA (device) dimension, with + // strides 1 and 2 respectively. The start of the projection defined by + // MachineView is at MachineSpaceCoordinates (1, 0), and the machine + // space has 2 nodes and 6 devices per node. + + /** + * +-------+-------+-------+-------+-------+-------+ + * | (0,0) | (1,0) | | | (0,1) | (1,1) | + * +-------+-------+-------+-------+-------+-------+ + * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. 
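+       *
+       * (Illustrative arithmetic, consistent with the subcases below: both
+       * task dimensions share the device axis; x advances by its stride of
+       * 1 and y advances by 4, i.e. its stride of 2 scaled by the 2-device
+       * extent of the x dimension, giving device_idx = 0 + x + 4 * y on
+       * node 1.)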
+ */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/2, + /*num_cpus_per_node=*/6, + /*num_gpus_per_node=*/6, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/0, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/4, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/1, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/5, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("3D case") { + // This operator has shape (2, 2, 2), and thus 2 * 2 * 2 = 8 tasks. + // - The first dimension is projected onto the INTER (node) dimension + // with stride 1, + // - The second dimension is projected onto the INTRA (device) dimension + // with stride 2, + // - The third dimension is projected onto the INTRA (device) dimension + // with stride 1. The start of the projection defined by MachineView is + // at MachineSpaceCoordinates (0, 1), and the machine space has 2 nodes + // and 8 devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * | |(0,0,0)| |(0,0,1)| |(0,1,0)| |(0,1,1)| + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * | |(1,0,0)| |(1,0,1)| |(1,1,0)| |(1,1,1)| + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * Where the (x,y,z) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. 
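+       *
+       * (Illustrative reading, consistent with the subcases below: with an
+       * inter-node stride of 1 and a start at node 0, the first coordinate
+       * selects the node directly, while the two INTRA coordinates are laid
+       * out jointly along the device axis starting from device 1, as drawn
+       * above.)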
+ */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2, 2}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/2, + /*num_cpus_per_node=*/8, + /*num_gpus_per_node=*/8, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/3, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/5, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/7, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + } + } + } +} diff --git a/lib/pcg/test/src/pcg/operator_task_space.cc b/lib/pcg/test/src/pcg/operator_task_space.cc new file mode 100644 index 0000000000..13198d9456 --- /dev/null +++ b/lib/pcg/test/src/pcg/operator_task_space.cc @@ -0,0 +1,66 @@ +#include "pcg/operator_task_space.h" +#include "utils/fmt/unordered_set.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_task_space_coordinates") { + + SUBCASE("OperatorTaskSpace has 0 dimensions") { + OperatorTaskSpace task = OperatorTaskSpace{{}}; + + std::unordered_set correct = { + TaskSpaceCoordinate{{}}}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + SUBCASE("OperatorTaskSpace has 2 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + + std::unordered_set correct = {{ + TaskSpaceCoordinate{{0, 0}}, + TaskSpaceCoordinate{{0, 1}}, + TaskSpaceCoordinate{{1, 0}}, + TaskSpaceCoordinate{{1, 1}}, + }}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + SUBCASE("OperatorTaskSpace has 3 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{1, 2, 2}}; + + std::unordered_set correct = {{ + TaskSpaceCoordinate{{0, 0, 0}}, + TaskSpaceCoordinate{{0, 0, 1}}, + TaskSpaceCoordinate{{0, 1, 0}}, + TaskSpaceCoordinate{{0, 1, 1}}, + }}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + } + TEST_CASE("get_task_space_maximum_coordinate") { + SUBCASE("OperatorTaskSpace has 2 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{3, 2}}; + + TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2, 1}}; + TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); + CHECK(correct == result); + } + SUBCASE("OperatorTaskSpace has 3 dimensions") { + + OperatorTaskSpace task = 
OperatorTaskSpace{{3, 2, 4}}; + + TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2, 1, 3}}; + TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); + CHECK(correct == result); + } + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 188447da92..fc07edf5b3 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,8 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/rapidcheck.h" #include "utils/containers/get_only.h" @@ -35,4 +39,277 @@ TEST_SUITE(FF_TEST_SUITE) { // std::vector correct = {layer1, layer2, layer3}; // CHECK(result == correct); } + + TEST_CASE( + "get_incoming_inputs(ParallelComputationGraph, parallel_layer_guid_t)") { + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("layer has no inputs") { + std::string input_name = "my input"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + b.create_input_tensor(input_shape, CreateGrad::YES, input_name); + + return b.pcg; + }(); + + parallel_layer_guid_t input_layer = + get_parallel_layer_by_name(pcg, input_name); + + std::vector result = + get_incoming_inputs(pcg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string my_op_name = "my op"; + + ParallelComputationGraphBuilder b; + + parallel_tensor_guid_t input = + b.create_input_tensor(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/my_op_name); + + ParallelComputationGraph pcg = b.pcg; + + parallel_layer_guid_t my_op_layer = + get_parallel_layer_by_name(pcg, my_op_name); + + std::vector result = + get_incoming_inputs(pcg, my_op_layer); + std::vector correct = {input}; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_incoming_weights(ParallelComputationGraph, parallel_layer_guid_t)") { + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("layer has no inputs or weights") { + std::string input_name = "my input"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + b.create_input_tensor(input_shape, CreateGrad::YES, input_name); + + return b.pcg; + }(); + + parallel_layer_guid_t input_layer = + get_parallel_layer_by_name(pcg, input_name); + + std::vector result = + get_incoming_weights(pcg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string my_op_name = "my op"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + 
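+        // relu takes a single INPUT-role tensor and introduces no weight
+        // layers, so get_incoming_weights is expected to return an empty
+        // vector for this layer.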
parallel_tensor_guid_t input = + b.create_input_tensor(input_shape, CreateGrad::YES); + b.relu(input, my_op_name); + + return b.pcg; + }(); + + parallel_layer_guid_t my_op_layer = + get_parallel_layer_by_name(pcg, my_op_name); + + std::vector result = + get_incoming_weights(pcg, my_op_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights, and weights are separate by " + "parallel ops") { + std::string my_op_name = "my op"; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + LinearAttrs op_attrs = LinearAttrs{ + /*out_channels=*/14, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + ParallelLayerAddedResult input_added = [&] { + ParallelLayerAttrs input_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{InputAttrs{}}, + std::nullopt, + }; + ParallelTensorAttrs input_tensor_attrs = + ParallelTensorAttrs{input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + return add_parallel_layer(pcg, input_attrs, {}, {input_tensor_attrs}); + }(); + parallel_tensor_guid_t input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = [&] { + ParallelTensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(op_attrs, input_shape)); + + TensorShape unpar_projection_shape = + get_reduced_shape(projection_weight_shape); + ParallelTensorShape raw_projection_weight_shape = + lift_to_parallel(unpar_projection_shape); + + ParallelLayerAttrs raw_projection_weight_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{unpar_projection_shape}}, + std::nullopt, + }; + ParallelTensorAttrs raw_projection_tensor_attrs = + ParallelTensorAttrs{raw_projection_weight_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + ParallelLayerAddedResult raw_weight_added = + add_parallel_layer(pcg, + raw_projection_weight_attrs, + {}, + {raw_projection_tensor_attrs}); + + ReplicateAttrs replicate_attrs = ReplicateAttrs{/*degree=*/2}; + ParallelLayerAttrs replicate_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{replicate_attrs}, + std::nullopt, + }; + ParallelTensorAttrs replicated_projection_tensor_attrs = + ParallelTensorAttrs{ + get_output_shape(replicate_attrs, raw_projection_weight_shape), + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + return add_parallel_layer(pcg, + replicate_layer_attrs, + {}, + {replicated_projection_tensor_attrs}); + }(); + parallel_tensor_guid_t projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult my_op_added = [&] { + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(op_attrs, input_shape)); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{op_attrs}, + std::nullopt, + }; + ParallelTensorAttrs output_tensor_attrs = + ParallelTensorAttrs{output_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + return add_parallel_layer(pcg, + layer_attrs, + {input, projection_weight}, + {output_tensor_attrs}); + }(); + + parallel_layer_guid_t my_op_layer = my_op_added.parallel_layer; + + std::vector result = + get_incoming_weights(pcg, my_op_layer); + std::vector correct = {projection_weight}; + + CHECK(result == correct); + } + } + + TEST_CASE("pcg_add_input_layer") { + ParallelTensorShape tensor_shape = ParallelTensorShape{ + ParallelTensorDims{ + 
FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelComputationGraph result = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + pcg_add_input_layer(pcg, tensor_shape); + return pcg; + }(); + + ParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); + + return pcg; + }(); + + CHECK(pcgs_are_isomorphic(result, correct)); + } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 440f735e80..20bd0ac92d 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,9 +1,9 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/ops/conv_2d.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "test/utils/doctest.h" #include "utils/containers/count.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" @@ -12,7 +12,16 @@ #include "utils/containers/values.h" #include "utils/containers/without_nullopts.h" #include "utils/hash/pair.h" +#include +using namespace ::FlexFlow; + +// Stylistically these tests are not great (they're rather complicated +// and hard to read) and should not be used as a model for other FlexFlow +// tests. 
+// +// Improving them is being tracked in +// https://github.com/flexflow/FlexFlow/issues/1474 TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { ParallelComputationGraphBuilder b; @@ -42,9 +51,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t out = b.add(lhs, rhs); parallel_layer_guid_t layer = get_source_layer(out); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {lhs, rhs}; CHECK(result == correct); } @@ -105,9 +114,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); parallel_layer_guid_t layer = get_source_layer(out); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {a_tensor, b_tensor}; CHECK(result == correct); } @@ -148,9 +157,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.cast(input, output_datatype); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -227,7 +236,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(num_replicate_attrs == 2); parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( - as_vector(items(layers)), + vector_of(items(layers)), [](std::pair const &kv) -> std::optional { if (get_op_type(kv.second) == OperatorType::CONV2D) { @@ -258,20 +267,20 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape correct_bias_shape = get_bias_shape(correct_attrs, input_shape); - std::vector conv_inputs = - get_layer_inputs(b.pcg, conv_guid); + std::vector conv_incoming = + get_incoming_tensors(b.pcg, conv_guid); - parallel_tensor_guid_t conv_input = conv_inputs.at(0); + parallel_tensor_guid_t conv_input = conv_incoming.at(0); ParallelTensorShape conv_input_shape = get_parallel_tensor_attrs(b.pcg, conv_input).shape; CHECK(conv_input_shape == input_shape); - parallel_tensor_guid_t conv_kernel = conv_inputs.at(1); + parallel_tensor_guid_t conv_kernel = conv_incoming.at(1); ParallelTensorShape conv_kernel_shape = get_parallel_tensor_attrs(b.pcg, conv_kernel).shape; CHECK(conv_kernel_shape == correct_kernel_shape); - parallel_tensor_guid_t conv_bias = conv_inputs.at(2); + parallel_tensor_guid_t conv_bias = conv_incoming.at(2); ParallelTensorShape conv_bias_shape = get_parallel_tensor_attrs(b.pcg, conv_bias).shape; CHECK(conv_bias_shape == correct_bias_shape); @@ -313,9 +322,9 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 3); @@ -356,9 +365,9 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 2); @@ -406,9 +415,9 @@ TEST_SUITE(FF_TEST_SUITE) { b.multihead_attention(query, key, value, embed_dim, num_heads); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + 
get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == query); CHECK(result.at(1) == key); CHECK(result.at(2) == value); @@ -447,9 +456,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.relu(input); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -486,9 +495,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -525,9 +534,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -564,9 +573,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_replicate(input, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -603,9 +612,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_reduce(input, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } diff --git a/lib/pcg/test/src/pcg/start_invariant_machine_view.cc b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc new file mode 100644 index 0000000000..8383754aa2 --- /dev/null +++ b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc @@ -0,0 +1,229 @@ +#include "pcg/start_invariant_machine_view.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("StartInvariantMachineView - utility functions") { + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}}, + DeviceType::GPU}; + + SUBCASE("num_dims") { + int result = num_dims(simv); + int correct = 2; + CHECK(result == correct); + } + + SUBCASE("get_device_type") { + DeviceType result = get_device_type(simv); + DeviceType correct = DeviceType::GPU; + CHECK(result == correct); + } + + SUBCASE("get_strides") { + std::vector result = get_strides(simv); + std::vector correct = {stride_t{2}, stride_t{2}}; + CHECK(result == correct); + } + + SUBCASE("get_dimensions") { + std::vector result = get_dimensions(simv); + std::vector correct = { + MachineSpecificationDimension::INTER_NODE, + MachineSpecificationDimension::INTER_NODE}; + CHECK(result == correct); + } + } + + TEST_CASE("StartInvariantMachineView - conversions") { + MachineSpaceCoordinate start = + MachineSpaceCoordinate{1, 2, DeviceType::GPU}; + std::vector dimensions = { + 
MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{3}, + MachineSpecificationDimension::INTRA_NODE}}; + + MachineView mv = MachineView{start, dimensions}; + StartInvariantMachineView simv = + StartInvariantMachineView{dimensions, DeviceType::GPU}; + + SUBCASE("start_invariant_from_machine_view") { + StartInvariantMachineView result = start_invariant_from_machine_view(mv); + StartInvariantMachineView correct = simv; + CHECK(result == correct); + } + + SUBCASE("machine_view_from_start_invariant") { + MachineView result = machine_view_from_start_invariant(simv, start); + MachineView correct = mv; + CHECK(result == correct); + } + + SUBCASE("conversion is invertible") { + SUBCASE("MachineView -> StartInvariant -> MachineView") { + MachineView result = machine_view_from_start_invariant( + start_invariant_from_machine_view(mv), start); + MachineView correct = mv; + CHECK(result == correct); + } + + SUBCASE("StartInvariant -> MachineView -> StartInvariant") { + StartInvariantMachineView result = start_invariant_from_machine_view( + machine_view_from_start_invariant(simv, start)); + StartInvariantMachineView correct = simv; + CHECK(result == correct); + } + } + } + + TEST_CASE("StartInvariantMachineView - get_machine_space_offset") { + SUBCASE("1D case") { + // This operator has shape (3,), and thus 3 tasks. + // The (only) dimension is projected on the INTRA (device) dimension with + // a stride of 2. The machine space has 1 node and 6 devices per node. + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | (0,) | | (1,) | | (2,) | | + * +-------+-------+-------+-------+-------+-------+ + */ + OperatorTaskSpace task = OperatorTaskSpace{{3}}; + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}, + DeviceType::GPU}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/1, + /*num_cpus_per_node=*/6, + /*num_gpus_per_node=*/6, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("get_machine_space_offset") { + SUBCASE("Task with TaskSpaceCoordinate = (0,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (2,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 4, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("get_machine_space_offsets") { + std::unordered_set correct = { + MachineSpaceOffset{0, 0, DeviceType::GPU}, + MachineSpaceOffset{0, 2, DeviceType::GPU}, + MachineSpaceOffset{0, 4, DeviceType::GPU}}; + std::unordered_set result = + get_machine_space_offsets(task, simv, ms); + CHECK(correct == result); + } + } + + SUBCASE("2D case") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. 
+ // The first dimension is projected onto the INTER (node) dimension with + // stride 1, while the second dimension is projected onto the INTRA + // (device) dimension with stride 2. The machine space has 2 nodes and 4 + // devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+ + * | (0,0) | | (0,1) | | + * +-------+-------+-------+-------+ + * | (1,0) | | (1,1) | | + * +-------+-------+-------+-------+ + */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}, + DeviceType::GPU}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/2, + /*num_cpus_per_node=*/4, + /*num_gpus_per_node=*/4, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("get_machine_space_offset") { + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 0}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0}}; + MachineSpaceOffset correct = + MachineSpaceOffset{1, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1}}; + MachineSpaceOffset correct = + MachineSpaceOffset{1, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("get_machine_space_offsets") { + std::unordered_set correct = { + MachineSpaceOffset{0, 0, DeviceType::GPU}, + MachineSpaceOffset{0, 2, DeviceType::GPU}, + MachineSpaceOffset{1, 0, DeviceType::GPU}, + MachineSpaceOffset{1, 2, DeviceType::GPU}}; + std::unordered_set result = + get_machine_space_offsets(task, simv, ms); + CHECK(correct == result); + } + } + } +} diff --git a/lib/pcg/test/src/test_machine_view.cc b/lib/pcg/test/src/test_machine_view.cc deleted file mode 100644 index 70fe958d8c..0000000000 --- a/lib/pcg/test/src/test_machine_view.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/machine_view.h" -#include "pcg/strided_rectangle.h" -#include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MachineView general util functions") { - StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}, - StridedRectangleSide{num_points_t{10}, 2}}}; - gpu_id_t start(1); - MachineView mv{device_id_t{start}, rect}; - SUBCASE("num_dims") { - CHECK(num_dims(mv) == 2); - } - SUBCASE("num_devices") { - CHECK(num_devices(mv) == 7 * 10); - } - SUBCASE("get_device_type") { - CHECK(get_device_type(mv) == DeviceType::GPU); - } - } - - TEST_CASE("MachineView make_1d_machine_view - GPU") { - 
StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}}}; - device_id_t start_gpu{gpu_id_t{1}}; - MachineView gpu_mv{start_gpu, rect}; - - SUBCASE("make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride)") { - MachineView result = - make_1d_machine_view(start_gpu, device_id_t{gpu_id_t(1 + 7 * 5)}, 5); - MachineView correct = gpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(gpu_id_t start, num_points_t num_points, int " - "stride)") { - MachineView result = make_1d_machine_view(start_gpu, num_points_t{7}, 5); - MachineView correct = gpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(gpu_id_t start, side_size_t interval_size, " - "int stride)") { - MachineView result = make_1d_machine_view( - start_gpu, get_side_size(rect.sides.at(ff_dim_t{0})), 5); - MachineView correct = gpu_mv; - CHECK(result == correct); - } - } - - TEST_CASE("MachineView make_1d_machine_view - CPU") { - StridedRectangle rect{{StridedRectangleSide{num_points_t{11}, 4}}}; - device_id_t start_cpu{cpu_id_t{2}}; - MachineView cpu_mv{start_cpu, rect}; - - SUBCASE("make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride)") { - MachineView result = - make_1d_machine_view(start_cpu, device_id_t{cpu_id_t(2 + 11 * 4)}, 4); - MachineView correct = cpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(cpu_id_t start, num_points_t num_points, int " - "stride)") { - MachineView result = make_1d_machine_view(start_cpu, num_points_t{11}, 4); - MachineView correct = cpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(cpu_id_t start, side_size_t interval_size, " - "int stride)") { - MachineView result = make_1d_machine_view( - start_cpu, get_side_size(rect.sides.at(ff_dim_t{0})), 4); - MachineView correct = cpu_mv; - CHECK(result == correct); - } - } -} diff --git a/lib/pcg/test/src/test_strided_rectangle.cc b/lib/pcg/test/src/test_strided_rectangle.cc deleted file mode 100644 index 2fe3005b15..0000000000 --- a/lib/pcg/test/src/test_strided_rectangle.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "pcg/strided_rectangle.h" -#include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_side_size(StridedRectangleSide)") { - StridedRectangleSide side{num_points_t{7}, 5}; - - CHECK(get_side_size(side) == side_size_t{7 * 5}); - } - TEST_CASE("strided_side_from_size_and_stride") { - StridedRectangleSide correct{num_points_t{10}, 3}; - StridedRectangleSide result = - strided_side_from_size_and_stride(side_size_t{10 * 3}, 3); - CHECK(result == correct); - } - - TEST_CASE("StridedRectangle - helper functions") { - - StridedRectangleSide s0{num_points_t{7}, 5}; - StridedRectangleSide s1{num_points_t{10}, 2}; - StridedRectangleSide s2{num_points_t{8}, 1}; - StridedRectangle rect{{s0, s1, s2}}; - - SUBCASE("get_num_dims") { - CHECK(get_num_dims(rect) == 3); - } - SUBCASE("get_num_points") { - CHECK(get_num_points(rect) == num_points_t{7 * 8 * 10}); - } - SUBCASE("get_side_at_idx") { - CHECK(get_side_at_idx(rect, ff_dim_t{0}) == s0); - CHECK(get_side_at_idx(rect, ff_dim_t{1}) == s1); - CHECK(get_side_at_idx(rect, ff_dim_t{2}) == s2); - } - } -} diff --git a/lib/runtime/src/accessor.cc b/lib/runtime/src/accessor.cc index 44ad8ab40d..84573fb4aa 100644 --- a/lib/runtime/src/accessor.cc +++ b/lib/runtime/src/accessor.cc @@ -129,7 +129,7 @@ struct GetTensorPointerWOFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void 
*)helperGetTensorPointerWO>( + return (void *)helperGetTensorPointerWO>( region, req, fid, ctx, runtime); } }; @@ -141,7 +141,7 @@ struct GetTensorPointerROFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void const *)helperGetTensorPointerRO>( + return (void const *)helperGetTensorPointerRO>( region, req, fid, ctx, runtime); } }; @@ -153,7 +153,7 @@ struct GetTensorPointerRWFUnctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerRW>( + return (void *)helperGetTensorPointerRW>( region, req, fid, ctx, runtime); } }; diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index e46a481a1a..471f2a2709 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "legion/legion_utilities.h" #include "op-attrs/ffconst.h" -#include "op-attrs/operator_attrs.h" #include "serialization.h" #include diff --git a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index d4d38af228..a5f0cc6fdc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -10,64 +10,66 @@ namespace FlexFlow { std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchNormAttrs const &p, +std::optional get_attribute(BatchNormAttrs const &, OperatorAttributeKey); -std::optional get_attribute(CastAttrs const &p, +std::optional get_attribute(BroadcastAttrs const &, OperatorAttributeKey); -std::optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CastAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(CombineAttrs const &, OperatorAttributeKey); -std::optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(ConcatAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &, OperatorAttributeKey); -std::optional get_attribute(DropoutAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &, OperatorAttributeKey); -std::optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(DropoutAttrs const &, OperatorAttributeKey); -std::optional get_attribute(FlatAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &, OperatorAttributeKey); -std::optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(FlatAttrs const &, OperatorAttributeKey); -std::optional get_attribute(InputAttrs const &p, +std::optional get_attribute(GatherAttrs const &, OperatorAttributeKey); -std::optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(InputAttrs const &, OperatorAttributeKey); -std::optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &, + OperatorAttributeKey); +std::optional get_attribute(LinearAttrs const &, OperatorAttributeKey); std::optional - get_attribute(MultiHeadAttentionAttrs const &p, 
OperatorAttributeKey); + get_attribute(MultiHeadAttentionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(NoopAttrs const &p, +std::optional get_attribute(NoopAttrs const &, OperatorAttributeKey); -std::optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReductionAttrs const &p, +std::optional get_attribute(ReductionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReverseAttrs const &p, +std::optional get_attribute(ReverseAttrs const &, OperatorAttributeKey); -std::optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &, OperatorAttributeKey); -std::optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &, OperatorAttributeKey); -std::optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &, OperatorAttributeKey); -std::optional get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &, OperatorAttributeKey); -// optional get_attribute(FusedParallelOpAttrs const &p, +// optional get_attribute(FusedParallelOpAttrs const &, // OperatorAttributeKey); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml index 59e913750e..eb758ea4fc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -55,7 +55,8 @@ values = [ { name = "SHOULD_BROADCAST_LHS" }, { name = "SHOULD_BROADCAST_RHS" }, { name = "DIM" }, - { name = "ELEMENTWISE_AFFINE" }, + { name = "AFFINE" }, + { name = "MOMENTUM" }, { name = "REGULARIZER" }, { name = "SHAPE" }, { name = "SPLITS" }, diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 7df65ef361..02a856f59a 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -25,6 +25,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] @@ -68,5 +69,8 @@ type = "::FlexFlow::PoolOp" [[values]] type = "::FlexFlow::TensorShape" +[[values]] +type = "::FlexFlow::TensorDims" + [[values]] type = "::FlexFlow::DataType" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 00032045c0..2d76352ccf 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ 
b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -45,10 +45,6 @@ std::unordered_set get_subgraph_outgoing_edges( SubParallelComputationGraph const &, std::unordered_set const &); -std::unordered_set get_subgraph_incoming_edges( - SubParallelComputationGraph const &, - std::unordered_set const &); - std::unordered_set get_parallel_tensor_uses(SubParallelComputationGraph const &, open_parallel_tensor_guid_t const &); diff --git a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h index de9d1cd78a..b7ce13db0e 100644 --- a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h +++ b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 26f8ff5062..442d3345a1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -1,6 +1,6 @@ #include "substitutions/operator_pattern/get_attribute.h" #include "op-attrs/get_op_type.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { @@ -19,8 +19,24 @@ std::optional get_attribute(BatchNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::RELU: - return p.relu; + case OperatorAttributeKey::EPSILON: + return p.eps; + case OperatorAttributeKey::AFFINE: + return p.affine; + case OperatorAttributeKey::MOMENTUM: + return p.momentum; + default: + return std::nullopt; + } +} + +std::optional get_attribute(BroadcastAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::TARGET_DIMS: + return p.target_dims; default: return std::nullopt; } @@ -177,6 +193,10 @@ std::optional get_attribute(LayerNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); + case OperatorAttributeKey::AFFINE: + return p.elementwise_affine; + case OperatorAttributeKey::AXES: + return vector_of(p.axes); default: return std::nullopt; } @@ -364,7 +384,7 @@ std::optional get_attribute(TransposeAttrs const &p, case OperatorAttributeKey::OP_TYPE: return get_op_type(p); case OperatorAttributeKey::PERMUTATION: - return as_vector(p.perm); + return vector_of(p.perm); default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 0bbe0e97a7..0c673f0a8a 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -4,7 +4,7 @@ #include "utils/containers/values.h" #include 
"utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" @@ -54,12 +54,6 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( UnorderedSetLabelledOpenDataflowGraph>( sub_pcg.raw_graph)}; - // return ParallelComputationGraph{ - // make_lazy_copy_of< - // UnorderedSetLabelledOpenDataflowGraph - // >(sub_pcg.raw_graph) - // }; } parallel_layer_guid_t diff --git a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc index 0bde326bd1..9fa91d75b7 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,5 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/get_output_shapes.h" #include "utils/containers/map_keys.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 05f21247c7..286bc69b84 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,7 +1,7 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "op-attrs/parallel_tensor_dims.h" -#include "utils/containers/as_vector.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -11,13 +11,13 @@ TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector sizes = - transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + transform(vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { std::vector degrees = transform( - as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return size_t_from_int(d.degree); }); return TensorAttributeValue{degrees}; } diff --git a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc index 70e960bc73..95b61e0ef4 100644 --- a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc @@ -1,4 +1,5 @@ #include "substitutions/operator_pattern/get_attribute.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 6922798a97..d9273b4bcf 
100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -5,9 +5,9 @@ #include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include using namespace ::FlexFlow; @@ -35,7 +35,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::string a_name = "a"; parallel_tensor_guid_t a_tensor = - builder.create_input_tensor(a_shape, /*create_grad=*/true, a_name); + builder.create_input_tensor(a_shape, CreateGrad::YES, a_name); int outDim = 16; std::string x_matmul_name = "x_matmul"; @@ -65,14 +65,14 @@ TEST_SUITE(FF_TEST_SUITE) { get_parallel_layer_by_name(pcg, x_matmul_name); parallel_layer_guid_t y_matmul = get_parallel_layer_by_name(pcg, y_matmul_name); - std::vector x_inputs = - get_layer_inputs(pcg, x_matmul); - REQUIRE(x_inputs.size() == 2); - parallel_tensor_guid_t x_weights = x_inputs.at(1); - std::vector y_inputs = - get_layer_inputs(pcg, y_matmul); - REQUIRE(y_inputs.size() == 2); - parallel_tensor_guid_t y_weights = y_inputs.at(1); + std::vector x_incoming = + get_incoming_tensors(pcg, x_matmul); + REQUIRE(x_incoming.size() == 2); + parallel_tensor_guid_t x_weights = x_incoming.at(1); + std::vector y_incoming = + get_incoming_tensors(pcg, y_matmul); + REQUIRE(y_incoming.size() == 2); + parallel_tensor_guid_t y_weights = y_incoming.at(1); LabelledOpenDataflowGraph g = LabelledOpenDataflowGraph using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 3475c10235..e0805dbfd4 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,8 +1,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 9478195523..aeedd65f82 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -1,9 +1,6 @@ -#include "doctest/doctest.h" -#include "rapidcheck.h" #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" #include "substitutions/unlabelled/pattern_matching.h" -#include "test/utils/all.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -13,6 +10,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" #include "utils/overload.h" +#include using namespace FlexFlow; diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index ae5e120fad..a0d77b9f76 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -13,7 +13,6 @@ ff_add_library( fmt json cuda - doctest ) 
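As an illustration of the two call-site updates that recur through the test hunks above (this sketch is editorial, reuses the identifiers from the pcg_pattern.cc hunk, and is not itself part of the patch):

// Input creation: the bare boolean flag is replaced by the CreateGrad enum.
//   before: builder.create_input_tensor(a_shape, /*create_grad=*/true, a_name);
parallel_tensor_guid_t a_tensor =
    builder.create_input_tensor(a_shape, CreateGrad::YES, a_name);

// Incoming-tensor query: get_layer_inputs is renamed to get_incoming_tensors,
// with the same arguments and return type.
//   before: std::vector<parallel_tensor_guid_t> x_inputs = get_layer_inputs(pcg, x_matmul);
std::vector<parallel_tensor_guid_t> x_incoming =
    get_incoming_tensors(pcg, x_matmul);
REQUIRE(x_incoming.size() == 2);
parallel_tensor_guid_t x_weights = x_incoming.at(1);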
add_subdirectory(ffi) diff --git a/lib/utils/include/utils/any_value_type/any_value_type.h b/lib/utils/include/utils/any_value_type/any_value_type.h new file mode 100644 index 0000000000..a99ce5c8f0 --- /dev/null +++ b/lib/utils/include/utils/any_value_type/any_value_type.h @@ -0,0 +1,66 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H + +#include +#include +#include +#include + +namespace FlexFlow { + +struct any_value_type { +public: + any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string); + + bool operator==(any_value_type const &other) const; + bool operator!=(any_value_type const &other) const; + + template + T get() const { + return std::any_cast(value); + } + + friend std::string format_as(any_value_type const &); + +private: + std::any value; + std::function eq; + std::function neq; + std::function hash; + std::function to_string; + + friend std::hash; +}; + +template +any_value_type make_any_value_type(T const &t) { + return any_value_type{ + std::make_any(t), + [](std::any const &l, std::any const &r) { + return std::any_cast(l) == std::any_cast(r); + }, + [](std::any const &l, std::any const &r) { + return std::any_cast(l) != std::any_cast(r); + }, + [](std::any const &v) { return std::hash{}(std::any_cast(v)); }, + [](std::any const &v) { return fmt::to_string(std::any_cast(v)); }, + }; +} + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::any_value_type> { + size_t operator()(::FlexFlow::any_value_type const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h new file mode 100644 index 0000000000..1635747612 --- /dev/null +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H + +#include +#include + +namespace FlexFlow { + +template +struct value_type { + value_type() = delete; + + value_type(value_type const &) { + assert(false); + } + value_type &operator=(value_type const &) { + assert(false); + } + + value_type(value_type &&) { + assert(false); + } + value_type &operator=(value_type &&) { + assert(false); + } + + bool operator==(value_type const &) const { + assert(false); + } + bool operator!=(value_type const &) const { + assert(false); + } +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::value_type> { + size_t operator()(::FlexFlow::value_type const &) const { + assert(false); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/cli/cli_argument_key.variant.toml b/lib/utils/include/utils/cli/cli_argument_key.variant.toml new file mode 100644 index 0000000000..be118160ce --- /dev/null +++ b/lib/utils/include/utils/cli/cli_argument_key.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "CLIArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_positional_argument_key.dtg.h", + "utils/cli/cli_flag_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::CLIPositionalArgumentKey" + +[[values]] +type = "::FlexFlow::CLIFlagKey" diff --git a/lib/utils/include/utils/cli/cli_flag_key.struct.toml b/lib/utils/include/utils/cli/cli_flag_key.struct.toml new file mode 100644 index 
0000000000..790a752911 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIFlagKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_flag_spec.struct.toml b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml new file mode 100644 index 0000000000..66a47de067 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "CLIFlagSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "long_flag" +type = "std::string" + +[[fields]] +name = "short_flag" +type = "std::optional" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_get_help_message.h b/lib/utils/include/utils/cli/cli_get_help_message.h new file mode 100644 index 0000000000..d51579a8e2 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_get_help_message.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H + +#include "utils/cli/cli_spec.dtg.h" + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse.h b/lib/utils/include/utils/cli/cli_parse.h new file mode 100644 index 0000000000..3c91a8423b --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H + +#include "utils/cli/cli_parse_result.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg); +tl::expected + cli_parse(CLISpec const &, std::vector const &); +tl::expected + cli_parse(CLISpec const &, int argc, char const *const *argv); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.h b/lib/utils/include/utils/cli/cli_parse_result.h new file mode 100644 index 0000000000..155caac7ae --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_parse_result.dtg.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &, CLIArgumentKey const &); +std::string cli_get_argument(CLIParseResult const &, CLIArgumentKey const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.struct.toml b/lib/utils/include/utils/cli/cli_parse_result.struct.toml new file mode 100644 index 0000000000..b63da7be14 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "CLIParseResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "utils/cli/cli_flag_key.dtg.h", + "utils/cli/cli_positional_argument_key.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "flags" +type = "std::unordered_map<::FlexFlow::CLIFlagKey, bool>" + +[[fields]] +name = "positional_arguments" 
+type = "std::unordered_map<::FlexFlow::CLIPositionalArgumentKey, std::string>" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml new file mode 100644 index 0000000000..d571d0deb3 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml new file mode 100644 index 0000000000..b1e74701ee --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "name" +type = "std::string" + +[[fields]] +name = "choices" +type = "std::optional>" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_spec.h b/lib/utils/include/utils/cli/cli_spec.h new file mode 100644 index 0000000000..2c0df08c55 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_flag_spec.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +CLISpec empty_cli_spec(); +std::vector cli_get_flag_keys(CLISpec const &); +CLIArgumentKey cli_add_help_flag(CLISpec &); +CLIArgumentKey cli_add_flag(CLISpec &, CLIFlagSpec const &); +CLIArgumentKey cli_add_positional_argument(CLISpec &, + CLIPositionalArgumentSpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_spec.struct.toml b/lib/utils/include/utils/cli/cli_spec.struct.toml new file mode 100644 index 0000000000..9f64f62c15 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "CLISpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_flag_spec.dtg.h", + "utils/cli/cli_positional_argument_spec.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "flags" +type = "std::vector<::FlexFlow::CLIFlagSpec>" + +[[fields]] +name = "positional_arguments" +type = "std::vector<::FlexFlow::CLIPositionalArgumentSpec>" diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h index 79c8705999..37b1838146 100644 --- a/lib/utils/include/utils/containers/are_all_same.h +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -1,7 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H - -#include "utils/exception.h" +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H namespace FlexFlow { @@ -10,6 +8,7 @@ bool are_all_same(C const &c) { if (c.empty()) { return true; } + auto const &first = 
*c.cbegin(); for (auto const &v : c) { if (v != first) { diff --git a/lib/utils/include/utils/containers/cartesian_product.h b/lib/utils/include/utils/containers/cartesian_product.h new file mode 100644 index 0000000000..28d0fb118c --- /dev/null +++ b/lib/utils/include/utils/containers/cartesian_product.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H + +#include "utils/hash/vector.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_multiset> + cartesian_product(std::vector const &containers) { + std::unordered_multiset> result; + + std::function &, size_t)> recurse = + [&](std::vector ¤t, size_t depth) { + if (depth == containers.size()) { + result.insert(current); + return; + } + + for (E const &item : containers.at(depth)) { + current.push_back(item); + recurse(current, depth + 1); + current.pop_back(); + } + }; + + std::vector current; + recurse(current, 0); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h index 11ee8d2352..700106ea3f 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H #include -#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/filter.h b/lib/utils/include/utils/containers/filter.h index fb8c703d2a..07f25dc348 100644 --- a/lib/utils/include/utils/containers/filter.h +++ b/lib/utils/include/utils/containers/filter.h @@ -44,6 +44,14 @@ std::map filter(std::map const &m, F const &f) { return result; } +template +std::unordered_multiset filter(std::unordered_multiset const &m, + F const &f) { + std::unordered_multiset result; + std::copy_if(m.cbegin(), m.cend(), std::inserter(result, result.begin()), f); + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 0f8906f34a..b016a1e03d 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -3,7 +3,9 @@ #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/merge_maps.h" #include +#include namespace FlexFlow { @@ -39,6 +41,23 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } +template < + typename InK, + typename InV, + typename F, + typename OutK = typename std::invoke_result_t::key_type, + typename OutV = typename std::invoke_result_t::mapped_type> +std::unordered_map flatmap(std::unordered_map const &m, + F &&f) { + std::unordered_map result; + + for (auto const &[k, v] : m) { + result = merge_maps(result, f(k, v)); + } + + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/foldl.h b/lib/utils/include/utils/containers/foldl.h new file mode 100644 index 0000000000..16851d7d9b --- /dev/null +++ b/lib/utils/include/utils/containers/foldl.h @@ -0,0 +1,72 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H + +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include +#include +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Iteratively applies `func` 
to the elements of `c` from left to right. + * `init` is used as the starting value. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * int result = foldl(nums, 0, [](int a, int b) { return a + b; }); + * result -> ((((0+1)+2)+3)+4) = 10 + * + * @note + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl + */ +template +T foldl(C const &c, T init, F func) { + T result = init; + for (auto const &elem : c) { + result = func(result, elem); + } + return result; +} + +/** + * @brief + * Applies `func` to the elements of `c` from left to right, accumulating the + * result. The first element of `c` is used as the starting point for the + * accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * int result = foldl1(nums, [](int a, int b) { return a + b; }); + * result -> (((1+2)+3)+4) = 10 + * + * @note + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl1 + * @throws std::runtime_error if the container is empty. + */ +template +E foldl1(C const &c, F func) { + if (c.empty()) { + throw mk_runtime_error( + fmt::format("foldl1 received empty container: {}", c)); + } + std::optional result = std::nullopt; + + for (E const &e : c) { + if (!result.has_value()) { + result = e; + } else { + result = func(result.value(), e); + } + } + return result.value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/foldl1.h b/lib/utils/include/utils/containers/foldl1.h new file mode 100644 index 0000000000..f542f8cf00 --- /dev/null +++ b/lib/utils/include/utils/containers/foldl1.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldl1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldl1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.cbegin(); + T result = *it; + it++; + + for (; it != vec.cend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/foldr1.h b/lib/utils/include/utils/containers/foldr1.h new file mode 100644 index 0000000000..4a7e8e098c --- /dev/null +++ b/lib/utils/include/utils/containers/foldr1.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldr1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldr1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.crbegin(); + T result = *it; + it++; + for (; it != vec.crend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/generate_map.h b/lib/utils/include/utils/containers/generate_map.h index 1afa534a19..53b2a590c5 100644 --- a/lib/utils/include/utils/containers/generate_map.h +++ b/lib/utils/include/utils/containers/generate_map.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_type.h" +#include 
"utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" #include "utils/type_traits_core.h" #include @@ -17,7 +17,7 @@ std::unordered_map generate_map(C const &c, F const &f) { static_assert(is_hashable_v, "Key type should be hashable (but is not)"); auto transformed = - vector_transform(as_vector(c), [&](K const &k) -> std::pair { + vector_transform(vector_of(c), [&](K const &k) -> std::pair { return {k, f(k)}; }); return {transformed.cbegin(), transformed.cend()}; diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h new file mode 100644 index 0000000000..9981948f47 --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H + +#include "utils/containers/cartesian_product.h" +#include "utils/containers/keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_map_from_pairs.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/vector_of.h" +#include "utils/containers/zip.h" +#include "utils/hash/unordered_map.h" +#include +#include +#include + +namespace FlexFlow { + +/** + * @note If \p options_per_key is empty, an set containing a single empty + * assignment is returned + */ +template +std::unordered_set> get_all_assignments( + std::unordered_map> const &options_per_key) { + if (options_per_key.empty()) { + return {{}}; + } + + std::vector ordered_keys = vector_of(keys(options_per_key)); + std::vector> ordered_value_option_sets = transform( + ordered_keys, [&](K const &k) { return options_per_key.at(k); }); + + std::unordered_set> result = transform( + unordered_set_of(cartesian_product(ordered_value_option_sets)), + [&](std::vector const &chosen_values) { + return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); + }); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h new file mode 100644 index 0000000000..ccdde0131a --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h @@ -0,0 +1,50 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_WITH_REPETITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_WITH_REPETITION_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief For a given container `c` and integer `n`, return all possible vectors + * of size `n` that only contain (possibly duplicated) elements of `c`. 
+ * @details + * https://en.wikipedia.org/wiki/Permutation#Permutations_with_repetition + **/ +template +std::unordered_multiset> + get_all_permutations_with_repetition(C const &container, int n) { + std::unordered_multiset> result; + + if (container.empty() || n == 0) { + return result; + } + + std::vector elements(std::begin(container), std::end(container)); + std::vector indices(n, 0); + + while (true) { + std::vector perm(n); + for (int i = 0; i < n; ++i) { + perm[i] = elements[indices[i]]; + } + result.insert(perm); + + int i = n - 1; + while (i != -1 && ++indices[i] == elements.size()) { + indices[i] = 0; + --i; + } + + if (i == -1) { + break; + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index fedb87413d..201095c47d 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -10,8 +10,8 @@ namespace FlexFlow { template typename C::value_type get_only(C const &c) { return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error("Encountered container with size {} in get_only", - c.size()); + throw mk_runtime_error(fmt::format( + "Encountered container with size {} in get_only", c.size())); }); } diff --git a/lib/utils/include/utils/containers/map_from_keys_and_values.h b/lib/utils/include/utils/containers/map_from_keys_and_values.h new file mode 100644 index 0000000000..499965dc5e --- /dev/null +++ b/lib/utils/include/utils/containers/map_from_keys_and_values.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_KEYS_AND_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_KEYS_AND_VALUES_H + +#include "utils/containers/zip.h" +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +std::unordered_map + map_from_keys_and_values(std::vector const &keys, + std::vector const &values) { + if (keys.size() != values.size()) { + throw mk_runtime_error(fmt::format( + "recieved keys (of size {}) not matching values (of size {})", + keys.size(), + values.size())); + } + std::unordered_map result; + for (auto const &[k, v] : zip(keys, values)) { + result.insert({k, v}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index 5d21dd9fa1..b3d6d0c6d7 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H #include "utils/exception.h" #include diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index 4176a36762..dd886ab8aa 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -4,7 +4,7 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" #include "utils/exception.h" -#include +#include "utils/fmt/unordered_map.h" #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/multiset_union.h b/lib/utils/include/utils/containers/multiset_union.h new file mode 100644 index 0000000000..6f2b2a7889 --- /dev/null +++ b/lib/utils/include/utils/containers/multiset_union.h @@ 
-0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_multiset + multiset_union(std::unordered_multiset const &lhs, + std::unordered_multiset const &rhs) { + std::unordered_multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::multiset multiset_union(std::multiset const &lhs, + std::multiset const &rhs) { + std::multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::unordered_multiset multiset_union(C const &c) { + std::unordered_multiset result; + for (auto const &s : c) { + for (T const &element : s) { + result.insert(element); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/range.h b/lib/utils/include/utils/containers/range.h new file mode 100644 index 0000000000..ff6b9f44ee --- /dev/null +++ b/lib/utils/include/utils/containers/range.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RANGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RANGE_H + +#include + +namespace FlexFlow { + +std::vector range(int start, int end, int step = 1); +std::vector range(int end); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/replicate.h b/lib/utils/include/utils/containers/replicate.h new file mode 100644 index 0000000000..aa3d0a7e35 --- /dev/null +++ b/lib/utils/include/utils/containers/replicate.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H + +#include + +namespace FlexFlow { + +template +std::vector replicate(int n, T const &element) { + return std::vector(n, element); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h new file mode 100644 index 0000000000..2f42243857 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H + +#include +#include + +namespace FlexFlow { + +template +tl::expected require_all_same1(C const &c) { + if (c.empty()) { + return tl::unexpected(fmt::format( + "require_all_same1 expected non-empty container, but received {}", c)); + } + + T const &first = *c.cbegin(); + for (T const &v : c) { + if (v != first) { + return tl::unexpected(fmt::format("require_all_same1 found non-same " + "elements {} and {} in containers {}", + first, + v, + c)); + } + } + return first; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_no_duplicates.h b/lib/utils/include/utils/containers/require_no_duplicates.h new file mode 100644 index 0000000000..0cbe361bdd --- /dev/null +++ b/lib/utils/include/utils/containers/require_no_duplicates.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H + +#include "utils/exception.h" +#include "utils/fmt/multiset.h" +#include "utils/fmt/unordered_multiset.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set + 
require_no_duplicates(std::unordered_multiset const &s) { + std::unordered_set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +template +std::set require_no_duplicates(std::multiset const &s) { + std::set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/reversed.h b/lib/utils/include/utils/containers/reversed.h index 621eee9519..902b247469 100644 --- a/lib/utils/include/utils/containers/reversed.h +++ b/lib/utils/include/utils/containers/reversed.h @@ -1,15 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H +#include + namespace FlexFlow { template -T reversed(T const &t) { - T r; - for (auto i = t.cend() - 1; i >= t.begin(); i--) { - r.push_back(*i); - } - return r; +std::vector reversed(std::vector const &t) { + std::vector result(std::crbegin(t), std::crend(t)); + return result; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/scanl.h b/lib/utils/include/utils/containers/scanl.h new file mode 100644 index 0000000000..a30a9e1576 --- /dev/null +++ b/lib/utils/include/utils/containers/scanl.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `op` to the elements of `c` from left to right, accumulating + * the intermediate results in a vector. `init` is used as the starting point + * for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl(nums, 0, [](int a, int b) {return a+b;}); + * result -> {0,1,3,6,10} + * + * @note + * Essentially a foldl which stores the intermediate results + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl + */ +template +std::vector scanl(C const &c, T init, F const &op) { + std::vector result; + + result.push_back(init); + for (auto const &elem : c) { + init = op(init, elem); + result.push_back(init); + } + + return result; +} + +/** + * @brief + * Applies `op` to the elements of `c` from left to right, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {1,3,6,10} + * + * @note + * Essentially a foldl1 which stores the intermediate results. 
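Since scanl and scanl1 above are described as folds that keep their intermediate results, the relationship can be made concrete with a short sketch (editorial, using only the foldl and scanl headers added in this diff):

#include "utils/containers/foldl.h"
#include "utils/containers/scanl.h"
#include <cassert>
#include <vector>

using namespace ::FlexFlow;

void example_fold_vs_scan() {
  std::vector<int> nums = {1, 2, 3, 4};
  auto add = [](int a, int b) { return a + b; };

  // foldl returns only the final accumulated value...
  int total = foldl(nums, 0, add); // 10

  // ...while scanl additionally records every intermediate accumulator,
  // starting from the initial value.
  std::vector<int> partials = scanl(nums, 0, add); // {0, 1, 3, 6, 10}

  // Consequently the last element of the scan equals the fold.
  assert(partials.back() == total);
}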
+ * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 + */ +template <typename C, typename F, typename T = typename C::value_type> +std::vector<T> scanl1(C const &c, F op) { + + if (c.empty()) { + return std::vector<T>(); + } + + std::optional<T> init = std::nullopt; + std::vector<T> result; + + for (T const &elem : c) { + if (!init.has_value()) { + init = elem; + } else { + init = op(init.value(), elem); + } + result.push_back(init.value()); + } + return result; +} + +} // namespace FlexFlow + +#endif
diff --git a/lib/utils/include/utils/containers/set_minus.h b/lib/utils/include/utils/containers/set_minus.h index 6efa2f0a84..fdd1f11995 100644 --- a/lib/utils/include/utils/containers/set_minus.h +++ b/lib/utils/include/utils/containers/set_minus.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H +#include <set> #include <unordered_set> namespace FlexFlow { @@ -15,6 +16,15 @@ std::unordered_set<T> set_minus(std::unordered_set<T> const &l, return result; } +template <typename T> +std::set<T> set_minus(std::set<T> const &l, std::set<T> const &r) { + std::set<T> result = l; + for (T const &t : r) { + result.erase(t); + } + return result; +} + } // namespace FlexFlow #endif
diff --git a/lib/utils/include/utils/containers/set_of.h b/lib/utils/include/utils/containers/set_of.h new file mode 100644 index 0000000000..14658209aa --- /dev/null +++ b/lib/utils/include/utils/containers/set_of.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H + +#include <set> + +namespace FlexFlow { + +template <typename C, typename T = typename C::value_type> +std::set<T> set_of(C const &c) { + std::set<T> result; + for (T const &t : c) { + result.insert(t); + } + return result; +} + +} // namespace FlexFlow + +#endif
diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 13586d21da..c89e9227de 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -31,6 +31,7 @@ std::vector<T> subvec(std::vector<T> const &v, if (maybe_start.has_value()) { begin_iter += resolve_loc(maybe_start.value()); } + if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } @@ -38,6 +39,10 @@ std::vector<T> subvec(std::vector<T> const &v, return {}; } + if (end_iter < begin_iter) { + end_iter = begin_iter; + } + std::vector<T> output(begin_iter, end_iter); return output; }
diff --git a/lib/utils/include/utils/containers/to_uppercase.h b/lib/utils/include/utils/containers/to_uppercase.h new file mode 100644 index 0000000000..a2dc7786f9 --- /dev/null +++ b/lib/utils/include/utils/containers/to_uppercase.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H + +#include <string> + +namespace FlexFlow { + +std::string to_uppercase(std::string const &); + +} // namespace FlexFlow + +#endif
diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index ec4185f822..a8a6a749cd 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -25,7 +25,7 @@ auto transform(req<C> const &c, F const &f) template <typename F, typename In, typename Out = std::invoke_result_t<F, In>> std::unordered_set<Out> transform(std::unordered_set<In> const &v, F const &f) { std::unordered_set<Out> result; - for (auto const &e : v) { + for (In const &e : v) { result.insert(f(e)); } return result; @@ -35,7 +35,7 @@ template <typename F, typename In, typename Out = std::invoke_result_t<F, In>> std::unordered_multiset<Out> transform(std::unordered_multiset<In> const &v, F const &f) { std::unordered_multiset<Out> result; - for (auto const &e : v) { + for (In const &e : v) { result.insert(f(e)); } return result; @@ -44,6 +44,15 @@ std::unordered_multiset<Out> transform(std::unordered_multiset<In> const &v, template <typename F, typename In, typename Out = std::invoke_result_t<F, In>> std::set<Out> transform(std::set<In> const &v, F const &f) { std::set<Out> result; + for (In const &e : v) { + result.insert(f(e)); + } + return result; +} + +template <typename F, typename In, typename Out = std::invoke_result_t<F, In>> +std::multiset<Out> transform(std::multiset<In> const &v, F const &f) { + std::multiset<Out> result; for (auto const &e : v) { result.insert(f(e)); }
diff --git a/lib/utils/include/utils/containers/try_at.h b/lib/utils/include/utils/containers/try_at.h new file mode 100644 index 0000000000..45e50fca27 --- /dev/null +++ b/lib/utils/include/utils/containers/try_at.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H + +#include "utils/containers/contains_key.h" +#include <map> +#include <optional> +#include <unordered_map> + +namespace FlexFlow { + +template <typename K, typename V> +std::optional<V> try_at(std::unordered_map<K, V> const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +template <typename K, typename V> +std::optional<V> try_at(std::map<K, V> const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +} // namespace FlexFlow + +#endif
diff --git a/lib/utils/include/utils/containers/unordered_map_from_pairs.h b/lib/utils/include/utils/containers/unordered_map_from_pairs.h new file mode 100644 index 0000000000..660c57c5e7 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_map_from_pairs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H + +#include <unordered_map> + +namespace FlexFlow { + +template <typename K, typename V, typename C> +std::unordered_map<K, V> unordered_map_from_pairs(C const &c) { + return std::unordered_map<K, V>(c.cbegin(), c.cend()); +} + +} // namespace FlexFlow + +#endif
diff --git a/lib/utils/include/utils/containers/values.h b/lib/utils/include/utils/containers/values.h index 7c487d1d43..2a730ccc42 100644 --- a/lib/utils/include/utils/containers/values.h +++ b/lib/utils/include/utils/containers/values.h @@ -1,15 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H -#include <vector> +#include <unordered_set> namespace FlexFlow { template <typename C, typename V = typename C::mapped_type> -std::vector<V> values(C const &c) { - std::vector<V> result; +std::unordered_multiset<V> values(C const &c) { + std::unordered_multiset<V> result; for (auto const &kv : c) { - result.push_back(kv.second); + result.insert(kv.second); } return result; }
diff --git a/lib/utils/include/utils/containers/as_vector.h b/lib/utils/include/utils/containers/vector_of.h similarity index 54% rename from lib/utils/include/utils/containers/as_vector.h rename to lib/utils/include/utils/containers/vector_of.h index fafa1dc799..7fb903b4a8 100644 --- a/lib/utils/include/utils/containers/as_vector.h +++ b/lib/utils/include/utils/containers/vector_of.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H #include <vector> namespace FlexFlow { template <typename C, typename T = typename C::value_type> -std::vector<T> as_vector(C const &c) { +std::vector<T> vector_of(C const &c) { std::vector<T> result(c.cbegin(), c.cend()); return result; } diff --git 
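A minimal usage sketch of the container helpers introduced above (illustrative only, not part of the patch; the header paths and the default template parameters are assumptions inferred from the surrounding hunks):

// Assumed header paths, following the one-function-per-header layout above.
#include "utils/containers/scanl1.h"
#include "utils/containers/set_minus.h"
#include "utils/containers/try_at.h"
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

void container_helpers_example() {
  using namespace ::FlexFlow;

  // scanl1 keeps every intermediate result of a left fold:
  std::vector<int> prefix_sums =
      scanl1(std::vector<int>{1, 2, 3, 4}, [](int a, int b) { return a + b; });
  assert((prefix_sums == std::vector<int>{1, 3, 6, 10}));

  // the new std::set overload of set_minus erases every element of r from l:
  assert((set_minus(std::set<int>{1, 2, 3}, std::set<int>{2}) ==
          std::set<int>{1, 3}));

  // try_at is a map lookup that returns std::nullopt instead of throwing:
  std::unordered_map<std::string, int> m = {{"a", 1}};
  assert(try_at(m, std::string{"a"}) == std::optional<int>{1});
  assert(try_at(m, std::string{"b"}) == std::nullopt);
}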
a/lib/utils/include/utils/containers/without_order.h b/lib/utils/include/utils/containers/without_order.h deleted file mode 100644 index 7199b2bd4a..0000000000 --- a/lib/utils/include/utils/containers/without_order.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_WITHOUT_ORDER_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_WITHOUT_ORDER_H - -#include - -namespace FlexFlow { - -template -std::unordered_multiset without_order(C const &c) { - return {c.cbegin(), c.cend()}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 94182577ee..0f6dbed1d3 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H +#include #include #include @@ -16,6 +17,17 @@ std::vector> zip(std::vector const &l, return result; } +template +std::vector> zip(std::vector const &a, + std::vector const &b, + std::vector const &c) { + std::vector> result; + for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { + result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index 20a8098040..080cbb3611 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -34,12 +34,7 @@ T throw_if_unexpected(tl::expected const &r) { } } -template -std::runtime_error mk_runtime_error(fmt::format_string fmt_str, - T &&...args) { - return std::runtime_error( - fmt::vformat(fmt_str, fmt::make_format_args(args...))); -} +std::runtime_error mk_runtime_error(std::string const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index d339e6e251..7ef7f24eb7 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -1,9 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H -#include "fmt/format.h" #include "utils/check_fmtable.h" -#include +#include #include #include @@ -44,15 +43,4 @@ std::ostream &operator<<(std::ostream &s, tl::expected const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(tl::expected const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/json.h b/lib/utils/include/utils/fmt/json.h new file mode 100644 index 0000000000..c7aa87e3eb --- /dev/null +++ b/lib/utils/include/utils/fmt/json.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H + +#include +#include + +namespace fmt { + +template +struct formatter<::nlohmann::json, Char> : formatter { + template + auto format(::nlohmann::json const &j, FormatContext &ctx) { + std::ostringstream oss; + oss << j; + return formatter::format(oss.str(), ctx); + } +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index f01a8d63df..9225040d4d 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -5,7 +5,6 @@ #include "utils/containers/sorted.h" #include "utils/fmt/pair.h" #include 
"utils/join_strings.h" -#include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/monostate.h b/lib/utils/include/utils/fmt/monostate.h new file mode 100644 index 0000000000..884f4d389e --- /dev/null +++ b/lib/utils/include/utils/fmt/monostate.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H + +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::monostate, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::std::monostate const &, FormatContext &ctx) + -> decltype(ctx.out()) { + std::string result = ""; + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &, std::monostate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index b60471d8a8..4234d90e94 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 4917454a82..cb72cd03e3 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H #include "utils/check_fmtable.h" -#include #include #include @@ -33,7 +32,7 @@ struct formatter< template struct formatter : formatter { template - auto format(std::nullopt_t, FormatContext &ctx) -> decltype(ctx.out()) { + auto format(std::nullopt_t, FormatContext &ctx) const -> decltype(ctx.out()) { return formatter::format("nullopt", ctx); } }; @@ -55,22 +54,4 @@ inline std::ostream &operator<<(std::ostream &s, std::nullopt_t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::optional const &m) { - return toString(fmt::to_string(m)); - } -}; - -template <> -struct StringMaker { - static String convert(std::nullopt_t) { - return toString("nullopt"); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index 218764e485..d261c344a1 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #include "utils/check_fmtable.h" -#include #include #include @@ -40,15 +39,4 @@ std::ostream &operator<<(std::ostream &s, std::pair const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::pair const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h 
index 9dd3438fae..c46984cc5a 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/containers/sorted.h" #include "utils/join_strings.h" -#include #include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 93089c02e7..12faa64e32 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -5,7 +5,6 @@ #include "utils/fmt/pair.h" #include "utils/join_strings.h" #include -#include #include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 2648bc6589..a4c17f7f5e 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -24,7 +23,6 @@ struct formatter< ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - // } return formatter::format("{" + result + "}", ctx); } }; @@ -42,15 +40,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 4000a2f098..20a08916fc 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include "utils/type_traits_core.h" -#include #include #include @@ -25,7 +24,6 @@ struct formatter< ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - // } return formatter::format("{" + result + "}", ctx); } }; @@ -43,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 81786f9cb0..1690955286 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H -#include #include #include @@ -33,15 +32,4 @@ std::ostream &operator<<(std::ostream &s, std::variant const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static 
String convert(std::variant const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index b46ac9423a..1eec7a306b 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -41,15 +40,4 @@ std::ostream &operator<<(std::ostream &s, std::vector const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::vector const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.h b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h new file mode 100644 index 0000000000..e3ed967a23 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path(); +BinaryTreePath nest_inside_left_child(BinaryTreePath const &); +BinaryTreePath nest_inside_right_child(BinaryTreePath const &); + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &); +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml new file mode 100644 index 0000000000..08955c2d75 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "BinaryTreePath" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "utils/full_binary_tree/binary_tree_path_entry.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "entries" +type = "std::vector<::FlexFlow::BinaryTreePathEntry>" diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml new file mode 100644 index 0000000000..6c81123dcf --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "BinaryTreePathEntry" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT_CHILD" +key = "left" + +[[values]] +name = "RIGHT_CHILD" +key = "right" diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..9cf5d63210 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/visit.h" +#include + 
+namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + Leaf const &needle) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform( + find_paths_to_leaf(impl.get_left_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform( + find_paths_to_leaf(impl.get_right_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + if (leaf == needle) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml new file mode 100644 index 0000000000..bf08701840 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "get_left_child" +type = "std::function" + +[[fields]] +name = "get_right_child" +type = "std::function" + +[[fields]] +name = "is_leaf" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" + +[[fields]] +name = "require_parent" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml new file mode 100644 index 0000000000..1f8af17cf3 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeNodeType" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "PARENT" +key = "parent" + +[[values]] +name = "LEAF" +key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml new file mode 100644 index 0000000000..7418d7a016 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeVisitor" +features = [] + +template_params = [ + "Result", + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "parent_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..822acfe9ee --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/visit.h" +#include 
"utils/overload.h" +#include + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(get_all_leaf_paths(impl.get_left_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(impl.get_right_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + return {binary_tree_root_path()}; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h new file mode 100644 index 0000000000..7517028ec0 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include + +namespace FlexFlow { + +template +Tree get_child(Parent const &parent, + FullBinaryTreeImplementation const &impl, + BinaryTreePathEntry const &e) { + switch (e) { + case BinaryTreePathEntry::LEFT_CHILD: + return impl.get_left_child(parent); + case BinaryTreePathEntry::RIGHT_CHILD: + return impl.get_right_child(parent); + default: + throw mk_runtime_error( + fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h new file mode 100644 index 0000000000..8f9d8e919f --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H + +#include "utils/containers/multiset_union.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = + FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_multiset { + return multiset_union( + get_leaves(impl.get_left_child(parent), impl), + get_leaves(impl.get_right_child(parent), impl)); + }, + [](Leaf const &leaf) -> std::unordered_multiset { + return {leaf}; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..922a42242c --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H + +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { 
+ +template +int get_num_tree_nodes( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = FullBinaryTreeVisitor{ + [&](Parent const &parent) -> int { + return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + + get_num_tree_nodes(impl.get_right_child(parent), impl); + }, + [](Leaf const &) -> int { return 1; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..83ce1367b9 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::optional get_subtree_at_path( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + BinaryTreePath const &p) { + if (p == binary_tree_root_path()) { + return tree; + } + + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::optional { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); + }, + [](Leaf const &leaf) -> std::optional { return std::nullopt; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h new file mode 100644 index 0000000000..832d39bdff --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" + +namespace FlexFlow { + +template +Result visit(Tree const &tree, + FullBinaryTreeImplementation const &impl, + FullBinaryTreeVisitor const &visitor) { + if (impl.is_leaf(tree)) { + return visitor.leaf_func(impl.require_leaf(tree)); + } else { + return visitor.parent_func(impl.require_parent(tree)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h new file mode 100644 index 0000000000..de7ead8fb6 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst); + +} // namespace FlexFlow 
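A sketch of how the FullBinaryTreeImplementation bundle and the traversals above are meant to be used. Illustrative only: ExprTree, ExprParent, and make_expr_tree_impl are hypothetical caller-side names; only get_leaves, get_num_tree_nodes, and find_paths_to_leaf come from the headers above, and the exact std::function field signatures are assumed from how this patch itself brace-constructs the same struct later on.

#include "utils/full_binary_tree/find_paths_to_leaf.h"
#include "utils/full_binary_tree/get_leaves.h"
#include "utils/full_binary_tree/get_num_tree_nodes.h"
#include <cassert>
#include <memory>
#include <unordered_set>
#include <variant>

// Hypothetical caller-side tree: leaves are ints, parents own two subtrees.
struct ExprParent;
using ExprTree = std::variant<int, std::shared_ptr<ExprParent>>;
struct ExprParent {
  ExprTree lhs;
  ExprTree rhs;
};

FlexFlow::FullBinaryTreeImplementation<ExprTree, std::shared_ptr<ExprParent>, int>
    make_expr_tree_impl() {
  using Impl = FlexFlow::
      FullBinaryTreeImplementation<ExprTree, std::shared_ptr<ExprParent>, int>;
  return Impl{
      /*get_left_child=*/[](std::shared_ptr<ExprParent> const &p)
          -> ExprTree const & { return p->lhs; },
      /*get_right_child=*/[](std::shared_ptr<ExprParent> const &p)
          -> ExprTree const & { return p->rhs; },
      /*is_leaf=*/[](ExprTree const &t) -> bool {
        return std::holds_alternative<int>(t);
      },
      /*require_leaf=*/[](ExprTree const &t) -> int const & {
        return std::get<int>(t);
      },
      /*require_parent=*/[](ExprTree const &t)
          -> std::shared_ptr<ExprParent> const & {
        return std::get<std::shared_ptr<ExprParent>>(t);
      },
  };
}

void full_binary_tree_example() {
  auto parent = [](ExprTree l, ExprTree r) -> ExprTree {
    return std::make_shared<ExprParent>(ExprParent{std::move(l), std::move(r)});
  };

  // ((1 2) 3): three leaves and two internal nodes
  ExprTree t = parent(parent(ExprTree{1}, ExprTree{2}), ExprTree{3});
  auto impl = make_expr_tree_impl();

  assert(FlexFlow::get_num_tree_nodes(t, impl) == 5);
  assert((FlexFlow::get_leaves(t, impl) ==
          std::unordered_multiset<int>{1, 2, 3}));

  // find_paths_to_leaf reports every root-to-leaf position at which the
  // needle occurs; leaf 2 appears at exactly one position in this tree.
  assert(FlexFlow::find_paths_to_leaf(t, impl, 2).size() == 1);
}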
+ +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h new file mode 100644 index 0000000000..2ed0bc02be --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h new file mode 100644 index 0000000000..be0e57435a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h new file mode 100644 index 0000000000..e53bb876a1 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h new file mode 100644 index 0000000000..ad8eadda0e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml new file mode 100644 index 0000000000..32582a6b74 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "SplitBoundaryNodes" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h new file mode 100644 index 0000000000..916e8f7896 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml new file mode 100644 index 0000000000..54c710b26e --- /dev/null +++ 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TransitiveReducedDataflowGraphView" +features = [] + +includes = [ + "utils/graph/dataflow_graph/dataflow_graph_view.h", + "utils/graph/digraph/digraph_view.h", +] + +[[fields]] +name = "full_dataflow_graph" +type = "::FlexFlow::DataflowGraphView" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h index fc372f68aa..afc9c47c1c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h @@ -6,6 +6,9 @@ namespace FlexFlow { +std::optional + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &, std::vector const &); std::optional get_cbc_decomposition(DiGraphView const &); diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h new file mode 100644 index 0000000000..3066886e37 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &); +bool is_complete_bipartite_digraph(DiGraphView const &, + std::unordered_set const &srcs); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h new file mode 100644 index 0000000000..ee533a1180 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &, + std::function const &get_node_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h new file mode 100644 index 0000000000..87d0d3143a --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &, DirectedEdge const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h new file mode 100644 index 0000000000..240fc66426 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" +namespace FlexFlow { + +std::unordered_set + get_edges_from_subgraph_to_subgraph(DiGraphView const &, + std::unordered_set const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..6d98c5c20d --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h new file mode 100644 index 0000000000..2c48d327c4 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h new file mode 100644 index 0000000000..c9751124c8 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h new file mode 100644 index 0000000000..db2526f973 --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H + +#include "utils/graph/node/node_source.h" +#include "utils/graph/undirected/i_undirected_graph.h" + 
+namespace FlexFlow { + +struct UnorderedSetUndirectedGraph final : public IUndirectedGraph { +public: + UnorderedSetUndirectedGraph(); + + Node add_node() override; + void add_node_unsafe(Node const &) override; + void remove_node_unsafe(Node const &) override; + void add_edge(UndirectedEdge const &) override; + void remove_edge(UndirectedEdge const &) override; + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(UndirectedEdgeQuery const &) const override; + + UnorderedSetUndirectedGraph *clone() const override; + +private: + UnorderedSetUndirectedGraph(NodeSource const &, + std::unordered_set const &, + std::unordered_set const &); + + NodeSource node_source; + std::unordered_set nodes; + std::unordered_set edges; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h similarity index 88% rename from lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h rename to lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h index a8e08cb995..b9894fbac3 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" @@ -8,6 +8,10 @@ namespace FlexFlow { +// NOTE(@lockshaw) This code is not tested and I don't necessarily trust it. 
+// Figuring out what to do with it is tracked in +// https://github.com/flexflow/FlexFlow/issues/1513 + template struct LazyLabelledDataflowGraph final : public ILabelledDataflowGraph { @@ -42,11 +46,11 @@ struct LazyLabelledDataflowGraph final return this->get_view().query_outputs(q); } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->get_view().at(n); } - ValueLabel const &at(DataflowOutput const &v) const override { + ValueLabel at(DataflowOutput const &v) const override { return this->get_view().at(v); } @@ -95,7 +99,7 @@ template static typename std::enable_if< std::is_base_of, T>::value, LabelledDataflowGraph>::type - make_lazy_copy_of( + create_lazy_copy_of_labelled_dataflow_graph_view( LabelledDataflowGraphView const &view) { std::function( LabelledDataflowGraphView const &)> diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h new file mode 100644 index 0000000000..07aa64aa62 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" + +namespace FlexFlow { + +template > +LabelledDataflowGraphView rewrite_node_labels( + LabelledDataflowGraphView const &g, F f) { + return rewrite_node_labels( + view_as_labelled_open_dataflow_graph(g), f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h index a1d6e9e37a..8306dad1ec 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/zip.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 0b6f348ddf..d5c22e5d3d 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h deleted file mode 100644 index be6b9ce12c..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H - -#include 
"utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/optional.h" -#include -#include - -namespace FlexFlow { - -std::optional - get_serial_parallel_decomposition(DiGraphView const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h deleted file mode 100644 index 6285d7ae1f..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -std::variant - flatten_ast(std::variant const &ast); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h deleted file mode 100644 index 7d8efc96f2..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include - -namespace FlexFlow { - -std::variant internal_to_final_ast( - std::variant const &ast); -SerialParallelDecomposition - to_final_ast(std::variant const &); - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp); -std::unordered_set get_nodes(SerialSplit const &); -std::unordered_set get_nodes(ParallelSplit const &); -std::unordered_set get_nodes(Node const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h deleted file mode 100644 index 081137e513..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h +++ /dev/null @@ -1,80 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H - -#include "utils/graph/node/node.dtg.h" -#include -#include - -namespace FlexFlow { - -struct SerialSplit; -struct ParallelSplit; - -struct SerialSplit { -public: - SerialSplit() = delete; - explicit SerialSplit(std::vector> const &); - explicit SerialSplit( - std::initializer_list> const &); - - bool operator==(SerialSplit const &) const; - bool operator!=(SerialSplit const &) const; - -public: - std::vector> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(SerialSplit const &); -std::ostream &operator<<(std::ostream &, SerialSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::SerialSplit> { - size_t operator()(::FlexFlow::SerialSplit const &) const; -}; - -} // namespace std - -namespace FlexFlow { - -struct ParallelSplit { -public: - ParallelSplit() = 
delete; - explicit ParallelSplit( - std::unordered_set> const &); - explicit ParallelSplit( - std::initializer_list> const &); - - bool operator==(ParallelSplit const &) const; - bool operator!=(ParallelSplit const &) const; - -public: - std::unordered_set> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(ParallelSplit const &); -std::ostream &operator<<(std::ostream &, ParallelSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::ParallelSplit> { - size_t operator()(::FlexFlow::ParallelSplit const &) const; -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml new file mode 100644 index 0000000000..37e3bbee09 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml new file mode 100644 index 0000000000..7e6e86ba76 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..de48cd17e9 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree(); + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); + +SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml new file mode 100644 index 0000000000..c586b49d9d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::Node" +key = "node" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..105f5490a4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + Leaf const &needle) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return find_paths_to_leaf(tree, full_binary_impl, needle); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h new file mode 100644 index 0000000000..0bddbee81c --- /dev/null +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &impl) { + + using Parent = std::variant; + + auto full_binary_impl = FullBinaryTreeImplementation{ + /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_left_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_left_child(parallel); + }, + }, + parent); + }, + /*get_right_child=*/ + [impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_right_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_right_child(parallel); + }, + }, + parent); + }, + /*is_leaf=*/ + [impl](Tree const &tree) -> bool { + return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; + }, + /*require_leaf=*/ + [impl](Tree const &tree) -> Leaf const & { + return impl.require_leaf(tree); + }, + /*require_parent=*/ + [impl](Tree const &tree) -> Parent { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return Parent{impl.require_series(tree)}; + case SPDecompositionTreeNodeType::PARALLEL: + return Parent{impl.require_parallel(tree)}; + default: + throw mk_runtime_error(fmt::format( + "Unexpected SPDecompositionTreeNodeType: {}", node_type)); + } + }}; + + return full_binary_impl; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml new file mode 100644 index 0000000000..3ccbfd959b --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", + "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h", +] + +[[fields]] +name = "series_get_left_child" +type = "std::function" + +[[fields]] +name = "parallel_get_left_child" +type = 
"std::function" + +[[fields]] +name = "series_get_right_child" +type = "std::function" + +[[fields]] +name = "parallel_get_right_child" +type = "std::function" + +[[fields]] +name = "get_node_type" +type = "std::function<::FlexFlow::SPDecompositionTreeNodeType(Tree const &)>" + +[[fields]] +name = "require_series" +type = "std::function" + +[[fields]] +name = "require_parallel" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml new file mode 100644 index 0000000000..6275c82a0c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeVisitor" +features = [] + +template_params = [ + "ReturnType", + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "series_func" +type = "std::function" + +[[fields]] +name = "parallel_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..b0bb8355db --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_all_leaf_paths(tree, full_binary_impl); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h new file mode 100644 index 0000000000..c543375148 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H + +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_multiset get_leaves( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_leaves(tree, full_binary_impl); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..4678e0c0f7 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H + +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_num_tree_nodes(tree, full_binary_impl); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..c48185fb7f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include + +namespace FlexFlow { + +template +std::optional get_subtree_at_path( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + BinaryTreePath const &path) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_subtree_at_path(tree, full_binary_impl, path); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h new file mode 100644 index 0000000000..68e0a3af32 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_right_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_left_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_right_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_left_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h new file mode 100644 index 0000000000..7042765203 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + 
return impl.get_node_type(impl.series_get_left_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_right_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_left_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_right_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h new file mode 100644 index 0000000000..c06db135b2 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" + +namespace FlexFlow { + +template +ReturnType + visit(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + GenericBinarySPDecompositionTreeVisitor const &visitor) { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + ReturnType result = visitor.series_func(impl.require_series(tree)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + ReturnType result = visitor.parallel_func(impl.require_parallel(tree)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + ReturnType result = visitor.leaf_func(impl.require_leaf(tree)); + return result; + } + default: + throw mk_runtime_error(fmt::format( + "Unknown SPDecompositionTreeNodeType value: {}", node_type)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..183ece3a89 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include 
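
Both predicates enforce a canonical shape: in a left-associative tree a series (or parallel) split never has a child of its own kind on the right, so a chain like A ; B ; C must be nested as ((A ; B) ; C); a right-associative tree is the mirror image, (A ; (B ; C)). The visit helper above is the dispatch primitive they are built on. Below is a sketch of using it directly, reusing the ToyTree types and make_toy_impl from the earlier sketch (those toy names are assumptions, not part of this diff):

int count_toy_leaves(ToyTree const &t) {
  GenericBinarySPDecompositionTreeVisitor<int, ToyTree, ToySeries, ToyParallel, int>
      visitor{
          /*series_func=*/
          [](ToySeries const &s) {
            return count_toy_leaves(*s.lhs) + count_toy_leaves(*s.rhs);
          },
          /*parallel_func=*/
          [](ToyParallel const &p) {
            return count_toy_leaves(*p.lhs) + count_toy_leaves(*p.rhs);
          },
          /*leaf_func=*/
          [](int) { return 1; },
      };
  return visit(t, make_toy_impl(), visitor);
}
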
"utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h new file mode 100644 index 0000000000..f5174aee56 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..e01ec0bdde --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h new file mode 100644 index 0000000000..f2a006d899 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/optional.h" +#include +#include + +namespace FlexFlow { + +std::optional + get_series_parallel_decomposition(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h b/lib/utils/include/utils/graph/series_parallel/graph_generation.h similarity index 56% rename from lib/utils/include/utils/graph/serial_parallel/graph_generation.h rename to lib/utils/include/utils/graph/series_parallel/graph_generation.h index fac9c98db2..f18fd63d24 100644 --- 
a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h +++ b/lib/utils/include/utils/graph/series_parallel/graph_generation.h @@ -1,23 +1,23 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H #include "utils/graph/dataflow_graph/dataflow_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext); -void serial_extend(DataflowGraph &g, DataflowGraphView const &ext); +void series_extend(DataflowGraph &g, DataflowGraphView const &ext); -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph parallel_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition); + SeriesParallelDecomposition const &sp_decomposition); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h new file mode 100644 index 0000000000..1283a6df3a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +std::variant + flatten_ast(std::variant const &ast); + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml rename to lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml index 08f03ed12a..e7666fcd3f 100644 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml @@ -8,7 +8,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/split_type.dtg.h", + "utils/graph/series_parallel/split_type.dtg.h", "", "", "utils/graph/node/node.dtg.h", diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h similarity index 70% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 71cc5e3998..3fc1347ee5 100644 --- 
a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -1,8 +1,8 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/parallel_reduction.dtg.h" +#include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml new file mode 100644 index 0000000000..dd68adf3f6 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct SeriesSplit" +] + +post_includes = [ + "utils/graph/series_parallel/series_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "children" +type = "std::unordered_multiset>" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h new file mode 100644 index 0000000000..52d2cb7236 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +std::variant internal_to_final_ast( + std::variant const &ast); +SeriesParallelDecomposition + to_final_ast(std::variant const &); + +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp); +std::unordered_multiset get_nodes(SeriesSplit const &); +std::unordered_multiset get_nodes(ParallelSplit const &); +std::unordered_multiset get_nodes(Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml similarity index 62% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml rename to lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml index f816abfbb4..921499ebd1 100644 --- 
a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "SerialParallelDecomposition" +name = "SeriesParallelDecomposition" features = [ "eq", "hash", @@ -7,12 +7,12 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_splits.h", + "utils/graph/series_parallel/series_parallel_splits.h", "utils/graph/node/node.dtg.h", ] [[values]] -type = "::FlexFlow::SerialSplit" +type = "::FlexFlow::SeriesSplit" [[values]] type = "::FlexFlow::ParallelSplit" diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h new file mode 100644 index 0000000000..7374b45a60 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -0,0 +1,76 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" + +namespace FlexFlow { + +// struct SeriesSplit { +// public: +// SeriesSplit() = delete; +// explicit SeriesSplit(std::vector> const +// &); explicit SeriesSplit( +// std::initializer_list> const &); +// +// bool operator==(SeriesSplit const &) const; +// bool operator!=(SeriesSplit const &) const; +// +// public: +// std::vector> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(SeriesSplit const &); +// std::ostream &operator<<(std::ostream &, SeriesSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::SeriesSplit> { +// size_t operator()(::FlexFlow::SeriesSplit const &) const; +// }; +// +// } // namespace std +// +// namespace FlexFlow { +// +// struct ParallelSplit { +// public: +// ParallelSplit() = delete; +// explicit ParallelSplit( +// std::unordered_multiset> const &); +// explicit ParallelSplit( +// std::initializer_list> const &); +// +// bool operator==(ParallelSplit const &) const; +// bool operator!=(ParallelSplit const &) const; +// +// public: +// std::unordered_multiset> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(ParallelSplit const &); +// std::ostream &operator<<(std::ostream &, ParallelSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::ParallelSplit> { +// size_t operator()(::FlexFlow::ParallelSplit const &) const; +// }; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h similarity index 77% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.h rename to lib/utils/include/utils/graph/series_parallel/series_reduction.h index c9bae58546..a7d53fecfc 100644 --- a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -1,9 +1,9 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/series_reduction.dtg.h" +#include "utils/graph/series_parallel/series_reduction.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml new file mode 100644 index 0000000000..fdb0a29972 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "SeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ParallelSplit" +] + +post_includes = [ + "utils/graph/series_parallel/parallel_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml diff --git a/lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml diff --git a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml new file mode 100644 index 0000000000..2050800cbd --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SPDecompositionTreeNodeType" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "SERIES" + +[[values]] +name = "PARALLEL" + +[[values]] +name = "NODE" diff --git a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml rename to lib/utils/include/utils/graph/series_parallel/split_type.enum.toml index 96d85f0e12..c1a1cb5978 100644 --- a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml +++ b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml @@ -8,7 +8,7 @@ features = [ ] [[values]] -name = "SERIAL" +name = "SERIES" [[values]] name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h new file mode 
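
The split types make the intended semantics explicit in their field types: SeriesSplit keeps its children in a std::vector (order is significant), while ParallelSplit keeps them in a std::unordered_multiset (order is irrelevant and duplicates are representable). A short sketch, using the same assumed brace construction as in the earlier sketches:

#include "utils/graph/series_parallel/series_parallel_decomposition.h"
#include <unordered_set>

using namespace ::FlexFlow;

void demo_sp_shapes() {
  // Parallel children are an unordered multiset, so ordering is ignored:
  ParallelSplit p1 = ParallelSplit{{Node{2}, Node{3}}};
  ParallelSplit p2 = ParallelSplit{{Node{3}, Node{2}}};
  bool same = (p1 == p2); // true

  // Series children are a vector, so ordering matters:
  SeriesSplit s1 = SeriesSplit{{Node{1}, Node{4}}};
  SeriesSplit s2 = SeriesSplit{{Node{4}, Node{1}}};
  bool different = !(s1 == s2); // true

  // get_nodes flattens a decomposition into its leaves, keeping multiplicity:
  std::unordered_multiset<Node> nodes = get_nodes(SeriesParallelDecomposition{
      SeriesSplit{{Node{1}, ParallelSplit{{Node{2}, Node{3}}}, Node{4}}}});
  // nodes == {n1, n2, n3, n4}
}
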
100644 index 0000000000..3e951b1db1 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h new file mode 100644 index 0000000000..bc605360d2 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &, + Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h index 1662ec6d8c..4761275031 100644 --- a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h @@ -15,7 +15,7 @@ struct IUndirectedGraph : public IUndirectedGraphView { virtual std::unordered_set query_nodes(NodeQuery const &query) const = 0; - virtual IUndirectedGraph *clone() const override = 0; + virtual IUndirectedGraph *clone() const = 0; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h index 9aa0f189ec..65939acc87 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h @@ -1,11 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H +#include "utils/graph/undirected/undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.dtg.h" namespace FlexFlow { UndirectedEdgeQuery undirected_edge_query_all(); +bool matches_edge(UndirectedEdgeQuery const &, UndirectedEdge const &); UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &, UndirectedEdgeQuery const &); diff --git a/lib/utils/include/utils/hash/multiset.h b/lib/utils/include/utils/hash/multiset.h new file mode 100644 index 0000000000..4695b89165 --- /dev/null +++ b/lib/utils/include/utils/hash/multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/unordered_multiset.h b/lib/utils/include/utils/hash/unordered_multiset.h new file mode 100644 index 0000000000..b19c76bfef --- /dev/null +++ b/lib/utils/include/utils/hash/unordered_multiset.h @@ -0,0 +1,20 @@ +#ifndef 
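
The new hash specializations let multiset-valued fields (such as ParallelSplit's children, whose .struct.toml pulls in utils/hash/unordered_multiset.h) participate in std::hash-based containers and in the generated hash features. The usage below assumes unordered_container_hash combines element hashes order-independently, which is what its name and these specializations rely on:

#include "utils/hash/unordered_multiset.h"
#include <unordered_set>

void demo_multiset_hash() {
  // With the specialization above, a multiset can itself be a hash-set element:
  std::unordered_set<std::unordered_multiset<int>> seen;
  seen.insert({1, 2, 2, 3});

  // Same elements with the same multiplicities, listed in a different order,
  // compare equal and (given an order-independent hash) hash identically:
  bool already_seen = seen.count({2, 1, 3, 2}) > 0; // true
}
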
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::unordered_multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h b/lib/utils/include/utils/json/check_is_json_deserializable.h new file mode 100644 index 0000000000..dd5f397c19 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H + +#include "utils/json/is_json_deserializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h new file mode 100644 index 0000000000..dfcb26081d --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H + +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_serializable::value, \ + #TYPENAME " should be json serializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h new file mode 100644 index 0000000000..41a64a1b83 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSONABLE(TYPENAME) \ + static_assert(is_json_serializable::value, \ + #TYPENAME " should be json serializeable"); \ + static_assert(is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_deserializable.h b/lib/utils/include/utils/json/is_json_deserializable.h new file mode 100644 index 0000000000..9e6625428b --- /dev/null +++ b/lib/utils/include/utils/json/is_json_deserializable.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_deserializable : std::false_type {}; + +template +struct is_json_deserializable< + T, + void_t().get())>> + : std::true_type {}; + +template +inline constexpr bool is_json_deserializable_v = + is_json_deserializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_serializable.h b/lib/utils/include/utils/json/is_json_serializable.h new file mode 100644 index 0000000000..926a8037d4 --- /dev/null 
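
The detection traits and CHECK_* macros give compile-time feedback when a type is expected to round-trip through nlohmann::json. A sketch with a hypothetical Point2D type (the to_json/from_json overloads here are illustrative, not part of this diff):

#include "utils/json/check_is_jsonable.h"
#include <nlohmann/json.hpp>

using namespace ::FlexFlow;

struct Point2D {
  int x;
  int y;
};

void to_json(nlohmann::json &j, Point2D const &p) {
  j = nlohmann::json{{"x", p.x}, {"y", p.y}};
}

void from_json(nlohmann::json const &j, Point2D &p) {
  j.at("x").get_to(p.x);
  j.at("y").get_to(p.y);
}

// Both traits are satisfied, so this pair of static_asserts compiles cleanly:
CHECK_IS_JSONABLE(Point2D);

// A type without to_json/from_json overloads is rejected at compile time:
struct NotJsonable {};
static_assert(!is_json_serializable<NotJsonable>::value,
              "NotJsonable has no to_json overload");
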
+++ b/lib/utils/include/utils/json/is_json_serializable.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_serializable : std::false_type {}; + +template +struct is_json_serializable< + T, + void_t() = std::declval())>> + : std::true_type {}; + +template +inline constexpr bool is_json_serializable_v = is_json_serializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_jsonable.h b/lib/utils/include/utils/json/is_jsonable.h new file mode 100644 index 0000000000..2c8c103650 --- /dev/null +++ b/lib/utils/include/utils/json/is_jsonable.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +template +struct is_jsonable + : std::conjunction, is_json_deserializable> {}; + +template +inline constexpr bool is_jsonable_v = is_jsonable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/optional.h b/lib/utils/include/utils/json/optional.h new file mode 100644 index 0000000000..c88dd24a15 --- /dev/null +++ b/lib/utils/include/utils/json/optional.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H + +#include "utils/json/is_jsonable.h" +#include +#include + +namespace nlohmann { + +template +struct adl_serializer< + std::optional, + typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { + static void to_json(json &j, std::optional const &t) { + if (t.has_value()) { + j = t.value(); + } else { + j = nullptr; + } + } + + static void from_json(json const &j, std::optional &t) { + if (j == nullptr) { + t = std::nullopt; + } else { + t = j.get(); + } + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json/variant.h b/lib/utils/include/utils/json/variant.h new file mode 100644 index 0000000000..fe2c3f3b6c --- /dev/null +++ b/lib/utils/include/utils/json/variant.h @@ -0,0 +1,89 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H + +#include "utils/json/is_jsonable.h" +#include + +namespace FlexFlow { + +struct VariantToJsonFunctor { + VariantToJsonFunctor(nlohmann::json &j) : j(j) {} + + nlohmann::json &j; + + template + void operator()(T const &t) { + static_assert(is_jsonable::value, ""); + + j = t; + } +}; + +template +void variant_to_json(json &j, std::variant const &v) { + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + j["value"] = jval; + j["index"] = v.index(); +} + +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; + + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} + +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. 
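
The adl_serializer specialization for std::optional above follows the usual convention: an empty optional serializes to JSON null and back. A small round-trip sketch:

#include "utils/json/optional.h"
#include <nlohmann/json.hpp>
#include <optional>

void demo_optional_json() {
  nlohmann::json present = std::optional<int>{42}; // 42
  nlohmann::json absent = std::optional<int>{};    // null

  std::optional<int> back = absent.get<std::optional<int>>();
  // back == std::nullopt
}
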
+ std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); + } + } + return std::nullopt; +} + +template +std::variant variant_from_json(json const &j) { + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); + if (!result.has_value()) { + throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", + j.at("index").get()); + } + return result.value(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer, + typename std::enable_if<::FlexFlow::elements_satisfy< + ::FlexFlow::is_json_serializable, + std::variant>::value>::type> { + static void to_json(json &j, std::variant const &v) { + return ::FlexFlow::variant_to_json(j, v); + } + + static std::variant from_json(json const &j) { + return ::FlexFlow::variant_from_json(j); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json/visitable.h similarity index 52% rename from lib/utils/include/utils/json.h rename to lib/utils/include/utils/json/visitable.h index f56917e329..abc20065de 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json/visitable.h @@ -1,6 +1,9 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/json_core.h" #include "utils/optional.h" #include "utils/sequence.h" @@ -10,33 +13,6 @@ namespace FlexFlow { -template -struct is_json_serializable : std::false_type {}; - -template -struct is_json_serializable< - T, - void_t() = std::declval())>> - : std::true_type {}; - -template -struct is_json_deserializable : std::false_type {}; - -template -struct is_json_deserializable().get())>> - : std::true_type {}; - -template -struct is_jsonable - : conjunction, is_json_deserializable> {}; - -#define CHECK_IS_JSONABLE(TYPENAME) \ - static_assert(is_json_serializable::value, \ - #TYPENAME " should be json serializeable"); \ - static_assert(is_json_deserializable::value, \ - #TYPENAME " should be json deserializeable") - struct json_serialization_visitor { json_serialization_visitor() = delete; json_serialization_visitor(json &j) : j(j) {} @@ -134,66 +110,6 @@ T moveonly_visit_json_deserialize(json const &j) { return visitable_from_tuple(tuple_from_json(j)); } -struct VariantToJsonFunctor { - VariantToJsonFunctor(json &j) : j(j) {} - - json &j; - - template - void operator()(T const &t) { - static_assert(is_jsonable::value, ""); - - j = t; - } -}; - -template -void variant_to_json(json &j, std::variant const &v) { - json jval; - visit(::FlexFlow::VariantToJsonFunctor{jval}, v); - j["value"] = jval; - j["index"] = v.index(); -} - -template -std::optional variant_from_json_impl(json const &j) { - using Type = typename std::variant_alternative::type; - - if (j.at("index").get() == Idx) { - return j.at("value").get(); - } - return std::nullopt; -} - -template -std::optional variant_from_json_impl(json const &j, - std::index_sequence) { - // If there were no errors when parsing, all but one element of the array - // will be nullopt. This is because each call to variant_from_json_impl will - // have a unique index and exactly one of them will match the index in the - // json object. 
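
variant_to_json/variant_from_json encode a variant as an object holding both the active alternative's index and its value, which keeps deserialization unambiguous even when two alternatives would serialize to the same JSON shape. A round-trip sketch:

#include "utils/json/variant.h"
#include <nlohmann/json.hpp>
#include <string>
#include <variant>

void demo_variant_json() {
  std::variant<int, std::string> v = std::string{"hello"};

  nlohmann::json j = v;
  // Encoded as {"index": 1, "value": "hello"}.

  auto back = j.get<std::variant<int, std::string>>();
  // back holds the std::string alternative again; an index that does not
  // match any alternative makes variant_from_json report an error.
}
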
- std::array, sizeof...(Is)> results{ - variant_from_json_impl(j)...}; - for (std::optional &maybe : results) { - if (maybe) { - return maybe.value(); - } - } - return std::nullopt; -} - -template -std::variant variant_from_json(json const &j) { - using Variant = std::variant; - std::optional result = variant_from_json_impl( - j, std::make_index_sequence()); - if (!result.has_value()) { - throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("index").get()); - } - return result.value(); -} - } // namespace FlexFlow namespace nlohmann { @@ -231,41 +147,6 @@ struct adl_serializer< } }; -template -struct adl_serializer< - std::optional, - typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { - static void to_json(json &j, std::optional const &t) { - if (t.has_value()) { - to_json(j, t.value()); - } else { - j = nullptr; - } - } - - static void from_json(json const &j, std::optional &t) { - if (j == nullptr) { - t = std::nullopt; - } else { - t = j.get(); - } - } -}; - -template -struct adl_serializer, - typename std::enable_if<::FlexFlow::elements_satisfy< - ::FlexFlow::is_json_serializable, - std::variant>::value>::type> { - static void to_json(json &j, std::variant const &v) { - return ::FlexFlow::variant_to_json(j, v); - } - - static std::variant from_json(json const &j) { - return ::FlexFlow::variant_from_json(j); - } -}; - } // namespace nlohmann #endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3448ec4e0e..377561d70c 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "utils/exception.h" #include "utils/fmt/optional.h" @@ -7,6 +7,15 @@ namespace FlexFlow { +template +T or_else(std::optional const &o, F &&f) { + if (o.has_value()) { + return o.value(); + } else { + return f(); + } +} + template T const &unwrap(std::optional const &o, F const &f) { if (o.has_value()) { @@ -25,18 +34,4 @@ T const &assert_unwrap(std::optional const &o) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::map( - gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { - return m ? std::optional(std::move(*m)) : std::optional(); - }); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/include/utils/rapidcheck/optional.h b/lib/utils/include/utils/rapidcheck/optional.h new file mode 100644 index 0000000000..edb28fdb81 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/optional.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? 
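
or_else is a small convenience added to utils/optional.h: it returns the contained value when present and otherwise the result of calling the fallback. Usage sketch:

#include "utils/optional.h"
#include <optional>
#include <string>

void demo_or_else() {
  std::optional<std::string> configured = std::nullopt;

  std::string value =
      ::FlexFlow::or_else(configured, [] { return std::string{"default"}; });
  // value == "default"; the lambda is only invoked because `configured`
  // is empty.
}
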
std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index 9cdd7918dd..d16b67ba86 100644 --- a/lib/utils/include/utils/required.h +++ b/lib/utils/include/utils/required.h @@ -1,9 +1,13 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H -#include "utils/json.h" +#include "utils/fmt/vector.h" +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/required_core.h" #include "utils/type_traits.h" +#include namespace FlexFlow { @@ -14,11 +18,11 @@ static_assert(is_list_initializable, int>::value, ""); namespace nlohmann { template struct adl_serializer<::FlexFlow::req> { - static ::FlexFlow::req from_json(json const &j) { + static ::FlexFlow::req from_json(nlohmann::json const &j) { return {j.template get()}; } - static void to_json(json &j, ::FlexFlow::req const &t) { + static void to_json(nlohmann::json &j, ::FlexFlow::req const &t) { j = static_cast(t); } }; diff --git a/lib/utils/include/utils/sequence.h b/lib/utils/include/utils/sequence.h index 6c66949fd8..07e4554299 100644 --- a/lib/utils/include/utils/sequence.h +++ b/lib/utils/include/utils/sequence.h @@ -135,7 +135,7 @@ auto seq_get(F const &f, int i, seq const &s) template auto seq_get(F const &f, int i, seq<> const &) -> decltype(f(std::declval>())) { - throw mk_runtime_error("Failed seq_get for index {}", i); + throw mk_runtime_error(fmt::format("Failed seq_get for index {}", i)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 321e66bac8..2a7b2d1849 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,9 +4,9 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" -#include "utils/json.h" #include "utils/type_traits.h" #include +#include #include #include @@ -70,13 +70,13 @@ template using stack_string = stack_basic_string; template -void to_json(json &j, stack_string const &v) { +void to_json(nlohmann::json &j, stack_string const &v) { std::string as_string = v; j = as_string; } template -void from_json(json const &j, stack_string &v) { +void from_json(nlohmann::json const &j, stack_string &v) { std::string as_string; j.get_to(as_string); v = stack_string{as_string}; diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 1d654e3415..7a7bce7afc 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,12 +3,12 @@ #include "utils/hash-utils.h" #include "utils/join_strings.h" -#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include #include #include +#include #include #include #include @@ -326,13 +326,13 @@ std::ostream &operator<<(std::ostream &s, stack_vector const &v) { } template -void to_json(json &j, stack_vector const &v) { +void to_json(nlohmann::json &j, stack_vector const &v) { std::vector as_vec(v.begin(), v.end()); j = as_vec; } template -void from_json(json const &j, stack_vector &v) { +void from_json(nlohmann::json const &j, stack_vector &v) { std::vector as_vec; j.get_to(as_vec); v = stack_vector{as_vec.begin(), as_vec.end()}; diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index afc16d4c4b..0296e365a3 100644 --- a/lib/utils/include/utils/tuple.h +++ 
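
The rc::Arbitrary<std::optional<T>> specialization now lives in its own header, presumably so that utils/optional.h no longer drags in rapidcheck. Property-based tests can take optional parameters directly; a sketch of a property relying on it:

#include "utils/rapidcheck/optional.h"
#include <optional>
#include <rapidcheck.h>

void demo_optional_property() {
  rc::check("value_or never loses a present value", [](std::optional<int> x) {
    RC_PRE(x.has_value());
    RC_ASSERT(x.value_or(0) == *x);
  });
}
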
b/lib/utils/include/utils/tuple.h @@ -67,8 +67,8 @@ template std::any get(std::tuple const &t, int idx) { size_t tuple_size = std::tuple_size::value; if (idx < 0 || idx >= tuple_size) { - throw mk_runtime_error( - "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size); + throw mk_runtime_error(fmt::format( + "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size)); } std::any result; visit_tuple(t, tuple_get_visitor{idx, result}); diff --git a/lib/utils/src/utils/any_value_type/any_value_type.cc b/lib/utils/src/utils/any_value_type/any_value_type.cc new file mode 100644 index 0000000000..d4c605c441 --- /dev/null +++ b/lib/utils/src/utils/any_value_type/any_value_type.cc @@ -0,0 +1,34 @@ +#include "utils/any_value_type/any_value_type.h" + +namespace FlexFlow { + +any_value_type::any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string) + : value(value), eq(eq), neq(neq), hash(hash), to_string(to_string) {} + +bool any_value_type::operator==(any_value_type const &other) const { + return this->eq(this->value, other.value); +} + +bool any_value_type::operator!=(any_value_type const &other) const { + return this->neq(this->value, other.value); +} + +std::string format_as(any_value_type const &v) { + return v.to_string(v.value); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::any_value_type>::operator()( + ::FlexFlow::any_value_type const &v) const { + return v.hash(v); +} + +} // namespace std diff --git a/lib/utils/src/utils/archetypes/value_type.cc b/lib/utils/src/utils/archetypes/value_type.cc new file mode 100644 index 0000000000..f7da47d8f9 --- /dev/null +++ b/lib/utils/src/utils/archetypes/value_type.cc @@ -0,0 +1,7 @@ +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template struct value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..03c53c9356 --- /dev/null +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,101 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/maximum.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" +#include "utils/join_strings.h" +#include + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &cli) { + auto render_pos_arg = [](CLIPositionalArgumentSpec const &pos_arg_spec) { + if (pos_arg_spec.choices.has_value()) { + return "{" + join_strings(pos_arg_spec.choices.value(), ",") + "}"; + } else { + return pos_arg_spec.name; + } + }; + + auto render_flag_option_column_key = [](CLIFlagSpec const &flag_spec) { + std::ostringstream oss; + if (flag_spec.short_flag.has_value()) { + oss << "-" << flag_spec.short_flag.value() << ", "; + } + oss << "--" << flag_spec.long_flag; + return oss.str(); + }; + + std::ostringstream oss; + + oss << "usage: " << program_name; + for (CLIFlagSpec const &flag_spec : cli.flags) { + if (flag_spec.short_flag.has_value()) { + oss << " [-" << flag_spec.short_flag.value() << "]"; + } else { + oss << " [--" << flag_spec.long_flag << "]"; + } + } + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << " " << render_pos_arg(pos_arg_spec); + } + + oss << std::endl; + + std::vector all_arg_columns = concat_vectors(std::vector{ + 
transform(cli.positional_arguments, render_pos_arg), + transform(cli.flags, render_flag_option_column_key), + }); + std::vector all_arg_column_widths = + transform(all_arg_columns, [](std::string const &s) { return s.size(); }); + + if (!all_arg_columns.empty()) { + int max_column_width = + std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + + auto render_column = [&](std::string const &key, + std::optional const &description) { + if (description.has_value()) { + if (key.size() > max_column_width) { + return " " + key + "\n" + std::string(24, ' ') + description.value(); + } else { + } + return fmt::format( + " {:<{}} {}", key, max_column_width, description.value()); + } else { + return fmt::format(" {}", key); + } + }; + + if (!cli.positional_arguments.empty()) { + oss << std::endl; + oss << "positional arguments:" << std::endl; + + if (!cli.positional_arguments.empty()) { + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << render_column(render_pos_arg(pos_arg_spec), + pos_arg_spec.description) + << std::endl; + } + } + } + + if (!cli.flags.empty()) { + oss << std::endl; + oss << "options:" << std::endl; + + for (CLIFlagSpec const &flag_spec : cli.flags) { + oss << render_column(render_flag_option_column_key(flag_spec), + flag_spec.description) + << std::endl; + } + } + } + + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse.cc b/lib/utils/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..07982c0c2d --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse.cc @@ -0,0 +1,96 @@ +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_spec.h" +#include "utils/containers/contains.h" +#include "utils/containers/enumerate.h" +#include "utils/containers/generate_map.h" + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg) { + for (auto const &[idx, flag_spec] : enumerate(cli.flags)) { + CLIFlagKey key = CLIFlagKey{idx}; + if (("--" + flag_spec.long_flag) == arg) { + return key; + } + + if (flag_spec.short_flag.has_value()) { + if ((std::string{"-"} + flag_spec.short_flag.value()) == arg) { + return key; + } + } + } + + return tl::unexpected(fmt::format("Encountered unknown flag {}", arg)); +} + +tl::expected + cli_parse(CLISpec const &cli, std::vector const &args) { + CLIParseResult result = CLIParseResult{ + generate_map(cli_get_flag_keys(cli), + [](CLIFlagKey const &) { return false; }), + {}, + }; + + int consumed_positional_args = 0; + auto parse_positional_arg = + [&](std::string const &arg) -> std::optional { + if (consumed_positional_args >= cli.positional_arguments.size()) { + return fmt::format("Too many positional arguments: expected {}", + cli.positional_arguments.size()); + } + + CLIPositionalArgumentSpec arg_spec = + cli.positional_arguments.at(consumed_positional_args); + + if (arg_spec.choices.has_value() && + !contains(arg_spec.choices.value(), arg)) { + return fmt::format( + "Invalid option for positional argument \"{}\": \"{}\"", + arg_spec.name, + arg); + } + + result.positional_arguments.insert( + {CLIPositionalArgumentKey{consumed_positional_args}, arg}); + consumed_positional_args++; + + return std::nullopt; + }; + + for (int i = 1; i < args.size(); i++) { + std::string arg = args.at(i); + + if (!arg.empty() && arg.at(0) == '-') { + tl::expected parsed_flag = + cli_parse_flag(cli, arg); + + if (parsed_flag.has_value()) { + result.flags.at(parsed_flag.value()) = true; + } + } else { + std::optional 
maybe_err_msg = parse_positional_arg(arg); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + } + + if (consumed_positional_args != cli.positional_arguments.size()) { + return tl::unexpected( + fmt::format("Not enough positional arguments: found {}, expected {}", + consumed_positional_args, + cli.positional_arguments.size())); + } + + return result; +} + +tl::expected + cli_parse(CLISpec const &cli, int argc, char const *const *argv) { + std::vector args = {argv, argv + argc}; + + return cli_parse(cli, args); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse_result.cc b/lib/utils/src/utils/cli/cli_parse_result.cc new file mode 100644 index 0000000000..6682a7a6eb --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse_result.cc @@ -0,0 +1,14 @@ +#include "utils/cli/cli_parse_result.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &result, CLIArgumentKey const &key) { + return result.flags.at(key.get()); +} + +std::string cli_get_argument(CLIParseResult const &result, + CLIArgumentKey const &key) { + return result.positional_arguments.at(key.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_spec.cc b/lib/utils/src/utils/cli/cli_spec.cc new file mode 100644 index 0000000000..ca51cfe57f --- /dev/null +++ b/lib/utils/src/utils/cli/cli_spec.cc @@ -0,0 +1,37 @@ +#include "utils/cli/cli_spec.h" +#include "utils/containers/count.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +CLISpec empty_cli_spec() { + return CLISpec{{}, {}}; +} + +std::vector cli_get_flag_keys(CLISpec const &cli) { + return transform(count(cli.flags.size()), + [](int idx) { return CLIFlagKey{idx}; }); +} + +CLIArgumentKey cli_add_help_flag(CLISpec &cli) { + CLIFlagSpec help_flag = + CLIFlagSpec{"help", 'h', "show this help message and exit"}; + return cli_add_flag(cli, help_flag); +} + +CLIArgumentKey cli_add_flag(CLISpec &cli, CLIFlagSpec const &flag_spec) { + cli.flags.push_back(flag_spec); + + return CLIArgumentKey{CLIFlagKey{int_from_size_t(cli.flags.size()) - 1}}; +} + +CLIArgumentKey + cli_add_positional_argument(CLISpec &cli, + CLIPositionalArgumentSpec const &arg) { + cli.positional_arguments.push_back(arg); + return CLIArgumentKey{CLIPositionalArgumentKey{ + int_from_size_t(cli.positional_arguments.size()) - 1}}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/as_vector.cc b/lib/utils/src/utils/containers/as_vector.cc deleted file mode 100644 index 9c7b63ca58..0000000000 --- a/lib/utils/src/utils/containers/as_vector.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/as_vector.h" diff --git a/lib/utils/src/utils/containers/cartesian_product.cc b/lib/utils/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..b716a49ad5 --- /dev/null +++ b/lib/utils/src/utils/containers/cartesian_product.cc @@ -0,0 +1 @@ +#include "utils/containers/cartesian_product.h" diff --git a/lib/utils/src/utils/containers/enumerate_vector.cc b/lib/utils/src/utils/containers/enumerate_vector.cc new file mode 100644 index 0000000000..d4fd131af2 --- /dev/null +++ b/lib/utils/src/utils/containers/enumerate_vector.cc @@ -0,0 +1 @@ +#include "utils/containers/enumerate_vector.h" diff --git a/lib/utils/src/utils/containers/foldl.cc b/lib/utils/src/utils/containers/foldl.cc new file mode 100644 index 0000000000..a4c32e83cc --- /dev/null +++ b/lib/utils/src/utils/containers/foldl.cc @@ -0,0 +1 @@ +#include 
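// --- Illustrative usage sketch (not part of the patch): how the CLI helpers introduced
// above (empty_cli_spec, cli_add_flag, cli_add_help_flag, cli_parse, cli_get_flag,
// cli_get_help_message) appear to compose. The CLIFlagSpec field order is assumed from
// the cli_add_help_flag call above, and the "verbose" flag is made up for the example.
#include "utils/cli/cli_get_help_message.h"
#include "utils/cli/cli_parse.h"
#include "utils/cli/cli_parse_result.h"
#include "utils/cli/cli_spec.h"
#include <iostream>

int main(int argc, char **argv) {
  using namespace ::FlexFlow;

  CLISpec cli = empty_cli_spec();
  CLIArgumentKey verbose_key =
      cli_add_flag(cli, CLIFlagSpec{"verbose", 'v', "enable verbose output"});
  CLIArgumentKey help_key = cli_add_help_flag(cli);

  tl::expected<CLIParseResult, std::string> parsed = cli_parse(cli, argc, argv);
  if (!parsed.has_value()) {
    // e.g. wrong number of positional arguments
    std::cerr << parsed.error() << std::endl;
    std::cerr << cli_get_help_message(argv[0], cli);
    return 1;
  }

  if (cli_get_flag(parsed.value(), help_key)) {
    std::cout << cli_get_help_message(argv[0], cli);
    return 0;
  }

  bool verbose = cli_get_flag(parsed.value(), verbose_key);
  std::cout << "verbose = " << std::boolalpha << verbose << std::endl;
  return 0;
}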
"utils/containers/foldl.h" diff --git a/lib/utils/src/utils/containers/foldl1.cc b/lib/utils/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..c6cdd0eec9 --- /dev/null +++ b/lib/utils/src/utils/containers/foldl1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldl1.h" diff --git a/lib/utils/src/utils/containers/foldr1.cc b/lib/utils/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..9d00d81565 --- /dev/null +++ b/lib/utils/src/utils/containers/foldr1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldr1.h" diff --git a/lib/utils/src/utils/containers/get_all_assignments.cc b/lib/utils/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..f920ba1c1a --- /dev/null +++ b/lib/utils/src/utils/containers/get_all_assignments.cc @@ -0,0 +1,12 @@ +#include "utils/containers/get_all_assignments.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_set> + get_all_assignments(std::unordered_map> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/get_element_counts.cc b/lib/utils/src/utils/containers/get_element_counts.cc index 9840ed34d8..ac8e289523 100644 --- a/lib/utils/src/utils/containers/get_element_counts.cc +++ b/lib/utils/src/utils/containers/get_element_counts.cc @@ -1,10 +1,10 @@ #include "utils/containers/get_element_counts.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { std::unordered_map get_element_counts(std::string const &s) { - return get_element_counts(as_vector(s)); + return get_element_counts(vector_of(s)); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/multiset_union.cc b/lib/utils/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..a053d05fa6 --- /dev/null +++ b/lib/utils/src/utils/containers/multiset_union.cc @@ -0,0 +1 @@ +#include "utils/containers/multiset_union.h" diff --git a/lib/utils/src/utils/containers/range.cc b/lib/utils/src/utils/containers/range.cc new file mode 100644 index 0000000000..d3ebd1063b --- /dev/null +++ b/lib/utils/src/utils/containers/range.cc @@ -0,0 +1,26 @@ +#include "utils/containers/range.h" +#include + +namespace FlexFlow { + +std::vector range(int start, int end, int step) { + assert(step != 0); + + std::vector result; + if (step > 0) { + for (int i = start; i < end; i += step) { + result.push_back(i); + } + } else { + for (int i = start; i > end; i += step) { + result.push_back(i); + } + } + return result; +} + +std::vector range(int end) { + return range(0, end); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/replicate.cc b/lib/utils/src/utils/containers/replicate.cc new file mode 100644 index 0000000000..2fb2f079f6 --- /dev/null +++ b/lib/utils/src/utils/containers/replicate.cc @@ -0,0 +1 @@ +#include "utils/containers/replicate.h" diff --git a/lib/utils/src/utils/containers/require_all_same1.cc b/lib/utils/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..295339a91d --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_same1.cc @@ -0,0 +1 @@ +#include "utils/containers/require_all_same1.h" diff --git a/lib/utils/src/utils/containers/require_no_duplicates.cc b/lib/utils/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..b1d21ad832 --- /dev/null +++ b/lib/utils/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1 @@ +#include 
"utils/containers/require_no_duplicates.h" diff --git a/lib/utils/src/utils/containers/scanl.cc b/lib/utils/src/utils/containers/scanl.cc new file mode 100644 index 0000000000..4f7ff78b9f --- /dev/null +++ b/lib/utils/src/utils/containers/scanl.cc @@ -0,0 +1 @@ +#include "utils/containers/scanl.h" diff --git a/lib/utils/src/utils/containers/set_of.cc b/lib/utils/src/utils/containers/set_of.cc new file mode 100644 index 0000000000..3a12ee539d --- /dev/null +++ b/lib/utils/src/utils/containers/set_of.cc @@ -0,0 +1 @@ +#include "utils/containers/set_of.h" diff --git a/lib/utils/src/utils/containers/to_uppercase.cc b/lib/utils/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..6c02b5a109 --- /dev/null +++ b/lib/utils/src/utils/containers/to_uppercase.cc @@ -0,0 +1,10 @@ +#include "utils/containers/to_uppercase.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::string to_uppercase(std::string const &s) { + return transform(s, [](char c) -> char { return std::toupper(c); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/try_at.cc b/lib/utils/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..0d1ed3b04a --- /dev/null +++ b/lib/utils/src/utils/containers/try_at.cc @@ -0,0 +1 @@ +#include "utils/containers/try_at.h" diff --git a/lib/utils/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..60cc978be7 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1 @@ +#include "utils/containers/unordered_map_from_pairs.h" diff --git a/lib/utils/src/utils/containers/vector_of.cc b/lib/utils/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..b997076511 --- /dev/null +++ b/lib/utils/src/utils/containers/vector_of.cc @@ -0,0 +1 @@ +#include "utils/containers/vector_of.h" diff --git a/lib/utils/src/utils/containers/without_order.cc b/lib/utils/src/utils/containers/without_order.cc deleted file mode 100644 index 3ef44b8044..0000000000 --- a/lib/utils/src/utils/containers/without_order.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/without_order.h" diff --git a/lib/utils/src/utils/exception.cc b/lib/utils/src/utils/exception.cc index 9bbf780fd8..c645f241aa 100644 --- a/lib/utils/src/utils/exception.cc +++ b/lib/utils/src/utils/exception.cc @@ -1 +1,9 @@ #include "utils/exception.h" + +namespace FlexFlow { + +std::runtime_error mk_runtime_error(std::string const &s) { + return std::runtime_error(s); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/json.cc b/lib/utils/src/utils/fmt/json.cc new file mode 100644 index 0000000000..49ad57fba7 --- /dev/null +++ b/lib/utils/src/utils/fmt/json.cc @@ -0,0 +1,7 @@ +#include "utils/fmt/json.h" + +namespace fmt { + +template struct formatter<::nlohmann::json, char>; + +} diff --git a/lib/utils/src/utils/fmt/monostate.cc b/lib/utils/src/utils/fmt/monostate.cc new file mode 100644 index 0000000000..55988cdce0 --- /dev/null +++ b/lib/utils/src/utils/fmt/monostate.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/monostate.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::monostate const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc new file mode 100644 index 0000000000..8445a2721a --- /dev/null +++ 
b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc @@ -0,0 +1,34 @@ +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/subvec.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path() { + return BinaryTreePath{{}}; +} + +BinaryTreePath nest_inside_left_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::LEFT_CHILD); + return result; +} + +BinaryTreePath nest_inside_right_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::RIGHT_CHILD); + return result; +} + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &p) { + return p.entries.at(0); +} + +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &p) { + return BinaryTreePath{ + subvec(p.entries, 1, std::nullopt), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..47845720ed --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::unordered_set + find_paths_to_leaf(Tree const &, + FullBinaryTreeImplementation const &, + Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..b4d8aa1011 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc @@ -0,0 +1,12 @@ +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template std::unordered_set + get_all_leaf_paths(value_type<0> const &, + FullBinaryTreeImplementation, + value_type<1>, + value_type<2>> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_child.cc b/lib/utils/src/utils/full_binary_tree/get_child.cc new file mode 100644 index 0000000000..19362ae510 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_child.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_child.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template Tree + get_child(Parent const &, + FullBinaryTreeImplementation const &, + BinaryTreePathEntry const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_leaves.cc b/lib/utils/src/utils/full_binary_tree/get_leaves.cc new file mode 100644 index 0000000000..0d7e9106f6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_leaves.cc @@ -0,0 +1,14 @@ +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::unordered_multiset + get_leaves(Tree const &, + FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..7a99dd60fa 
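// --- Illustrative usage sketch (not part of the patch): expected results for the
// range() helper added earlier in this diff (lib/utils/src/utils/containers/range.cc),
// read directly off the loop bounds in that file.
#include "utils/containers/range.h"
#include <cassert>
#include <vector>

int main() {
  using ::FlexFlow::range;

  assert(range(5) == (std::vector<int>{0, 1, 2, 3, 4}));  // range(end) starts at 0
  assert(range(2, 10, 3) == (std::vector<int>{2, 5, 8})); // half-open: end is excluded
  assert(range(5, 0, -2) == (std::vector<int>{5, 3, 1})); // negative step counts down
  assert(range(3, 3, 1).empty());                         // empty when start == end
  return 0;
}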
--- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc @@ -0,0 +1,13 @@ +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template int get_num_tree_nodes( + Tree const &, FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..1eea13fedd --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::optional get_subtree_at_path( + Tree const &, + FullBinaryTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc new file mode 100644 index 0000000000..4a4f7c9302 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -0,0 +1,9 @@ +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template int visit(std::string const &, + FullBinaryTreeImplementation const &, + FullBinaryTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 323f444a22..6ed41daf43 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -219,10 +219,6 @@ std::unordered_set get_endpoints(UndirectedEdge const &e) { // return g.query_edges(MultiDiEdgeQuery::all()); // } -std::unordered_set get_edges(UndirectedGraphView const &g) { - return g.query_edges(undirected_edge_query_all()); -} - // std::unordered_set get_edges(OpenMultiDiGraphView const &g) // { // return g.query_edges(OpenMultiDiEdgeQuery::all()); diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..c07d344d05 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,15 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst) { + return g.query_edges(DataflowEdgeQuery{ + /*src_nodes=*/query_set{src}, + /*src_idxs=*/query_set::matchall(), + /*dst_nodes=*/query_set{dst}, + /*dst_idxs=*/query_set::matchall(), + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..d17a84dd12 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + 
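// --- Illustrative usage sketch (not part of the patch): BinaryTreePath entries are
// read root-first, so nest_inside_left_child wraps an existing path one level deeper
// on the left. The asserts assume entries is the std::vector manipulated by
// binary_tree_path.cc above and that BinaryTreePath is brace-constructible as it is
// there.
#include "utils/full_binary_tree/binary_tree_path.h"
#include <cassert>
#include <vector>

int main() {
  using namespace ::FlexFlow;

  // path to the right child of the root's left child: go LEFT first, then RIGHT
  BinaryTreePath p =
      nest_inside_left_child(nest_inside_right_child(binary_tree_root_path()));

  assert(p.entries == (std::vector<BinaryTreePathEntry>{
                          BinaryTreePathEntry::LEFT_CHILD,
                          BinaryTreePathEntry::RIGHT_CHILD,
                      }));
  assert(binary_tree_path_get_top_level(p) == BinaryTreePathEntry::LEFT_CHILD);
  assert(binary_tree_path_get_non_top_level(p).entries ==
         (std::vector<BinaryTreePathEntry>{BinaryTreePathEntry::RIGHT_CHILD}));
  return 0;
}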
get_subgraph_incoming_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + + std::unordered_set all_nodes = get_nodes(g); + query_set src_query = query_set{set_minus(all_nodes, ns)}; + + DataflowEdgeQuery query = DataflowEdgeQuery{ + src_query, + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..70a66c9a21 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set edges = + get_transitive_reduced_edges_across_split(tr_g, split); + + std::unordered_set src_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.src.node; }); + + std::unordered_set dst_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.dst.node; }); + + return SplitBoundaryNodes{ + /*pre_split_boundary=*/src_boundary_nodes, + /*post_split_boundary=*/dst_boundary_nodes, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..8a4adf0b3a --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,27 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set src_subgraph = + unordered_set_of(get_leaves(split.get_left_child())); + std::unordered_set dst_subgraph = + unordered_set_of(get_leaves(split.get_right_child())); + + std::unordered_set raw_edges = + get_edges_from_subgraph_to_subgraph( + tr_g.transitive_reduction, src_subgraph, dst_subgraph); + + return flatmap(raw_edges, [&](DirectedEdge const &e) { + return get_dataflow_edges_from_node_to_node( + tr_g.full_dataflow_graph, e.src, e.dst); + }); +} + +} // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..0bb94c87f4 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,14 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + return transform(get_transitive_reduced_edges_across_split(tr_g, split), + [](DataflowEdge const &e) { return e.src; }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc new file mode 100644 index 0000000000..81751702a2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc @@ -0,0 +1,17 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &g) { + DiGraphView as_digraph = g; + DiGraphView transitive_reduced = transitive_reduction(as_digraph); + + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/g, + /*transitive_reduction=*/transitive_reduced, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index f1f1cd7b21..8afe7da926 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,10 +1,15 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/extend.h" -#include "utils/containers/get_one_of.h" +#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/set.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" @@ -12,23 +17,35 @@ #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" 
#include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include namespace FlexFlow { std::optional - get_cbc_decomposition(DiGraphView const &g) { + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &g, std::vector const &edge_order) { // implementation of the algorithm from https://doi.org/10.1145/800135.804393 // top left of page 8, second paragraph + std::queue edges_to_process; + for (DirectedEdge const &e : edge_order) { + edges_to_process.push(e); + } + std::unordered_set already_in_a_head = {}; std::unordered_set already_in_a_tail = {}; - std::unordered_set edges_to_process = get_edges(g); + + std::unordered_set already_processed = {}; CompleteBipartiteCompositeDecomposition result = CompleteBipartiteCompositeDecomposition{{}}; while (!edges_to_process.empty()) { - DirectedEdge e = get_one_of(edges_to_process); + DirectedEdge e = edges_to_process.front(); + edges_to_process.pop(); + if (contains(already_processed, e)) { + continue; + } std::unordered_set head = get_predecessors(g, e.dst); std::unordered_set tail = get_successors(g, e.src); @@ -39,6 +56,12 @@ std::optional std::unordered_set from_head_to_tail = g.query_edges(DirectedEdgeQuery{head, tail}); + + DiGraphView subgraph = get_subgraph(g, set_union(head, tail)); + if (!is_complete_bipartite_digraph(subgraph, head)) { + return std::nullopt; + } + if (set_union(values(get_outgoing_edges(g, head))) != from_head_to_tail) { return std::nullopt; } @@ -47,7 +70,7 @@ std::optional } result.subgraphs.insert(BipartiteComponent{head, tail}); - edges_to_process = set_minus(edges_to_process, from_head_to_tail); + already_processed = set_union(already_processed, from_head_to_tail); extend(already_in_a_head, head); extend(already_in_a_tail, tail); } @@ -58,4 +81,10 @@ std::optional return result; } +std::optional + get_cbc_decomposition(DiGraphView const &g) { + std::vector edge_order = vector_of(get_edges(g)); + return get_cbc_decomposition_with_edge_order_internal(g, edge_order); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc new file mode 100644 index 0000000000..2eab8371b2 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -0,0 +1,29 @@ +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/containers/get_first.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &g) { + return is_complete_bipartite_digraph(g, get_sources(g)); +} + +bool is_complete_bipartite_digraph(DiGraphView const &g, + std::unordered_set const &srcs) { + std::unordered_set sinks = set_minus(get_nodes(g), srcs); + + std::unordered_set edges = get_edges(g); + + std::unordered_set expected_edges; + for (Node const &src : srcs) { + for (Node const &sink : sinks) { + expected_edges.insert(DirectedEdge{src, sink}); + } + } + + return edges == expected_edges; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc new file mode 100644 index 0000000000..ad7830cc76 --- /dev/null +++ 
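// --- Illustrative usage sketch (not part of the patch) for the new
// is_complete_bipartite_digraph helper below. The graph-construction calls
// (DiGraph::create<AdjacencyDiGraph>(), add_node, add_edge, remove_edge) are assumed
// from the surrounding codebase conventions rather than introduced by this patch.
#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h"
#include "utils/graph/instances/adjacency_digraph.h"
#include <cassert>

int main() {
  using namespace ::FlexFlow;

  DiGraph g = DiGraph::create<AdjacencyDiGraph>();
  Node a = g.add_node();
  Node b = g.add_node();
  Node x = g.add_node();
  Node y = g.add_node();

  // every source -> sink edge is present and nothing else: complete bipartite
  g.add_edge(DirectedEdge{a, x});
  g.add_edge(DirectedEdge{a, y});
  g.add_edge(DirectedEdge{b, x});
  g.add_edge(DirectedEdge{b, y});
  assert(is_complete_bipartite_digraph(g, {a, b}));

  // dropping one cross edge breaks the property
  g.remove_edge(DirectedEdge{b, y});
  assert(!is_complete_bipartite_digraph(g, {a, b}));
  return 0;
}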
b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc @@ -0,0 +1,32 @@ +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &g, + std::function const &get_node_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + auto get_node_name = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + for (Node const &n : get_nodes(g)) { + RecordFormatter rec; + rec << get_node_label(n); + dot.add_record_node(get_node_name(n), rec); + } + + for (DirectedEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src), get_node_name(e.dst)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc new file mode 100644 index 0000000000..5c790abb8c --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc @@ -0,0 +1,13 @@ +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &g, DirectedEdge const &e) { + return !g.query_edges(DirectedEdgeQuery{ + query_set{e.src}, + query_set{e.dst}, + }) + .empty(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..2c6606a06b --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,25 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/containers/are_disjoint.h" + +namespace FlexFlow { + +std::unordered_set get_edges_from_subgraph_to_subgraph( + DiGraphView const &g, + std::unordered_set const &src_subgraph, + std::unordered_set const &dst_subgraph) { + if (!are_disjoint(src_subgraph, dst_subgraph)) { + throw mk_runtime_error( + fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) 
" + "expected src_subgraph and dst_subgraph to be disjoint, " + "but found src_subgraph={}, dst_subgraph={}", + src_subgraph, + dst_subgraph)); + } + + return g.query_edges(DirectedEdgeQuery{ + /*srcs=*/query_set{src_subgraph}, + /*dsts=*/query_set{dst_subgraph}, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc index 2e570cbdf9..34cc7fcc6f 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/get_imm_dominators_map.h" -#include "utils/containers/as_vector.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" #include "utils/containers/generate_map.h" @@ -7,6 +6,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms/get_dominators_map.h" #include "utils/graph/node/algorithms.h" @@ -22,8 +22,8 @@ std::unordered_map> std::unordered_set n_dominators = node_to_its_dominators.at(n); n_dominators.erase(n); std::vector recursive_dominator_list = concat_vectors( - transform(as_vector(n_dominators), [&](Node const &dominator) { - return as_vector(node_to_its_dominators.at(dominator)); + transform(vector_of(n_dominators), [&](Node const &dominator) { + return vector_of(node_to_its_dominators.at(dominator)); })); std::unordered_map dominator_counts = get_element_counts(recursive_dominator_list); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..f19deb3046 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set get_subgraph_outgoing_edges( + DiGraphView const &g, std::unordered_set const &subgraph_nodes) { + std::unordered_set external_nodes = + set_minus(get_nodes(g), subgraph_nodes); + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{subgraph_nodes}, + query_set{external_nodes}}; + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc new file mode 100644 index 0000000000..e860fb11b1 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &g, + std::unordered_set const &subgraph_nodes) { + std::unordered_set successors = + transform(get_subgraph_outgoing_edges(g, subgraph_nodes), + [](DirectedEdge const &e) { return e.dst; }); + + return successors; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 
0000000000..3efea1c138 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,51 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &g) { + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + std::unordered_set edges = get_edges(g); + + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } + + DiGraph result = materialize_digraph_view(g); + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; + result.add_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(j)}); + } + } + } + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index 10ffe4fc33..97a2439263 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,7 +1,12 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_closure.h" #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -24,29 +29,60 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { } DiGraphView transitive_reduction(DiGraphView const &g) { - std::unordered_set edge_mask = get_edges(g); + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. 
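// --- Illustrative sketch (not part of the patch): a self-contained, standard-C++
// rendering of the adjacency-matrix approach described in the comments above. It runs
// a Warshall-style closure, then drops every original edge (i, k) for which the
// closure already contains some i -> j -> k. Checking against the unmodified closure
// is a slight simplification of the in-place variant used in transitive_reduction.cc,
// and gives the same result on an acyclic graph.
#include <cassert>
#include <vector>

int main() {
  int const n = 4;
  auto at = [&](std::vector<bool> &mat, int i, int j) -> std::vector<bool>::reference {
    return mat[i * n + j];
  };

  // chain 0 -> 1 -> 2 -> 3 plus the redundant shortcut 0 -> 2
  std::vector<bool> original(n * n, false);
  at(original, 0, 1) = true;
  at(original, 1, 2) = true;
  at(original, 2, 3) = true;
  at(original, 0, 2) = true; // implied by 0 -> 1 -> 2, so the reduction should drop it

  // transitive closure (Warshall): afterwards closure(i, j) means "j reachable from i"
  std::vector<bool> closure = original;
  for (int k = 0; k < n; k++) {
    for (int i = 0; i < n; i++) {
      if (at(closure, i, k)) {
        for (int j = 0; j < n; j++) {
          if (at(closure, k, j)) {
            at(closure, i, j) = true;
          }
        }
      }
    }
  }
  assert(at(closure, 0, 3)); // 3 is reachable from 0 through the chain

  // transitive reduction: drop original edges the closure explains transitively
  std::vector<bool> reduced = original;
  for (int j = 0; j < n; j++) {
    for (int i = 0; i < n; i++) {
      if (at(closure, i, j)) {
        for (int k = 0; k < n; k++) {
          if (at(closure, j, k)) {
            at(reduced, i, k) = false;
          }
        }
      }
    }
  }
  assert(!at(reduced, 0, 2));                                         // shortcut removed
  assert(at(reduced, 0, 1) && at(reduced, 1, 2) && at(reduced, 2, 3)); // chain kept
  return 0;
}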
+ // + // transitive_closure inlined to avoid any drifts in node numbering + // between transitive_closure and transitive_reduction + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } - while (true) { - std::unordered_set new_edge_mask = edge_mask; - for (DirectedEdge const &e1 : edge_mask) { - for (DirectedEdge const &e2 : edge_mask) { - if (e1.dst == e2.src && e1 != e2) { - DirectedEdge trans_edge = DirectedEdge{e1.src, e2.dst}; - if (contains(new_edge_mask, trans_edge)) { - new_edge_mask.erase(trans_edge); + // compute transitive closure + // see https://cs.winona.edu/lin/cs440/ch08-2.pdf slide 8-8 + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; } } } } + } - if (new_edge_mask == edge_mask) { - break; - } else { - edge_mask = new_edge_mask; + DiGraph result = materialize_digraph_view(g); + // compute transitive reduction + // see https://stackoverflow.com/a/6702198 + std::unordered_set edge_mask = get_edges(g); + for (int j = 0; j < num_nodes; j++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, j)) { + for (int k = 0; k < num_nodes; k++) { + if (has_edge(j, k)) { + has_edge(i, k) = false; + result.remove_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(k)}); + } + } + } } } - return DiGraphView::create(g, edge_mask); + return result; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc index 34a8eff503..68ef12c49e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc @@ -38,11 +38,7 @@ void AdjacencyDiGraph::add_edge(DirectedEdge const &e) { } void AdjacencyDiGraph::remove_edge(DirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.src); - auto iter = m.find(e.dst); - if (iter != m.end()) { - m.erase(iter); - } + this->adjacency.at(e.src).erase(e.dst); } std::unordered_set diff --git a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc index 4f4c846433..941c8e8e3e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc @@ -6,6 +6,7 @@ #include "utils/containers/values.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 92a6d0b9eb..61c4f80763 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -23,12 +23,12 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.bigger); + throw 
mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.bigger)); } if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.smaller); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.smaller)); } this->adjacency.at(e.bigger).insert(e.smaller); diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc new file mode 100644 index 0000000000..6f6722f635 --- /dev/null +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -0,0 +1,58 @@ +#include "utils/graph/instances/unordered_set_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph() {} + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph( + NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &edges) + : node_source(node_source), nodes(nodes), edges(edges) {} + +Node UnorderedSetUndirectedGraph::add_node() { + Node new_node = this->node_source.new_node(); + this->nodes.insert(new_node); + return new_node; +} + +void UnorderedSetUndirectedGraph::add_node_unsafe(Node const &n) { + this->nodes.insert(n); +} + +void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { + this->nodes.erase(n); +} + +void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { + assert(contains(this->nodes, e.bigger)); + assert(contains(this->nodes, e.smaller)); + this->edges.insert(e); +} + +void UnorderedSetUndirectedGraph::remove_edge(UndirectedEdge const &e) { + this->edges.erase(e); +} + +std::unordered_set + UnorderedSetUndirectedGraph::query_nodes(NodeQuery const &q) const { + return apply_node_query(q, this->nodes); +} + +std::unordered_set UnorderedSetUndirectedGraph::query_edges( + UndirectedEdgeQuery const &q) const { + return filter(this->edges, + [&](UndirectedEdge const &e) { return matches_edge(q, e); }); +} + +UnorderedSetUndirectedGraph *UnorderedSetUndirectedGraph::clone() const { + return new UnorderedSetUndirectedGraph{ + this->node_source, + this->nodes, + this->edges, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc new file mode 100644 index 0000000000..28d63f9ee1 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc new file mode 100644 index 0000000000..dc5ce4fbda --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc index 47096d492c..53497a715d 100644 
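// --- Illustrative note (not part of the patch): several hunks in this diff, like the
// HashmapUndirectedGraph change just above, rewrite throw sites from the variadic
// mk_runtime_error("...", args...) form to an explicit fmt::format call, matching the
// plain-string mk_runtime_error overload added in exception.cc earlier in the diff.
// The call-site pattern after the change, shown on a made-up helper:
#include "utils/exception.h"
#include <fmt/format.h>

namespace FlexFlow {

// hypothetical function, for illustration only
void require_positive(int x) {
  if (x <= 0) {
    // format the message first, then build the exception from the finished string
    throw mk_runtime_error(
        fmt::format("require_positive expected a positive value, got {}", x));
  }
}

} // namespace FlexFlow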
--- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc @@ -1,7 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_counts.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_map get_edge_counts(MultiDiGraphView const &g) { return get_element_counts( - transform(as_vector(get_edges(g)), + transform(vector_of(get_edges(g)), [&](MultiDiEdge const &e) { return get_directed_edge(g, e); })); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index 387b5b8972..fa17678943 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -2,12 +2,12 @@ #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/get_one_of.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" @@ -201,7 +201,7 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = as_vector(get_sinks(src)); + std::vector src_sink_nodes = vector_of(get_sinks(src)); std::unordered_set dst_sink_nodes = get_sinks(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { @@ -209,7 +209,7 @@ std::unordered_set } std::vector src_unused_graph_inputs = - as_vector(get_unused_open_dataflow_graph_inputs(src)); + vector_of(get_unused_open_dataflow_graph_inputs(src)); std::unordered_set dst_unused_graph_inputs = get_unused_open_dataflow_graph_inputs(dst); diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index 4ade34941c..08dda09698 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -5,6 +5,7 @@ #include "utils/containers/values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/hash/vector.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc deleted file mode 100644 index 8fa42d4b22..0000000000 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "utils/graph/serial_parallel/serial_parallel_splits.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" -#include "utils/hash-utils.h" -#include 
"utils/hash/unordered_set.h" -#include "utils/hash/vector.h" - -namespace FlexFlow { - -SerialSplit::SerialSplit( - std::vector> const &children) - : children(children) {} - -SerialSplit::SerialSplit( - std::initializer_list> const &children) - : children(children) {} - -bool SerialSplit::operator==(SerialSplit const &other) const { - return this->tie() == other.tie(); -} - -bool SerialSplit::operator!=(SerialSplit const &other) const { - return this->tie() != other.tie(); -} - -SerialSplit::Tie SerialSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(SerialSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { - return s << fmt::to_string(split); -} - -ParallelSplit::ParallelSplit( - std::unordered_set> const &children) - : children(children) {} - -ParallelSplit::ParallelSplit( - std::initializer_list> const &children) - : children(children) {} - -bool ParallelSplit::operator==(ParallelSplit const &other) const { - return this->tie() == other.tie(); -} - -bool ParallelSplit::operator!=(ParallelSplit const &other) const { - return this->tie() != other.tie(); -} - -ParallelSplit::Tie ParallelSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(ParallelSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { - return s << fmt::to_string(split); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::SerialSplit>::operator()( - ::FlexFlow::SerialSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -size_t hash<::FlexFlow::ParallelSplit>::operator()( - ::FlexFlow::ParallelSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -} // namespace std diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..62489ff75f --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -0,0 +1,85 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node>{ + /*series_get_left_child=*/[](BinarySeriesSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + 
/*parallel_get_right_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { + return tree.require_series(); + }, + /*require_parallel=*/ + [](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { + return tree.require_parallel(); + }, + /*require_leaf=*/ + [](BinarySPDecompositionTree const &tree) -> Node const & { + return tree.require_node(); + }, + }; +} + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_left_associative(tree, + generic_impl_for_binary_sp_tree()); +} + +bool is_binary_sp_tree_right_associative( + BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_right_associative(tree, + generic_impl_for_binary_sp_tree()); +} + +std::unordered_multiset + get_leaves(BinarySPDecompositionTree const &tree) { + return get_leaves(tree, generic_impl_for_binary_sp_tree()); +} + +SPDecompositionTreeNodeType + get_node_type(BinarySPDecompositionTree const &tree) { + return tree.visit(overload{ + [](BinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](BinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..07e2c3e3e3 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_set find_paths_to_leaf( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc new file mode 100644 index 0000000000..56a6d0cc85 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +FullBinaryTreeImplementation, Leaf> + 
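// --- Illustrative sketch (not part of the patch): get_node_type above dispatches with
// tree.visit(overload{...}), i.e. the usual "overloaded lambdas" visitor idiom. A
// self-contained standard-C++ version of the same idiom over a std::variant, to make
// the dispatch pattern explicit (FlexFlow's overload helper is assumed to play the
// role of Overloaded here):
#include <cassert>
#include <variant>

namespace example {

template <typename... Fs>
struct Overloaded : Fs... {
  using Fs::operator()...;
};
template <typename... Fs>
Overloaded(Fs...) -> Overloaded<Fs...>;

enum class NodeKind { SERIES, PARALLEL, LEAF };

struct SeriesTag {};
struct ParallelTag {};
struct LeafTag {};

NodeKind kind_of(std::variant<SeriesTag, ParallelTag, LeafTag> const &v) {
  return std::visit(Overloaded{
                        [](SeriesTag const &) { return NodeKind::SERIES; },
                        [](ParallelTag const &) { return NodeKind::PARALLEL; },
                        [](LeafTag const &) { return NodeKind::LEAF; },
                    },
                    v);
}

} // namespace example

int main() {
  assert(example::kind_of(example::ParallelTag{}) == example::NodeKind::PARALLEL);
  return 0;
}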
get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..71d3f6ac31 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..3bb90bfa32 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_multiset + get_leaves(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..3d166145c1 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..d1d8079c0b --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -0,0 +1,19 @@ +#include 
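Aside on the instantiation-only .cc files above (and the similar ones that follow): each contains nothing but value_type<N> aliases and a single explicit template instantiation. The point is to force the header-only templates to be compiled, with any errors reported, when the utils library itself is built, against opaque archetype types that promise nothing beyond what the template actually requires. A minimal self-contained illustration of the same pattern, using made-up names (Box, unwrap) rather than anything from this patch:

// box.h: a function template that would otherwise only be type-checked
// when some downstream client happens to instantiate it
template <typename T>
struct Box {
  T value;
};

template <typename T>
T unwrap(Box<T> const &b) {
  return b.value;
}

// box.cc: explicit instantiation so the template body is compiled as part
// of the library build; the real files use value_type<N> archetypes instead
// of int so the instantiation cannot rely on operations the template never
// asked for
template int unwrap(Box<int> const &);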
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::optional get_subtree_at_path( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..69cbb28582 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..584099e33e --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc new file mode 100644 index 0000000000..056ae2a8d4 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -0,0 +1,24 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using ReturnType = value_type<0>; +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template ReturnType + visit(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + GenericBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..69b2ebea8e --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,76 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldl1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function const &)> + from_series_child; + std::function const &)> + from_parallel_child; + + auto from_node = [](Node const &n) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{n}; + }; + + auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { + std::vector children = + transform(s.children, from_series_child); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{accum, x}, + }; + }); + }; + + auto from_parallel = + [&](ParallelSplit const &s) -> BinarySPDecompositionTree { + std::vector children = + transform(vector_of(s.get_children()), from_parallel_child); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{accum, x}, + }; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> BinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> BinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..3b8affd16d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,12 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &binary) { + return to_final_ast(from_binary_sp_tree(binary)); +} + +} // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..478d90e0c3 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,73 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldr1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function const &)> + from_series_child; + std::function const &)> + from_parallel_child; + + auto from_node = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + auto from_series = [&](SeriesSplit const &s) { + std::vector children = + transform(s.children, from_series_child); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{x, accum}, + }; + }); + }; + + auto from_parallel = [&](ParallelSplit const &s) { + std::vector children = + transform(vector_of(s.get_children()), from_parallel_child); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{x, accum}, + }; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> BinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> BinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 63% rename from lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 632f5245db..cd29af59a0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,23 +1,28 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" +#include 
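Aside: the two conversions above differ only in fold direction. foldl1 nests each new split on the left, turning a series of children a, b, c into ((a ; b) ; c), while foldr1 nests on the right, producing (a ; (b ; c)); nary_sp_tree_from_binary flattens either shape back to the same n-ary SeriesParallelDecomposition. A rough usage sketch against the headers added in this patch; constructing Node directly from an integer id is an assumption made for illustration.

#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h"
#include <cassert>
#include <unordered_set>

using namespace ::FlexFlow;

void associativity_example() {
  Node a = Node{1};
  Node b = Node{2};
  Node c = Node{3};

  // ((a ; b) ; c): the shape left_associative_binary_sp_tree_from_nary builds
  BinarySPDecompositionTree left = BinarySPDecompositionTree{BinarySeriesSplit{
      BinarySPDecompositionTree{BinarySeriesSplit{
          BinarySPDecompositionTree{a},
          BinarySPDecompositionTree{b},
      }},
      BinarySPDecompositionTree{c},
  }};

  assert(is_binary_sp_tree_left_associative(left));
  assert(!is_binary_sp_tree_right_associative(left));

  // the choice of associativity does not change the leaf multiset
  assert((get_leaves(left) == std::unordered_multiset<Node>{a, b, c}));
}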
"utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/parallel_reduction.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_reduction.h" namespace FlexFlow { -std::optional - get_serial_parallel_decomposition(DiGraphView const &g) { +std::optional + get_series_parallel_decomposition(DiGraphView const &g) { + + DiGraphView transitively_reduced = transitive_reduction(g); InverseLineGraphResult inverse_line_graph_result = ({ std::optional maybe_line_graph = - get_inverse_line_graph(g); + get_inverse_line_graph(transitively_reduced); if (!maybe_line_graph.has_value()) { return std::nullopt; } @@ -27,14 +32,11 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map> + std::unordered_map ttsp_edge_to_sp_tree = map_values( inverse_line_graph_result.inverse_edge_to_line_node_bidict .as_unordered_map(), - [](Node const &n) { - return std::variant{n}; - }); + [](Node const &n) { return BinarySPDecompositionTree{n}; }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -44,11 +46,12 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -63,11 +66,12 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::SERIAL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -83,7 +87,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return to_final_ast(ttsp_edge_to_sp_tree.at(e)); + return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); } } } diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc similarity index 79% rename from lib/utils/src/utils/graph/serial_parallel/graph_generation.cc rename to 
lib/utils/src/utils/graph/series_parallel/graph_generation.cc index 4c9eb9d3ef..7070d04c4a 100644 --- a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc +++ b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/graph_generation.h" +#include "utils/graph/series_parallel/graph_generation.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -12,7 +12,7 @@ void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { } } -void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { +void series_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { // TODO(@lockshaw): This function signature is impossible to implement in // general, as there is no guarantee that the graph view ext actually has // source nodes with inputs Either the signature should be changed, or an @@ -22,11 +22,11 @@ void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { NOT_IMPLEMENTED(); } -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2) { DataflowGraph g = DataflowGraph::create_copy_of(g1); - serial_extend_unsafe(g, g2); + series_extend_unsafe(g, g2); return g; } @@ -39,8 +39,8 @@ DataflowGraph parallel_composition(DataflowGraphView const &g1, } DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition) { - // TODO(@lockshaw): see existing concerns about serial_extend_unsafe + SeriesParallelDecomposition const &sp_decomposition) { + // TODO(@lockshaw): see existing concerns about series_extend_unsafe NOT_IMPLEMENTED(); } diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc similarity index 54% rename from lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc rename to lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 6384bd9159..410a40236d 100644 --- a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,5 +1,7 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/containers/extend.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/overload.h" namespace FlexFlow { @@ -45,4 +47,31 @@ std::variant flatten_ast( return std::visit(FlattenAST{}, ast); } +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &binary) { + return binary + .template visit>( + overload{ + [](Node const &n) { return n; }, + [](BinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(s.get_left_child()), + from_binary_sp_tree(s.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(p.get_left_child()), + from_binary_sp_tree(p.get_right_child()), + }, + }; + }, + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc 
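Aside: the reworked get_series_parallel_decomposition above now (1) transitively reduces the input graph, (2) materializes its inverse line graph as a multigraph whose edges each start out carrying a one-node BinarySPDecompositionTree, and (3) alternates parallel reductions (merge two edges with the same endpoints) and series reductions (merge two edges that meet at an internal node touched by no other edge) until a single edge remains, building BinaryParallelSplit and BinarySeriesSplit nodes as it merges; the tree attached to the final edge is the result. A self-contained toy version of just that reduction loop over string-labelled edges; none of these names exist in the codebase, and the real code additionally handles the failure cases and final sanity checks shown in the hunk.

#include <optional>
#include <string>
#include <vector>

struct Edge {
  int src;
  int dst;
  std::string label;
};

// Reduce a two-terminal multigraph (source = 0, sink = 1) to a single
// series-parallel term, or nullopt if the graph is not series-parallel.
std::optional<std::string> reduce_to_sp_term(std::vector<Edge> edges) {
  auto step = [&]() -> bool {
    // parallel reduction: two edges with identical endpoints
    for (size_t i = 0; i < edges.size(); i++) {
      for (size_t j = i + 1; j < edges.size(); j++) {
        if (edges[i].src == edges[j].src && edges[i].dst == edges[j].dst) {
          edges[i].label = "P(" + edges[i].label + "," + edges[j].label + ")";
          edges.erase(edges.begin() + j);
          return true;
        }
      }
    }
    // series reduction: edges (a -> v) and (v -> b) where v is internal and
    // touched by exactly these two edges
    for (size_t i = 0; i < edges.size(); i++) {
      for (size_t j = 0; j < edges.size(); j++) {
        if (i == j || edges[i].dst != edges[j].src) {
          continue;
        }
        int v = edges[i].dst;
        int degree = 0;
        for (Edge const &e : edges) {
          degree += (e.src == v) + (e.dst == v);
        }
        if (v != 0 && v != 1 && degree == 2) {
          edges[i] = Edge{edges[i].src, edges[j].dst,
                          "S(" + edges[i].label + "," + edges[j].label + ")"};
          edges.erase(edges.begin() + j);
          return true;
        }
      }
    }
    return false;
  };

  while (step()) {
  }

  if (edges.size() == 1) {
    return edges.front().label;
  }
  return std::nullopt;
}

For example, the edge set {0->2 "a", 2->1 "b", 0->1 "c"} reduces to "P(S(a,b),c)", roughly mirroring how the real loop would assemble BinaryParallelSplit{BinarySeriesSplit{a, b}, c}.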
b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 93% rename from lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 30aa10edd7..12a6630bf0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 51% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 666bf40f10..b7a84b871a 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,18 +1,20 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" namespace FlexFlow { struct ToFinalAST { - std::variant + std::variant operator()(IntermediateSpDecompositionTree const &node) { - if (node.type == SplitType::SERIAL) { - return SerialSplit{transform( + if (node.type == SplitType::SERIES) { + return SeriesSplit{transform( node.children, [](std::variant const &s) { return narrow>( @@ -20,54 +22,55 @@ struct ToFinalAST { .value(); })}; } else { - return ParallelSplit{unordered_set_of(transform( + return ParallelSplit{unordered_multiset_of(transform( node.children, [](std::variant const &s) { - return narrow>( + return narrow>( internal_to_final_ast(s)) .value(); }))}; } } - std::variant operator()(Node const &node) { + std::variant operator()(Node const &node) { return node; } }; -std::variant internal_to_final_ast( +std::variant internal_to_final_ast( std::variant const &ast) { return std::visit(ToFinalAST{}, flatten_ast(ast)); } -SerialParallelDecomposition to_final_ast( +SeriesParallelDecomposition to_final_ast( std::variant const &ast) { - return std::visit([](auto &&x) { return SerialParallelDecomposition{x}; }, + return std::visit([](auto &&x) { return SeriesParallelDecomposition{x}; }, internal_to_final_ast(ast)); } -std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { - return sp.visit>( +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp) { + return sp.visit>( [](auto &&t) { return get_nodes(t); }); } -std::unordered_set get_nodes(SerialSplit const &serial) { - return set_union(transform( +std::unordered_multiset get_nodes(SeriesSplit const &serial) { + return multiset_union(transform( serial.children, [](std::variant const &child) - -> std::unordered_set { + -> std::unordered_multiset { return 
std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(ParallelSplit const ¶llel) { - return set_union(transform( - parallel.children, [](std::variant const &child) { +std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { + return multiset_union(transform( + vector_of(parallel.get_children()), + [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(Node const &node) { +std::unordered_multiset get_nodes(Node const &node) { return {node}; } diff --git a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc similarity index 97% rename from lib/utils/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/series_reduction.cc index e26f460e0e..7300c93fb0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc new file mode 100644 index 0000000000..8ae825c1ab --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc @@ -0,0 +1,10 @@ +#include "utils/graph/undirected/algorithms/get_edges.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &g) { + return g.query_edges(undirected_edge_query_all()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc new file mode 100644 index 0000000000..3c05b9d5d5 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -0,0 +1,19 @@ +#include "utils/graph/undirected/algorithms/get_neighboring_nodes.h" +#include "utils/containers/vector_of.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, + Node const &n) { + std::unordered_set edges = + g.query_edges(UndirectedEdgeQuery{query_set{n}}); + + std::unordered_set result = + set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { + return std::unordered_set{e.bigger, e.smaller}; + })); + result.erase(n); + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 5c41eef7da..3cccf1c6eb 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -6,6 +6,10 @@ UndirectedEdgeQuery undirected_edge_query_all() { return UndirectedEdgeQuery{matchall()}; } +bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { + return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); +} + UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, UndirectedEdgeQuery const &rhs) { return UndirectedEdgeQuery{ diff --git a/lib/utils/src/utils/hash/multiset.cc 
b/lib/utils/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..d84ca7d614 --- /dev/null +++ b/lib/utils/src/utils/hash/multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/multiset.h" diff --git a/lib/utils/src/utils/hash/unordered_multiset.cc b/lib/utils/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..7f6f73f428 --- /dev/null +++ b/lib/utils/src/utils/hash/unordered_multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/unordered_multiset.h" diff --git a/lib/utils/src/utils/json/check_is_json_deserializable.cc b/lib/utils/src/utils/json/check_is_json_deserializable.cc new file mode 100644 index 0000000000..7e17ced7e5 --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/check_is_json_serializable.cc b/lib/utils/src/utils/json/check_is_json_serializable.cc new file mode 100644 index 0000000000..1c9af4d3cb --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_serializable.h" diff --git a/lib/utils/src/utils/json/check_is_jsonable.cc b/lib/utils/src/utils/json/check_is_jsonable.cc new file mode 100644 index 0000000000..1e78fdb21f --- /dev/null +++ b/lib/utils/src/utils/json/check_is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_jsonable.h" diff --git a/lib/utils/src/utils/json/is_json_deserializable.cc b/lib/utils/src/utils/json/is_json_deserializable.cc new file mode 100644 index 0000000000..17df41433d --- /dev/null +++ b/lib/utils/src/utils/json/is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/is_json_serializable.cc b/lib/utils/src/utils/json/is_json_serializable.cc new file mode 100644 index 0000000000..883ee9f51a --- /dev/null +++ b/lib/utils/src/utils/json/is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_serializable.h" diff --git a/lib/utils/src/utils/json/is_jsonable.cc b/lib/utils/src/utils/json/is_jsonable.cc new file mode 100644 index 0000000000..3f819f8556 --- /dev/null +++ b/lib/utils/src/utils/json/is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/is_jsonable.h" diff --git a/lib/utils/src/utils/json/optional.cc b/lib/utils/src/utils/json/optional.cc new file mode 100644 index 0000000000..c8f0fd2e3c --- /dev/null +++ b/lib/utils/src/utils/json/optional.cc @@ -0,0 +1 @@ +#include "utils/json/optional.h" diff --git a/lib/utils/src/utils/rapidcheck/optional.cc b/lib/utils/src/utils/rapidcheck/optional.cc new file mode 100644 index 0000000000..6d62532e7e --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/optional.cc @@ -0,0 +1 @@ +#include "utils/rapidcheck/optional.h" diff --git a/lib/utils/test/common/include/test/utils/all.h b/lib/utils/test/common/include/test/utils/all.h deleted file mode 100644 index ced1c9ce38..0000000000 --- a/lib/utils/test/common/include/test/utils/all.h +++ /dev/null @@ -1,2 +0,0 @@ -#include "test/utils/doctest.h" -#include "test/utils/rapidcheck.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h similarity index 100% rename from lib/utils/test/common/include/test/utils/doctest.h rename to lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h new file mode 100644 index 
0000000000..8333ac4777 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H + +#include "utils/fmt/expected.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h new file mode 100644 index 0000000000..d20dbe6943 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H + +#include "utils/fmt/map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h new file mode 100644 index 0000000000..b26eee28ba --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H + +#include "utils/fmt/multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h new file mode 100644 index 0000000000..519cde7d74 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H + +#include "utils/fmt/optional.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::optional const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h new file mode 100644 index 0000000000..db0ed24f13 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H + +#include "utils/fmt/pair.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::pair const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h new file mode 100644 index 0000000000..3dd386645c --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h @@ -0,0 +1,18 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H + +#include "utils/fmt/set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h new file mode 100644 index 0000000000..4fd5d15009 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H + +#include "utils/fmt/unordered_map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h new file mode 100644 index 0000000000..94dae42239 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H + +#include "utils/fmt/unordered_multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h new file mode 100644 index 0000000000..441590365d --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H + +#include "utils/fmt/unordered_set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h new file mode 100644 index 0000000000..c30862274a --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H + +#include "utils/fmt/variant.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::variant const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h new file mode 100644 index 0000000000..56198a7558 --- /dev/null +++ 
b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H + +#include "utils/fmt/vector.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::vector const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/common.cc b/lib/utils/test/common/src/common.cc deleted file mode 100644 index 51e981b1f5..0000000000 --- a/lib/utils/test/common/src/common.cc +++ /dev/null @@ -1 +0,0 @@ -#include "test/utils/all.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc new file mode 100644 index 0000000000..1cff2195db --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/expected.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc new file mode 100644 index 0000000000..976e65cfca --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc new file mode 100644 index 0000000000..9c5b2f4d1e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc new file mode 100644 index 0000000000..8a3f7f158e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/optional.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc new file mode 100644 index 0000000000..106fb1c900 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/pair.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc new file mode 100644 index 0000000000..9ec70698bc --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc new file mode 100644 index 0000000000..b893e632ed --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc new file mode 100644 index 0000000000..55d2e69056 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc new file mode 100644 index 0000000000..13ad811e63 --- /dev/null 
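Aside: the test/utils/doctest/fmt/*.h headers above replace the old catch-all test/utils/doctest.h. Each one specializes doctest::StringMaker for a single container so that a failing CHECK prints the value through the matching utils/fmt formatter instead of doctest's default "{?}" placeholder. A hedged sketch of a test relying on this (the values are made up):

#include "test/utils/doctest/fmt/vector.h"
#include <doctest/doctest.h>
#include <vector>

TEST_SUITE(FF_TEST_SUITE) {
  TEST_CASE("StringMaker specialization makes container mismatches readable") {
    std::vector<int> result = {1, 2, 3};
    std::vector<int> correct = {1, 2, 4};

    // deliberately failing: with the specialization from fmt/vector.h in
    // scope, doctest stringifies both vectors via fmt::to_string rather
    // than falling back to "{?}"
    CHECK(result == correct);
  }
}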
+++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc new file mode 100644 index 0000000000..b6cc4f54e4 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/variant.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc new file mode 100644 index 0000000000..0102cd86da --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/vector.h" diff --git a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc index 6e3ac8c155..b5a373e5c9 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -1,7 +1,7 @@ #include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 739d2b72c7..d158af129f 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -1,6 +1,8 @@ #include "utils/bidict/bidict.h" -#include "test/utils/doctest.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/check_without_stringify.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/vector.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc index 2eb8f869f9..49fed81b29 100644 --- a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc @@ -1,6 +1,6 @@ #include "utils/bidict/try_merge_nondisjoint_bidicts.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/cli/cli_get_help_message.cc b/lib/utils/test/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..b3ee4d3318 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,519 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/join_strings.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_get_help_message(std::string, CLISpec)") { + std::string program_name = "prog_name"; + + SUBCASE("no flags or positional arguments") { + CLISpec cli = CLISpec{ + {}, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name\n"); + + CHECK(result == correct); + } + + SUBCASE("no flags") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "pos-arg-1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name pos-arg-1\n" + "\n" + "positional 
arguments:\n" + " pos-arg-1\n"); + + CHECK(result == correct); + } + + SUBCASE("no positional arguments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag-1", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag-1\n"); + + CHECK(result == correct); + } + + SUBCASE("flag formatting") { + SUBCASE("flag with shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flag without shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag]\n" + "\n" + "options:\n" + " --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flags are displayed in provided order") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag2] [--flag1]\n" + "\n" + "options:\n" + " --flag2\n" + " --flag1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("positional argument formatting") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg\n" + "\n" + "positional arguments:\n" + " posarg\n"); + + CHECK(result == correct); + } + + SUBCASE("with choices") { + SUBCASE("choices are not empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {red,blue,green}\n" + "\n" + "positional arguments:\n" + " {red,blue,green}\n"); + + CHECK(result == correct); + } + + SUBCASE("choices are empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {}\n" + "\n" + "positional arguments:\n" + " {}\n"); + + CHECK(result == correct); + } + } + + SUBCASE("are displayed in provided order") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg2 posarg1\n" + "\n" + "positional arguments:\n" + " posarg2\n" + " posarg1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("flag and positional argument alignment") { + SUBCASE("flags are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + CLIFlagSpec{ + "flag2-is-long", + std::nullopt, + "flag2-is-long description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg", + 
std::nullopt, + "help text for posarg", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] [--flag2-is-long] posarg\n" + "\n" + "positional arguments:\n" + " posarg help text for posarg\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n" + " --flag2-is-long flag2-is-long description\n"); + + CHECK(result == correct); + } + + SUBCASE("pos args are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1-is-very-long", + std::nullopt, + "help text for posarg1-is-very-long", + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + "help text for posarg2", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] posarg1-is-very-long posarg2\n" + "\n" + "positional arguments:\n" + " posarg1-is-very-long help text for posarg1-is-very-long\n" + " posarg2 help text for posarg2\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n"); + + CHECK(result == correct); + } + + SUBCASE("line break behavior") { + SUBCASE("line breaks max out other argument alignments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + "flag help text", + }, + }, + { + CLIPositionalArgumentSpec{ + "abcdefghijklmnopqrstuvwxyz0123456789", + std::nullopt, + "long arg help text", + }, + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "posarg help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f] " + "abcdefghijklmnopqrstuvwxyz0123456789 posarg\n" + "\n" + "positional arguments:\n" + " abcdefghijklmnopqrstuvwxyz0123456789\n" + " long arg help text\n" + " posarg posarg help text\n" + "\n" + "options:\n" + " -f, --flag flag help text\n"); + + CHECK(result == correct); + } + SUBCASE("positional argument line break behavior") { + SUBCASE("positional arguments cause a line break at or above " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 22); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE("positional arguments do not cause a line break below " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 21); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + } + } + + SUBCASE("flag line break behavior") { + SUBCASE("flags cause a line break at or above formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 21); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + 
"\n" + "options:\n" + " -b, --bbbbbbbbbbbbbbb\n" + " flag description\n"); + + CHECK(result == correct); + } + + SUBCASE("flags do not cause a line break below formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 20); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbb flag description\n"); + + CHECK(result == correct); + } + } + + SUBCASE("choice line breakpoint formatting") { + SUBCASE( + "choices cause a line break at or above formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "fffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 21); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,fffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,fffffffff}\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE( + "choices do not cause a line break below formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "ffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 20); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,ffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,ffffffff} help text\n"); + + CHECK(result == correct); + } + } + } + } + } +} diff --git a/lib/utils/test/src/utils/cli/cli_parse.cc b/lib/utils/test/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..40dea86ae0 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_parse.cc @@ -0,0 +1,477 @@ +#include "utils/cli/cli_parse.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_parse_flag(CLISpec, std::string)") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("correctly parses short flag") { + std::string input = "-2"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag2; + + CHECK(result == correct); + } + + SUBCASE("correctly parses long flag") { + std::string input = "--flag1"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag1; + + CHECK(result == correct); + } + + SUBCASE("fails on unknown flag") { + std::string input = "--not-real"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = + tl::unexpected("Encountered unknown flag --not-real"); + + CHECK(result == correct); + } + + SUBCASE("fails on non-flag") { + std::string input = "-flag1"; + + std::optional result = + optional_from_expected(cli_parse_flag(cli, input)); + 
std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + TEST_CASE("cli_parse(CLISpec, std::vector)") { + SUBCASE("works even if cli is empty") { + CLISpec cli = CLISpec{{}, {}}; + std::vector inputs = {"prog_name"}; + + tl::expected result = cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, {}}; + + CHECK(result == correct); + } + + SUBCASE("flag parsing") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("parses flags in any order") { + std::vector inputs = {"prog_name", "-2", "--flag1"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if some are not present") { + std::vector inputs = {"prog_name", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if none are present") { + std::vector inputs = {"prog_name"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, false}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine even if the program name is a flag") { + std::vector inputs = {"--flag1", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + } + + SUBCASE("positional argument parsing") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("can parse multiple positional arguments") { + std::vector inputs = {"prog_name", "hello", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("requires all positional arguments to be present") { + std::vector inputs = {"prog_name", "hello"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + + SUBCASE("requires no extra positional arguments to be present") { + std::vector inputs = { + "prog_name", "hello", "there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + tl::unexpected("Too many positional arguments: expected 2"); + + CHECK(result == correct); + } + + SUBCASE("allows arguments to contain spaces") { + std::vector inputs = { + "prog_name", "hello there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello there"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("allows arguments to be empty") { + std::vector inputs = {"prog_name", "hello", ""}; + + 
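Aside: taken together, the cases in these two test files pin down the whole CLI surface (CLISpec, CLIFlagSpec, CLIPositionalArgumentSpec, cli_parse, cli_get_help_message). For orientation, a sketch of how a binary might wire these together; the flag and argument names are illustrative only, the result lookup is left as a comment because the accessor side of CLIParseResult is not shown in this hunk, and only constructors and functions that appear in this patch are called.

#include "utils/cli/cli_get_help_message.h"
#include "utils/cli/cli_parse.h"
#include <iostream>
#include <string>
#include <vector>

using namespace ::FlexFlow;

int main(int argc, char **argv) {
  CLISpec cli = CLISpec{
      {
          CLIFlagSpec{"verbose", 'v', "enable verbose output"},
      },
      {
          CLIPositionalArgumentSpec{
              "model",
              std::vector<std::string>{"transformer", "bert"},
              "model architecture to export",
          },
      },
  };

  tl::expected<CLIParseResult, std::string> parsed = cli_parse(cli, argc, argv);

  if (!parsed.has_value()) {
    std::cerr << parsed.error() << "\n\n"
              << cli_get_help_message(/*program_name=*/argv[0], cli);
    return 1;
  }

  // values are then read out of *parsed using CLIFlagKey{0} and
  // CLIPositionalArgumentKey{0}, matching the keys used in the tests above
  return 0;
}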
tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, ""}, + }}; + + CHECK(result == correct); + } + } + + SUBCASE("with choices") { + SUBCASE("choices is non-empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg = CLIPositionalArgumentKey{0}; + + SUBCASE( + "succeeds if a positional argument is set to a valid choice") { + std::vector inputs = {"prog_name", "blue"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + {}, + { + {key_posarg, "blue"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE( + "fails if a positional argument is set to an invalid choice") { + std::vector inputs = {"prog_name", " red"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \" red\""); + + CHECK(result == correct); + } + } + + SUBCASE("if choices is empty, rejects everything") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::vector inputs = {"prog_name", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \"\""); + + CHECK(result == correct); + } + } + } + + SUBCASE("correctly differentiates mixed arguments/flags") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("works if flags are before positional arguments") { + std::vector inputs = { + "prog_name", "-f", "--flag3", "red", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("works if flags are interspersed") { + std::vector inputs = { + "prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("detects if posargs are missing instead of treating flags as " + "posarg values") { + std::vector inputs = {"prog_name", "-f", "red", "--flag2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + } + } + + TEST_CASE("cli_parse(CLISpec, int argc, char const * const *argv)") { + // most cases are checked in the other overload, + // i.e., cli_parse(CLISpec, std::vector), + // so here we just throw in a single 
check to make sure + // nothing has unexpectedly gone wrong + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + int argc = 5; + char const *argv[] = {"prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, argc, argv); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc index df671f8541..fd8b321439 100644 --- a/lib/utils/test/src/utils/containers/are_all_same.cc +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -5,20 +5,32 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("are_all_same") { - SUBCASE("All elements are the same") { - std::vector input = {2, 2, 2, 2}; - CHECK(are_all_same(input)); + TEST_CASE("are_all_same(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); } - SUBCASE("Not all elements are the same") { - std::vector input = {1, 2, 3, 4}; - CHECK_FALSE(are_all_same(input)); + SUBCASE("input elements are all same") { + std::vector input = {1, 1, 1}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); } - SUBCASE("Empty Container") { - std::vector input = {}; - CHECK(are_all_same(input)); + SUBCASE("input elements are not all same") { + std::vector input = {1, 1, 2, 1}; + + bool result = are_all_same(input); + bool correct = false; + + CHECK(result == correct); } } } diff --git a/lib/utils/test/src/utils/containers/cartesian_product.cc b/lib/utils/test/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..773d94c8d0 --- /dev/null +++ b/lib/utils/test/src/utils/containers/cartesian_product.cc @@ -0,0 +1,71 @@ +#include "utils/containers/cartesian_product.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cartesian_product") { + + SUBCASE("empty") { + std::vector> containers = {}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = {{}}; + CHECK(result == correct); + } + + SUBCASE("single container, one element") { + std::vector> containers = {{1}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = {{1}}; + CHECK(result == correct); + } + + SUBCASE("single container, multiple elements") { + std::vector> containers = {{1, 2, 3}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = {{1}, {2}, {3}}; + CHECK(result == correct); + } + + SUBCASE("multiple 
containers, one element each") { + std::vector> containers = {{1}, {2}, {3}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = {{1, 2, 3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, multiple elements") { + std::vector> containers = {{1, 2}, {3, 4}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = { + {1, 3}, {1, 4}, {2, 3}, {2, 4}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, duplicate elements") { + std::vector> containers = {{1, 1}, {2, 3}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = { + {1, 2}, {1, 3}, {1, 3}, {1, 2}}; + CHECK(result == correct); + } + + SUBCASE("1 empty container, 1 non-empty container") { + std::vector> containers = {{}, {2, 3}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/contains_key.cc b/lib/utils/test/src/utils/containers/contains_key.cc index acc6551cd4..da099113a6 100644 --- a/lib/utils/test/src/utils/containers/contains_key.cc +++ b/lib/utils/test/src/utils/containers/contains_key.cc @@ -1,8 +1,11 @@ #include "utils/containers/contains_key.h" -#include "test/utils/doctest.h" +#include #include +#include #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains_key(std::unordered_map, K)") { std::unordered_map m = { diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 61d70f5aff..fb55643fb3 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,16 +1,14 @@ #include "utils/containers/enumerate.h" -#include "utils/containers/as_vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/keys.h" -#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" -#include "utils/fmt/map.h" -#include "utils/fmt/pair.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "utils/containers/vector_of.h" #include #include -#include using namespace ::FlexFlow; @@ -31,7 +29,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("check iteration order") { std::vector> iterated_result = - as_vector(result); + vector_of(result); std::vector> correct_iteration_order = { {0, "zero"}, {1, "one"}, diff --git a/lib/utils/test/src/utils/containers/extend.cc b/lib/utils/test/src/utils/containers/extend.cc index e0d156a3fc..ef2a67725c 100644 --- a/lib/utils/test/src/utils/containers/extend.cc +++ b/lib/utils/test/src/utils/containers/extend.cc @@ -1,6 +1,6 @@ #include "utils/containers/extend.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc index da459094ef..9462d30024 100644 --- a/lib/utils/test/src/utils/containers/filter.cc +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -1,10 +1,11 @@ #include 
"utils/containers/filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" using namespace ::FlexFlow; @@ -95,4 +96,13 @@ TEST_SUITE(FF_TEST_SUITE) { }; CHECK(result == correct); } + + TEST_CASE("filter(std::unordered_multiset, F)") { + std::unordered_multiset input = {1, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 8}; + auto predicate = [](int x) { return x % 2 == 0; }; + + std::unordered_multiset result = filter(input, predicate); + std::unordered_multiset correct = {2, 2, 2, 4, 6, 8, 8}; + CHECK(result == correct); + } } diff --git a/lib/utils/test/src/utils/containers/filtermap_keys.cc b/lib/utils/test/src/utils/containers/filtermap_keys.cc index 758264627b..582e94392b 100644 --- a/lib/utils/test/src/utils/containers/filtermap_keys.cc +++ b/lib/utils/test/src/utils/containers/filtermap_keys.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_keys.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_values.cc b/lib/utils/test/src/utils/containers/filtermap_values.cc index d2b6ddd220..8db6d6a964 100644 --- a/lib/utils/test/src/utils/containers/filtermap_values.cc +++ b/lib/utils/test/src/utils/containers/filtermap_values.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_values.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtrans.cc b/lib/utils/test/src/utils/containers/filtrans.cc index b8bb832b06..cd1c2f896c 100644 --- a/lib/utils/test/src/utils/containers/filtrans.cc +++ b/lib/utils/test/src/utils/containers/filtrans.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtrans.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index f96175b879..25dcea38dc 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -1,12 +1,16 @@ #include "utils/containers/flatmap.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/containers/map_keys.h" +#include "utils/hash/pair.h" #include -#include +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("flatmap") { + TEST_CASE("flatmap(std::vector, F)") { SUBCASE("same data-type") { auto get_factors = [](int x) -> std::vector { // Returns a vector of factors of x @@ -37,5 +41,97 @@ 
TEST_SUITE(FF_TEST_SUITE) { "1", "2", "4", "3", "4", "8", "9", "10", "20"}; CHECK(result == correct); } + + TEST_CASE("flatmap(std::unordered_set, F)") { + auto get_chars = [](std::string const &s) { + std::unordered_set result; + for (char c : s) { + result.insert(c); + } + return result; + }; + + SUBCASE("type changing") { + std::unordered_set input = {"hello", " ", "", "world", "!"}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = { + 'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; + + CHECK(result == correct); + } + + SUBCASE("input is empty") { + std::unordered_set input = {}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("flatmap(std::unordered_map, F)") { + auto de_nest_keys = [](int k1, + std::unordered_map const &v) { + return map_keys(v, [&](int k2) { return std::pair{k1, k2}; }); + }; + + SUBCASE("input is empty") { + std::unordered_map> input = {}; + + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::unordered_map> input = { + { + 1, + { + {2, "a"}, + {3, "b"}, + }, + }, + { + 2, + {}, + }, + { + 3, + { + {3, "a"}, + }, + }, + }; + + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = { + {{1, 2}, "a"}, + {{1, 3}, "b"}, + {{3, 3}, "a"}, + }; + + CHECK(result == correct); + } + + SUBCASE("duplicate result keys") { + auto always_return_same_map = [](int, std::string const &) { + return std::unordered_map{ + {"mykey", 10000}, + }; + }; + + std::unordered_map input = { + {1, "a"}, + {2, "b"}, + }; + + CHECK_THROWS(flatmap(input, always_return_same_map)); + } } } diff --git a/lib/utils/test/src/utils/containers/foldl.cc b/lib/utils/test/src/utils/containers/foldl.cc new file mode 100644 index 0000000000..9ed9768a92 --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldl.cc @@ -0,0 +1,47 @@ +#include "utils/containers/foldl.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldl") { + SUBCASE("product") { + std::vector container = {1, 2, 3, 4, 5}; + int result = + foldl(container, 1, [](int acc, int elem) { return acc * elem; }); + int correct = 120; + CHECK(result == correct); + } + + SUBCASE("string concat") { + std::vector container = {1, 2, 3, 4, 5}; + std::string result = + foldl(container, std::string(""), [](std::string acc, int elem) { + return acc + std::to_string(elem); + }); + std::string correct = "12345"; + CHECK(result == correct); + } + } + + TEST_CASE("foldl1") { + SUBCASE("product") { + std::vector container = {1, 2, 3, 4, 5}; + int result = + foldl1(container, [](int acc, int elem) { return acc * elem; }); + int correct = 120; + CHECK(result == correct); + } + + SUBCASE("string concat") { + std::vector container = {"1", "2", "3", "4", "5"}; + std::string result = + foldl1(container, + [](std::string acc, std::string elem) { return acc + elem; }); + std::string correct = "12345"; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/foldl1.cc b/lib/utils/test/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..597aa5e109 --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldl1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldl1.h" +#include + +using namespace ::FlexFlow; + 
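+// Illustrative note on the fold convention the tests below assume:
+//   foldl1({"a s", "tr", "ing"}, concat) == concat(concat("a s", "tr"), "ing") == "a string",
+// i.e., the first element seeds the accumulator and the remaining elements are
+// combined left to right, which is also why an empty container is expected to
+// throw: there is no element available to seed the fold.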
+TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldl1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldl1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"a s", "tr", "ing"}; + + std::string result = foldl1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/foldr1.cc b/lib/utils/test/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..3c9d9b66ae --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldr1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldr1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldr1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldr1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"ing", "tr", "a s"}; + + std::string result = foldr1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_assignments.cc b/lib/utils/test/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..d5f989318f --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_assignments.cc @@ -0,0 +1,53 @@ +#include "utils/containers/get_all_assignments.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_assignments") { + SUBCASE("empty input") { + std::unordered_map> input = {}; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = {{}}; + + CHECK(result == correct); + } + + SUBCASE("non-empty input") { + std::unordered_map> input = { + {"a", {1, 2, 3}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = { + {{"a", 1}, {"b", 2}}, + {{"a", 1}, {"b", 3}}, + {{"a", 2}, {"b", 2}}, + {{"a", 2}, {"b", 3}}, + {{"a", 3}, {"b", 2}}, + {{"a", 3}, {"b", 3}}, + }; + + CHECK(result == correct); + } + + SUBCASE("one possible-values set is empty") { + std::unordered_map> input = { + {"a", {}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations.cc b/lib/utils/test/src/utils/containers/get_all_permutations.cc index 5f22266809..cc5edb4075 100644 --- a/lib/utils/test/src/utils/containers/get_all_permutations.cc +++ b/lib/utils/test/src/utils/containers/get_all_permutations.cc @@ -1,8 +1,7 @@ #include "utils/containers/get_all_permutations.h" -#include "utils/containers/as_vector.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/vector.h" #include "utils/hash/vector.h" #include diff --git a/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc b/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc new file mode 100644 index 0000000000..f25bcf65b1 --- /dev/null +++ 
b/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc @@ -0,0 +1,75 @@ +#include "utils/containers/get_all_permutations_with_repetition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_all_permutations_with_repetition") { + SUBCASE("output vector has only one element") { + std::vector input = {1, 2, 3}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 1); + std::unordered_multiset> correct = { + {1}, + {2}, + {3}, + }; + + CHECK(result == correct); + } + + SUBCASE("input vector has only one element") { + std::vector input = {1}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 2); + std::unordered_multiset> correct = { + {1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("input, output vectors have more than 1 element") { + std::vector input = {1, 2}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 3); + std::unordered_multiset> correct = { + {1, 1, 1}, + {1, 1, 2}, + {1, 2, 1}, + {1, 2, 2}, + {2, 1, 1}, + {2, 1, 2}, + {2, 2, 1}, + {2, 2, 2}, + }; + + CHECK(result == correct); + } + + SUBCASE("duplicate elements") { + std::vector input = {1, 2, 2}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 2); + std::unordered_multiset> correct = {{1, 1}, + {1, 2}, + {1, 2}, + {2, 1}, + {2, 1}, + {2, 2}, + {2, 2}, + {2, 2}, + {2, 2}}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_element_counts.cc b/lib/utils/test/src/utils/containers/get_element_counts.cc index 11e2ef7e05..8fc87dba90 100644 --- a/lib/utils/test/src/utils/containers/get_element_counts.cc +++ b/lib/utils/test/src/utils/containers/get_element_counts.cc @@ -1,5 +1,5 @@ #include "utils/containers/get_element_counts.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/inplace_filter.cc b/lib/utils/test/src/utils/containers/inplace_filter.cc index 7ef9d73339..ac430279b0 100644 --- a/lib/utils/test/src/utils/containers/inplace_filter.cc +++ b/lib/utils/test/src/utils/containers/inplace_filter.cc @@ -1,10 +1,11 @@ #include "utils/containers/inplace_filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/intersection.cc index ac9acf5e2b..52de6ee6d3 100644 --- a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/intersection.cc @@ -1,6 +1,6 @@ #include "utils/containers/intersection.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git 
a/lib/utils/test/src/utils/containers/multiset_union.cc b/lib/utils/test/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..8c40bf55ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/multiset_union.cc @@ -0,0 +1,29 @@ +#include "utils/containers/multiset_union.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("multiset_union(std::unordered_multiset, " + "std::unordered_multiset)") { + std::unordered_multiset input_lhs = {1, 2, 2, 3}; + std::unordered_multiset input_rhs = {1, 2, 5}; + + std::unordered_multiset result = multiset_union(input_lhs, input_rhs); + std::unordered_multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } + + TEST_CASE("multiset_union(std::multiset, std::multiset)") { + std::multiset input_lhs = {1, 2, 2, 3}; + std::multiset input_rhs = {1, 2, 5}; + + std::multiset result = multiset_union(input_lhs, input_rhs); + std::multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/range.cc b/lib/utils/test/src/utils/containers/range.cc new file mode 100644 index 0000000000..f115855323 --- /dev/null +++ b/lib/utils/test/src/utils/containers/range.cc @@ -0,0 +1,54 @@ +#include "utils/containers/range.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("range") { + SUBCASE("step=1") { + std::vector result = range(0, 5); + std::vector correct = {0, 1, 2, 3, 4}; + CHECK(result == correct); + } + + SUBCASE("step = 2") { + std::vector result = range(-2, 10, 2); + std::vector correct = {-2, 0, 2, 4, 6, 8}; + CHECK(result == correct); + } + + SUBCASE("step = -1") { + std::vector result = range(5, 0, -1); + std::vector correct = {5, 4, 3, 2, 1}; + CHECK(result == correct); + } + + SUBCASE("single argument") { + std::vector result = range(5); + std::vector correct = {0, 1, 2, 3, 4}; + CHECK(result == correct); + } + + SUBCASE("start = end") { + std::vector result = range(5, 5); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("start > end") { + std::vector result = range(5, 4); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("start < end, step < 0") { + std::vector result = range(0, 10, -1); + std::vector correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/repeat.cc b/lib/utils/test/src/utils/containers/repeat.cc index 50e4b3e7c5..d8ffe76a64 100644 --- a/lib/utils/test/src/utils/containers/repeat.cc +++ b/lib/utils/test/src/utils/containers/repeat.cc @@ -1,5 +1,5 @@ #include "utils/containers/repeat.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/replicate.cc b/lib/utils/test/src/utils/containers/replicate.cc new file mode 100644 index 0000000000..1c7845642e --- /dev/null +++ b/lib/utils/test/src/utils/containers/replicate.cc @@ -0,0 +1,25 @@ +#include "utils/containers/replicate.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("replicate") { + SUBCASE("ints") { + int x = 42; + std::vector result = replicate(5, x); + std::vector 
correct = {42, 42, 42, 42, 42}; + CHECK(result == correct); + } + SUBCASE("unordered_set") { + std::unordered_set x = {1.0, 1.5}; + std::vector> result = replicate(3, x); + std::vector> correct = { + {1.0, 1.5}, {1.0, 1.5}, {1.0, 1.5}}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..48c1ab0b99 --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -0,0 +1,54 @@ +#include "utils/containers/require_all_same1.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/expected.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("require_all_same1(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + SUBCASE("input is empty") { + T input = {}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input elements are all the same") { + T input = {1, 1, 1}; + + tl::expected result = require_all_same1(input); + tl::expected correct = 1; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all the same") { + T input = {1, 1, 2, 1}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/require_no_duplicates.cc b/lib/utils/test/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..67733d791a --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1,62 @@ +#include "utils/containers/require_no_duplicates.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("require_no_duplicates(std::unordered_multiset)") { + SUBCASE("empty") { + std::unordered_multiset input = {}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::unordered_multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::unordered_multiset input = {1, 2, 4}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } + + TEST_CASE("require_no_duplicates(std::multiset)") { + SUBCASE("empty") { + std::multiset input = {}; + + std::set result = require_no_duplicates(input); + std::set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::multiset input = {1, 2, 4}; + + std::set result = require_no_duplicates(input); + std::set correct = 
{1, 2, 4}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc index 50900d936e..834a497152 100644 --- a/lib/utils/test/src/utils/containers/reversed.cc +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -1,16 +1,27 @@ #include "utils/containers/reversed.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include -#include -using namespace FlexFlow; +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("reversed(std::vector)") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector correct = {5, 4, 3, 2, 1}; + TEST_CASE("reversed(std::vector)") { + SUBCASE("non-empty input") { + std::vector input = {1, 2, 3, 2}; - CHECK(result == correct); + std::vector result = reversed(input); + std::vector correct = {2, 3, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + + std::vector result = reversed(input); + std::vector correct = {}; + + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/containers/scanl.cc b/lib/utils/test/src/utils/containers/scanl.cc new file mode 100644 index 0000000000..d6da0ac0a1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/scanl.cc @@ -0,0 +1,71 @@ +#include "utils/containers/scanl.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("scanl") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanl(input, 0, [](int a, int b) { return a + b; }); + std::vector correct = {0, 1, 3, 6, 10}; + CHECK(result == correct); + } + + SUBCASE("custom function") { + std::vector input = {1, 3, 1, 2}; + auto op = [](int a, int b) { return (a + 1) * (b + 1); }; + std::vector result = scanl(input, 1, op); + std::vector correct = {1, 4, 20, 42, 129}; + CHECK(result == correct); + } + + SUBCASE("heterogeneous types") { + std::vector input = {1, 2, 3, 4}; + auto op = [](std::string const &a, int b) { + return a + std::to_string(b); + }; + std::vector result = scanl(input, std::string(""), op); + std::vector correct = {"", "1", "12", "123", "1234"}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanl(input, 0, [](int a, int b) { return a + b; }); + std::vector correct = {0}; + CHECK(result == correct); + } + } + + TEST_CASE("scanl1") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanl1(input, [](int a, int b) { return a + b; }); + std::vector correct = {1, 3, 6, 10}; + CHECK(result == correct); + } + + SUBCASE("custom function") { + std::vector input = {1, 2, 5, 2}; + auto op = [](int a, int b) { return a * b + 1; }; + std::vector result = scanl1(input, op); + std::vector correct = {1, 3, 16, 33}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanl1(input, [](int a, int b) { return a + b; }); + std::vector correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc index 211c70ce7a..32d8cd32a3 100644 --- a/lib/utils/test/src/utils/containers/sum.cc +++ b/lib/utils/test/src/utils/containers/sum.cc @@ -1,23 +1,27 @@ #include "utils/containers/sum.h" #include -#include #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - 
TEST_CASE("sum") { - SUBCASE("non-empty container") { - std::vector input = {1, 2, 3, 4, 5}; - int correct = 15; + TEST_CASE("sum(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + int result = sum(input); - CHECK(correct == result); - } - SUBCASE("empty container") { - std::unordered_set input = {}; int correct = 0; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::vector input = {1, 3, 2}; + int result = sum(input); - CHECK(correct == result); + int correct = 6; + + CHECK(result == correct); } } } diff --git a/lib/utils/test/src/utils/containers/to_uppercase.cc b/lib/utils/test/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..9729307304 --- /dev/null +++ b/lib/utils/test/src/utils/containers/to_uppercase.cc @@ -0,0 +1,15 @@ +#include "utils/containers/to_uppercase.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("to_uppercase(std::string)") { + std::string input = "Hello World"; + + std::string result = to_uppercase(input); + std::string correct = "HELLO WORLD"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/transform.cc b/lib/utils/test/src/utils/containers/transform.cc index 916bc20928..3122c67117 100644 --- a/lib/utils/test/src/utils/containers/transform.cc +++ b/lib/utils/test/src/utils/containers/transform.cc @@ -1,7 +1,7 @@ #include "utils/containers/transform.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/try_at.cc b/lib/utils/test/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..548c9b0c79 --- /dev/null +++ b/lib/utils/test/src/utils/containers/try_at.cc @@ -0,0 +1,29 @@ +#include "utils/containers/try_at.h" +#include "test/utils/doctest/fmt/optional.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("try_at(T, K)", + T, + std::unordered_map, + std::map) { + T m = {{1, "one"}, {2, "two"}}; + + SUBCASE("map contains key") { + std::optional result = try_at(m, 1); + std::optional correct = "one"; + + CHECK(result == correct); + } + + SUBCASE("map does not contain key") { + std::optional result = try_at(m, 3); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc index 6aeab4ae6e..b8a7a85f74 100644 --- a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc +++ b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc @@ -1,7 +1,7 @@ #include "utils/containers/try_merge_nondisjoint_unordered_maps.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..f0cdb19611 --- /dev/null +++ b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1,57 @@ +#include 
"utils/containers/unordered_map_from_pairs.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "utils/containers/contains.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("unordered_map_from_pairs") { + SUBCASE("nonempty input") { + std::vector> input = { + {1, "hello"}, + {3, "world"}, + }; + + std::unordered_map result = + unordered_map_from_pairs(input); + std::unordered_map correct = { + {1, "hello"}, + {3, "world"}, + }; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector> input = {}; + + std::unordered_map result = + unordered_map_from_pairs(input); + std::unordered_map correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input with duplicate keys") { + std::vector> input = { + {1, "a"}, + {2, "c"}, + {1, "b"}, + }; + + std::unordered_map result = + unordered_map_from_pairs(input); + + std::vector> + possible_correct_values = { + {{1, "a"}, {2, "c"}}, + {{1, "b"}, {2, "c"}}, + }; + + CHECK(contains(possible_correct_values, result)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc index 0ab0ef1446..becb7fdce0 100644 --- a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/containers/unordered_set_of.cc b/lib/utils/test/src/utils/containers/unordered_set_of.cc index d42b41dd50..b8ca1d1797 100644 --- a/lib/utils/test/src/utils/containers/unordered_set_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_set_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_set_of.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include #include diff --git a/lib/utils/test/src/utils/containers/vector_of.cc b/lib/utils/test/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..8b9353e1b0 --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_of.cc @@ -0,0 +1,17 @@ +#include "utils/containers/vector_of.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("vector_of(std::set)") { + std::set input = {2, 3, 1, 4}; + + std::vector result = vector_of(input); + std::vector correct = {1, 2, 3, 4}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/without_order.cc b/lib/utils/test/src/utils/containers/without_order.cc deleted file mode 100644 index 939c6ff108..0000000000 --- a/lib/utils/test/src/utils/containers/without_order.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "utils/containers/without_order.h" -#include "utils/fmt/unordered_multiset.h" -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("without_order") { - std::vector input = {1, 2, 3, 3, 2, 3}; - std::unordered_multiset result = without_order(input); - std::unordered_multiset correct = {1, 2, 3, 3, 2, 3}; - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/deduplicated_priority_queue.cc b/lib/utils/test/src/utils/deduplicated_priority_queue.cc index 2bf413675d..048e95acb7 100644 --- a/lib/utils/test/src/utils/deduplicated_priority_queue.cc +++ 
b/lib/utils/test/src/utils/deduplicated_priority_queue.cc @@ -1,5 +1,5 @@ #include "utils/deduplicated_priority_queue.h" -#include "test/utils/doctest.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { diff --git a/lib/utils/test/src/utils/dot_file.cc b/lib/utils/test/src/utils/dot_file.cc index d7f5dcb02c..e409572511 100644 --- a/lib/utils/test/src/utils/dot_file.cc +++ b/lib/utils/test/src/utils/dot_file.cc @@ -1,5 +1,5 @@ #include "utils/dot_file.h" -#include "test/utils/doctest.h" +#include #include TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/utils/expected.cc b/lib/utils/test/src/utils/expected.cc index 14679e0d13..3e5de13d49 100644 --- a/lib/utils/test/src/utils/expected.cc +++ b/lib/utils/test/src/utils/expected.cc @@ -1,6 +1,6 @@ #include "utils/expected.h" -#include "utils/fmt/expected.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/expected.cc b/lib/utils/test/src/utils/fmt/expected.cc index fb39732761..48df8634db 100644 --- a/lib/utils/test/src/utils/fmt/expected.cc +++ b/lib/utils/test/src/utils/fmt/expected.cc @@ -1,5 +1,6 @@ #include "utils/fmt/expected.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include using namespace ::FlexFlow; @@ -19,24 +20,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } - - TEST_CASE("doctest::toString(tl::expected)") { - SUBCASE("has expected") { - tl::expected input = 3; - - doctest::String result = doctest::toString(input); - doctest::String correct = "expected(3)"; - - CHECK(result == correct); - } - - SUBCASE("has unexpected") { - tl::expected input = tl::make_unexpected("error"); - - doctest::String result = doctest::toString(input); - doctest::String correct = "unexpected(error)"; - - CHECK(result == correct); - } - } } diff --git a/lib/utils/test/src/utils/fmt/map.cc b/lib/utils/test/src/utils/fmt/map.cc index b65b4791ea..19f3a7d5cf 100644 --- a/lib/utils/test/src/utils/fmt/map.cc +++ b/lib/utils/test/src/utils/fmt/map.cc @@ -1,5 +1,5 @@ #include "utils/fmt/map.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/optional.cc b/lib/utils/test/src/utils/fmt/optional.cc index e7815a26ac..1cd79da747 100644 --- a/lib/utils/test/src/utils/fmt/optional.cc +++ b/lib/utils/test/src/utils/fmt/optional.cc @@ -1,5 +1,5 @@ #include "utils/fmt/optional.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/pair.cc b/lib/utils/test/src/utils/fmt/pair.cc index 3d7cc78756..e848eb08c7 100644 --- a/lib/utils/test/src/utils/fmt/pair.cc +++ b/lib/utils/test/src/utils/fmt/pair.cc @@ -1,5 +1,5 @@ #include "utils/fmt/pair.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/set.cc b/lib/utils/test/src/utils/fmt/set.cc index 66824f2b2a..e317954b02 100644 --- a/lib/utils/test/src/utils/fmt/set.cc +++ b/lib/utils/test/src/utils/fmt/set.cc @@ -1,5 +1,5 @@ #include "utils/fmt/set.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_map.cc b/lib/utils/test/src/utils/fmt/unordered_map.cc index 99752d73f4..c980bc1e52 100644 --- a/lib/utils/test/src/utils/fmt/unordered_map.cc +++ b/lib/utils/test/src/utils/fmt/unordered_map.cc @@ -1,6 +1,7 @@ #include 
"utils/fmt/unordered_map.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include "utils/containers/get_element_counts.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_set.cc b/lib/utils/test/src/utils/fmt/unordered_set.cc index 9dc8d236f1..f492ea844d 100644 --- a/lib/utils/test/src/utils/fmt/unordered_set.cc +++ b/lib/utils/test/src/utils/fmt/unordered_set.cc @@ -1,7 +1,7 @@ #include "utils/fmt/unordered_set.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/variant.cc b/lib/utils/test/src/utils/fmt/variant.cc index 3ada166de9..0c8dca35d7 100644 --- a/lib/utils/test/src/utils/fmt/variant.cc +++ b/lib/utils/test/src/utils/fmt/variant.cc @@ -1,5 +1,5 @@ #include "utils/fmt/variant.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/vector.cc b/lib/utils/test/src/utils/fmt/vector.cc index fee3eb34a5..91ef6c9efc 100644 --- a/lib/utils/test/src/utils/fmt/vector.cc +++ b/lib/utils/test/src/utils/fmt/vector.cc @@ -1,5 +1,5 @@ #include "utils/fmt/vector.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/graph/cow_ptr_t.cc b/lib/utils/test/src/utils/graph/cow_ptr_t.cc index 65088c19de..e6a6f9661e 100644 --- a/lib/utils/test/src/utils/graph/cow_ptr_t.cc +++ b/lib/utils/test/src/utils/graph/cow_ptr_t.cc @@ -1,5 +1,5 @@ #include "utils/graph/cow_ptr_t.h" -#include "test/utils/doctest.h" +#include #include #include #include diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..fec5d3401e --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,104 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dataflow_edges_from_node_to_node") { + DataflowGraph g = DataflowGraph::create(); + + SUBCASE("gets edges if there are multiple") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o0 = n1_added.outputs.at(0); + DataflowOutput n1_o1 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o0, n1_o0, n1_o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = { + DataflowEdge{ + n1_o0, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o0, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not get edges to/from other nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = 
g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n3); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE( + "does not get flipped edges (i.e., respects from vs to direction)") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n2, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if no edges exist between the given nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if src node == dst node (as cycles cannot exist " + "in DataflowGraph") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..330628adfd --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,43 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph_incoming_edges(DataflowGraphView, " + "std::unordered_set") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + std::unordered_set input_node_set = {n2, n3}; + + std::unordered_set result = + get_subgraph_incoming_edges(g, input_node_set); + + std::unordered_set correct = { + DataflowEdge{o1, DataflowInput{n2, 0}}, + DataflowEdge{o1, DataflowInput{n3, 0}}, + DataflowEdge{o1, DataflowInput{n3, 2}}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc index 7e02686dde..779d0a9560 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -7,7 +7,8 @@ using namespace ::FlexFlow; 
TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(DataflowGraphView, std::unordered_set") { + TEST_CASE("get_subgraph_outgoing_edges(DataflowGraphView, " + "std::unordered_set") { DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..c35789044d --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,55 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + SplitBoundaryNodes result = + get_transitive_reduced_boundary_nodes_for_split(tr_g, split); + SplitBoundaryNodes correct = SplitBoundaryNodes{ + /*pre_split_boundary=*/{n2}, + /*post_split_boundary=*/{n3}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..1f8f66b932 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,146 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/get_only.h" +#include 
"utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_edges_across_split") { + DataflowGraph g = DataflowGraph::create(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("multiple nodes with edges across") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o1, + DataflowInput{n3, 1}, + }, + DataflowEdge{ + o2, + DataflowInput{n3, 0}, + }, + DataflowEdge{ + o1, + DataflowInput{n4, 0}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("nodes each have multiple edges across") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o1 = n1_added.outputs.at(0); + DataflowOutput n1_o2 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o1, n1_o2, n1_o1}, 1); + Node n2 = n2_added.node; + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_leaf(n1), + make_leaf(n2), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + n1_o1, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o2, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not return edges eliminated by transitive reduction") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + 
get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o2, + DataflowInput{n3, 1}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..0e77739434 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,52 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_outputs_across_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_outputs_across_split(tr_g, split); + std::unordered_set correct = {o2}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc index cfc912af6b..7a3237d432 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc @@ -1,9 +1,11 @@ -#include "test/utils/doctest.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/node_query.h" +#include + +using 
namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("UnorderedSetDataflowGraph") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 2ebfe232b6..eca7aa6c79 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,5 +1,8 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" +#include "utils/containers/reversed.h" +#include "utils/containers/vector_of.h" #include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -9,6 +12,25 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_cbc_decomposition") { DiGraph g = DiGraph::create(); + // used to check that the cbc decomposition result is the same regardless + // of the order in which the graph edges are processed, as this is a + // property that should hold, and violations of this property have been a + // source of bugs in the past + auto check_cbc_decomposition_is_edge_order_invariant = + [](DiGraphView const &g) { + std::unordered_set edges = get_edges(g); + + std::vector edge_order1 = vector_of(edges); + std::vector edge_order2 = reversed(edge_order1); + + std::optional result1 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order1); + std::optional result2 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order2); + + CHECK(result1 == result2); + }; + SUBCASE("six-node diamond graph") { std::vector n = add_nodes(g, 6); add_edges(g, @@ -32,6 +54,8 @@ TEST_SUITE(FF_TEST_SUITE) { }}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } SUBCASE("graph without any edges") { @@ -43,6 +67,27 @@ TEST_SUITE(FF_TEST_SUITE) { CompleteBipartiteCompositeDecomposition{{}}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); + } + + SUBCASE("irreducible n-graph (non-cbc graph)") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_cbc_decomposition(g); + std::optional correct = + std::nullopt; + + CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc new file mode 100644 index 0000000000..17c8b8da27 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc @@ -0,0 +1,175 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView, " + "std::unordered_set)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + 
DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + SUBCASE("source group") { + std::unordered_set group1 = {n.at(0), n.at(1), n.at(2)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("sink group") { + std::unordered_set group1 = {n.at(3), n.at(4)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + SUBCASE("missing an edge (i.e., not complete)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge (i.e., not bipartite)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("flipped edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("group too small") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("missing an edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc 
b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..5a1ea99671 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,142 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_edges_from_subgraph_to_subgraph") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + SUBCASE("basic tests") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(4)}; + std::unordered_set dst_subgraph = {n.at(2), n.at(3)}; + + SUBCASE("returns all edges between subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(4), n.at(2)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = unordered_set_of(e); + + CHECK(result == correct); + } + + SUBCASE("does not return reverse edges") { + std::vector e = { + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(0)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(0)}; + + CHECK(result == correct); + } + + SUBCASE("does not return edges within subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if there are no edges from src_subgraph to " + "dst_subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(2), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("empty subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + SUBCASE("returns no edges if no nodes in src_subgraph") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, unordered_set_of(n)); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if no nodes in dst_subgraph") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, unordered_set_of(n), {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if both subgraphs are empty") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("if subgraphs do not cover graph, then does not return external " + "edges") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set src_subgraph = {n.at(0)}; + std::unordered_set dst_subgraph = {n.at(3)}; + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, 
dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("throws an error if subgraphs are not disjoint") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(2)}; + std::unordered_set dst_subgraph = {n.at(1), n.at(3)}; + CHECK_THROWS( + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index fd2f469f93..a635658755 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" @@ -139,5 +140,27 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_bidict == correct_bidict); } } + + SUBCASE("sp n-graph (inverse line graph does not exist)") { + // Tests that the inverse line graph of the sp n-graph + // + // a-b + // \ + // c-d + // + // does not exist + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_inverse_line_graph(transitive_reduction(g)); + + CHECK_FALSE(result.has_value()); + } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc index 3ad506f40a..e675e6903f 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..5f72355ed0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,50 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transitive_closure(DiGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("maximum number of new edges") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_closure(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set 
correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc index b8a35346f4..1f9062a8ed 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -76,5 +77,66 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_edges == correct_edges); } } + + SUBCASE("longer paths") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + + SUBCASE("irreducible sp n-graph") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + CHECK(result_edges == correct_edges); + } + } } } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc deleted file mode 100644 index 04d82bf1d8..0000000000 --- a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ /dev/null @@ -1,162 +0,0 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_serial_parallel_decomposition (base case)") { - DiGraph g = DiGraph::create(); - Node n = g.add_node(); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{n}; - CHECK(result == correct); - } - - 
TEST_CASE("get_serial_parallel_decomposition (parallel)") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 2); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ParallelSplit{ - n.at(0), - n.at(1), - }}; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (serial)") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 2); - g.add_edge(DirectedEdge{n.at(0), n.at(1)}); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ - n.at(0), - n.at(1), - }}; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (composite)") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 3); - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - }); - - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ - n.at(0), - ParallelSplit{ - n.at(1), - n.at(2), - }, - }, - }; - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (diamond graph)") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 6); - - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(3), n.at(5)}, - DirectedEdge{n.at(4), n.at(5)}, - }); - - std::optional correct = - SerialParallelDecomposition{SerialSplit{ - n.at(0), - ParallelSplit{ - SerialSplit{ - n.at(1), - n.at(3), - }, - SerialSplit{ - n.at(2), - n.at(4), - }, - }, - n.at(5), - }}; - - std::optional result = - get_serial_parallel_decomposition(g); - - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (all-to-all connection)") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - - add_edges(g, - { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - }); - - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ - ParallelSplit{ - n.at(0), - n.at(1), - }, - ParallelSplit{ - n.at(2), - n.at(3), - }, - }, - }; - - std::optional result = - get_serial_parallel_decomposition(g); - - CHECK(result == correct); - } - - TEST_CASE("get_serial_parallel_decomposition (non-sp graph)") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - - // N-graph - add_edges(g, - { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - }); - - std::optional correct = std::nullopt; - std::optional result = - get_serial_parallel_decomposition(g); - - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..9ca869b2b0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,118 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_leaves") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto generic_get_leaves = [&](BinarySPDecompositionTree const &tree) { + return get_leaves(tree, impl); + }; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{n1}; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1}; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; + + CHECK(result == correct); + } + } + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1, n2, n2, n3}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..ad7e1c2609 --- /dev/null +++ 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_num_tree_nodes") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + auto generic_get_num_tree_nodes = + [&](BinarySPDecompositionTree const &tree) { + return get_num_tree_nodes(tree, impl); + }; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = make_leaf(n1); + + int result = generic_get_num_tree_nodes(input); + int correct = 1; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n2)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n1)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n2)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n1)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); + + int result = generic_get_num_tree_nodes(input); + int correct = 9; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..3fae155280 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,97 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_left_associative") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("input is actually left associative") { + SUBCASE("just node") { + BinarySPDecompositionTree input = make_leaf(n1); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not left associative") { + SUBCASE("just series") { + BinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..5b4e26107e --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,97 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" 
+#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_right_associative") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("input is actually right associative") { + SUBCASE("just node") { + BinarySPDecompositionTree input = make_leaf(n1); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + BinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not right associative") { + SUBCASE("just series") { + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..fee971e5e0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,106 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/rapidcheck.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("left_associative_binary_sp_tree_from_nary(" + 
"SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{n1, n2, n3}}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{n1, n2, n3}}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // left-associative binary SP trees + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{ + n1, + SeriesSplit{{ + n2, + n3, + n3, + n5, + }}, + SeriesSplit{{ + n6, + n4, + }}, + n5, + }}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..fd540f853f --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,140 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nary_sp_tree_from_binary(BinarySPDecompositionTree)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, 
rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = make_leaf(n1); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + + CHECK(result == correct); + } + + SUBCASE("left associative series") { + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; + + CHECK(result == correct); + } + + SUBCASE("right associative series") { + BinarySPDecompositionTree input = make_series_split( + make_leaf(n2), make_series_split(make_leaf(n1), make_leaf(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; + + CHECK(result == correct); + } + + SUBCASE("series with duplicate children") { + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n1, n1}}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("left associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; + + CHECK(result == correct); + } + + SUBCASE("right associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n2), make_parallel_split(make_leaf(n1), make_leaf(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; + + CHECK(result == correct); + } + + SUBCASE("parallel with duplicate children") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n1, n1}}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split( + make_parallel_split( + make_leaf(n1), + make_series_split( + make_series_split( + make_series_split(make_leaf(n2), make_leaf(n3)), + make_leaf(n3)), + make_leaf(n5))), + make_series_split(make_leaf(n6), make_leaf(n4))), + make_leaf(n5)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + ParallelSplit{{ + n1, + SeriesSplit{{ + n2, + n3, + n3, + n5, + }}, + SeriesSplit{{ + n6, + n4, + }}, + n5, + }}, + }; + + CHECK(result == correct); + } + } +} diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..532ff86c90 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,104 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("right_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{n1, n2, n3}}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{n1, n2, n3}}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // right-associative binary SP trees + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{ + n1, + SeriesSplit{{ + n2, + n3, + n3, + n5, + }}, + SeriesSplit{{ + n6, + n4, + }}, + n5, + }}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc 
new file mode 100644 index 0000000000..e5b9045739 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -0,0 +1,192 @@ +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_series_parallel_decomposition (base case)") { + DiGraph g = DiGraph::create(); + Node n = g.add_node(); + + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{n}; + CHECK(result == correct); + } + + TEST_CASE("get_series_parallel_decomposition (parallel)") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 2); + + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{{ + n.at(0), + n.at(1), + }}}; + CHECK(result == correct); + } + + TEST_CASE("get_series_parallel_decomposition (serial)") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 2); + g.add_edge(DirectedEdge{n.at(0), n.at(1)}); + + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{{ + n.at(0), + n.at(1), + }}}; + CHECK(result == correct); + } + + TEST_CASE("get_series_parallel_decomposition (composite)") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + }); + + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{{ + n.at(0), + ParallelSplit{{ + n.at(1), + n.at(2), + }}, + }}, + }; + CHECK(result == correct); + } + + TEST_CASE("get_series_parallel_decomposition (diamond graph)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 6); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{{ + n.at(0), + ParallelSplit{{ + SeriesSplit{{ + n.at(1), + n.at(3), + }}, + SeriesSplit{{ + n.at(2), + n.at(4), + }}, + }}, + n.at(5), + }}}; + + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + } + + TEST_CASE("get_series_parallel_decomposition (all-to-all connection)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{{ + ParallelSplit{{ + n.at(0), + n.at(1), + }}, + ParallelSplit{{ + n.at(2), + n.at(3), + }}, + }}, + }; + + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + } + + TEST_CASE("get_series_parallel_decomposition (non-sp graph)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + + // N-graph + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional correct = std::nullopt; + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + 
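+    // A complementary, order-insensitive sanity check for the positive cases
+    // above (a minimal sketch; unordered_multiset_of is a hypothetical helper
+    // for comparing the graph's node set against the multiset returned by
+    // get_nodes on a decomposition):
+    //
+    //   if (result.has_value()) {
+    //     CHECK(get_nodes(result.value()) ==
+    //           unordered_multiset_of(get_nodes(g)));
+    //   }
+  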
} + + TEST_CASE( + "get_series_parallel_decomposition (requires transitive reduction)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{{ + n.at(0), + n.at(1), + n.at(2), + n.at(3), + }}, + }; + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc similarity index 83% rename from lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc rename to lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 4560f95ff7..3a486c7094 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/fmt/variant.h" #include @@ -8,11 +8,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("flatten_ast") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{2}, Node{3}, @@ -25,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { flatten_ast(input); std::variant correct = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, Node{2}, diff --git a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc index 8259d256d3..a62f528bcf 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 66% rename from lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 7cf17c3fee..f5766c9fdd 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,5 +1,5 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" 
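+// The formatter include above moves from unordered_set to unordered_multiset
+// because get_nodes on a decomposition now returns
+// std::unordered_multiset<Node> (see the hunks below): a decomposition can
+// contain the same Node more than once, e.g. (sketch reusing only types that
+// appear in this file)
+//
+//   SeriesSplit{{Node{1}, Node{1}}}   // get_nodes(...) == {Node{1}, Node{1}}
+//
+// and a multiset keeps both occurrences visible to the CHECKs.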
#include using namespace ::FlexFlow; @@ -7,20 +7,20 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (base case)") { std::variant input = Node{1}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{Node{1}}; + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{Node{1}}; CHECK(result == correct); } TEST_CASE("to_final_ast (serial)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, {Node{1}, Node{2}}, }; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{ - SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ Node{1}, Node{2}, }}, @@ -30,11 +30,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (composite)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{0}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ @@ -55,9 +55,9 @@ TEST_SUITE(FF_TEST_SUITE) { Node{5}, }}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{ Node{0}, Node{1}, ParallelSplit{{ @@ -70,55 +70,55 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_nodes(SerialParallelDecomposition)") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{{ + TEST_CASE("get_nodes(SeriesParallelDecomposition)") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{ ParallelSplit{{ Node{1}, Node{2}, }}, - Node{3}, + Node{2}, ParallelSplit{{ Node{4}, Node{5}, }}, }}}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{2}, Node{4}, Node{5}, }; CHECK(result == correct); } - TEST_CASE("get_nodes(SerialSplit)") { + TEST_CASE("get_nodes(SeriesSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, ParallelSplit{{ Node{3}, Node{4}, }}, }}, - SerialSplit{{ - Node{5}, + SeriesSplit{{ + Node{1}, Node{6}, }}, Node{7}, }}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, Node{3}, Node{4}, - Node{5}, + Node{1}, Node{6}, Node{7}, }; @@ -129,9 +129,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(ParallelSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, - Node{3}, + Node{4}, ParallelSplit{{ Node{4}, Node{5}, @@ -139,11 +139,11 @@ TEST_SUITE(FF_TEST_SUITE) { }}, }}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{4}, Node{4}, Node{5}, }; @@ -153,8 +153,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(Node)") { Node input = Node{5}; - std::unordered_set 
-    std::unordered_set<Node> correct = {input};
+    std::unordered_multiset<Node> result = get_nodes(input);
+    std::unordered_multiset<Node> correct = {input};

     CHECK(result == correct);
   }
 }
diff --git a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc
similarity index 99%
rename from lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc
rename to lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc
index e4d53b4136..c6b45ec6ce 100644
--- a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc
+++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc
@@ -1,4 +1,4 @@
-#include "utils/graph/serial_parallel/series_reduction.h"
+#include "utils/graph/series_parallel/series_reduction.h"
 #include "utils/containers/set_minus.h"
 #include "utils/graph/instances/adjacency_multidigraph.h"
 #include "utils/graph/multidigraph/algorithms/add_edges.h"
diff --git a/lib/utils/test/src/utils/hash/multiset.cc b/lib/utils/test/src/utils/hash/multiset.cc
new file mode 100644
index 0000000000..5c2e01fda8
--- /dev/null
+++ b/lib/utils/test/src/utils/hash/multiset.cc
@@ -0,0 +1,34 @@
+#include "utils/hash/multiset.h"
+#include <doctest/doctest.h>
+
+using namespace ::FlexFlow;
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("std::hash<std::multiset<int>>") {
+    std::multiset<int> input = {1, 2, 2, 1, 5};
+    size_t input_hash = get_std_hash(input);
+
+    SUBCASE("same values have the same hash") {
+      std::multiset<int> also_input = {2, 1, 2, 5, 1};
+      size_t also_input_hash = get_std_hash(also_input);
+
+      CHECK(input_hash == also_input_hash);
+    }
+
+    SUBCASE("different values have different hashes") {
+      SUBCASE("different number of duplicates") {
+        std::multiset<int> other = {1, 2, 2, 1, 5, 5};
+        size_t other_hash = get_std_hash(other);
+
+        CHECK(input_hash != other_hash);
+      }
+
+      SUBCASE("different elements") {
+        std::multiset<int> other = {1, 2, 2, 1, 6};
+        size_t other_hash = get_std_hash(other);
+
+        CHECK(input_hash != other_hash);
+      }
+    }
+  }
+}
diff --git a/lib/utils/test/src/utils/hash/unordered_multiset.cc b/lib/utils/test/src/utils/hash/unordered_multiset.cc
new file mode 100644
index 0000000000..6c730fad3c
--- /dev/null
+++ b/lib/utils/test/src/utils/hash/unordered_multiset.cc
@@ -0,0 +1,34 @@
+#include "utils/hash/unordered_multiset.h"
+#include <doctest/doctest.h>
+
+using namespace ::FlexFlow;
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("std::hash<std::unordered_multiset<int>>") {
+    std::unordered_multiset<int> input = {1, 2, 2, 1, 5};
+    size_t input_hash = get_std_hash(input);
+
+    SUBCASE("same values have the same hash") {
+      std::unordered_multiset<int> also_input = {2, 1, 2, 5, 1};
+      size_t also_input_hash = get_std_hash(also_input);
+
+      CHECK(input_hash == also_input_hash);
+    }
+
+    SUBCASE("different values have different hashes") {
+      SUBCASE("different number of duplicates") {
+        std::unordered_multiset<int> other = {1, 2, 2, 1, 5, 5};
+        size_t other_hash = get_std_hash(other);
+
+        CHECK(input_hash != other_hash);
+      }
+
+      SUBCASE("different elements") {
+        std::unordered_multiset<int> other = {1, 2, 2, 1, 6};
+        size_t other_hash = get_std_hash(other);
+
+        CHECK(input_hash != other_hash);
+      }
+    }
+  }
+}
diff --git a/lib/utils/test/src/utils/json/optional.cc b/lib/utils/test/src/utils/json/optional.cc
new file mode 100644
index 0000000000..61f5868c53
--- /dev/null
+++ b/lib/utils/test/src/utils/json/optional.cc
@@ -0,0 +1,49 @@
+#include "utils/json/optional.h"
+#include "test/utils/doctest/fmt/optional.h"
+#include <doctest/doctest.h>
+
+using namespace ::FlexFlow;
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("adl_serializer<std::optional<int>>") {
+    SUBCASE("to_json") {
+      SUBCASE("has value") {
+        std::optional<int> input = 5;
+
+        nlohmann::json result = input;
+        nlohmann::json correct = 5;
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("has nullopt") {
+        std::optional<int> input = std::nullopt;
+
+        nlohmann::json result = input;
+        nlohmann::json correct = nullptr;
+
+        CHECK(result == correct);
+      }
+    }
+
+    SUBCASE("from_json") {
+      SUBCASE("has value") {
+        nlohmann::json input = 5;
+
+        std::optional<int> result = input;
+        std::optional<int> correct = 5;
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("has nullopt") {
+        nlohmann::json input = nullptr;
+
+        std::optional<int> result = input.get<std::optional<int>>();
+        std::optional<int> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+    }
+  }
+}
diff --git a/lib/utils/test/src/utils/optional.cc b/lib/utils/test/src/utils/rapidcheck/optional.cc
similarity index 67%
rename from lib/utils/test/src/utils/optional.cc
rename to lib/utils/test/src/utils/rapidcheck/optional.cc
index 16c9e964cb..96b17a5400 100644
--- a/lib/utils/test/src/utils/optional.cc
+++ b/lib/utils/test/src/utils/rapidcheck/optional.cc
@@ -1,7 +1,8 @@
-#include "utils/optional.h"
-#include "test/utils/doctest.h"
+#include "utils/rapidcheck/optional.h"
 #include "test/utils/rapidcheck.h"
-#include
+#include <doctest/doctest.h>
+
+using namespace ::FlexFlow;

 TEST_SUITE(FF_TEST_SUITE) {
   TEST_CASE_TEMPLATE(
diff --git a/lib/utils/test/src/utils/record_formatter.cc b/lib/utils/test/src/utils/record_formatter.cc
index f211f5eb7b..f0d396a123 100644
--- a/lib/utils/test/src/utils/record_formatter.cc
+++ b/lib/utils/test/src/utils/record_formatter.cc
@@ -1,5 +1,5 @@
 #include "utils/record_formatter.h"
-#include "test/utils/doctest.h"
+#include <doctest/doctest.h>

 std::string formatRecord(RecordFormatter const &formatter) {
   std::ostringstream oss;
diff --git a/lib/utils/test/src/utils/sequence.cc b/lib/utils/test/src/utils/sequence.cc
index c0a0aaf64e..a758476fd9 100644
--- a/lib/utils/test/src/utils/sequence.cc
+++ b/lib/utils/test/src/utils/sequence.cc
@@ -1,5 +1,5 @@
 #include "utils/sequence.h"
-#include "test/utils/doctest.h"
+#include <doctest/doctest.h>

 using namespace FlexFlow;
diff --git a/lib/utils/test/src/utils/stack_string.cc b/lib/utils/test/src/utils/stack_string.cc
index 08659375d1..e0da9044c4 100644
--- a/lib/utils/test/src/utils/stack_string.cc
+++ b/lib/utils/test/src/utils/stack_string.cc
@@ -1,6 +1,6 @@
-#include "utils/stack_string.h"
-#include "test/utils/doctest.h"
 #include "test/utils/rapidcheck.h"
+#include "utils/stack_string.h"
+#include <doctest/doctest.h>

 using namespace FlexFlow;
diff --git a/lib/utils/test/src/utils/stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc
index 68e2af3d12..bce068da20 100644
--- a/lib/utils/test/src/utils/stack_vector.cc
+++ b/lib/utils/test/src/utils/stack_vector.cc
@@ -1,7 +1,6 @@
-#include "utils/stack_vector.h"
-#include "test/utils/doctest.h"
 #include "test/utils/rapidcheck.h"
-#include "utils/fmt/vector.h"
+#include "utils/stack_vector.h"
+#include <doctest/doctest.h>
 #include

 using namespace FlexFlow;
diff --git a/lib/utils/test/src/utils/tuple.cc b/lib/utils/test/src/utils/tuple.cc
index e01d0c333a..01a1ebca18 100644
--- a/lib/utils/test/src/utils/tuple.cc
+++ b/lib/utils/test/src/utils/tuple.cc
@@ -1,5 +1,5 @@
 #include "utils/tuple.h"
-#include "test/utils/doctest.h"
+#include <doctest/doctest.h>
 #include
 #include
diff --git a/lib/utils/test/src/utils/type_index.cc b/lib/utils/test/src/utils/type_index.cc
index a92b9578ea..729ad2797a 100644
--- a/lib/utils/test/src/utils/type_index.cc
+++ b/lib/utils/test/src/utils/type_index.cc
@@ -1,5 +1,5 @@
 #include "utils/type_index.h"
-#include "test/utils/doctest.h"
+#include <doctest/doctest.h>
 #include

 using namespace FlexFlow;
diff --git a/lib/utils/test/src/utils/variant.cc b/lib/utils/test/src/utils/variant.cc
index 9a55539ccf..0740a7ecec 100644
--- a/lib/utils/test/src/utils/variant.cc
+++ b/lib/utils/test/src/utils/variant.cc
@@ -1,8 +1,8 @@
 #include "utils/variant.h"
-#include "test/utils/doctest.h"
 #include "test/utils/rapidcheck.h"
 #include "utils/fmt/optional.h"
 #include "utils/fmt/variant.h"
+#include <doctest/doctest.h>

 TEST_SUITE(FF_TEST_SUITE) {
   TEST_CASE("widen and narrow functions") {
diff --git a/lib/utils/test/src/utils/vector.cc b/lib/utils/test/src/utils/vector.cc
index e4a8d742f0..c6eb0828b8 100644
--- a/lib/utils/test/src/utils/vector.cc
+++ b/lib/utils/test/src/utils/vector.cc
@@ -1,5 +1,4 @@
 #include "utils/vector.h"
-#include "utils/fmt/vector.h"
 #include <doctest/doctest.h>

 TEST_SUITE(FF_TEST_SUITE) {