Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unity device mapping algorithm #1459

Merged
merged 35 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
950da68
pass existing tests
wmdi Aug 5, 2024
d36d1ea
unity algorithm builds
wmdi Aug 7, 2024
c71b773
fmt
wmdi Aug 7, 2024
678e990
fix
wmdi Aug 21, 2024
ee2429c
Merge branch 'repo-refactor' into compiler_test
lockshaw Aug 22, 2024
49a4a0c
Merge remote-tracking branch 'flexflow/repo-refactor' into compiler_test
wmdi Aug 27, 2024
8e27f2a
refactor machine mapping
wmdi Aug 27, 2024
8191bcd
Merge branch 'compiler_test' of github.com:wmdi/FlexFlow into compile…
wmdi Aug 27, 2024
9b9f529
add unit tests
wmdi Aug 29, 2024
150ca5e
fmt
wmdi Aug 29, 2024
fc388ce
add more tests
wmdi Sep 2, 2024
e628c72
fmt
wmdi Sep 2, 2024
8eff2b9
fix
wmdi Sep 4, 2024
0d409e9
Merge remote-tracking branch 'flexflow/repo-refactor' into compiler_test
wmdi Sep 11, 2024
7c03f24
refactor get_optimal_machine_mapping a bit and improve the tests
wmdi Sep 12, 2024
89ed108
remove debug codes
wmdi Sep 12, 2024
a112225
Merge remote-tracking branch 'origin/repo-refactor' into wmdi-compile…
lockshaw Sep 18, 2024
3a35951
A lot of simplifying and modularizing of unity dp code
lockshaw Sep 26, 2024
00c2bae
Get tests building again
lockshaw Sep 27, 2024
3d9c9bb
Merge remote-tracking branch 'origin/repo-refactor' into wmdi-compile…
lockshaw Sep 27, 2024
7e73162
Get all the new testcases working
lockshaw Sep 28, 2024
bdcc10e
Move over to ProblemTree/ResultTree framework for machine mapping
lockshaw Sep 29, 2024
b0475b4
Settle on ProblemTree/BinaryTreePath-indexed-MachineMappingResult for…
lockshaw Sep 30, 2024
2bbec5c
More code cleanup and PR prep
lockshaw Oct 1, 2024
85fd5b4
Get tests building again
lockshaw Oct 2, 2024
597e13c
Pass some basic tests of get_optimal_machine_mapping
lockshaw Oct 3, 2024
0c2ab05
Migrate over to use type-erased binary tree
lockshaw Oct 3, 2024
e4073bc
Move back to templated FullBinaryTree
lockshaw Oct 3, 2024
5d22c6d
Get all existing tests passing again
lockshaw Oct 4, 2024
3d08831
Fix tests and format
lockshaw Oct 5, 2024
4b180df
Move graph_optimize_state.cc to correct location
lockshaw Oct 5, 2024
dcd2e13
Further code simplification and polishing
lockshaw Oct 7, 2024
39c8f1c
Pass all tests
lockshaw Oct 8, 2024
75f7e98
Remove a bunch of unnecessary code
lockshaw Oct 8, 2024
a2b8832
Format
lockshaw Oct 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
6 changes: 3 additions & 3 deletions .github/workflows/per-lib-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ jobs:
run: |
test_target.sh substitutions

# - name: Test compiler
# run: |
# test_target.sh compiler
- name: Test compiler
run: |
test_target.sh compiler

- name: Test substitution-generator
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@ name = "JsonSPModelExport"
features = [
"eq",
"hash",
"json",
"fmt",
"json",
]

includes = [
"pcg/file_format/v1/v1_computation_graph.dtg.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h",
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h",
]

src_includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h",
"pcg/file_format/v1/v1_binary_sp_decomposition/json.h",
]

[[fields]]
name = "sp_decomposition"
type = "::FlexFlow::GenericBinarySPDecompositionTree<int>"
type = "::FlexFlow::V1BinarySPDecomposition"

[[fields]]
name = "computation_graph"
Expand Down
10 changes: 4 additions & 6 deletions bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h"
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h"
#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h"
#include "export_model_arch/json_sp_model_export.dtg.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
Expand All @@ -13,7 +13,6 @@
#include "utils/cli/cli_parse.h"
#include "utils/cli/cli_parse_result.h"
#include "utils/cli/cli_spec.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h"
#include "utils/graph/series_parallel/get_series_parallel_decomposition.h"

Expand Down Expand Up @@ -105,9 +104,8 @@ tl::expected<JsonSPModelExport, std::string>
to_v1_including_node_numbering(computation_graph);
V1ComputationGraph v1_cg = v1_result.first;
bidict<int, layer_guid_t> layer_numbering = v1_result.second;
GenericBinarySPDecompositionTree<int> v1_sp_decomposition =
transform(sp_decomposition.raw_tree,
[&](layer_guid_t const &l) { return layer_numbering.at_r(l); });
V1BinarySPDecomposition v1_sp_decomposition =
to_v1(sp_decomposition, layer_numbering);

return JsonSPModelExport{
v1_sp_decomposition,
Expand Down
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 0 additions & 61 deletions lib/compiler/include/compiler/cost_estimate.h

This file was deleted.

45 changes: 45 additions & 0 deletions lib/compiler/include/compiler/cost_estimator/cost_estimator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H

#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h"
#include "compiler/cost_estimator/tensor_set_movement.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "op-attrs/pcg_operator_attrs.dtg.h"
#include "pcg/machine_view.dtg.h"
#include <memory>      // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if, std::is_base_of
#include <utility>     // std::forward
#include <vector>

namespace FlexFlow {

/**
 * @brief Abstract interface for cost estimation backends.
 *
 * Implementations provide one overload for an OpCostEstimateKey and one for
 * a TensorSetMovement; both return a float.
 * NOTE(review): the unit of the returned cost is not stated in this header.
 *
 * The interface is non-copyable (copy construction and copy assignment are
 * deleted) and has a virtual destructor so that implementations can be
 * destroyed through a pointer to ICostEstimator.
 */
struct ICostEstimator {
  virtual float estimate_cost(OpCostEstimateKey const &) const = 0;
  virtual float estimate_cost(TensorSetMovement const &) const = 0;

  ICostEstimator() = default;
  ICostEstimator(ICostEstimator const &) = delete;
  ICostEstimator &operator=(ICostEstimator const &) = delete;

  virtual ~ICostEstimator() = default;
};
// NOTE(review): macro is defined outside this header; presumably a
// compile-time check of the copy/virtual-destructor rules above -- confirm.
CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator);

/**
 * @brief Copyable handle that holds a std::shared_ptr to an ICostEstimator
 *        implementation.
 *
 * Instances can only be obtained through create(), because the constructor
 * is private.
 * NOTE(review): the estimate_cost overloads are defined out of line;
 * presumably they forward to the held implementation -- confirm.
 */
struct CostEstimator {
  float estimate_cost(OpCostEstimateKey const &k) const;
  float estimate_cost(TensorSetMovement const &m) const;

  /**
   * @brief Build a CostEstimator backed by a newly allocated T.
   *
   * Only participates in overload resolution when T derives from
   * ICostEstimator.
   *
   * @tparam T    concrete ICostEstimator implementation to instantiate
   * @tparam Args constructor argument types for T
   * @param args  arguments perfectly forwarded to T's constructor
   * @return a CostEstimator holding the new T instance
   */
  template <typename T, typename... Args>
  static typename std::enable_if<std::is_base_of<ICostEstimator, T>::value,
                                 CostEstimator>::type
      create(Args &&...args) {
    return CostEstimator(std::make_shared<T>(std::forward<Args>(args)...));
  }

private:
  CostEstimator(std::shared_ptr<ICostEstimator> implementation_ptr);

  std::shared_ptr<ICostEstimator> implementation_ptr;
};

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Struct specification for ::FlexFlow::OpCostEstimateKey.
# Fields: operator attributes, three lists of parallel tensor shapes
# (inputs, weights, outputs), and a machine view.
namespace = "FlexFlow"
name = "OpCostEstimateKey"
# Capabilities requested for the struct: equality, ordering, formatting, hashing.
features = [
"eq",
"ord",
"fmt",
"hash",
]

# Headers that declare the field types used below.
includes = [
"op-attrs/pcg_operator_attrs.dtg.h",
"op-attrs/parallel_tensor_shape.dtg.h",
"<vector>",
"pcg/machine_view.dtg.h",
]

# NOTE(review): presumably supply the hash/fmt support for the std::vector
# fields; confirm against the code generator.
src_includes = [
"utils/hash/vector.h",
"utils/fmt/vector.h",
]

[[fields]]
name = "op_attrs"
type = "::FlexFlow::PCGOperatorAttrs"

[[fields]]
name = "input_shapes"
type = "std::vector<::FlexFlow::ParallelTensorShape>"

[[fields]]
name = "weight_shapes"
type = "std::vector<::FlexFlow::ParallelTensorShape>"

[[fields]]
name = "output_shapes"
type = "std::vector<::FlexFlow::ParallelTensorShape>"

[[fields]]
name = "machine_view"
type = "::FlexFlow::MachineView"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Struct specification for ::FlexFlow::SingleTensorMovement.
# Fields: one parallel tensor shape plus two sets of machine views
# (source and destination).
namespace = "FlexFlow"
name = "SingleTensorMovement"
# Capabilities requested for the struct: equality, hashing, formatting
# (no ordering).
features = [
"eq",
"hash",
"fmt",
]

# Headers that declare the field types used below.
includes = [
"op-attrs/parallel_tensor_shape.dtg.h",
"pcg/machine_view.dtg.h",
"<unordered_set>",
]

# NOTE(review): presumably supply the hash/fmt support for the
# std::unordered_set fields; confirm against the code generator.
src_includes = [
"utils/hash/unordered_set.h",
"utils/fmt/unordered_set.h",
]

[[fields]]
name = "parallel_tensor_shape"
type = "::FlexFlow::ParallelTensorShape"

[[fields]]
name = "src_machine_views"
type = "std::unordered_set<::FlexFlow::MachineView>"

[[fields]]
name = "dst_machine_views"
type = "std::unordered_set<::FlexFlow::MachineView>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Struct specification for ::FlexFlow::TensorSetMovement.
# Single field: a multiset of SingleTensorMovement values, so the same
# movement may appear more than once.
namespace = "FlexFlow"
name = "TensorSetMovement"
# Capabilities requested for the struct: equality, hashing, formatting
# (no ordering).
features = [
"eq",
"hash",
"fmt",
]

includes = [
"compiler/cost_estimator/single_tensor_movement.dtg.h",
# <unordered_set> is the standard header that defines std::unordered_multiset.
"<unordered_set>",
]

# NOTE(review): presumably supply the fmt/hash support for the
# std::unordered_multiset field; confirm against the code generator.
src_includes = [
"utils/fmt/unordered_multiset.h",
"utils/hash/unordered_multiset.h",
]

[[fields]]
name = "single_tensor_movements"
type = "std::unordered_multiset<::FlexFlow::SingleTensorMovement>"
16 changes: 16 additions & 0 deletions lib/compiler/include/compiler/graph_optimize_result.struct.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Struct specification for ::FlexFlow::GraphOptimizeResult.
# Fields: a parallel computation graph and a machine mapping.
namespace = "FlexFlow"
name = "GraphOptimizeResult"
# No optional capabilities (eq/hash/fmt/...) are requested for this struct.
features = [ ]

# Headers that declare the field types used below.
includes = [
"compiler/machine_mapping/machine_mapping.dtg.h",
"pcg/parallel_computation_graph/parallel_computation_graph.h"
]

[[fields]]
name = "pcg"
type = "::FlexFlow::ParallelComputationGraph"

[[fields]]
name = "machine_mapping"
type = "::FlexFlow::MachineMapping"
31 changes: 31 additions & 0 deletions lib/compiler/include/compiler/graph_optimize_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Include guard is derived from this header's path
// (lib/compiler/include/compiler/graph_optimize_state.h).
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_STATE_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_STATE_H

#include "compiler/graph_optimize_result.dtg.h"
#include <cstddef>    // size_t
#include <functional> // std::hash primary template

namespace FlexFlow {

/**
 * @brief A GraphOptimizeResult paired with a float runtime value.
 *
 * The comparison operators are declared here and defined out of line.
 * NOTE(review): presumably operator< orders states by runtime -- confirm
 * against the definitions. The unit of `runtime` is not stated in this
 * header.
 */
struct GraphOptimizeState {
  /**
   * @param graph_optimize_result the result (PCG and machine mapping) to store
   * @param runtime               the runtime value associated with that result
   */
  GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result,
                     float runtime);

  GraphOptimizeResult graph_optimize_result;
  float runtime;

  bool operator==(GraphOptimizeState const &other) const;
  bool operator!=(GraphOptimizeState const &other) const;
  bool operator<(GraphOptimizeState const &other) const;
};

} // namespace FlexFlow

namespace std {

/**
 * @brief std::hash specialization for GraphOptimizeState (defined out of
 *        line), so it can be used as a key in unordered standard containers.
 */
template <>
struct hash<::FlexFlow::GraphOptimizeState> {
  size_t operator()(::FlexFlow::GraphOptimizeState const &) const;
};

} // namespace std

#endif
37 changes: 0 additions & 37 deletions lib/compiler/include/compiler/graph_utils.h

This file was deleted.

Loading
Loading