Skip to content

Commit

Permalink
Migrate over to use type-erased binary tree
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Oct 3, 2024
1 parent 597e13c commit 0c2ab05
Show file tree
Hide file tree
Showing 44 changed files with 666 additions and 467 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H
#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H

#include "compiler/machine_mapping/machine_mapping_cache.h"
#include "compiler/machine_mapping/machine_mapping_cache.dtg.h"
#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h"
#include "compiler/machine_mapping/machine_mapping_context.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h"
#include "compiler/machine_mapping/parallel_split_transformation.dtg.h"
#include "compiler/machine_mapping/machine_mapping_cache.dtg.h"
#include "pcg/machine_specification.dtg.h"

namespace FlexFlow {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H
#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H

#include "compiler/machine_mapping/machine_mapping_state.dtg.h"
#include "compiler/machine_mapping/machine_mapping_result.dtg.h"
#include "utils/optional.h"
#include "compiler/machine_mapping/machine_mapping_cache.dtg.h"

namespace FlexFlow {

class MachineMappingCache {
public:
MachineMappingCache() = default;

std::optional<MachineMappingResult> load(MachineMappingState const &) const;
void save(MachineMappingState const &, MachineMappingResult const &);

private:
std::unordered_map<MachineMappingState, MachineMappingResult> cache;
};
MachineMappingCache empty_machine_mapping_cache();
std::optional<MachineMappingResult> machine_mapping_cache_load(MachineMappingCache const &, MachineMappingState const &);
void machine_mapping_cache_save(MachineMappingCache &, MachineMappingState const &, MachineMappingResult const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
namespace = "FlexFlow"
name = "MachineMappingCache"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"<unordered_map>",
"compiler/machine_mapping/machine_mapping_state.dtg.h",
"compiler/machine_mapping/machine_mapping_result.dtg.h",
]

src_includes = [
"utils/fmt/unordered_map.h",
"utils/hash/unordered_map.h",
]

[[fields]]
name = "raw_map"
type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>"
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ name = "problem_tree"
type = "::FlexFlow::MachineMappingProblemTree"

[[fields]]
name = "resource"
name = "resources"
type = "::FlexFlow::MachineSpecification"

[[fields]]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "compiler/machine_mapping/get_optimal_machine_mapping.h"
#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h"
#include "compiler/machine_mapping/get_machine_resource_splits.h"
#include "compiler/machine_mapping/machine_mapping_cache.h"
#include "compiler/machine_mapping/machine_mapping_constraints.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h"
Expand Down Expand Up @@ -39,7 +40,7 @@ MachineMappingResult get_optimal_machine_mapping(

{
std::optional<MachineMappingResult> cached_result =
result_cache.load(state);
machine_mapping_cache_load(result_cache, state);
if (cached_result) {
return cached_result.value();
}
Expand Down Expand Up @@ -67,7 +68,7 @@ MachineMappingResult get_optimal_machine_mapping(
},
});

result_cache.save(state, result);
machine_mapping_cache_save(result_cache, state, result);
return result;
}

Expand Down
20 changes: 13 additions & 7 deletions lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
#include "compiler/machine_mapping/machine_mapping_cache.h"
#include "utils/containers/try_at.h"
#include "utils/containers/contains_key.h"

namespace FlexFlow {

std::optional<MachineMappingResult>
MachineMappingCache::load(MachineMappingState const &state) const {
return try_at(this->cache, state);
MachineMappingCache empty_machine_mapping_cache() {
return MachineMappingCache{{}};
}

void MachineMappingCache::save(MachineMappingState const &state,
MachineMappingResult const &result) {
assert(!contains_key(cache, state));
cache.emplace(state, result);
std::optional<MachineMappingResult> machine_mapping_cache_load(MachineMappingCache const &cache, MachineMappingState const &k) {
return try_at(cache.raw_map, k);
}

void machine_mapping_cache_save(MachineMappingCache &cache, MachineMappingState const &k, MachineMappingResult const &v) {
if (contains_key(cache.raw_map, k)) {
throw mk_runtime_error(fmt::format("machine_mapping_cache_save expected key to not already exist, but received existing key {}", k));
}

cache.raw_map.emplace(k, v);
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "./cost_estimator_for_test.h"
#include <doctest/doctest.h>
#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h"
#include "compiler/machine_mapping/machine_mapping_cache.h"
#include "compiler/machine_mapping/machine_mapping_constraints.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h"
Expand Down Expand Up @@ -109,7 +110,7 @@ TEST_SUITE(FF_TEST_SUITE) {
allowed_machine_views1,
};

MachineMappingCache cache;
MachineMappingCache cache = empty_machine_mapping_cache();

SUBCASE("single layer") {
MachineMappingProblemTree problem_tree =
Expand Down Expand Up @@ -206,9 +207,5 @@ TEST_SUITE(FF_TEST_SUITE) {

CHECK(result == correct);
}

SUBCASE("multiple edges across split") {
FAIL("TODO");
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ TEST_SUITE(FF_TEST_SUITE) {
UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape);

PCGBinarySPDecomposition sp_decomposition = \
make_pcg_series_split(
make_pcg_parallel_split(
make_pcg_leaf_node(input1_layer),
make_pcg_leaf_node(input2_layer));

Expand Down
26 changes: 2 additions & 24 deletions lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,15 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H

#include "utils/full_binary_tree/binary_tree_path.dtg.h"
#include "utils/full_binary_tree/binary_tree_path.h"
#include "utils/full_binary_tree/full_binary_tree.h"
#include "utils/full_binary_tree/visit.h"
#include "utils/full_binary_tree/full_binary_tree.dtg.h"
#include <unordered_set>
#include "utils/overload.h"
#include "utils/containers/transform.h"
#include "utils/containers/set_union.h"

namespace FlexFlow {

template <typename ParentLabel, typename LeafLabel>
std::unordered_set<BinaryTreePath> find_paths_to_leaf(FullBinaryTree<ParentLabel, LeafLabel> const &tree,
LeafLabel const &leaf) {
return visit<std::unordered_set<BinaryTreePath>>(
tree,
overload {
[&](LeafLabel const &l) -> std::unordered_set<BinaryTreePath> {
if (l == leaf) {
return {binary_tree_root_path()};
} else {
return {};
}
},
[&](FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &parent) {
return set_union(
transform(find_paths_to_leaf(get_left_child(parent), leaf),
nest_inside_left_child),
transform(find_paths_to_leaf(get_right_child(parent), leaf),
nest_inside_right_child));
}
});
return find_paths_to_leaf(tree.raw_tree, make_any_value_type<LeafLabel>(leaf));
}

} // namespace FlexFlow
Expand Down
47 changes: 0 additions & 47 deletions lib/utils/include/utils/full_binary_tree/fmt.h

This file was deleted.

Loading

0 comments on commit 0c2ab05

Please sign in to comment.