From 5e1f349215294ae312ccac9f2891aecacaf95e39 Mon Sep 17 00:00:00 2001
From: Reyna Abhyankar
Date: Thu, 22 Aug 2024 14:59:04 -0700
Subject: [PATCH] Local execution tests (#1418)

* pr for debugging kernel driver issues
* Commit flake files
* current kernel tests
* softmax, flat, transpose kernel tests
* clang formatting kernel tests
* reverse, split, full dropout kernels
* rest of kernel-tests
* minor cleanup
* Restore .proj.toml
* Delete misadded directory
* merge fix
* more merge fixes
* resolved merge conflicts with repo-refactor
* code review changes
* allocator updates
* allocation util updates
* test clean up and review fixes
* fixed forward backward pass consistencies, added filler tests for all tests, other review changes
* unnested test subcases and more review changes
* Add == in OpTaskBinding
* Add single operator test example
* Finish multi operator test
* added managed_stream and handle classes, other minor clean up
* fix accessor and corresponding shape clarity, other clean up
* merge error fixes
* More aggressive subcasing
* Remove comment
* managed handle and stream fixes, removed datatype dispatch from cuda_helper, other clean up
* managed handle and stream updates
* Refactoring and split tests
* Fix build
* Fix build
* Add cuda test suite
* Remove mock
* Pass task registry
* Pass slots backing and task arg acc
* Pass cost estimator test
* Fix
* PR fixes
* Fixes
* Add test to ci
* Fix test libs
* Fix build, add more fmt placeholders
* Fixes
* Fixes
* Delete file
* Fixes
* Fixes
* Fixes
* Fix includes
* Fix includes

---------

Co-authored-by: Dylan Lim
Co-authored-by: Dylan Lim <72822184+oOTigger@users.noreply.github.com>
Co-authored-by: Colin Unger
---
 .github/workflows/helpers/test_libs.sh        |   2 +-
 .github/workflows/per-lib-check.yml           |   4 +
 .proj.toml                                    |   4 +-
 cmake/flexflow-utils.cmake                    |   2 +-
 lib/kernels/CMakeLists.txt                    |   2 +-
 lib/kernels/include/kernels/accessor.h        |  17 +
 lib/kernels/include/kernels/array_shape.h     |   6 +-
 .../include/kernels/attention_kernels.h       |  22 +
 lib/kernels/include/kernels/device.h          |   4 +-
 lib/kernels/include/kernels/ff_handle.h       |   3 +
 lib/kernels/include/kernels/legion_dim.h      |  17 +
 lib/kernels/include/kernels/profiling.h       |  14 +-
 .../kernels/profiling_settings.struct.toml    |  18 +
 lib/kernels/src/accessor.cc                   |  29 ++
 lib/kernels/src/array_shape.cc                |  37 +-
 lib/kernels/src/cuda/cuda_helper.cu           |  27 +-
 lib/kernels/src/cuda/ops/attention_kernels.cu |  35 ++
 lib/kernels/src/device.cc                     |   4 +
 lib/kernels/src/device.h                      |  12 +-
 lib/kernels/src/ff_handle.cc                  |  18 +
 lib/kernels/src/local_cuda_allocator.cc       |   2 +-
 lib/kernels/test/src/test_cast_kernel.cc      |   1 -
 lib/local-execution/CMakeLists.txt            |   2 +
 .../include/local-execution/arg_ref.h         |  13 +-
 .../include/local-execution/concrete_arg.h    |   9 +
 .../local-execution/cost_details.struct.toml  |   1 -
 .../include/local-execution/device_specific.h |  22 +-
 ...device_specific_device_states.variant.toml |  90 ++++
 .../include/local-execution/device_states.h   |  40 --
 .../fwd_bwd_task_impl_function.h              |  32 ++
 .../local-execution/init_task_impl_function.h |  33 ++
 .../local-execution/local_cpu_allocator.h     |  22 +
 .../local-execution/local_slots_backing.h     |  23 +-
 .../local_task_argument_accessor.h            |  10 +
 .../local-execution/local_training_backing.h  |  14 +-
 .../include/local-execution/op_arg_ref.h      |  20 +-
 .../op_arg_ref_type.variant.toml              |  22 +
 .../local-execution/op_arg_spec.variant.toml  |   2 +-
 .../local-execution/op_task_invocation.h      |  17 +-
 .../local-execution/op_task_signature.h       |  34 +-
 .../local-execution}/ops/attention.h          |   0
 ...parallel_tensor_shape_ref_type.struct.toml |  14 +
 .../local-execution/per_device_op_state.h     |  15 +
 .../per_device_op_state.variant.toml          |  87 ++++
 .../per_device_op_state_ref_type.struct.toml  |  12 +
 .../include/local-execution/runtime_arg_ref.h |   1 +
 .../local-execution/task_argument_accessor.h  |  16 +-
 .../local-execution/task_id_t.enum.toml       | 431 ++++++++++++++++
 .../task_impl_function.variant.toml           |  21 +
 .../include/local-execution/task_registry.h   |  21 +-
 .../local-execution/task_registry.struct.toml |  34 ++
 .../local-execution/task_signature_impl.h     |  23 +-
 .../task_signature_impl.struct.toml           |  20 +
 .../include/local-execution/tasks.h           | 160 +-----
 .../local-execution/tracked_allocator.h       |   1 +
 lib/local-execution/src/concrete_arg.cc       |  29 ++
 .../src/fwd_bwd_task_impl_function.cc         |  54 ++
 .../src/init_task_impl_function.cc            |  47 ++
 .../src/legion_tensor_shape.cc                |  13 +
 .../src/local_cost_estimator.cc               |   3 +-
 .../src/local_cpu_allocator.cc                |  24 +
 .../src/local_slots_backing.cc                |  70 ++-
 .../src/local_task_argument_accessor.cc       |  34 ++
 .../src/local_training_backing.cc             | 111 ++--
 lib/local-execution/src/op_arg_ref.cc         |   2 +-
 .../src/{local-execution => }/op_arg_spec.cc  |   0
 lib/local-execution/src/op_task_invocation.cc |  31 +-
 lib/local-execution/src/op_task_signature.cc  |  69 ++-
 lib/local-execution/src/ops/attention.cc      |  39 +-
 lib/local-execution/src/ops/batch_matmul.cc   |  14 +-
 lib/local-execution/src/ops/batch_norm.cc     |  23 +-
 lib/local-execution/src/ops/cast.cc           |  10 +-
 lib/local-execution/src/ops/combine.cc        |  11 +-
 lib/local-execution/src/ops/concat.cc         |  13 +-
 lib/local-execution/src/ops/conv_2d.cc        |  21 +-
 lib/local-execution/src/ops/dropout.cc        |  21 +-
 lib/local-execution/src/ops/dropout.h         |   2 +-
 lib/local-execution/src/ops/element_binary.cc |  23 +-
 lib/local-execution/src/ops/element_unary.cc  |  23 +-
 lib/local-execution/src/ops/flat.cc           |  10 +-
 lib/local-execution/src/ops/gather.cc         |  21 +-
 lib/local-execution/src/ops/input.cc          |   9 +
 lib/local-execution/src/ops/input.h           |  13 +
 lib/local-execution/src/ops/layer_norm.cc     |  21 +-
 lib/local-execution/src/ops/linear.cc         |  41 +-
 lib/local-execution/src/ops/noop.cc           |  14 +-
 lib/local-execution/src/ops/noop.h            |   6 +-
 lib/local-execution/src/ops/pool_2d.cc        |  57 ++-
 lib/local-execution/src/ops/reduce.cc         |  21 +-
 lib/local-execution/src/ops/reduction.cc      |  10 +-
 lib/local-execution/src/ops/repartition.cc    |  23 +-
 lib/local-execution/src/ops/replicate.cc      |  10 +-
 lib/local-execution/src/ops/reshape.cc        |  21 +-
 lib/local-execution/src/ops/reverse.cc        |  10 +-
 lib/local-execution/src/ops/softmax.cc        |  21 +-
 lib/local-execution/src/ops/split.cc          |  10 +-
 lib/local-execution/src/ops/topk.cc           |  21 +-
 lib/local-execution/src/ops/transpose.cc      |  21 +-
 lib/local-execution/src/ops/weight.cc         |   9 +
 lib/local-execution/src/ops/weight.h          |  13 +
 lib/local-execution/src/per_device_state.cc   |  12 +
 lib/local-execution/src/task_registry.cc      |  59 ++-
 .../src/task_signature_impl.cc                | 478 ++++++++++++------
 lib/local-execution/src/tracked_allocator.cc  |   3 +-
 lib/local-execution/test/CMakeLists.txt       |  14 +
 .../test/src/test_local_cost_estimator.cc     |  77 +++
 .../test/src/test_local_slots_backing.cc      | 273 ++++++++++
 .../test/src/test_local_task_arg_accessor.cc  | 143 ++++++
 .../test/src/test_task_registry.cc            | 131 +++++
 lib/local-execution/test/src/test_utils.cc    |   9 +
 lib/local-execution/test/src/test_utils.h     |  12 +
 lib/op-attrs/src/op-attrs/ops/attention.cc    |  48 +-
 lib/pcg/include/pcg/computation_graph.h       |   3 +
 .../include/pcg/computation_graph_builder.h   |  11 +-
 lib/pcg/src/pcg/computation_graph.cc          |  10 +
 lib/pcg/src/pcg/computation_graph_builder.cc  |  41 ++
 .../src/ops/embedding.cc                      |   0
 lib/utils/include/utils/join_strings.h        |   2 +-
 118 files changed, 3055 insertions(+), 882 deletions(-)
 create mode 100644 lib/kernels/include/kernels/profiling_settings.struct.toml
 create mode 100644 lib/kernels/src/ff_handle.cc
 create mode 100644 lib/local-execution/include/local-execution/device_specific_device_states.variant.toml
 delete mode 100644 lib/local-execution/include/local-execution/device_states.h
 create mode 100644 lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h
 create mode 100644 lib/local-execution/include/local-execution/init_task_impl_function.h
 create mode 100644 lib/local-execution/include/local-execution/local_cpu_allocator.h
 create mode 100644 lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml
 rename lib/local-execution/{src => include/local-execution}/ops/attention.h (100%)
 create mode 100644 lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml
 create mode 100644 lib/local-execution/include/local-execution/per_device_op_state.h
 create mode 100644 lib/local-execution/include/local-execution/per_device_op_state.variant.toml
 create mode 100644 lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml
 create mode 100644 lib/local-execution/include/local-execution/task_id_t.enum.toml
 create mode 100644 lib/local-execution/include/local-execution/task_impl_function.variant.toml
 create mode 100644 lib/local-execution/include/local-execution/task_registry.struct.toml
 create mode 100644 lib/local-execution/include/local-execution/task_signature_impl.struct.toml
 create mode 100644 lib/local-execution/src/concrete_arg.cc
 create mode 100644 lib/local-execution/src/fwd_bwd_task_impl_function.cc
 create mode 100644 lib/local-execution/src/init_task_impl_function.cc
 create mode 100644 lib/local-execution/src/legion_tensor_shape.cc
 create mode 100644 lib/local-execution/src/local_cpu_allocator.cc
 rename lib/local-execution/src/{local-execution => }/op_arg_spec.cc (100%)
 create mode 100644 lib/local-execution/src/ops/input.cc
 create mode 100644 lib/local-execution/src/ops/input.h
 create mode 100644 lib/local-execution/src/ops/weight.cc
 create mode 100644 lib/local-execution/src/ops/weight.h
 create mode 100644 lib/local-execution/src/per_device_state.cc
 create mode 100644 lib/local-execution/test/CMakeLists.txt
 create mode 100644 lib/local-execution/test/src/test_local_cost_estimator.cc
 create mode 100644 lib/local-execution/test/src/test_local_slots_backing.cc
 create mode 100644 lib/local-execution/test/src/test_local_task_arg_accessor.cc
 create mode 100644 lib/local-execution/test/src/test_task_registry.cc
 create mode 100644 lib/local-execution/test/src/test_utils.cc
 create mode 100644 lib/local-execution/test/src/test_utils.h
 rename lib/{local-execution => runtime}/src/ops/embedding.cc (100%)

diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_libs.sh
index 7662a7e601..69baa66364 100755
--- a/.github/workflows/helpers/test_libs.sh
+++ b/.github/workflows/helpers/test_libs.sh
@@ -7,7 +7,7 @@ DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")"
 REPO="$(realpath -- "$DIR/../../../")"
 
 TEST_LIBS=("${@/%/-tests}")
-REGEX="^$(IFS='|'; echo "${TEST_LIBS[*]}")\$"
+REGEX="^($(IFS='|'; echo "${TEST_LIBS[*]}"))\$"
 
 cd "$REPO/build-ci"
 make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "${TEST_LIBS[@]}"
diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml
index 35860fbcec..38556a3c0e 100644
--- a/.github/workflows/per-lib-check.yml
+++ b/.github/workflows/per-lib-check.yml
@@ -116,6 +116,10 @@ jobs:
         run: |
           test_libs.sh substitution-generator

+      - name: Test local-execution
+        run: |
+          test_libs.sh local-execution
+
       - name: Generate code coverage
         run: |
           echo "gitwork: $GITHUB_WORKSPACE"

diff --git a/.proj.toml b/.proj.toml
index 14bdcdb3b7..ee91d07833 100644
--- a/.proj.toml
+++ b/.proj.toml
@@ -11,16 +11,18 @@ build_targets = [
   "substitutions",
   "compiler",
   "substitution-generator",
-  "local-execution", 
+  "local-execution",
 ]
 test_targets = [
+  # "kernels-tests",
   "utils-tests",
   "op-attrs-tests",
   "pcg-tests",
   "substitutions-tests",
   "compiler-tests",
   "substitution-generator-tests",
+  "local-execution-tests"
 ]

 [cmake_flags_extra]

diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake
index 32798e6833..1dbd16bdb1 100644
--- a/cmake/flexflow-utils.cmake
+++ b/cmake/flexflow-utils.cmake
@@ -118,7 +118,7 @@ function(ff_add_test_executable)
     ${FF_TEST_EXEC_NAME}
     ${FF_TEST_EXEC_DEPS})

-  target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}")
+  target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}" FF_CUDA_TEST_SUITE="cuda-${FF_TEST_EXEC_NAME}")

   define_ff_vars(${FF_TEST_EXEC_NAME})
   ff_set_cxx_properties(${FF_TEST_EXEC_NAME})

diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt
index f166dd027c..8ccd7c1011 100644
--- a/lib/kernels/CMakeLists.txt
+++ b/lib/kernels/CMakeLists.txt
@@ -40,4 +40,4 @@ set_target_properties(
   CUDA_STANDARD 17
 )

-add_subdirectory(test)
\ No newline at end of file
+add_subdirectory(test)

diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h
index 1ef121fb2a..2ee081ecbc 100644
--- a/lib/kernels/include/kernels/accessor.h
+++ b/lib/kernels/include/kernels/accessor.h
@@ -145,6 +145,22 @@ std::vector<GenericTensorAccessorR const *>
 GenericTensorAccessorR read_only_accessor_from_write_accessor(
     GenericTensorAccessorW const &write_accessor);

+bool is_shape_and_dtype_equal(GenericTensorAccessorW const &acc1,
+                              GenericTensorAccessorW const &acc2);
+
+bool shape_and_dtype_matches(GenericTensorAccessorW const &accessor,
+                             ArrayShape const &expected_shape,
+                             DataType const &expected_dtype);
+
+bool shape_and_dtype_matches(GenericTensorAccessorR const &accessor,
+                             ArrayShape const &expected_shape,
+                             DataType const &expected_dtype);
+
+std::pair<ArrayShape, DataType>
+    get_shape_and_datatype(GenericTensorAccessorR const &accessor);
+std::pair<ArrayShape, DataType>
+    get_shape_and_datatype(GenericTensorAccessorW const &accessor);
+
 } // namespace FlexFlow

 namespace FlexFlow {
@@ -152,6 +168,7 @@ static_assert(is_well_behaved_value_type_no_hash<GenericTensorAccessorR>::value,
               "");
 static_assert(is_well_behaved_value_type_no_hash<GenericTensorAccessorW>::value,
               "");
+
 } // namespace FlexFlow

 #endif

diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h
index 5427d25bc3..5de9fae7ad 100644
--- a/lib/kernels/include/kernels/array_shape.h
+++ b/lib/kernels/include/kernels/array_shape.h
@@ -39,7 +39,8 @@ struct ArrayShape {
   legion_dim_t last_idx() const;
   legion_dim_t neg_idx(int) const;

-  std::optional<std::size_t> at_maybe(std::size_t) const;
+  std::optional<std::size_t> at_maybe(legion_dim_t) const;
+  std::optional<std::size_t> at_maybe(ff_dim_t) const;

   ArrayShape sub_shape(std::optional<std::variant<ff_dim_t, legion_dim_t>> start,
                        std::optional<std::variant<ff_dim_t, legion_dim_t>> end) const;
@@ -54,6 +55,9 @@ size_t get_volume(ArrayShape const &);

 TensorShape get_tensor_shape(ArrayShape const &, DataType);

+std::string format_as(ArrayShape const &);
+std::ostream &operator<<(std::ostream &, ArrayShape const &);
+
 } // namespace FlexFlow

 #endif

diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h
index de37b4169f..575de57f09 100644
--- a/lib/kernels/include/kernels/attention_kernels.h
+++ b/lib/kernels/include/kernels/attention_kernels.h
@@ -25,6 +25,25 @@ struct MHAPerDeviceState {
   int *hiWinIdx;
   void *reserveSpace;
   Allocator allocator;
+
+  bool operator==(MHAPerDeviceState const &other) const;
+  bool operator!=(MHAPerDeviceState const &other) const;
+
+private:
+  std::tuple<PerDeviceFFHandle const &,
+             size_t const &,
+             size_t const &,
+             ffAttnDescriptor_t const &,
+             ffSeqDataDescriptor_t const &,
+             ffSeqDataDescriptor_t const &,
+             ffSeqDataDescriptor_t const &,
+             ffSeqDataDescriptor_t const &,
+             int *const &,
+             int *const &,
+             int *const &,
+             int *const &,
+             void *const &>
+      tie() const;
 };

 FF_VISITABLE_STRUCT_NO_EQ(MHAPerDeviceState,
@@ -43,6 +62,9 @@ FF_VISITABLE_STRUCT_NO_EQ(MHAPerDeviceState,
                           reserveSpace,
                           allocator);

+std::string format_as(MHAPerDeviceState const &x);
+std::ostream &operator<<(std::ostream &s, MHAPerDeviceState const &x);
+
 namespace Kernels {
 namespace MultiHeadAttention {

diff --git a/lib/kernels/include/kernels/device.h b/lib/kernels/include/kernels/device.h
index c4e78821dc..cf4329774d 100644
--- a/lib/kernels/include/kernels/device.h
+++ b/lib/kernels/include/kernels/device.h
@@ -95,11 +95,13 @@ using coord_t = long long;
     exit(1);                                                                   \
   } while (0)

+char const *getCudaErrorString(cudaError_t status);
+
 #define checkCUDA(status)                                                      \
   do {                                                                         \
     std::stringstream _error;                                                  \
     if (status != 0) {                                                         \
-      _error << "CUDA failure: " << cudaGetErrorString(status) << " ("         \
+      _error << "CUDA failure: " << getCudaErrorString(status) << " ("         \
              << status << ")";                                                 \
       FatalError(_error.str());                                                \
     }                                                                          \

diff --git a/lib/kernels/include/kernels/ff_handle.h b/lib/kernels/include/kernels/ff_handle.h
index 89df04e3c1..179ce41cbf 100644
--- a/lib/kernels/include/kernels/ff_handle.h
+++ b/lib/kernels/include/kernels/ff_handle.h
@@ -40,6 +40,9 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(PerDeviceFFHandle,
                                              allowTensorOpMathConversion);
 #endif

+std::string format_as(PerDeviceFFHandle const &x);
+std::ostream &operator<<(std::ostream &s, PerDeviceFFHandle const &x);
+
 } // namespace FlexFlow

 #endif

diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h
index d8ffd91489..aafbd2cdcb 100644
--- a/lib/kernels/include/kernels/legion_dim.h
+++ b/lib/kernels/include/kernels/legion_dim.h
@@ -15,6 +15,23 @@ using LegionOrdered = DimOrdered<legion_dim_t, T>;

 using LegionTensorDims = LegionOrdered<size_t>;

+template <typename T>
+FFOrdered<T>
+    ff_ordered_from_legion_ordered(LegionOrdered<T> const &legion_ordered) {
+  return FFOrdered<T>(legion_ordered.rbegin(), legion_ordered.rend());
+}
+
+template <typename T>
+std::string format_as(LegionOrdered<T> const &v) {
+  std::vector<T> as_vec(v.cbegin(), v.cend());
+  return fmt::format("<legion_ordered {}>", as_vec);
+}
+
+template <typename T>
+std::ostream &operator<<(std::ostream &s, LegionOrdered<T> const &v) {
+  return (s << fmt::to_string(v));
+}
+
 } // namespace FlexFlow

 #endif
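The ff_ordered_from_legion_ordered helper added above encodes the key invariant: Legion orders dimensions innermost-first, while FFOrdered lists them outermost-first, so the conversion is a reversal. A minimal sketch of the intended behavior, assuming DimOrdered's initializer-list constructor (the shape values below are hypothetical, not from this patch):

    // LegionOrdered lists dims innermost-first.
    LegionOrdered<size_t> legion_dims = {4, 16, 2};
    // FFOrdered lists the same dims outermost-first.
    FFOrdered<size_t> ff_dims = ff_ordered_from_legion_ordered(legion_dims);
    // ff_dims now reads {2, 16, 4}.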
diff --git a/lib/kernels/include/kernels/profiling.h b/lib/kernels/include/kernels/profiling.h
index 602689d491..655d540685 100644
--- a/lib/kernels/include/kernels/profiling.h
+++ b/lib/kernels/include/kernels/profiling.h
@@ -2,20 +2,11 @@
 #define _FLEXFLOW_KERNELS_PROFILING_H

 #include "device.h"
+#include "kernels/profiling_settings.dtg.h"
 #include "utils/visitable.h"

 namespace FlexFlow {

-struct ProfilingSettings : public use_visitable_cmp<ProfilingSettings> {
-public:
-  ProfilingSettings() = delete;
-  ProfilingSettings(int warmup_iters, int measure_iters);
-
-public:
-  int warmup_iters;
-  int measure_iters;
-};
-
 template <typename F, typename... Ts>
 std::optional<float>
     profiling_wrapper(F const &f, bool enable_profiling, Ts &&...ts) {
@@ -59,7 +50,4 @@ std::optional<float> profiling_wrapper(F const &f,

 } // namespace FlexFlow

-VISITABLE_STRUCT(::FlexFlow::ProfilingSettings, warmup_iters, measure_iters);
-MAKE_VISIT_HASHABLE(::FlexFlow::ProfilingSettings);
-
 #endif

diff --git a/lib/kernels/include/kernels/profiling_settings.struct.toml b/lib/kernels/include/kernels/profiling_settings.struct.toml
new file mode 100644
index 0000000000..694dfac76a
--- /dev/null
+++ b/lib/kernels/include/kernels/profiling_settings.struct.toml
@@ -0,0 +1,18 @@
+namespace = "FlexFlow"
+name = "ProfilingSettings"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+[[fields]]
+name = "warmup_iters"
+type = "int"
+
+[[fields]]
+name = "measure_iters"
+type = "int"

diff --git a/lib/kernels/src/accessor.cc b/lib/kernels/src/accessor.cc
index 56002718b1..a852f0d7b3 100644
--- a/lib/kernels/src/accessor.cc
+++ b/lib/kernels/src/accessor.cc
@@ -138,4 +138,33 @@ GenericTensorAccessorR read_only_accessor_from_write_accessor(
       writable.data_type, writable.shape, req<void const *>(writable.ptr)};
 }

+bool is_shape_and_dtype_equal(GenericTensorAccessorW const &acc1,
+                              GenericTensorAccessorW const &acc2) {
+  return acc1.shape == acc2.shape && acc1.data_type == acc2.data_type;
+}
+
+bool shape_and_dtype_matches(GenericTensorAccessorW const &accessor,
+                             ArrayShape const &expected_shape,
+                             DataType const &expected_dtype) {
+  return accessor.shape == expected_shape &&
+         accessor.data_type == expected_dtype;
+}
+
+bool shape_and_dtype_matches(GenericTensorAccessorR const &accessor,
+                             ArrayShape const &expected_shape,
+                             DataType const &expected_dtype) {
+  return accessor.shape == expected_shape &&
+         accessor.data_type == expected_dtype;
+}
+
+std::pair<ArrayShape, DataType>
+    get_shape_and_datatype(GenericTensorAccessorR const &accessor) {
+  return std::make_pair(accessor.shape, accessor.data_type);
+}
+
+std::pair<ArrayShape, DataType>
+    get_shape_and_datatype(GenericTensorAccessorW const &accessor) {
+  return std::make_pair(accessor.shape, accessor.data_type);
+}
+
 } // namespace FlexFlow
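With ProfilingSettings now generated from the struct.toml above, a call site for profiling_wrapper looks roughly like the following sketch; the kernel functor and its trailing arguments are hypothetical, only the wrapper signature comes from this patch:

    // Sketch: time a kernel invocation; returns elapsed ms when enabled,
    // std::nullopt when enable_profiling is false.
    std::optional<float> elapsed_ms =
        profiling_wrapper(some_forward_kernel, /*enable_profiling=*/true,
                          stream, input_ptr, output_ptr);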
diff --git a/lib/kernels/src/array_shape.cc b/lib/kernels/src/array_shape.cc
index 7daf97ecd1..d5e2f1167d 100644
--- a/lib/kernels/src/array_shape.cc
+++ b/lib/kernels/src/array_shape.cc
@@ -39,7 +39,15 @@ std::size_t ArrayShape::num_elements() const {
 }

 std::size_t ArrayShape::operator[](legion_dim_t idx) const {
-  return dims[idx];
+  return dims.at(idx);
+}
+
+std::size_t ArrayShape::at(legion_dim_t idx) const {
+  return dims.at(idx);
+}
+
+std::size_t ArrayShape::at(ff_dim_t idx) const {
+  return dims.at(legion_dim_from_ff_dim(idx, this->num_dims()));
 }

 ArrayShape ArrayShape::sub_shape(
@@ -48,16 +56,37 @@ ArrayShape ArrayShape::sub_shape(
   NOT_IMPLEMENTED();
 }

-std::optional<std::size_t> ArrayShape::at_maybe(std::size_t index) const {
-  if (index < dims.size()) {
-    return dims.at(legion_dim_t(index));
+std::optional<std::size_t> ArrayShape::at_maybe(legion_dim_t index) const {
+  if (index.value < dims.size()) {
+    return dims.at(index);
   } else {
     return std::nullopt;
   }
 }

+std::optional<std::size_t> ArrayShape::at_maybe(ff_dim_t index) const {
+  return this->at_maybe(legion_dim_from_ff_dim(index, this->num_dims()));
+}
+
 size_t get_volume(ArrayShape const &shape) {
   return shape.get_volume();
 }

+TensorShape get_tensor_shape(ArrayShape const &shape, DataType dtype) {
+  return TensorShape{TensorDims{ff_ordered_from_legion_ordered(shape.dims)},
+                     dtype};
+}
+
+std::string format_as(ArrayShape const &x) {
+  std::ostringstream oss;
+  oss << "<ArrayShape";
+  oss << " dims=" << x.dims;
+  oss << ">";
+  return oss.str();
+}
+
+std::ostream &operator<<(std::ostream &s, ArrayShape const &x) {
+  return (s << fmt::to_string(x));
+}
+
 } // namespace FlexFlow

diff --git a/lib/kernels/src/cuda/cuda_helper.cu b/lib/kernels/src/cuda/cuda_helper.cu
index 3488ce29af..2ff02038f4 100644
--- a/lib/kernels/src/cuda/cuda_helper.cu
+++ b/lib/kernels/src/cuda/cuda_helper.cu
@@ -220,25 +220,14 @@ __host__ void
 ffStatus_t
     cudnnSetTensorDescriptorFromArrayShape(cudnnTensorDescriptor_t tensor,
                                            ArrayShape const &shape) {
-  std::vector<std::size_t> reversed_dims(shape.dims.begin(), shape.dims.end());
-  reversed(reversed_dims);
-  ArrayShape flipped(reversed_dims);
-
-  if (flipped.get_dim() == 5) {
-    assert(flipped[legion_dim_t(0)] == 1);
-    flipped = flipped.sub_shape(legion_dim_t(1), std::nullopt);
-  }
-
-  assert(flipped.get_dim() > 0);
-  assert(flipped.get_dim() < 4);
-
-  return cudnnSetTensor4dDescriptor(tensor,
-                                    CUDNN_TENSOR_NCHW,
-                                    CUDNN_DATA_FLOAT,
-                                    flipped.at_maybe(0).value_or(1),
-                                    flipped.at_maybe(1).value_or(2),
-                                    flipped.at_maybe(2).value_or(3),
-                                    flipped.at_maybe(3).value_or(3));
+  return cudnnSetTensor4dDescriptor(
+      tensor,
+      CUDNN_TENSOR_NCHW,
+      CUDNN_DATA_FLOAT,
+      shape.at_maybe(legion_dim_t{0}).value_or(1),
+      shape.at_maybe(legion_dim_t{1}).value_or(1),
+      shape.at_maybe(legion_dim_t{2}).value_or(1),
+      shape.at_maybe(legion_dim_t{3}).value_or(1));
 }

 cudnnDataType_t ff_to_cudnn_datatype(DataType type) {

diff --git a/lib/kernels/src/cuda/ops/attention_kernels.cu b/lib/kernels/src/cuda/ops/attention_kernels.cu
index e50f3983cc..38c32ad9e4 100644
--- a/lib/kernels/src/cuda/ops/attention_kernels.cu
+++ b/lib/kernels/src/cuda/ops/attention_kernels.cu
@@ -18,6 +18,41 @@
 #include "kernels/device.h"

 namespace FlexFlow {
+
+bool MHAPerDeviceState::operator==(MHAPerDeviceState const &other) const {
+  return this->tie() == other.tie();
+}
+
+bool MHAPerDeviceState::operator!=(MHAPerDeviceState const &other) const {
+  return this->tie() != other.tie();
+}
+
+std::tuple<PerDeviceFFHandle const &,
+           size_t const &,
+           size_t const &,
+           ffAttnDescriptor_t const &,
+           ffSeqDataDescriptor_t const &,
+           ffSeqDataDescriptor_t const &,
+           ffSeqDataDescriptor_t const &,
+           ffSeqDataDescriptor_t const &,
+           int *const &,
+           int *const &,
+           int *const &,
+           int *const &,
+           void *const &>
+    MHAPerDeviceState::tie() const {
+  return std::tie(this->handle,
+                  this->weightSize,
+                  this->reserveSpaceSize,
+                  this->attnDesc,
+                  this->qDesc,
+                  this->kDesc,
+                  this->vDesc,
+                  this->oDesc,
+                  this->devQoSeqArray,
+                  this->devKvSeqArray,
+                  this->loWinIdx,
+                  this->hiWinIdx,
+                  this->reserveSpace);
+}
+
+std::string format_as(MHAPerDeviceState const &x) {
+  return fmt::format("MHAPerDeviceState");
+}
+
+std::ostream &operator<<(std::ostream &s, MHAPerDeviceState const &x) {
+  return (s << fmt::to_string(x));
+}
+
 namespace Kernels {
 namespace MultiHeadAttention {

diff --git a/lib/kernels/src/device.cc b/lib/kernels/src/device.cc
index 0df5e84ee9..f46099c79a 100644
--- a/lib/kernels/src/device.cc
+++ b/lib/kernels/src/device.cc
@@ -2,6 +2,10 @@

 namespace FlexFlow {

+char const *getCudaErrorString(cudaError_t status) {
+  return cudaGetErrorString(status);
+}
+
 ffError_t ffEventCreate(ffEvent_t *e) {
 #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
   return cudaEventCreate(e);
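The operator==/operator!= pair above follows the tie()-based comparison idiom this PR applies consistently (MHAPerDeviceState, ArgRefSpec, DeviceSpecific, OpTaskBinding): pack the members into a tuple of references once, and let std::tuple's lexicographic comparison do the rest. A generic sketch of the idiom, on a hypothetical two-field type:

    #include <tuple>

    struct Point { // hypothetical, for illustration only
      int x;
      int y;

      bool operator==(Point const &other) const {
        return this->tie() == other.tie();
      }

    private:
      // One tie() definition keeps the member list in a single place,
      // so adding a field cannot silently desynchronize == and !=.
      std::tuple<int const &, int const &> tie() const {
        return std::tie(this->x, this->y);
      }
    };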
diff --git a/lib/kernels/src/device.h b/lib/kernels/src/device.h
index 96670f712f..ceff2f92ff 100644
--- a/lib/kernels/src/device.h
+++ b/lib/kernels/src/device.h
@@ -7,6 +7,8 @@
 #include "op-attrs/operator_type.h"
 #include <sstream>

+namespace FlexFlow {
+
 #if defined(FF_USE_CUDA)
 #include
 #elif defined(FF_USE_HIP_CUDA)
 #include
 #elif defined(FF_USE_HIP_ROCM)
 #include
 #else
 #error "Unknown device"
 #endif

-using ::FlexFlow::DataType;
-using ::FlexFlow::OperatorType;
-
 #define checkCUDNN(status)                                                     \
   do {                                                                         \
     std::stringstream _error;                                                  \
@@ -132,9 +131,8 @@ __host__ void updateGAS(float *para_ptr,
 template <typename T>
 void print_tensor(T const *ptr, size_t num_elements, char const *prefix);

-ffStatus_t
-    cudnnSetTensorDescriptorFromArrayShape(ffTensorDescriptor_t tensor,
-                                           FlexFlow::ArrayShape const &shape);
+ffStatus_t cudnnSetTensorDescriptorFromArrayShape(ffTensorDescriptor_t tensor,
+                                                  ArrayShape const &shape);

 ffDataType_t ff_to_cuda_datatype(DataType type);

@@ -142,4 +140,6 @@ ffCudnnDataType_t ff_to_cudnn_datatype(DataType type);

 void handle_unimplemented_kernel(OperatorType op_type);

+} // namespace FlexFlow
+
 #endif

diff --git a/lib/kernels/src/ff_handle.cc b/lib/kernels/src/ff_handle.cc
new file mode 100644
index 0000000000..63ca6975fd
--- /dev/null
+++ b/lib/kernels/src/ff_handle.cc
@@ -0,0 +1,18 @@
+#include "kernels/ff_handle.h"
+
+namespace FlexFlow {
+
+std::string format_as(PerDeviceFFHandle const &x) {
+  std::ostringstream oss;
+  oss << "<PerDeviceFFHandle";
+  oss << " workSpaceSize=" << x.workSpaceSize;
+  oss << " allowTensorOpMathConversion=" << x.allowTensorOpMathConversion;
+  oss << ">";
+  return oss.str();
+}
+
+std::ostream &operator<<(std::ostream &s, PerDeviceFFHandle const &x) {
+  return s << fmt::to_string(x);
+}
+
+} // namespace FlexFlow

diff --git a/lib/kernels/src/local_cuda_allocator.cc b/lib/kernels/src/local_cuda_allocator.cc
index 9e9cb19070..cdcfb017a0 100644
--- a/lib/kernels/src/local_cuda_allocator.cc
+++ b/lib/kernels/src/local_cuda_allocator.cc
@@ -21,7 +21,7 @@ void LocalCudaAllocator::deallocate(void *ptr) {
 }

 LocalCudaAllocator::~LocalCudaAllocator() {
-  for (auto ptr : ptrs) {
+  for (void *ptr : this->ptrs) {
     checkCUDA(cudaFree(ptr));
   }
 }

diff --git a/lib/kernels/test/src/test_cast_kernel.cc b/lib/kernels/test/src/test_cast_kernel.cc
index 004bc9c32f..b110208bce 100644
--- a/lib/kernels/test/src/test_cast_kernel.cc
+++ b/lib/kernels/test/src/test_cast_kernel.cc
@@ -1,6 +1,5 @@
 #include "doctest/doctest.h"
 #include "kernels/cast_kernels.h"
-#include "kernels/cast_kernels_cpu.h"
 #include "test_utils.h"
 #include

diff --git a/lib/local-execution/CMakeLists.txt b/lib/local-execution/CMakeLists.txt
index 52b0f8edf7..f649f86ce3 100644
--- a/lib/local-execution/CMakeLists.txt
+++ b/lib/local-execution/CMakeLists.txt
@@ -14,3 +14,5 @@ ff_add_library(
     pcg
     spdlog
 )
+
+add_subdirectory(test)

diff --git a/lib/local-execution/include/local-execution/arg_ref.h b/lib/local-execution/include/local-execution/arg_ref.h
index 992a7971a5..30326b0e84 100644
--- a/lib/local-execution/include/local-execution/arg_ref.h
+++ b/lib/local-execution/include/local-execution/arg_ref.h
@@ -2,7 +2,6 @@
 #define _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H

 #include "kernels/ff_handle.h"
-#include "local-execution/profiling.h"
 // #include "local-execution/serialization.h"
 #include "utils/type_index.h"
 #include "utils/visitable.h"
@@ -32,6 +31,14 @@ struct ArgRefSpec {
     return this->type_idx;
   }

+  bool operator==(ArgRefSpec const &other) const {
+    return this->tie() == other.tie();
+  }
+
+  bool operator!=(ArgRefSpec const &other) const {
+    return this->tie() != other.tie();
+  }
+
   template <typename T>
   static ArgRefSpec create(ArgRef<LABEL_TYPE, T> const &r) {
     // static_assert(is_serializable<T>::value, "Type must be serializeable");
@@ -46,6 +53,10 @@ struct ArgRefSpec {
   std::type_index type_idx;
   LABEL_TYPE ref_type;

+  std::tuple<std::type_index const &, LABEL_TYPE const &> tie() const {
+    return std::tie(this->type_idx, this->ref_type);
+  }
+
   friend struct std::hash<ArgRefSpec<LABEL_TYPE>>;
 };
"local-execution/serialization.h" #include "utils/type_index.h" #include @@ -22,6 +23,9 @@ struct ConcreteArgSpec { return this->type_idx; } + bool operator==(ConcreteArgSpec const &other) const; + bool operator!=(ConcreteArgSpec const &other) const; + template static ConcreteArgSpec create(T const &t) { // static_assert(is_serializable::value, "Type must be serializable"); @@ -40,8 +44,13 @@ struct ConcreteArgSpec { std::type_index type_idx; std::shared_ptr ptr; + + std::tuple tie() const; }; +std::string format_as(ConcreteArgSpec const &); +std::ostream &operator<<(std::ostream &, ConcreteArgSpec const &); + } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/cost_details.struct.toml b/lib/local-execution/include/local-execution/cost_details.struct.toml index e0d89acfdb..d17438b9ff 100644 --- a/lib/local-execution/include/local-execution/cost_details.struct.toml +++ b/lib/local-execution/include/local-execution/cost_details.struct.toml @@ -16,4 +16,3 @@ type = "float" [[fields]] name = "total_mem_usage" type = "size_t" - diff --git a/lib/local-execution/include/local-execution/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h index 25d940e089..3a36e02327 100644 --- a/lib/local-execution/include/local-execution/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -13,7 +13,17 @@ struct DeviceSpecific { template static DeviceSpecific create(Args &&...args) { - NOT_IMPLEMENTED(); + size_t device_idx = 0; + return DeviceSpecific(std::make_shared(std::forward(args)...), + device_idx); + } + + bool operator==(DeviceSpecific const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(DeviceSpecific const &other) const { + return this->tie() != other.tie(); } T const *get(size_t curr_device_idx) const { @@ -23,17 +33,21 @@ struct DeviceSpecific { curr_device_idx, this->device_idx); } - return this->ptr; + return (T const *)this->ptr.get(); } // TODO: can modify ptr private: - DeviceSpecific(T *ptr, size_t device_idx) + DeviceSpecific(std::shared_ptr ptr, size_t device_idx) : ptr(ptr), device_idx(device_idx) {} - T *ptr; + std::shared_ptr ptr; size_t device_idx; + + std::tuple tie() const { + return std::tie(this->ptr, this->device_idx); + } }; // manually force serialization to make DeviceSpecific trivially diff --git a/lib/local-execution/include/local-execution/device_specific_device_states.variant.toml b/lib/local-execution/include/local-execution/device_specific_device_states.variant.toml new file mode 100644 index 0000000000..5f73bbbb8e --- /dev/null +++ b/lib/local-execution/include/local-execution/device_specific_device_states.variant.toml @@ -0,0 +1,90 @@ +namespace = "FlexFlow" +name = "DeviceSpecificDeviceStates" +features = [ + "eq", +] + +includes = [ + "kernels/attention_kernels.h", + "kernels/batch_norm_kernels.h", + "kernels/conv_2d_kernels.h", + "kernels/dropout_kernels.h", + "kernels/element_binary_kernels.h", + "kernels/element_unary_kernels.h", + "kernels/gather_kernels.h", + "kernels/layer_norm_kernels.h", + "kernels/linear_kernels.h", + "kernels/partition_kernels.h", + "kernels/pool_2d_kernels.h", + "kernels/reduce_kernels.h", + "kernels/reduction_kernels.h", + "kernels/reshape_kernels.h", + "kernels/softmax_kernels.h", + "kernels/topk_kernels.h", + "kernels/transpose_kernels.h", + "local-execution/device_specific.h", +] + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::MHAPerDeviceState>" +key = "device_specific_mha_per_device_state" + 
+[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::BatchNormPerDeviceState>" +key = "device_specific_batch_norm_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::Conv2DPerDeviceState>" +key = "device_specific_conv2d_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::DropoutPerDeviceState>" +key = "device_specific_dropout_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ElementBinaryPerDeviceState>" +key = "device_specific_element_binary_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ElementUnaryPerDeviceState>" +key = "device_specific_element_unary_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::GatherPerDeviceState>" +key = "device_specific_gather_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::LayerNormPerDeviceState>" +key = "device_specific_layer_norm_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::LinearPerDeviceState>" +key = "device_specific_linear_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::Pool2DPerDeviceState>" +key = "device_specific_pool_2d_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ReducePerDeviceState>" +key = "device_specific_reduce_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::RepartitionPerDeviceState>" +key = "device_specific_repartition_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ReshapePerDeviceState>" +key = "device_specific_reshape_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::SoftmaxPerDeviceState>" +key = "device_specific_softmax_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::TopKPerDeviceState>" +key = "device_specific_topk_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::TransposePerDeviceState>" +key = "device_specific_transpose_per_device_state" diff --git a/lib/local-execution/include/local-execution/device_states.h b/lib/local-execution/include/local-execution/device_states.h deleted file mode 100644 index 550c64a4bd..0000000000 --- a/lib/local-execution/include/local-execution/device_states.h +++ /dev/null @@ -1,40 +0,0 @@ - -#include "kernels/attention_kernels.h" -#include "kernels/batch_norm_kernels.h" -#include "kernels/conv_2d_kernels.h" -#include "kernels/dropout_kernels.h" -#include "kernels/element_binary_kernels.h" -#include "kernels/element_unary_kernels.h" -#include "kernels/gather_kernels.h" -#include "kernels/layer_norm_kernels.h" -#include "kernels/linear_kernels.h" -#include "kernels/partition_kernels.h" -#include "kernels/pool_2d_kernels.h" -#include "kernels/reduce_kernels.h" -#include "kernels/reduction_kernels.h" -#include "kernels/reshape_kernels.h" -#include "kernels/softmax_kernels.h" -#include "kernels/topk_kernels.h" -#include "kernels/transpose_kernels.h" -#include - -namespace FlexFlow { - -using DeviceStates = std::variant; - -} diff --git a/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h b/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h new file mode 100644 index 0000000000..7f80af77f3 --- /dev/null +++ b/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H + 
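DeviceSpecific<T>::create (above) now heap-allocates the wrapped state through a shared_ptr instead of leaving the body NOT_IMPLEMENTED, and the wrapped value can only be unwrapped on the device index it was created for. A usage sketch under the single-device assumption the implementation hard-codes (device_idx 0); the state object is hypothetical:

    // Wrap a freshly initialized per-device state for device 0.
    DeviceSpecific<MHAPerDeviceState> wrapped =
        DeviceSpecific<MHAPerDeviceState>::create(raw_state);
    // Unwrapping on the matching device index succeeds...
    MHAPerDeviceState const *state = wrapped.get(/*curr_device_idx=*/0);
    // ...while wrapped.get(1) would throw, since the stored device_idx is 0.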
+#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct FwdBwdTaskImplFunction { + + std::optional (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(FwdBwdTaskImplFunction const &) const; + bool operator!=(FwdBwdTaskImplFunction const &) const; + bool operator<(FwdBwdTaskImplFunction const &) const; + bool operator>(FwdBwdTaskImplFunction const &) const; + bool operator<=(FwdBwdTaskImplFunction const &) const; + bool operator>=(FwdBwdTaskImplFunction const &) const; +}; + +std::string format_as(FwdBwdTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, FwdBwdTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::FwdBwdTaskImplFunction> { + size_t operator()(::FlexFlow::FwdBwdTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/init_task_impl_function.h b/lib/local-execution/include/local-execution/init_task_impl_function.h new file mode 100644 index 0000000000..b85944e13a --- /dev/null +++ b/lib/local-execution/include/local-execution/init_task_impl_function.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H + +#include "local-execution/device_specific_device_states.dtg.h" +#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct InitTaskImplFunction { + + DeviceSpecificDeviceStates (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(InitTaskImplFunction const &) const; + bool operator!=(InitTaskImplFunction const &) const; + bool operator<(InitTaskImplFunction const &) const; + bool operator>(InitTaskImplFunction const &) const; + bool operator<=(InitTaskImplFunction const &) const; + bool operator>=(InitTaskImplFunction const &) const; +}; + +std::string format_as(InitTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, InitTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::InitTaskImplFunction> { + size_t operator()(::FlexFlow::InitTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/local_cpu_allocator.h b/lib/local-execution/include/local-execution/local_cpu_allocator.h new file mode 100644 index 0000000000..d1e81facf2 --- /dev/null +++ b/lib/local-execution/include/local-execution/local_cpu_allocator.h @@ -0,0 +1,22 @@ +#include "kernels/allocation.h" +#include + +namespace FlexFlow { + +struct LocalCPUAllocator : public IAllocator { + LocalCPUAllocator() = default; + LocalCPUAllocator(LocalCPUAllocator const &) = delete; + LocalCPUAllocator(LocalCPUAllocator &&) = delete; + ~LocalCPUAllocator() = default; + + void *allocate(size_t) override; + void deallocate(void *) override; + +private: + std::unordered_map> ptrs; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalCPUAllocator); + +Allocator create_local_cpu_memory_allocator(); + +} // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 46dcc3d914..6a0c28e988 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -3,9 +3,9 @@ #define _FLEXFLOW_LOCAL_EXECUTION_SLOT_REGISTRY_H #include "kernels/accessor.h" -#include "local-execution/device_states.h" #include 
"local-execution/local_task_argument_accessor.h" #include "local-execution/op_task_invocation.h" +#include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" namespace FlexFlow { @@ -18,17 +18,24 @@ struct LocalSlotsBacking { public: void add_per_device_op_state(layer_guid_t const &, - DeviceSpecific const &); - bool is_tensor_allocated(tensor_guid_t const &) const; - GenericTensorAccessorW const &get_tensor_backing(tensor_guid_t const &, - IsGrad) const; + DeviceSpecificDeviceStates const &); + void allocate_outgoing_tensors(layer_guid_t const &, + ComputationGraph const &, + Allocator &); TensorSlotsBacking construct_tensor_slots_backing(OpTaskBinding const &, layer_guid_t const &) const; ArgSlotsBacking construct_arg_slots_backing(OpTaskBinding const &, layer_guid_t const &) const; + + ConcreteArgSpec resolve_runtime_arg_ref_spec(RuntimeArgRefSpec const &) const; ConcreteArgSpec resolve_op_arg_ref_spec(OpArgRefSpec const &, layer_guid_t const &) const; - ConcreteArgSpec resolve_runtime_arg_ref_spec(RuntimeArgRefSpec const &) const; + +private: + bool is_tensor_allocated(tensor_guid_t const &) const; + bool is_gradient_tensor_allocated(tensor_guid_t const &) const; + GenericTensorAccessorW const &get_tensor_backing(tensor_guid_t const &, + IsGrad) const; public: // tensors @@ -36,13 +43,11 @@ struct LocalSlotsBacking { TensorBackingMap gradient_tensor_mapping; std::unordered_map> input_tensor_slots; - std::unordered_map> - weight_tensor_slots; std::unordered_map> output_tensor_slots; // arguments - std::unordered_map>> + std::unordered_map per_device_op_states; RuntimeArgConfig runtime_arg_config; }; diff --git a/lib/local-execution/include/local-execution/local_task_argument_accessor.h b/lib/local-execution/include/local-execution/local_task_argument_accessor.h index 27c8af0836..1e1516a0de 100644 --- a/lib/local-execution/include/local-execution/local_task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/local_task_argument_accessor.h @@ -38,6 +38,16 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { TensorSlotsBacking tensor_slots_backing; ArgSlotsBacking arg_slots_backing; }; + +using TensorSlotsBackingWithoutAddresses = std::unordered_map< + SlotGradId, + std::variant, + std::vector>>>; + +TensorSlotsBackingWithoutAddresses + get_slots_backing_without_tensor_allocation_addresses( + TensorSlotsBacking const &); + CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalTaskArgumentAccessor); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h index 2fe3a57407..b398bb8cc3 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -14,29 +14,27 @@ struct LocalTrainingBacking { ComputationGraph const &, TensorBackingMap const &, RuntimeArgConfig const &); - ~LocalTrainingBacking() = default; void execute_init(); PerLayerElapsedTime execute_forward(); PerLayerElapsedTime execute_backward(); void execute_update(); -private: - DeviceSpecific - call_init_task_impl(task_id_t, TaskArgumentAccessor const &); - std::optional call_task_impl(task_id_t, TaskArgumentAccessor); - TaskArgumentAccessor get_task_arg_accessor(OpTaskInvocation const &, layer_guid_t const &) const; +private: + DeviceSpecificDeviceStates call_init_task_impl(task_id_t, + TaskArgumentAccessor const &); + std::optional call_task_impl(task_id_t, 
TaskArgumentAccessor); + +private: Allocator allocator; ComputationGraph computation_graph; TaskRegistry task_registry; LocalSlotsBacking local_slots_backing; }; -std::vector get_task_ids(ComputationGraphOpAttrs const &); - } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h index 939fab21c0..20d6ccb1c5 100644 --- a/lib/local-execution/include/local-execution/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -3,30 +3,22 @@ #include "local-execution/arg_ref.h" #include "local-execution/device_specific.h" +#include "local-execution/op_arg_ref_type.dtg.h" +#include "local-execution/per_device_op_state.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -enum class OpArgRefLabel { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; - -struct PerDeviceOpStateRefType {}; - -struct ParallelTensorShapeRefType { - int idx; -}; - -using OpArgRefType = - std::variant; - template using OpArgRef = ArgRef; using OpArgRefSpec = ArgRefSpec; template -OpArgRef> per_device_op_state() { - OpArgRefType op_arg_ref_type = PerDeviceOpStateRefType{}; - ArgRef> arg_ref = {op_arg_ref_type}; +OpArgRef per_device_op_state() { + OpArgRefType op_arg_ref_type = OpArgRefType{PerDeviceOpStateRefType{}}; + static_assert(PerDeviceOpState::IsPartOfPerDeviceOpState_v); + ArgRef arg_ref = {op_arg_ref_type}; return arg_ref; } diff --git a/lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml b/lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml new file mode 100644 index 0000000000..cd226da161 --- /dev/null +++ b/lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OpArgRefType" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "local-execution/per_device_op_state_ref_type.dtg.h", + "local-execution/parallel_tensor_shape_ref_type.dtg.h", +] + +[[values]] +type = "::FlexFlow::PerDeviceOpStateRefType" +key = "per_device_op_state_ref_type" + +[[values]] +type = "::FlexFlow::ParallelTensorShapeRefType" +key = "parallel_tensor_shape_ref_type" diff --git a/lib/local-execution/include/local-execution/op_arg_spec.variant.toml b/lib/local-execution/include/local-execution/op_arg_spec.variant.toml index a13018e6a1..28169902ae 100644 --- a/lib/local-execution/include/local-execution/op_arg_spec.variant.toml +++ b/lib/local-execution/include/local-execution/op_arg_spec.variant.toml @@ -1,7 +1,7 @@ namespace = "FlexFlow" name = "OpArgSpec" features = [ - # "eq", + "eq", # "ord", # "hash", # "json", diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index eafd6b80b0..73a0460554 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -11,7 +11,7 @@ #include "local-execution/profiling.h" #include "local-execution/runtime_arg_ref.h" #include "local-execution/slot_grad_id.dtg.h" -#include "local-execution/tasks.h" +#include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" #include "op-attrs/computation_graph_op_attrs.h" #include "pcg/computation_graph.h" @@ -85,6 +85,8 @@ struct OpTaskBinding { void bind_arg(slot_id_t name, OpArgRef const &ref) { this->insert_arg_spec(name, OpArgSpec{OpArgRefSpec::create(ref)}); } + bool 
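Under the reworked op_arg_ref.h above, per_device_op_state<T>() is requested directly for the concrete state type rather than for DeviceSpecific<T>, and the static_assert rejects any T that is not a PerDeviceOpState alternative. A binding sketch; the PER_DEVICE_STATE slot name follows the convention used in src/ops/attention.cc and is illustrative here:

    // Sketch: ask the runtime to resolve this op's per-device state into a
    // named argument slot when the task is eventually launched.
    OpTaskBinding binding;
    binding.bind_arg(PER_DEVICE_STATE,
                     per_device_op_state<MHAPerDeviceState>());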
diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h
index eafd6b80b0..73a0460554 100644
--- a/lib/local-execution/include/local-execution/op_task_invocation.h
+++ b/lib/local-execution/include/local-execution/op_task_invocation.h
@@ -11,7 +11,7 @@
 #include "local-execution/profiling.h"
 #include "local-execution/runtime_arg_ref.h"
 #include "local-execution/slot_grad_id.dtg.h"
-#include "local-execution/tasks.h"
+#include "local-execution/task_id_t.dtg.h"
 #include "local-execution/variadic_tensor_ref.h"
 #include "op-attrs/computation_graph_op_attrs.h"
 #include "pcg/computation_graph.h"
@@ -85,6 +85,8 @@ struct OpTaskBinding {
   void bind_arg(slot_id_t name, OpArgRef<T> const &ref) {
     this->insert_arg_spec(name, OpArgSpec{OpArgRefSpec::create(ref)});
   }
+  bool operator==(OpTaskBinding const &other) const;
+  bool operator!=(OpTaskBinding const &other) const;

   std::unordered_map<SlotGradId, OpTensorSpec> const &
       get_tensor_bindings() const;
@@ -93,15 +95,19 @@ struct OpTaskBinding {
   void bind_from_forward(OpTaskBinding const &fwd);

 private:
-  void insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec);
   std::unordered_map<SlotGradId, OpTensorSpec> tensor_bindings;
   std::unordered_map<slot_id_t, OpArgSpec> arg_bindings;
+
+private:
+  void insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec);
+  std::tuple<std::unordered_map<SlotGradId, OpTensorSpec> const &,
+             std::unordered_map<slot_id_t, OpArgSpec> const &>
+      tie() const;
 };

 struct OpTaskInvocation {
 public:
   OpTaskInvocation() = delete;
-  OpTaskInvocation(task_id_t const &task_id, OpTaskBinding const &binding)
+  OpTaskInvocation(task_id_t task_id, OpTaskBinding const &binding)
       : task_id(task_id), binding(binding) {}

 public:
@@ -109,11 +115,6 @@ struct OpTaskInvocation {
   OpTaskBinding binding;
 };

-OpTaskInvocation init(ComputationGraphOpAttrs const &);
-OpTaskInvocation forward(ComputationGraphOpAttrs const &);
-OpTaskInvocation backward(ComputationGraphOpAttrs const &);
-
-OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd);
 OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd);

 bool is_invocation_valid(OpTaskSignature const &sig,

diff --git a/lib/local-execution/include/local-execution/op_task_signature.h b/lib/local-execution/include/local-execution/op_task_signature.h
index ad5177b289..0447644354 100644
--- a/lib/local-execution/include/local-execution/op_task_signature.h
+++ b/lib/local-execution/include/local-execution/op_task_signature.h
@@ -7,7 +7,9 @@
 #include "local-execution/serialization.h"
 #include "local-execution/slot_id_t.dtg.h"
 #include "local-execution/slot_type.dtg.h"
-#include "local-execution/tasks.h"
+#include "local-execution/task_id_t.dtg.h"
+#include "utils/hash/unordered_map.h"
+#include "utils/hash/unordered_set.h"
 #include "utils/type_index.h"
 #include "utils/visitable.h"

@@ -40,10 +42,9 @@ struct OpTaskSignature {
   void add_output_slot(int, SlotType slot_type = SlotType::TENSOR);
   void add_output_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR);

-  void add_bwd_necessary_output_slot(int,
-                                     SlotType slot_type = SlotType::TENSOR);
-  void add_bwd_necessary_output_slot(slot_id_t,
-                                     SlotType slot_type = SlotType::TENSOR);
+  void add_bwd_optional_output_slot(int, SlotType slot_type = SlotType::TENSOR);
+  void add_bwd_optional_output_slot(slot_id_t,
+                                    SlotType slot_type = SlotType::TENSOR);

   void add_weight_slot(int, SlotType slot_type = SlotType::TENSOR);
   void add_weight_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR);
@@ -96,27 +97,10 @@ struct OpTaskSignature {
 FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(
     OpTaskSignature, type, return_value, task_arg_types, op_tensor_slots);

-template <typename F>
-void register_task(task_id_t,
-                   std::string const &name,
-                   OpTaskSignature const &,
-                   F const &func);
-
-template <typename F>
-void register_task(task_id_t,
-                   std::string const &name,
-                   OpTaskSignature const &,
-                   F const &func,
-                   F const &cpu_func);
-
-template <task_id_t>
-OpTaskSignature init_signature();
-
-template <task_id_t>
-OpTaskSignature fwd_signature();
-
-template <task_id_t>
-OpTaskSignature bwd_signature();
+std::string format_as(OpTaskSignature const &x);
+std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x);
+
+OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd);

 } // namespace FlexFlow

diff --git a/lib/local-execution/src/ops/attention.h b/lib/local-execution/include/local-execution/ops/attention.h
similarity index 100%
rename from lib/local-execution/src/ops/attention.h
rename to lib/local-execution/include/local-execution/ops/attention.h
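infer_bwd_binding and infer_bwd_signature (above) let each operator define only its forward binding and have the backward one derived, with each forward tensor re-bound alongside its gradient. A sketch of the resulting per-op pattern; the attrs type and task id are placeholders, not names fixed by this patch:

    // Sketch: a typical backward() in lib/local-execution/src/ops/*.cc.
    OpTaskInvocation backward(SomeOpAttrs const &attrs) {
      // Reuse the forward binding, augmented with gradient slots.
      OpTaskBinding b = infer_bwd_binding(forward(attrs).binding);
      return OpTaskInvocation{task_id_t::SOMEOP_BWD_TASK_ID, b};
    }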
diff --git a/lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml b/lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml
new file mode 100644
index 0000000000..fe340f4451
--- /dev/null
+++ b/lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "ParallelTensorShapeRefType"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+[[fields]]
+name = "idx"
+type = "int"

diff --git a/lib/local-execution/include/local-execution/per_device_op_state.h b/lib/local-execution/include/local-execution/per_device_op_state.h
new file mode 100644
index 0000000000..1edd5b6360
--- /dev/null
+++ b/lib/local-execution/include/local-execution/per_device_op_state.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LOCAL_EXECUTION_PER_DEVICE_STATE_H
+#define _FLEXFLOW_LOCAL_EXECUTION_PER_DEVICE_STATE_H
+
+#include "local-execution/device_specific_device_states.dtg.h"
+#include "local-execution/per_device_op_state.dtg.h"
+
+namespace FlexFlow {
+
+PerDeviceOpState
+    get_device_state_from_device_specific(DeviceSpecificDeviceStates const &,
+                                          size_t device_idx);
+
+}
+
+#endif

diff --git a/lib/local-execution/include/local-execution/per_device_op_state.variant.toml b/lib/local-execution/include/local-execution/per_device_op_state.variant.toml
new file mode 100644
index 0000000000..f99ff10bb9
--- /dev/null
+++ b/lib/local-execution/include/local-execution/per_device_op_state.variant.toml
@@ -0,0 +1,87 @@
+namespace = "FlexFlow"
+name = "PerDeviceOpState"
+features = []
+
+includes = [
+  "kernels/attention_kernels.h",
+  "kernels/batch_norm_kernels.h",
+  "kernels/conv_2d_kernels.h",
+  "kernels/dropout_kernels.h",
+  "kernels/element_binary_kernels.h",
+  "kernels/element_unary_kernels.h",
+  "kernels/gather_kernels.h",
+  "kernels/layer_norm_kernels.h",
+  "kernels/linear_kernels.h",
+  "kernels/partition_kernels.h",
+  "kernels/pool_2d_kernels.h",
+  "kernels/reduce_kernels.h",
+  "kernels/reduction_kernels.h",
+  "kernels/reshape_kernels.h",
+  "kernels/softmax_kernels.h",
+  "kernels/topk_kernels.h",
+  "kernels/transpose_kernels.h",
+]
+
+[[values]]
+type = "::FlexFlow::MHAPerDeviceState"
+key = "mha_per_device_state"
+
+[[values]]
+type = "::FlexFlow::BatchNormPerDeviceState"
+key = "batch_norm_per_device_state"
+
+[[values]]
+type = "::FlexFlow::Conv2DPerDeviceState"
+key = "conv2d_per_device_state"
+
+[[values]]
+type = "::FlexFlow::DropoutPerDeviceState"
+key = "dropout_per_device_state"
+
+[[values]]
+type = "::FlexFlow::ElementBinaryPerDeviceState"
+key = "element_binary_per_device_state"
+
+[[values]]
+type = "::FlexFlow::ElementUnaryPerDeviceState"
+key = "element_unary_per_device_state"
+
+[[values]]
+type = "::FlexFlow::GatherPerDeviceState"
+key = "gather_per_device_state"
+
+[[values]]
+type = "::FlexFlow::LayerNormPerDeviceState"
+key = "layer_norm_per_device_state"
+
+[[values]]
+type = "::FlexFlow::LinearPerDeviceState"
+key = "linear_per_device_state"
+
+[[values]]
+type = "::FlexFlow::Pool2DPerDeviceState"
+key = "pool_2d_per_device_state"
+
+[[values]]
+type = "::FlexFlow::ReducePerDeviceState"
+key = "reduce_per_device_state"
+
+[[values]]
+type = "::FlexFlow::RepartitionPerDeviceState"
+key = "repartition_per_device_state"
+
+[[values]]
+type = "::FlexFlow::ReshapePerDeviceState"
+key = "reshape_per_device_state"
+
+[[values]]
+type = "::FlexFlow::SoftmaxPerDeviceState"
+key = "softmax_per_device_state"
+
+[[values]]
+type = "::FlexFlow::TopKPerDeviceState"
+key = "topk_per_device_state"
+
+[[values]]
+type = "::FlexFlow::TransposePerDeviceState"
+key = "transpose_per_device_state"

diff --git a/lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml b/lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml
new file mode 100644
index 0000000000..e3d48a02ee
--- /dev/null
+++ b/lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml
@@ -0,0 +1,12 @@
+namespace = "FlexFlow"
+name = "PerDeviceOpStateRefType"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+fields = []

diff --git a/lib/local-execution/include/local-execution/runtime_arg_ref.h b/lib/local-execution/include/local-execution/runtime_arg_ref.h
index 470f02929b..279d854a27 100644
--- a/lib/local-execution/include/local-execution/runtime_arg_ref.h
+++ b/lib/local-execution/include/local-execution/runtime_arg_ref.h
@@ -4,6 +4,7 @@
 #include "local-execution/arg_ref.h"
 #include "local-execution/config.h"
 #include "local-execution/device_specific.h"
+#include "local-execution/profiling.h"

 namespace FlexFlow {

diff --git a/lib/local-execution/include/local-execution/task_argument_accessor.h b/lib/local-execution/include/local-execution/task_argument_accessor.h
index 7a84bfb5c3..54c8dfc5f1 100644
--- a/lib/local-execution/include/local-execution/task_argument_accessor.h
+++ b/lib/local-execution/include/local-execution/task_argument_accessor.h
@@ -1,19 +1,27 @@
 #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H
 #define _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H

+#include "local-execution/device_specific.h"
 #include "local-execution/itask_argument_accessor.h"
+#include "local-execution/per_device_op_state.dtg.h"

 namespace FlexFlow {

 struct TaskArgumentAccessor {
   template <typename T>
-  T const &get_argument(int slot) const {
-    return this->get_argument<T>(slot_id_t{slot});
+  T const &get_argument(slot_id_t slot) const {
+    if constexpr (PerDeviceOpState::IsPartOfPerDeviceOpState_v<T>) {
+      PerDeviceOpState device_states =
+          this->ptr->get_concrete_arg(slot).get<PerDeviceOpState>();
+      return device_states.get<T>();
+    } else {
+      return this->ptr->get_concrete_arg(slot).get<T>();
+    }
   }

   template <typename T>
-  T const &get_argument(slot_id_t slot) const {
-    return this->ptr->get_concrete_arg(slot).get<T>();
+  T const &get_argument(int slot) const {
+    return this->get_argument<T>(slot_id_t{slot});
   }

   template <typename T>
"DROPOUT_FWD_TASK_ID" + +[[values]] +name = "DROPOUT_BWD_TASK_ID" + +[[values]] +name = "EMBED_INIT_TASK_ID" + +[[values]] +name = "EMBED_FWD_TASK_ID" + +[[values]] +name = "EMBED_BWD_TASK_ID" + +[[values]] +name = "GATHER_INIT_TASK_ID" + +[[values]] +name = "GATHER_FWD_TASK_ID" + +[[values]] +name = "GATHER_BWD_TASK_ID" + +[[values]] +name = "CAST_INIT_TASK_ID" + +[[values]] +name = "CAST_FWD_TASK_ID" + +[[values]] +name = "CAST_BWD_TASK_ID" + +[[values]] +name = "POOL2D_INIT_TASK_ID" + +[[values]] +name = "POOL2D_FWD_TASK_ID" + +[[values]] +name = "POOL2D_BWD_TASK_ID" + +[[values]] +name = "BATCHNORM_INIT_TASK_ID" + +[[values]] +name = "BATCHNORM_FWD_TASK_ID" + +[[values]] +name = "BATCHNORM_BWD_TASK_ID" + +[[values]] +name = "BATCHMATMUL_INIT_TASK_ID" + +[[values]] +name = "BATCHMATMUL_FWD_TASK_ID" + +[[values]] +name = "BATCHMATMUL_BWD_TASK_ID" + +[[values]] +name = "LAYERNORM_INIT_TASK_ID" + +[[values]] +name = "LAYERNORM_FWD_TASK_ID" + +[[values]] +name = "LAYERNORM_BWD_TASK_ID" + +[[values]] +name = "LINEAR_INIT_TASK_ID" + +[[values]] +name = "LINEAR_FWD_TASK_ID" + +[[values]] +name = "LINEAR_BWD_TASK_ID" + +[[values]] +name = "FLAT_INIT_TASK_ID" + +[[values]] +name = "FLAT_FWD_TASK_ID" + +[[values]] +name = "FLAT_BWD_TASK_ID" + +[[values]] +name = "SOFTMAX_INIT_TASK_ID" + +[[values]] +name = "SOFTMAX_FWD_TASK_ID" + +[[values]] +name = "SOFTMAX_BWD_TASK_ID" + +[[values]] +name = "CONCAT_INIT_TASK_ID" + +[[values]] +name = "CONCAT_FWD_TASK_ID" + +[[values]] +name = "CONCAT_BWD_TASK_ID" + +[[values]] +name = "SPLIT_INIT_TASK_ID" + +[[values]] +name = "SPLIT_FWD_TASK_ID" + +[[values]] +name = "SPLIT_BWD_TASK_ID" + +[[values]] +name = "REDUCE_INIT_TASK_ID" + +[[values]] +name = "REDUCE_FWD_TASK_ID" + +[[values]] +name = "REDUCE_BWD_TASK_ID" + +[[values]] +name = "RESHAPE_INIT_TASK_ID" + +[[values]] +name = "RESHAPE_FWD_TASK_ID" + +[[values]] +name = "RESHAPE_BWD_TASK_ID" + +[[values]] +name = "REVERSE_INIT_TASK_ID" + +[[values]] +name = "REVERSE_FWD_TASK_ID" + +[[values]] +name = "REVERSE_BWD_TASK_ID" + +[[values]] +name = "TOPK_INIT_TASK_ID" + +[[values]] +name = "TOPK_FWD_TASK_ID" + +[[values]] +name = "TOPK_BWD_TASK_ID" + +[[values]] +name = "TRANSPOSE_INIT_TASK_ID" + +[[values]] +name = "TRANSPOSE_FWD_TASK_ID" + +[[values]] +name = "TRANSPOSE_BWD_TASK_ID" + +[[values]] +name = "ATTENTION_INIT_TASK_ID" + +[[values]] +name = "ATTENTION_FWD_TASK_ID" + +[[values]] +name = "ATTENTION_BWD_TASK_ID" + +[[values]] +name = "MSELOSS_BWD_TASK_ID" + +[[values]] +name = "FUSEDOP_INIT_TASK_ID" + +[[values]] +name = "FUSEDOP_FWD_TASK_ID" + +[[values]] +name = "FUSEDOP_BWD_TASK_ID" + +[[values]] +name = "NOOP_INIT_TASK_ID" + +[[values]] +name = "METRICS_COMP_TASK_ID" + +[[values]] +name = "UPDATE_METRICS_TASK_ID" + +[[values]] +name = "PS_PREFETCH_TASK_ID" + +[[values]] +name = "LOSS_BWD_TASK_ID" + +[[values]] +name = "SGD_UPD_PS_TASK_ID" + +[[values]] +name = "ADAM_UPD_PS_TASK_ID" + +[[values]] +name = "SGD_UPD_NCCL_TASK_ID" + +[[values]] +name = "ADAM_UPD_NCCL_TASK_ID" + +[[values]] +name = "GLOROT_INIT_TASK_ID" + +[[values]] +name = "ZERO_INIT_TASK_ID" + +[[values]] +name = "CONSTANT_INIT_TASK_ID" + +[[values]] +name = "UNIFORM_INIT_TASK_ID" + +[[values]] +name = "NORMAL_INIT_TASK_ID" + +[[values]] +name = "NCCL_GETUNIQUEID_TASK_ID" + +[[values]] +name = "NCCL_INIT_COMMS_TASK_ID" + +[[values]] +name = "STRATEGY_SEARCH_TASK_ID" + +[[values]] +name = "GRAPH_OPTIMIZE_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_LOAD_ENTIRE_CPU_TASK_ID" + 
+[[values]] +name = "PY_DL_INT64_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "REPARTITION_INIT_TASK_ID" + +[[values]] +name = "REPARTITION_FWD_TASK_ID" + +[[values]] +name = "REPARTITION_BWD_TASK_ID" + +[[values]] +name = "COMBINE_INIT_TASK_ID" + +[[values]] +name = "COMBINE_FWD_TASK_ID" + +[[values]] +name = "COMBINE_BWD_TASK_ID" + +[[values]] +name = "REPLICATE_INIT_TASK_ID" + +[[values]] +name = "REPLICATE_FWD_TASK_ID" + +[[values]] +name = "REPLICATE_BWD_TASK_ID" + +[[values]] +name = "REDUCTION_INIT_TASK_ID" + +[[values]] +name = "REDUCTION_FWD_TASK_ID" + +[[values]] +name = "REDUCTION_BWD_TASK_ID" + +[[values]] +name = "PIPELINE_INIT_TASK_ID" + +[[values]] +name = "PIPELINE_FWD_TASK_ID" + +[[values]] +name = "PIPELINE_BWD_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_INIT_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_FWD_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_BWD_TASK_ID" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_FIRST" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_1" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_2" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_3" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_4" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_5" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_6" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_7" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_8" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_LAST" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_FIRST" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_1" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_2" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_3" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_4" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_5" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_6" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_7" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_LAST" + +[[values]] +name = "PYTHON_TOP_LEVEL_TASK_ID" diff --git a/lib/local-execution/include/local-execution/task_impl_function.variant.toml b/lib/local-execution/include/local-execution/task_impl_function.variant.toml new file mode 100644 index 0000000000..a12be37da2 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_impl_function.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TaskImplFunction" +features = [ + "eq", + "fmt", + "hash", + "ord" +] + +includes = [ + "local-execution/init_task_impl_function.h", + "local-execution/fwd_bwd_task_impl_function.h", +] + +[[values]] +type = "::FlexFlow::InitTaskImplFunction" +key = "init_task_impl_function" + +[[values]] +type = "::FlexFlow::FwdBwdTaskImplFunction" +key = "fwd_bwd_task_impl_function" diff --git a/lib/local-execution/include/local-execution/task_registry.h b/lib/local-execution/include/local-execution/task_registry.h index 01b7d29b36..e00cc183da 100644 --- a/lib/local-execution/include/local-execution/task_registry.h +++ b/lib/local-execution/include/local-execution/task_registry.h @@ -2,25 +2,16 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H #define _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H -#include "local-execution/op_task_invocation.h" -#include "local-execution/task_signature_impl.h" -#include "op-attrs/operator_attrs.h" -#include "pcg/computation_graph.h" +#include 
"local-execution/task_registry.dtg.h" +#include "op-attrs/computation_graph_op_attrs.h" namespace FlexFlow { -struct TaskRegistry { - TaskRegistry() = default; +TaskRegistry empty_task_registry(); - void register_task(task_id_t const &, - layer_guid_t const &, - ComputationGraphOpAttrs const &attrs); - - std::unordered_map init_task_ids; - std::unordered_map forward_task_ids; - std::unordered_map backward_task_ids; - std::unordered_map task_mapping; -}; +void register_tasks_for_layer(TaskRegistry &, + layer_guid_t const &, + ComputationGraphOpAttrs const &attrs); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/task_registry.struct.toml b/lib/local-execution/include/local-execution/task_registry.struct.toml new file mode 100644 index 0000000000..308527efac --- /dev/null +++ b/lib/local-execution/include/local-execution/task_registry.struct.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "TaskRegistry" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/task_signature_impl.dtg.h", + "local-execution/task_id_t.dtg.h", + "pcg/layer_guid_t.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "init_task_ids" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::task_id_t>>" + +[[fields]] +name = "forward_task_ids" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::task_id_t>>" + +[[fields]] +name = "backward_task_ids" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::task_id_t>>" + +[[fields]] +name = "task_mapping" +type = "std::unordered_map<::FlexFlow::task_id_t, ::FlexFlow::TaskSignatureAndImpl>" diff --git a/lib/local-execution/include/local-execution/task_signature_impl.h b/lib/local-execution/include/local-execution/task_signature_impl.h index 659ccb23d6..98c5c0cb3b 100644 --- a/lib/local-execution/include/local-execution/task_signature_impl.h +++ b/lib/local-execution/include/local-execution/task_signature_impl.h @@ -1,24 +1,19 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_IMPL_H #define _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_IMPL_H -#include "local-execution/device_specific.h" -#include "local-execution/device_states.h" -#include "local-execution/tasks.h" -#include "task_argument_accessor.h" -#include "utils/variant.h" +#include "local-execution/op_task_invocation.h" +#include "local-execution/task_id_t.dtg.h" +#include "local-execution/task_signature_impl.dtg.h" +#include "op-attrs/computation_graph_op_attrs.h" namespace FlexFlow { -using TaskImplFunction = std::variant< - std::function(TaskArgumentAccessor const &)>, - std::function(TaskArgumentAccessor const &)>>; - -struct TaskSignatureAndImpl { - TaskImplFunction impl_function; - OpTaskSignature task_signature; -}; - TaskSignatureAndImpl get_task_sig_impl(task_id_t const &); +std::vector get_task_ids(ComputationGraphOpAttrs const &); + +OpTaskInvocation init(ComputationGraphOpAttrs const &); +OpTaskInvocation forward(ComputationGraphOpAttrs const &); +OpTaskInvocation backward(ComputationGraphOpAttrs const &); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/task_signature_impl.struct.toml b/lib/local-execution/include/local-execution/task_signature_impl.struct.toml new file mode 100644 index 0000000000..981794503b --- /dev/null +++ b/lib/local-execution/include/local-execution/task_signature_impl.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TaskSignatureAndImpl" 
+features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/task_impl_function.dtg.h", + "local-execution/op_task_signature.h", +] + +[[fields]] +name = "impl_function" +type = "::FlexFlow::TaskImplFunction" + +[[fields]] +name = "task_signature" +type = "::FlexFlow::OpTaskSignature" diff --git a/lib/local-execution/include/local-execution/tasks.h b/lib/local-execution/include/local-execution/tasks.h index a747c95383..4f5b26c43b 100644 --- a/lib/local-execution/include/local-execution/tasks.h +++ b/lib/local-execution/include/local-execution/tasks.h @@ -1,169 +1,13 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASKS_H #define _FLEXFLOW_LOCAL_EXECUTION_TASKS_H +#include "local-execution/task_id_t.dtg.h" #include #include #include namespace FlexFlow { - -enum task_id_t { - TOP_LEVEL_TASK_ID, - FF_INIT_TASK_ID, - IMAGE_INIT_TASK_ID, - LABEL_INIT_TASK_ID, - LOAD_IMAGES_TASK_ID, - NORMALIZE_IMAGES_TASK_ID, - ELEMENTBINARY_INIT_TASK_ID, - ELEMENTBINARY_FWD_TASK_ID, - ELEMENTBINARY_BWD_TASK_ID, - ELEMENTUNARY_INIT_TASK_ID, - ELEMENTUNARY_FWD_TASK_ID, - ELEMENTUNARY_BWD_TASK_ID, - CONV2D_INIT_TASK_ID, - CONV2D_FWD_TASK_ID, - CONV2D_BWD_TASK_ID, - DROPOUT_INIT_TASK_ID, - DROPOUT_FWD_TASK_ID, - DROPOUT_BWD_TASK_ID, - EMBED_INIT_TASK_ID, - EMBED_FWD_TASK_ID, - EMBED_BWD_TASK_ID, - GATHER_INIT_TASK_ID, - GATHER_FWD_TASK_ID, - GATHER_BWD_TASK_ID, - CAST_INIT_TASK_ID, - CAST_FWD_TASK_ID, - CAST_BWD_TASK_ID, - POOL2D_INIT_TASK_ID, - POOL2D_FWD_TASK_ID, - POOL2D_BWD_TASK_ID, - BATCHNORM_INIT_TASK_ID, - BATCHNORM_FWD_TASK_ID, - BATCHNORM_BWD_TASK_ID, - BATCHMATMUL_INIT_TASK_ID, - BATCHMATMUL_FWD_TASK_ID, - BATCHMATMUL_BWD_TASK_ID, - LAYERNORM_INIT_TASK_ID, - LAYERNORM_FWD_TASK_ID, - LAYERNORM_BWD_TASK_ID, - LINEAR_INIT_TASK_ID, - LINEAR_FWD_TASK_ID, - LINEAR_BWD_TASK_ID, - FLAT_INIT_TASK_ID, - FLAT_FWD_TASK_ID, - FLAT_BWD_TASK_ID, - SOFTMAX_INIT_TASK_ID, - SOFTMAX_FWD_TASK_ID, - SOFTMAX_BWD_TASK_ID, - CONCAT_INIT_TASK_ID, - CONCAT_FWD_TASK_ID, - CONCAT_BWD_TASK_ID, - SPLIT_INIT_TASK_ID, - SPLIT_FWD_TASK_ID, - SPLIT_BWD_TASK_ID, - REDUCE_INIT_TASK_ID, - REDUCE_FWD_TASK_ID, - REDUCE_BWD_TASK_ID, - RESHAPE_INIT_TASK_ID, - RESHAPE_FWD_TASK_ID, - RESHAPE_BWD_TASK_ID, - REVERSE_INIT_TASK_ID, - REVERSE_FWD_TASK_ID, - REVERSE_BWD_TASK_ID, - TOPK_INIT_TASK_ID, - TOPK_FWD_TASK_ID, - TOPK_BWD_TASK_ID, - TRANSPOSE_INIT_TASK_ID, - TRANSPOSE_FWD_TASK_ID, - TRANSPOSE_BWD_TASK_ID, - ATTENTION_INIT_TASK_ID, - ATTENTION_FWD_TASK_ID, - ATTENTION_BWD_TASK_ID, - MSELOSS_BWD_TASK_ID, - FUSEDOP_INIT_TASK_ID, - FUSEDOP_FWD_TASK_ID, - FUSEDOP_BWD_TASK_ID, - NOOP_INIT_TASK_ID, - // Metrics tasks - METRICS_COMP_TASK_ID, - UPDATE_METRICS_TASK_ID, - // Parameter server prefetch task - PS_PREFETCH_TASK_ID, - // Loss - LOSS_BWD_TASK_ID, - // Optimizer with PS - SGD_UPD_PS_TASK_ID, - ADAM_UPD_PS_TASK_ID, - // Optimizer with NCCL - SGD_UPD_NCCL_TASK_ID, - ADAM_UPD_NCCL_TASK_ID, - // Initializer - GLOROT_INIT_TASK_ID, - ZERO_INIT_TASK_ID, - CONSTANT_INIT_TASK_ID, - UNIFORM_INIT_TASK_ID, - NORMAL_INIT_TASK_ID, - // NCCL tasks - NCCL_GETUNIQUEID_TASK_ID, - NCCL_INIT_COMMS_TASK_ID, - // Search - STRATEGY_SEARCH_TASK_ID, - // Graph - GRAPH_OPTIMIZE_TASK_ID, - // Python data loader - PY_DL_FLOAT_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT32_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT64_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_FLOAT_INDEX_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT32_INDEX_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT64_INDEX_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_FLOAT_LOAD_BATCH_GPU_TASK_ID, - PY_DL_INT32_LOAD_BATCH_GPU_TASK_ID, - 
PY_DL_INT64_LOAD_BATCH_GPU_TASK_ID, - // Parallel Ops - REPARTITION_INIT_TASK_ID, - REPARTITION_FWD_TASK_ID, - REPARTITION_BWD_TASK_ID, - COMBINE_INIT_TASK_ID, - COMBINE_FWD_TASK_ID, - COMBINE_BWD_TASK_ID, - REPLICATE_INIT_TASK_ID, - REPLICATE_FWD_TASK_ID, - REPLICATE_BWD_TASK_ID, - REDUCTION_INIT_TASK_ID, - REDUCTION_FWD_TASK_ID, - REDUCTION_BWD_TASK_ID, - PIPELINE_INIT_TASK_ID, - PIPELINE_FWD_TASK_ID, - PIPELINE_BWD_TASK_ID, - FUSED_PARALLELOP_INIT_TASK_ID, - FUSED_PARALLELOP_FWD_TASK_ID, - FUSED_PARALLELOP_BWD_TASK_ID, - // Custom tasks - CUSTOM_GPU_TASK_ID_FIRST, - CUSTOM_GPU_TASK_ID_1, - CUSTOM_GPU_TASK_ID_2, - CUSTOM_GPU_TASK_ID_3, - CUSTOM_GPU_TASK_ID_4, - CUSTOM_GPU_TASK_ID_5, - CUSTOM_GPU_TASK_ID_6, - CUSTOM_GPU_TASK_ID_7, - CUSTOM_GPU_TASK_ID_8, - CUSTOM_GPU_TASK_ID_LAST, - CUSTOM_CPU_TASK_ID_FIRST, - CUSTOM_CPU_TASK_ID_1, - CUSTOM_CPU_TASK_ID_2, - CUSTOM_CPU_TASK_ID_3, - CUSTOM_CPU_TASK_ID_4, - CUSTOM_CPU_TASK_ID_5, - CUSTOM_CPU_TASK_ID_6, - CUSTOM_CPU_TASK_ID_7, - CUSTOM_CPU_TASK_ID_LAST, - // Make sure PYTHON_TOP_LEVEL_TASK_ID is - // consistent with python/main.cc - PYTHON_TOP_LEVEL_TASK_ID = 11111, -}; +// PYTHON_TOP_LEVEL_TASK_ID = 11111, void register_flexflow_internal_tasks(); diff --git a/lib/local-execution/include/local-execution/tracked_allocator.h b/lib/local-execution/include/local-execution/tracked_allocator.h index ae7bd076ce..731e04fdc8 100644 --- a/lib/local-execution/include/local-execution/tracked_allocator.h +++ b/lib/local-execution/include/local-execution/tracked_allocator.h @@ -17,6 +17,7 @@ struct TrackedAllocator : public IAllocator { private: size_t current_mem_usage = 0; + std::unordered_map<void *, size_t> ptr_mem_usage; Allocator allocator; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(TrackedAllocator); diff --git a/lib/local-execution/src/concrete_arg.cc b/lib/local-execution/src/concrete_arg.cc new file mode 100644 index 0000000000..450d663e17 --- /dev/null +++ b/lib/local-execution/src/concrete_arg.cc @@ -0,0 +1,29 @@ +#include "local-execution/concrete_arg.h" + +namespace FlexFlow { + +bool ConcreteArgSpec::operator==(ConcreteArgSpec const &other) const { + return this->tie() == other.tie(); +} + +bool ConcreteArgSpec::operator!=(ConcreteArgSpec const &other) const { + return this->tie() != other.tie(); +} + +std::tuple<std::type_index> ConcreteArgSpec::tie() const { + return std::tie(this->type_idx); +} + +std::string format_as(ConcreteArgSpec const &x) { + std::ostringstream oss; + oss << "<ConcreteArgSpec>"; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, ConcreteArgSpec const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/fwd_bwd_task_impl_function.cc b/lib/local-execution/src/fwd_bwd_task_impl_function.cc new file mode 100644 index 0000000000..f85d7cec61 --- /dev/null +++ b/lib/local-execution/src/fwd_bwd_task_impl_function.cc @@ -0,0 +1,54 @@ +#include "local-execution/fwd_bwd_task_impl_function.h" + +namespace FlexFlow { + +bool FwdBwdTaskImplFunction::operator==( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator!=( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator<( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator>( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool 
FwdBwdTaskImplFunction::operator<=( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator>=( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(FwdBwdTaskImplFunction const &x) { + std::ostringstream oss; + oss << "<FwdBwdTaskImplFunction>"; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, FwdBwdTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash<::FlexFlow::FwdBwdTaskImplFunction>::operator()( + ::FlexFlow::FwdBwdTaskImplFunction const &x) const { + return std::hash<decltype(x.function_ptr)>{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/init_task_impl_function.cc b/lib/local-execution/src/init_task_impl_function.cc new file mode 100644 index 0000000000..9501f72dd6 --- /dev/null +++ b/lib/local-execution/src/init_task_impl_function.cc @@ -0,0 +1,47 @@ +#include "local-execution/init_task_impl_function.h" + +namespace FlexFlow { + +bool InitTaskImplFunction::operator==(InitTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool InitTaskImplFunction::operator!=(InitTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool InitTaskImplFunction::operator<(InitTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool InitTaskImplFunction::operator>(InitTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool InitTaskImplFunction::operator<=(InitTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool InitTaskImplFunction::operator>=(InitTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(InitTaskImplFunction const &x) { + std::ostringstream oss; + oss << "<InitTaskImplFunction>"; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InitTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash<::FlexFlow::InitTaskImplFunction>::operator()( + ::FlexFlow::InitTaskImplFunction const &x) const { + return std::hash<decltype(x.function_ptr)>{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc new file mode 100644 index 0000000000..b3a045bab4 --- /dev/null +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -0,0 +1,13 @@ +#include "local-execution/legion_tensor_shape.h" + +namespace FlexFlow { + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, size_t num_dims) { + return legion_dim_t(num_dims - ff_dim.value - 1); +} + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, TensorShape const &shape) { + return legion_dim_t(num_dims(shape) - ff_dim.value - 1); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index be3dfb01aa..d4e0467cbf 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -32,7 +32,8 @@ CostDetails LocalCostEstimator::estimate_cost( std::vector<ParallelTensorAttrs> const &outputs, MachineView const &mv) const { - if (is_parallel_op(op) || op.has<NoopAttrs>() || op.has<InputAttrs>()) { + if (is_parallel_op(op) || op.has<NoopAttrs>() || op.has<InputAttrs>() || + op.has<WeightAttrs>()) { return CostDetails{0, 0}; } diff --git a/lib/local-execution/src/local_cpu_allocator.cc b/lib/local-execution/src/local_cpu_allocator.cc new file mode 
100644 index 0000000000..4ca5f987a8 --- /dev/null +++ b/lib/local-execution/src/local_cpu_allocator.cc @@ -0,0 +1,24 @@ +#include "local-execution/local_cpu_allocator.h" +#include "utils/containers/contains_key.h" + +namespace FlexFlow { +void *LocalCPUAllocator::allocate(size_t requested_memory_size) { + void *ptr = malloc(requested_memory_size); + this->ptrs.insert({ptr, std::unique_ptr<void, decltype(&free)>(ptr, free)}); + return ptr; +} + +void LocalCPUAllocator::deallocate(void *ptr) { + if (contains_key(this->ptrs, ptr)) { + this->ptrs.erase(ptr); + } else { + throw std::runtime_error( + "Deallocating a pointer that was not allocated by this Allocator"); + } +} + +Allocator create_local_cpu_memory_allocator() { + return Allocator::create<LocalCPUAllocator>(); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index a07799def6..0ec9068c6a 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -11,22 +11,61 @@ LocalSlotsBacking::LocalSlotsBacking(TensorBackingMap const &allocated_tensors, void LocalSlotsBacking::add_per_device_op_state( layer_guid_t const &op_guid, - DeviceSpecific<DeviceStates> const &device_state) { + DeviceSpecificDeviceStates const &device_state) { this->per_device_op_states.insert({op_guid, device_state}); } +void LocalSlotsBacking::allocate_outgoing_tensors( + layer_guid_t const &layer_guid, + ComputationGraph const &computation_graph, + Allocator &allocator) { + std::vector<tensor_guid_t> incoming_tensors = + get_incoming_tensors(computation_graph, layer_guid); + std::vector<tensor_guid_t> outgoing_tensors = + get_outgoing_tensors(computation_graph, layer_guid); + for (tensor_guid_t const &output_tensor : outgoing_tensors) { + TensorAttrs tensor_attrs = + get_tensor_attrs(computation_graph, output_tensor); + // tensor allocation + if (!is_tensor_allocated(output_tensor)) { + GenericTensorAccessorW tensor_backing = + allocator.allocate_tensor(tensor_attrs.shape); + this->tensor_mapping.insert({output_tensor, tensor_backing}); + } + + // gradient tensor allocation + if (tensor_attrs.create_gradients == CreateGrad::YES && + !is_gradient_tensor_allocated(output_tensor)) { + GenericTensorAccessorW gradient_tensor_backing = + allocator.allocate_tensor(tensor_attrs.shape); + this->gradient_tensor_mapping.insert( + {output_tensor, gradient_tensor_backing}); + } + } + + this->input_tensor_slots.insert({layer_guid, incoming_tensors}); + this->output_tensor_slots.insert({layer_guid, outgoing_tensors}); +} + bool LocalSlotsBacking::is_tensor_allocated( tensor_guid_t const &tensor_id) const { return contains_key(this->tensor_mapping, tensor_id); } +bool LocalSlotsBacking::is_gradient_tensor_allocated( + tensor_guid_t const &tensor_id) const { + return contains_key(this->gradient_tensor_mapping, tensor_id); +} + GenericTensorAccessorW const & LocalSlotsBacking::get_tensor_backing(tensor_guid_t const &tensor_id, IsGrad is_grad) const { switch (is_grad) { case IsGrad::NO: + assert(contains_key(this->tensor_mapping, tensor_id)); return this->tensor_mapping.at(tensor_id); case IsGrad::YES: + assert(contains_key(this->gradient_tensor_mapping, tensor_id)); return this->gradient_tensor_mapping.at(tensor_id); default: throw mk_runtime_error(fmt::format( @@ -43,12 +82,12 @@ TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( std::vector<tensor_guid_t> tensor_guids; switch (tensor_spec.role) { case TensorRole::INPUT: - tensor_guids = this->input_tensor_slots.at(op_guid); - break; case TensorRole::WEIGHT: - tensor_guids 
= this->weight_tensor_slots.at(op_guid); + assert(contains_key(this->input_tensor_slots, op_guid)); + tensor_guids = this->input_tensor_slots.at(op_guid); break; case TensorRole::OUTPUT: + assert(contains_key(this->output_tensor_slots, op_guid)); tensor_guids = this->output_tensor_slots.at(op_guid); break; default: @@ -56,9 +95,12 @@ TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( fmt::format("Invalid TensorRole")); // inserting role yields // "type_is_unformattable" error } + + assert(tensor_guids.size() > tensor_spec.idx); IsGrad is_grad = slot_grad_id.is_grad; GenericTensorAccessorW tensor_backing = this->get_tensor_backing(tensor_guids.at(tensor_spec.idx), is_grad); + mapping.insert({slot_grad_id, tensor_backing}); } return mapping; @@ -87,13 +129,22 @@ ArgSlotsBacking LocalSlotsBacking::construct_arg_slots_backing( ConcreteArgSpec LocalSlotsBacking::resolve_op_arg_ref_spec( OpArgRefSpec const &op_arg_ref_spec, layer_guid_t const &op_guid) const { - if (op_arg_ref_spec.holds>()) { - return ConcreteArgSpec::create(per_device_op_states.at(op_guid)); + if (op_arg_ref_spec.holds()) { + assert(contains_key(per_device_op_states, op_guid)); + DeviceSpecificDeviceStates device_specific = + per_device_op_states.at(op_guid); + PerDeviceOpState device_state = + get_device_state_from_device_specific(device_specific, 0); + return ConcreteArgSpec::create(device_state); } else if (op_arg_ref_spec.holds()) { ParallelTensorShapeRefType index_op_arg_ref = - std::get(op_arg_ref_spec.get_ref_type()); + op_arg_ref_spec.get_ref_type().get(); + + assert(contains_key(this->input_tensor_slots, op_guid)); std::vector input_tensor_guids = this->input_tensor_slots.at(op_guid); + + assert(input_tensor_guids.size() > index_op_arg_ref.idx); GenericTensorAccessorW tensor_backing = this->get_tensor_backing( input_tensor_guids.at(index_op_arg_ref.idx), IsGrad::NO); ParallelTensorShape shape = lift_to_parallel( @@ -107,9 +158,8 @@ ConcreteArgSpec LocalSlotsBacking::resolve_op_arg_ref_spec( ConcreteArgSpec LocalSlotsBacking::resolve_runtime_arg_ref_spec( RuntimeArgRefSpec const &runtime_arg_ref_spec) const { if (runtime_arg_ref_spec.holds>()) { - return ConcreteArgSpec::create(this->runtime_arg_config.ff_handle); - } else if (runtime_arg_ref_spec.holds()) { - return ConcreteArgSpec::create(this->runtime_arg_config.enable_profiling); + return ConcreteArgSpec::create( + *(this->runtime_arg_config.ff_handle.get(0))); } else if (runtime_arg_ref_spec.holds()) { return ConcreteArgSpec::create(this->runtime_arg_config.profiling_settings); } else { diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 62fe9b2d16..5d0156201e 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -1,5 +1,8 @@ #include "local-execution/local_task_argument_accessor.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/transform.h" #include "utils/hash/pair.h" +#include "utils/overload.h" namespace FlexFlow { @@ -54,6 +57,37 @@ Allocator LocalTaskArgumentAccessor::get_allocator() const { return this->allocator; } +TensorSlotsBackingWithoutAddresses + get_slots_backing_without_tensor_allocation_addresses( + TensorSlotsBacking const &slots_backing) { + + TensorSlotsBackingWithoutAddresses addressless_slots_backing; + + using TensorAccessorVariant = + std::variant>; + for (auto const &slot_tensor : slots_backing) { + TensorAccessorVariant accessor_variant 
= slot_tensor.second; + std::visit( + overload{ + [&](GenericTensorAccessorW const &accessor) { + addressless_slots_backing.insert( + {slot_tensor.first, get_shape_and_datatype(accessor)}); + }, + [&](std::vector const &variadic_accessor) { + std::vector> + variadic_addressless_accessor = + transform(variadic_accessor, + [](GenericTensorAccessorW const &accessor) { + return get_shape_and_datatype(accessor); + }); + addressless_slots_backing.insert( + {slot_tensor.first, variadic_addressless_accessor}); + }}, + accessor_variant); + } + return addressless_slots_backing; +} + size_t LocalTaskArgumentAccessor::get_device_idx() const { return 0; } diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index 6d5a5011fd..a2ee06a95a 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,4 +1,5 @@ #include "local-execution/local_training_backing.h" +#include "local-execution/task_signature_impl.h" #include "utils/containers/reversed.h" #include "utils/exception.h" @@ -10,51 +11,29 @@ LocalTrainingBacking::LocalTrainingBacking( TensorBackingMap const &tensor_backing_mapping, RuntimeArgConfig const &runtime_arg_config) : allocator(allocator), computation_graph(computation_graph), - local_slots_backing(tensor_backing_mapping, runtime_arg_config) { - std::vector layers = topological_ordering(computation_graph); - for (layer_guid_t const &node : layers) { + local_slots_backing(tensor_backing_mapping, runtime_arg_config), + task_registry(empty_task_registry()) { + + for (layer_guid_t const &node : topological_ordering(computation_graph)) { ComputationGraphOpAttrs attrs = get_layer_attrs(computation_graph, node).attrs; - // register tasks - std::vector task_ids = get_task_ids(attrs); - for (task_id_t task_id : task_ids) { - this->task_registry.register_task(task_id, node, attrs); - } + // allocate outgoing tensors + this->local_slots_backing.allocate_outgoing_tensors( + node, computation_graph, this->allocator); - // insert pre-allocated tensors - this->local_slots_backing.input_tensor_slots.insert( - {node, get_incoming_tensors(computation_graph, node)}); - this->local_slots_backing.output_tensor_slots.insert( - {node, get_outgoing_tensors(computation_graph, node)}); - - // allocate new tensors - for (tensor_guid_t const &edge : - get_outgoing_tensors(computation_graph, node)) { - if (!this->local_slots_backing.is_tensor_allocated(edge)) { - TensorAttrs tensor_attrs = get_tensor_attrs(computation_graph, edge); - GenericTensorAccessorW tensor_backing = - this->allocator.allocate_tensor(tensor_attrs.shape); - this->local_slots_backing.tensor_mapping.insert({edge, tensor_backing}); - - if (tensor_attrs.create_gradients == CreateGrad::YES) { - GenericTensorAccessorW gradient_tensor_backing = - this->allocator.allocate_tensor(tensor_attrs.shape); - this->local_slots_backing.gradient_tensor_mapping.insert( - {edge, gradient_tensor_backing}); - } - } - } + // register tasks + register_tasks_for_layer(this->task_registry, node, attrs); } } -DeviceSpecific +DeviceSpecificDeviceStates LocalTrainingBacking::call_init_task_impl(task_id_t task_id, TaskArgumentAccessor const &acc) { TaskSignatureAndImpl task_sig_impl = this->task_registry.task_mapping.at(task_id); - auto fn = std::get( - TaskArgumentAccessor const &)>>(task_sig_impl.impl_function); + auto fn = + task_sig_impl.impl_function.get().function_ptr; return fn(acc); } @@ -63,24 +42,26 @@ std::optional TaskArgumentAccessor acc) { 
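+  // Dispatch sketch: task_mapping holds one TaskSignatureAndImpl per
+  // registered task id, and fwd/bwd entries store a FwdBwdTaskImplFunction
+  // whose function_ptr is assumed to be a plain
+  // std::optional<float> (*)(TaskArgumentAccessor const &); fn(acc) below is
+  // then a direct function-pointer call.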
TaskSignatureAndImpl task_sig_impl = this->task_registry.task_mapping.at(task_id); - auto fn = std::get< - std::function(TaskArgumentAccessor const &)>>( - task_sig_impl.impl_function); + auto fn = + task_sig_impl.impl_function.get().function_ptr; return fn(acc); } void LocalTrainingBacking::execute_init() { for (layer_guid_t const &operator_node : topological_ordering(this->computation_graph)) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - OpTaskInvocation invocation = init(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - DeviceSpecific device_state = - this->call_init_task_impl(invocation.task_id, accessor); - this->local_slots_backing.add_per_device_op_state(operator_node, - device_state); + if (this->task_registry.init_task_ids.at(operator_node).has_value()) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = init(attrs); + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, operator_node); + DeviceSpecificDeviceStates device_state = + this->call_init_task_impl(invocation.task_id, accessor); + this->local_slots_backing.add_per_device_op_state(operator_node, + device_state); + } } } @@ -88,14 +69,17 @@ PerLayerElapsedTime LocalTrainingBacking::execute_forward() { PerLayerElapsedTime per_op_elapsed_time; for (layer_guid_t const &operator_node : topological_ordering(this->computation_graph)) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - OpTaskInvocation invocation = forward(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - std::optional elapsed_time = - this->call_task_impl(invocation.task_id, accessor); - per_op_elapsed_time.insert({operator_node, elapsed_time}); + if (this->task_registry.forward_task_ids.at(operator_node).has_value()) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = forward(attrs); + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, operator_node); + std::optional elapsed_time = + this->call_task_impl(invocation.task_id, accessor); + per_op_elapsed_time.insert({operator_node, elapsed_time}); + } } return per_op_elapsed_time; } @@ -104,14 +88,17 @@ PerLayerElapsedTime LocalTrainingBacking::execute_backward() { PerLayerElapsedTime per_op_elapsed_time; for (layer_guid_t const &operator_node : reversed(topological_ordering(this->computation_graph))) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - OpTaskInvocation invocation = backward(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - std::optional elapsed_time = - this->call_task_impl(invocation.task_id, accessor); - per_op_elapsed_time.insert({operator_node, elapsed_time}); + if (this->task_registry.backward_task_ids.at(operator_node).has_value()) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = backward(attrs); + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, operator_node); + std::optional elapsed_time = + this->call_task_impl(invocation.task_id, accessor); + per_op_elapsed_time.insert({operator_node, elapsed_time}); + } } return per_op_elapsed_time; } diff --git 
a/lib/local-execution/src/op_arg_ref.cc b/lib/local-execution/src/op_arg_ref.cc index cad251f33e..b3d6e2f1a5 100644 --- a/lib/local-execution/src/op_arg_ref.cc +++ b/lib/local-execution/src/op_arg_ref.cc @@ -3,7 +3,7 @@ namespace FlexFlow { OpArgRef<ParallelTensorShape> input_parallel_tensor_shape(int idx) { - OpArgRefType arg_ref_type = ParallelTensorShapeRefType{idx}; + OpArgRefType arg_ref_type = OpArgRefType{ParallelTensorShapeRefType{idx}}; ArgRef<OpArgRefType, ParallelTensorShape> arg_ref = {arg_ref_type}; return arg_ref; } diff --git a/lib/local-execution/src/local-execution/op_arg_spec.cc b/lib/local-execution/src/op_arg_spec.cc similarity index 100% rename from lib/local-execution/src/local-execution/op_arg_spec.cc rename to lib/local-execution/src/op_arg_spec.cc diff --git a/lib/local-execution/src/op_task_invocation.cc b/lib/local-execution/src/op_task_invocation.cc index 3569bfb122..19c8894b05 100644 --- a/lib/local-execution/src/op_task_invocation.cc +++ b/lib/local-execution/src/op_task_invocation.cc @@ -36,6 +36,20 @@ void OpTaskBinding::insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec) { this->arg_bindings.insert({name, arg_spec}); } +bool OpTaskBinding::operator==(OpTaskBinding const &other) const { + return this->tie() == other.tie(); +} + +bool OpTaskBinding::operator!=(OpTaskBinding const &other) const { + return this->tie() != other.tie(); +} + +std::tuple<std::unordered_map<SlotGradId, OpTensorSpec> const &, + std::unordered_map<slot_id_t, OpArgSpec> const &> + OpTaskBinding::tie() const { + return std::tie(this->tensor_bindings, this->arg_bindings); +} + std::unordered_map<SlotGradId, OpTensorSpec> const & OpTaskBinding::get_tensor_bindings() const { return this->tensor_bindings; @@ -82,6 +96,9 @@ bool is_tensor_invocation_valid(OpTaskSignature const &sig, return false; } } + + // FIXME -- make sure invocation doesn't contain MORE than signature + // https://github.com/flexflow/FlexFlow/issues/1442 return true; } @@ -93,13 +110,13 @@ bool is_arg_type_invalid(std::type_index expected_arg_type, bool is_arg_invocation_valid(OpTaskSignature const &sig, OpTaskInvocation const &inv) { - auto sig_arg_types = sig.get_arg_types(); - for (auto arg_binding : inv.binding.get_arg_bindings()) { - std::type_index arg_type = sig_arg_types.at(arg_binding.first); - if (is_arg_type_invalid(arg_type, arg_binding.second)) { - return false; - } - } + // FIXME -- arg signature/invocation checking + // https://github.com/flexflow/FlexFlow/issues/1442 + // auto sig_arg_types = sig.get_arg_types(); + // for (auto arg_binding : inv.binding.get_arg_bindings()) { + // std::type_index arg_type = sig_arg_types.at(arg_binding.first); + // assert (!is_arg_type_invalid(arg_type, arg_binding.second)); + // } return true; } diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 3267ff592f..36a1dd708d 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -1,4 +1,6 @@ #include "local-execution/op_task_signature.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -61,18 +63,6 @@ void OpTaskSignature::add_output_slot(int name, SlotType slot_type) { } void OpTaskSignature::add_output_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_bwd_necessary_output_slot(int name, - SlotType slot_type) { - this->add_bwd_necessary_output_slot(slot_id_t{name}, slot_type); -} - -void 
OpTaskSignature::add_bwd_necessary_output_slot(slot_id_t name, - SlotType slot_type) { OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{name, slot_type, @@ -82,6 +72,18 @@ void OpTaskSignature::add_bwd_necessary_output_slot(slot_id_t name, this->op_tensor_slots.insert(op_tensor_slot_spec); } +void OpTaskSignature::add_bwd_optional_output_slot(int name, + SlotType slot_type) { + this->add_bwd_optional_output_slot(slot_id_t{name}, slot_type); +} + +void OpTaskSignature::add_bwd_optional_output_slot(slot_id_t name, + SlotType slot_type) { + OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ + name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; + this->op_tensor_slots.insert(op_tensor_slot_spec); +} + void OpTaskSignature::add_weight_slot(int name, SlotType slot_type) { this->add_weight_slot(slot_id_t{name}, slot_type); } @@ -116,4 +118,47 @@ void OpTaskSignature::add_from_slot_spec(OpTensorSlotSpec const &spec) { this->op_tensor_slots.insert(spec); } +OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd) { + OpTaskSignature bwd = fwd; + bwd.type = OpTaskType::BWD; + for (auto const &op_tensor_slot_spec : fwd.get_tensor_slots()) { + OpSlotOptions slot_option = op_tensor_slot_spec.slot_option; + if (slot_option != OpSlotOptions::UNTRAINABLE && + slot_option != OpSlotOptions::OPTIONAL_UNTRAINABLE) { + OpTensorSlotSpec grad_spec = + OpTensorSlotSpec{op_tensor_slot_spec.name, + op_tensor_slot_spec.slot_type, + op_tensor_slot_spec.tensor_role, + IsGrad::YES, + op_tensor_slot_spec.slot_option}; + bwd.op_tensor_slots.insert(grad_spec); + } + } + + return bwd; +} + +std::unordered_set<OpTensorSlotSpec> OpTaskSignature::get_tensor_slots() const { + return this->op_tensor_slots; +} + +std::unordered_map<slot_id_t, std::type_index> + OpTaskSignature::get_arg_types() const { + return this->task_arg_types; +} + +std::string format_as(OpTaskSignature const &x) { + std::ostringstream oss; + oss << "<OpTaskSignature>"; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x) { + return s << fmt::to_string(x); +} + } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc index fc3627404d..eebef9039d 100644 --- a/lib/local-execution/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -13,9 +13,10 @@ * limitations under the License. 
*/ -#include "attention.h" +#include "local-execution/ops/attention.h" #include "kernels/attention_kernels.h" #include "local-execution/op_task_signature.h" +#include "op-attrs/ops/attention.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" namespace FlexFlow { @@ -56,7 +57,7 @@ OpTaskInvocation init(MultiHeadAttentionAttrs const &attrs) { b.bind_arg(VPROJSIZE, get_vProjSize(attrs)); b.bind_arg(OPROJSIZE, get_oProjSize(attrs)); - return {ATTENTION_INIT_TASK_ID, b}; + return {task_id_t::ATTENTION_INIT_TASK_ID, b}; } OpTaskInvocation forward(MultiHeadAttentionAttrs const &attrs) { @@ -71,23 +72,24 @@ OpTaskInvocation forward(MultiHeadAttentionAttrs const &attrs) { b.bind_arg(PROFILING, profiling_settings()); b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {ATTENTION_FWD_TASK_ID, b}; + return {task_id_t::ATTENTION_FWD_TASK_ID, b}; } OpTaskInvocation backward(MultiHeadAttentionAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {ATTENTION_BWD_TASK_ID, b}; + return {task_id_t::ATTENTION_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); Allocator allocator = acc.get_allocator(); - size_t qProjSize = acc.get_argument(QPROJSIZE); - size_t kProjSize = acc.get_argument(KPROJSIZE); - size_t vProjSize = acc.get_argument(VPROJSIZE); - size_t oProjSize = acc.get_argument(OPROJSIZE); + size_t qProjSize = acc.get_argument(QPROJSIZE); + size_t kProjSize = acc.get_argument(KPROJSIZE); + size_t vProjSize = acc.get_argument(VPROJSIZE); + size_t oProjSize = acc.get_argument(OPROJSIZE); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); ParallelTensorShape query_parallel_tensor_shape = acc.get_argument(QUERY_PARALLEL_TENSOR_SHAPE); @@ -129,7 +131,8 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -140,7 +143,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); ProfilingSettings profiling = acc.get_argument(PROFILING); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + MHAPerDeviceState per_device_state = + acc.get_argument(PER_DEVICE_STATE); return profile(forward_kernel, profiling, @@ -166,7 +170,8 @@ static std::optional auto key_grad = acc.get_tensor_grad(KEY); auto value_grad = acc.get_tensor_grad(VALUE); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + MHAPerDeviceState per_device_state = + acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); float *key_grad_ptr = @@ -197,13 +202,13 @@ static std::optional } TaskImplFunction get_attention_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_attention_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_attention_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_attention_init_signature() { @@ -245,7 +250,9 @@ OpTaskSignature get_attention_bwd_signature() { } std::vector get_task_ids(MultiHeadAttentionAttrs const &) { - return 
{ATTENTION_INIT_TASK_ID, ATTENTION_FWD_TASK_ID, ATTENTION_BWD_TASK_ID}; + return {task_id_t::ATTENTION_INIT_TASK_ID, + task_id_t::ATTENTION_FWD_TASK_ID, + task_id_t::ATTENTION_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/batch_matmul.cc b/lib/local-execution/src/ops/batch_matmul.cc index d18e58baf4..1eae409ae2 100644 --- a/lib/local-execution/src/ops/batch_matmul.cc +++ b/lib/local-execution/src/ops/batch_matmul.cc @@ -45,13 +45,13 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(ITERATION_CONFIG, iteration_config()); - return {BATCHMATMUL_FWD_TASK_ID, fwd}; + return {task_id_t::BATCHMATMUL_FWD_TASK_ID, fwd}; } OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return {BATCHMATMUL_BWD_TASK_ID, bwd}; + return {task_id_t::BATCHMATMUL_BWD_TASK_ID, bwd}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -153,10 +153,10 @@ static std::optional } TaskImplFunction get_batch_matmul_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_batch_matmul_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_batch_matmul_fwd_signature() { @@ -173,14 +173,14 @@ OpTaskSignature get_batch_matmul_fwd_signature() { } OpTaskSignature get_batch_matmul_bwd_signature() { - OpTaskSignature bwd = - infer_bwd_signature(fwd_signature()); + OpTaskSignature bwd = infer_bwd_signature(get_batch_matmul_fwd_signature()); return bwd; } std::vector get_task_ids(BatchMatmulAttrs const &) { - return {BATCHMATMUL_FWD_TASK_ID, BATCHMATMUL_BWD_TASK_ID}; + return {task_id_t::BATCHMATMUL_FWD_TASK_ID, + task_id_t::BATCHMATMUL_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/batch_norm.cc b/lib/local-execution/src/ops/batch_norm.cc index 5eaa264541..851566fc02 100644 --- a/lib/local-execution/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -43,7 +43,7 @@ OpTaskInvocation init(BatchNormAttrs const &attrs) { binding.bind_arg(PROFILING, profiling_settings()); binding.bind_arg(HANDLE, ff_handle()); - return {BATCHNORM_INIT_TASK_ID, binding}; + return {task_id_t::BATCHNORM_INIT_TASK_ID, binding}; } OpTaskInvocation forward(BatchNormAttrs const &attrs) { @@ -57,16 +57,16 @@ OpTaskInvocation forward(BatchNormAttrs const &attrs) { binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return {BATCHNORM_FWD_TASK_ID, binding}; + return {task_id_t::BATCHNORM_FWD_TASK_ID, binding}; } OpTaskInvocation backward(BatchNormAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {BATCHNORM_BWD_TASK_ID, binding}; + return {task_id_t::BATCHNORM_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { Allocator allocator = acc.get_allocator(); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -91,7 +91,8 @@ static DeviceSpecific output_w, attrs.relu); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -143,13 +144,13 @@ static std::optional } TaskImplFunction get_batch_norm_init_task_impl() { - return init_task_impl; + 
return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_batch_norm_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_batch_norm_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_batch_norm_init_signature() { @@ -185,9 +186,9 @@ OpTaskSignature get_batch_norm_bwd_signature() { std::vector get_task_ids(BatchNormAttrs const &) { return { - BATCHNORM_INIT_TASK_ID, - BATCHNORM_FWD_TASK_ID, - BATCHNORM_BWD_TASK_ID, + task_id_t::BATCHNORM_INIT_TASK_ID, + task_id_t::BATCHNORM_FWD_TASK_ID, + task_id_t::BATCHNORM_BWD_TASK_ID, }; } diff --git a/lib/local-execution/src/ops/cast.cc b/lib/local-execution/src/ops/cast.cc index 7e109a2140..3e7baf49a9 100644 --- a/lib/local-execution/src/ops/cast.cc +++ b/lib/local-execution/src/ops/cast.cc @@ -34,13 +34,13 @@ OpTaskInvocation forward(CastAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {CAST_FWD_TASK_ID, binding}; + return {task_id_t::CAST_FWD_TASK_ID, binding}; } OpTaskInvocation backward(CastAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {CAST_BWD_TASK_ID, binding}; + return {task_id_t::CAST_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -79,10 +79,10 @@ static std::optional } TaskImplFunction get_cast_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_cast_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_cast_fwd_signature() { @@ -104,7 +104,7 @@ OpTaskSignature get_cast_bwd_signature() { } std::vector get_task_ids(CastAttrs const &) { - return {CAST_FWD_TASK_ID, CAST_BWD_TASK_ID}; + return {task_id_t::CAST_FWD_TASK_ID, task_id_t::CAST_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/combine.cc b/lib/local-execution/src/ops/combine.cc index a6aeaebf14..ccc82cce17 100644 --- a/lib/local-execution/src/ops/combine.cc +++ b/lib/local-execution/src/ops/combine.cc @@ -32,13 +32,13 @@ OpTaskInvocation forward(CombineAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {COMBINE_FWD_TASK_ID, binding}; + return {task_id_t::COMBINE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(CombineAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {COMBINE_BWD_TASK_ID, b}; + return {task_id_t::COMBINE_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -84,4 +84,11 @@ OpTaskSignature get_combine_bwd_signature() { return bwd; } +TaskImplFunction get_combine_fwd_task_impl() { + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; +} +TaskImplFunction get_combine_bwd_task_impl() { + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; +} + }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 5dfd100d84..35f663b1cd 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -34,13 +34,13 @@ OpTaskInvocation forward(ConcatAttrs const &attrs) { binding.bind_arg(PROFILING, profiling_settings()); binding.bind_arg(ATTRS, attrs); - 
return {CONCAT_FWD_TASK_ID, binding}; + return {task_id_t::CONCAT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ConcatAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {CONCAT_BWD_TASK_ID, b}; + return {task_id_t::CONCAT_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -79,10 +79,10 @@ static std::optional } TaskImplFunction get_concat_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_concat_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_concat_fwd_signature() { @@ -97,14 +97,13 @@ OpTaskSignature get_concat_fwd_signature() { } OpTaskSignature get_concat_bwd_signature() { - OpTaskSignature bwd = - infer_bwd_signature(fwd_signature()); + OpTaskSignature bwd = infer_bwd_signature(get_concat_fwd_signature()); return bwd; } std::vector get_task_ids(ConcatAttrs const &) { - return {CONCAT_FWD_TASK_ID, CONCAT_BWD_TASK_ID}; + return {task_id_t::CONCAT_FWD_TASK_ID, task_id_t::CONCAT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/conv_2d.cc b/lib/local-execution/src/ops/conv_2d.cc index 7aede41355..d5c6e7f851 100644 --- a/lib/local-execution/src/ops/conv_2d.cc +++ b/lib/local-execution/src/ops/conv_2d.cc @@ -26,7 +26,7 @@ OpTaskInvocation init(Conv2DAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {CONV2D_INIT_TASK_ID, binding}; + return {task_id_t::CONV2D_INIT_TASK_ID, binding}; } OpTaskInvocation forward(Conv2DAttrs const &attrs) { @@ -42,16 +42,16 @@ OpTaskInvocation forward(Conv2DAttrs const &attrs) { binding.bind(FILTER, weight_tensor(0)); binding.bind(BIAS, weight_tensor(1)); - return {CONV2D_FWD_TASK_ID, binding}; + return {task_id_t::CONV2D_FWD_TASK_ID, binding}; } OpTaskInvocation backward(Conv2DAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {CONV2D_BWD_TASK_ID, binding}; + return {task_id_t::CONV2D_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -75,7 +75,8 @@ static DeviceSpecific output, filter.get_float_ptr(), filter_grad.get_float_ptr()); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -131,13 +132,13 @@ static std::optional } TaskImplFunction get_conv_2d_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_conv_2d_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_conv_2d_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_conv_2d_init_signature() { @@ -176,7 +177,9 @@ OpTaskSignature get_conv_2d_bwd_signature() { } std::vector get_task_ids(Conv2DAttrs const &) { - return {CONV2D_INIT_TASK_ID, CONV2D_FWD_TASK_ID, CONV2D_BWD_TASK_ID}; + return {task_id_t::CONV2D_INIT_TASK_ID, + task_id_t::CONV2D_FWD_TASK_ID, + task_id_t::CONV2D_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/dropout.cc 
b/lib/local-execution/src/ops/dropout.cc index 9195e859ce..cac08866cc 100644 --- a/lib/local-execution/src/ops/dropout.cc +++ b/lib/local-execution/src/ops/dropout.cc @@ -18,7 +18,7 @@ OpTaskInvocation init(DropoutAttrs const &attrs) { binding.bind_arg(FF_HANDLE, ff_handle()); binding.bind(OUTPUT, output_tensor(0)); - return {DROPOUT_INIT_TASK_ID, binding}; + return {task_id_t::DROPOUT_INIT_TASK_ID, binding}; } OpTaskInvocation forward(DropoutAttrs const &attrs) { @@ -31,16 +31,16 @@ OpTaskInvocation forward(DropoutAttrs const &attrs) { binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {DROPOUT_FWD_TASK_ID, binding}; + return {task_id_t::DROPOUT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(DropoutAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {DROPOUT_BWD_TASK_ID, b}; + return {task_id_t::DROPOUT_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); Allocator allocator = acc.get_allocator(); @@ -49,7 +49,8 @@ static DeviceSpecific DropoutPerDeviceState per_device_state = init_kernel(handle, attrs.rate, attrs.seed, output.shape, allocator); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -86,13 +87,13 @@ static std::optional } TaskImplFunction get_dropout_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_dropout_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_dropout_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_dropout_init_signature() { @@ -126,7 +127,9 @@ OpTaskSignature get_dropout_bwd_signature() { } std::vector get_task_ids(DropoutAttrs const &) { - return {DROPOUT_INIT_TASK_ID, DROPOUT_FWD_TASK_ID, DROPOUT_BWD_TASK_ID}; + return {task_id_t::DROPOUT_INIT_TASK_ID, + task_id_t::DROPOUT_FWD_TASK_ID, + task_id_t::DROPOUT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h index 58910ab6dc..84b67a29c2 100644 --- a/lib/local-execution/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "local-execution/tasks.h" +#include "local-execution/task_id_t.dtg.h" #include "op-attrs/ops/dropout.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_binary.cc b/lib/local-execution/src/ops/element_binary.cc index bd5b415df3..48c6c699a2 100644 --- a/lib/local-execution/src/ops/element_binary.cc +++ b/lib/local-execution/src/ops/element_binary.cc @@ -27,7 +27,7 @@ OpTaskInvocation init(ElementBinaryAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {ELEMENTBINARY_INIT_TASK_ID, binding}; + return {task_id_t::ELEMENTBINARY_INIT_TASK_ID, binding}; } OpTaskInvocation forward(ElementBinaryAttrs const &attrs) { @@ -42,16 +42,16 @@ OpTaskInvocation forward(ElementBinaryAttrs const &attrs) { per_device_op_state()); binding.bind_arg(HANDLE, ff_handle()); - return {ELEMENTBINARY_FWD_TASK_ID, binding}; + return 
{task_id_t::ELEMENTBINARY_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ElementBinaryAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {ELEMENTBINARY_BWD_TASK_ID, b}; + return {task_id_t::ELEMENTBINARY_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto input_lhs = acc.get_tensor(LHS_INPUT); auto input_rhs = acc.get_tensor(RHS_INPUT); @@ -68,7 +68,8 @@ static DeviceSpecific input_lhs.shape, input_rhs.shape, output.shape); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -125,15 +126,15 @@ static std::optional } TaskImplFunction get_element_binary_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_element_binary_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_element_binary_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_element_binary_init_signature() { @@ -172,9 +173,9 @@ OpTaskSignature get_element_binary_bwd_signature() { } std::vector get_task_ids(ElementBinaryAttrs const &) { - return {ELEMENTBINARY_INIT_TASK_ID, - ELEMENTBINARY_FWD_TASK_ID, - ELEMENTBINARY_BWD_TASK_ID}; + return {task_id_t::ELEMENTBINARY_INIT_TASK_ID, + task_id_t::ELEMENTBINARY_FWD_TASK_ID, + task_id_t::ELEMENTBINARY_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index 3185bbfef9..a52ebb8089 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -26,7 +26,7 @@ OpTaskInvocation init(ElementUnaryAttrs const &attrs) { b.bind_arg(ATTRS, attrs); b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); - return {ELEMENTUNARY_INIT_TASK_ID, b}; + return {task_id_t::ELEMENTUNARY_INIT_TASK_ID, b}; } OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { @@ -39,16 +39,16 @@ OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {ELEMENTUNARY_FWD_TASK_ID, b}; + return {task_id_t::ELEMENTUNARY_FWD_TASK_ID, b}; } OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {ELEMENTUNARY_BWD_TASK_ID, b}; + return {task_id_t::ELEMENTUNARY_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); @@ -61,7 +61,8 @@ static DeviceSpecific ElementUnaryPerDeviceState per_device_state = init_kernel( get_piece_shape(input_shape), get_piece_shape(output_shape), attrs); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -112,13 +113,13 @@ static std::optional } TaskImplFunction get_element_unary_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_element_unary_fwd_task_impl() { - return forward_task_impl; + return 
TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_element_unary_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_element_unary_init_signature() { @@ -152,9 +153,9 @@ OpTaskSignature get_element_unary_bwd_signature() { } std::vector get_task_ids(ElementUnaryAttrs const &) { - return {ELEMENTUNARY_INIT_TASK_ID, - ELEMENTUNARY_FWD_TASK_ID, - ELEMENTUNARY_BWD_TASK_ID}; + return {task_id_t::ELEMENTUNARY_INIT_TASK_ID, + task_id_t::ELEMENTUNARY_FWD_TASK_ID, + task_id_t::ELEMENTUNARY_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/flat.cc b/lib/local-execution/src/ops/flat.cc index 5d791e4459..3fe5029fa1 100644 --- a/lib/local-execution/src/ops/flat.cc +++ b/lib/local-execution/src/ops/flat.cc @@ -15,13 +15,13 @@ OpTaskInvocation forward(FlatAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind_arg(PROFILING, profiling_settings()); - return {FLAT_FWD_TASK_ID, binding}; + return {task_id_t::FLAT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(FlatAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {FLAT_BWD_TASK_ID, b}; + return {task_id_t::FLAT_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -53,10 +53,10 @@ static std::optional } TaskImplFunction get_flat_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_flat_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_flat_fwd_signature() { @@ -76,7 +76,7 @@ OpTaskSignature get_flat_bwd_signature() { } std::vector get_task_ids(FlatAttrs const &) { - return {FLAT_FWD_TASK_ID, FLAT_BWD_TASK_ID}; + return {task_id_t::FLAT_FWD_TASK_ID, task_id_t::FLAT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/gather.cc b/lib/local-execution/src/ops/gather.cc index 44455bc42c..a015c64f4d 100644 --- a/lib/local-execution/src/ops/gather.cc +++ b/lib/local-execution/src/ops/gather.cc @@ -34,7 +34,7 @@ OpTaskInvocation init(GatherAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {GATHER_INIT_TASK_ID, binding}; + return {task_id_t::GATHER_INIT_TASK_ID, binding}; } OpTaskInvocation forward(GatherAttrs const &attrs) { @@ -49,16 +49,16 @@ OpTaskInvocation forward(GatherAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind(INDEX, weight_tensor(0)); - return {GATHER_FWD_TASK_ID, binding}; + return {task_id_t::GATHER_FWD_TASK_ID, binding}; } OpTaskInvocation backward(GatherAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {GATHER_BWD_TASK_ID, binding}; + return {task_id_t::GATHER_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto index = acc.get_tensor(INDEX); @@ -80,7 +80,8 @@ static DeviceSpecific } GatherPerDeviceState per_device_state = {handle, legion_dim}; - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -121,13 +122,13 @@ static std::optional } TaskImplFunction get_gather_init_task_impl() { - return 
init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_gather_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_gather_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_gather_init_signature() { @@ -165,7 +166,9 @@ OpTaskSignature get_gather_bwd_signature() { } std::vector get_task_ids(GatherAttrs const &) { - return {GATHER_INIT_TASK_ID, GATHER_FWD_TASK_ID, GATHER_BWD_TASK_ID}; + return {task_id_t::GATHER_INIT_TASK_ID, + task_id_t::GATHER_FWD_TASK_ID, + task_id_t::GATHER_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/input.cc b/lib/local-execution/src/ops/input.cc new file mode 100644 index 0000000000..56d19fa1ba --- /dev/null +++ b/lib/local-execution/src/ops/input.cc @@ -0,0 +1,9 @@ +#include "input.h" + +namespace FlexFlow { + +std::vector get_task_ids(InputAttrs const &attrs) { + return {}; +} + +}; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/src/ops/input.h new file mode 100644 index 0000000000..97985585e1 --- /dev/null +++ b/lib/local-execution/src/ops/input.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_INPUT_H +#define _FLEXFLOW_INPUT_H + +#include "local-execution/op_task_invocation.h" +#include "op-attrs/ops/input.h" + +namespace FlexFlow { + +std::vector get_task_ids(InputAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/src/ops/layer_norm.cc b/lib/local-execution/src/ops/layer_norm.cc index 9530628d46..e99d27319c 100644 --- a/lib/local-execution/src/ops/layer_norm.cc +++ b/lib/local-execution/src/ops/layer_norm.cc @@ -46,7 +46,7 @@ OpTaskInvocation init(LayerNormAttrs const &attrs) { b.bind_arg(HANDLE, ff_handle()); b.bind_arg(ATTRS, attrs); - return {LAYERNORM_INIT_TASK_ID, b}; + return {task_id_t::LAYERNORM_INIT_TASK_ID, b}; } OpTaskInvocation forward(LayerNormAttrs const &attrs) { @@ -59,13 +59,13 @@ OpTaskInvocation forward(LayerNormAttrs const &attrs) { b.bind_arg(PROFILING, profiling_settings()); b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {LAYERNORM_FWD_TASK_ID, b}; + return {task_id_t::LAYERNORM_FWD_TASK_ID, b}; } OpTaskInvocation backward(LayerNormAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {LAYERNORM_BWD_TASK_ID, b}; + return {task_id_t::LAYERNORM_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -112,7 +112,7 @@ static std::optional beta_grad); } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); Allocator allocator = acc.get_allocator(); @@ -141,17 +141,18 @@ static DeviceSpecific effective_batch_size, effective_num_elements, attrs.eps); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } TaskImplFunction get_layer_norm_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_layer_norm_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_layer_norm_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature 
get_layer_norm_fwd_signature() { @@ -184,7 +185,9 @@ OpTaskSignature get_layer_norm_init_signature() { } std::vector get_task_ids(LayerNormAttrs const &) { - return {LAYERNORM_INIT_TASK_ID, LAYERNORM_FWD_TASK_ID, LAYERNORM_BWD_TASK_ID}; + return {task_id_t::LAYERNORM_INIT_TASK_ID, + task_id_t::LAYERNORM_FWD_TASK_ID, + task_id_t::LAYERNORM_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index 599f671e92..9934e2a45c 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -31,7 +31,7 @@ OpTaskInvocation init(LinearAttrs const &attrs) { binding.bind(WEIGHT, weight_tensor(0)); // weight binding.bind(OUTPUT, output_tensor(0)); // output - return {LINEAR_INIT_TASK_ID, binding}; + return {task_id_t::LINEAR_INIT_TASK_ID, binding}; } OpTaskInvocation forward(LinearAttrs const &attrs) { @@ -49,16 +49,16 @@ OpTaskInvocation forward(LinearAttrs const &attrs) { per_device_op_state()); binding.bind_arg(ATTRS, attrs); - return {LINEAR_FWD_TASK_ID, binding}; + return {task_id_t::LINEAR_FWD_TASK_ID, binding}; } OpTaskInvocation backward(LinearAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {LINEAR_BWD_TASK_ID, b}; + return {task_id_t::LINEAR_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -71,17 +71,18 @@ static DeviceSpecific float *one_ptr; - LinearPerDeviceState state = init_kernel(handle, - one_ptr, - attrs.activation, - attrs.regularizer, - attrs.use_bias, - input.data_type, - weight.data_type, - output.data_type, - batch_size, - attrs.out_channels); - return DeviceSpecific::create(state); + LinearPerDeviceState per_device_state = init_kernel(handle, + one_ptr, + attrs.activation, + attrs.regularizer, + attrs.use_bias, + input.data_type, + weight.data_type, + output.data_type, + batch_size, + attrs.out_channels); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -160,13 +161,13 @@ static std::optional } TaskImplFunction get_linear_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_linear_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_linear_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_linear_init_signature() { @@ -203,7 +204,9 @@ OpTaskSignature get_linear_bwd_signature() { } std::vector get_task_ids(LinearAttrs const &) { - return {LINEAR_INIT_TASK_ID, LINEAR_FWD_TASK_ID, LINEAR_BWD_TASK_ID}; + return {task_id_t::LINEAR_INIT_TASK_ID, + task_id_t::LINEAR_FWD_TASK_ID, + task_id_t::LINEAR_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/noop.cc b/lib/local-execution/src/ops/noop.cc index 168d547c17..e35fdec275 100644 --- a/lib/local-execution/src/ops/noop.cc +++ b/lib/local-execution/src/ops/noop.cc @@ -14,21 +14,11 @@ */ #include "noop.h" -#include "local-execution/op_task_invocation.h" -#include "utils/hash-utils.h" namespace FlexFlow { -std::optional init(NoopAttrs const &attrs) { - return std::nullopt; -} - -std::optional forward(NoopAttrs const &attrs) 
{ - return std::nullopt; -} - -std::optional backward(NoopAttrs const &attrs) { - return std::nullopt; +std::vector get_task_ids(NoopAttrs const &attrs) { + return {}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/src/ops/noop.h index fab2cf1f86..959f7dc054 100644 --- a/lib/local-execution/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -4,12 +4,12 @@ #include "local-execution/op_task_invocation.h" #include "op-attrs/ops/input.h" #include "op-attrs/ops/noop.h" +#include "op-attrs/ops/weight_attrs.dtg.h" namespace FlexFlow { -std::optional init(NoopAttrs const &); -std::optional forward(NoopAttrs const &); -std::optional backward(NoopAttrs const &); +std::vector get_task_ids(NoopAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc index d6f100390a..789ed2cd63 100644 --- a/lib/local-execution/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -20,10 +20,10 @@ OpTaskInvocation init(Pool2DAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {POOL2D_INIT_TASK_ID, binding}; + return {task_id_t::POOL2D_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -64,25 +64,26 @@ static DeviceSpecific printf("Warning: changing pool_padding_w to satisfy output_w size\n"); } - Pool2DPerDeviceState state = init_kernel(handle, - attrs.activation, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - pad_h, - pad_w, - attrs.kernel_h, - attrs.kernel_w, - attrs.stride_h, - attrs.stride_w, - attrs.pool_type); - - return DeviceSpecific::create(state); + Pool2DPerDeviceState per_device_state = init_kernel(handle, + attrs.activation, + input_w, + input_h, + input_c, + input_n, + output_w, + output_h, + output_c, + output_n, + pad_h, + pad_w, + attrs.kernel_h, + attrs.kernel_w, + attrs.stride_h, + attrs.stride_w, + attrs.pool_type); + + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } OpTaskInvocation forward(Pool2DAttrs const &attrs) { @@ -94,13 +95,13 @@ OpTaskInvocation forward(Pool2DAttrs const &attrs) { binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {POOL2D_FWD_TASK_ID, binding}; + return {task_id_t::POOL2D_FWD_TASK_ID, binding}; } OpTaskInvocation backward(Pool2DAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {POOL2D_BWD_TASK_ID, b}; + return {task_id_t::POOL2D_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -141,13 +142,13 @@ static std::optional } TaskImplFunction get_pool_2d_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_pool_2d_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_pool_2d_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_pool_2d_init_signature() { @@ -178,7 +179,9 @@ OpTaskSignature get_pool_2d_bwd_signature() { } std::vector get_task_ids(Pool2DAttrs const &) { - return {POOL2D_INIT_TASK_ID, POOL2D_FWD_TASK_ID, POOL2D_BWD_TASK_ID}; + return 
{task_id_t::POOL2D_INIT_TASK_ID, + task_id_t::POOL2D_FWD_TASK_ID, + task_id_t::POOL2D_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reduce.cc b/lib/local-execution/src/ops/reduce.cc index 23c05eb17e..a043d9f847 100644 --- a/lib/local-execution/src/ops/reduce.cc +++ b/lib/local-execution/src/ops/reduce.cc @@ -29,10 +29,10 @@ OpTaskInvocation init(ReduceAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REDUCE_INIT_TASK_ID, binding}; + return {task_id_t::REDUCE_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto attrs = acc.get_argument(ATTRS); @@ -44,7 +44,8 @@ static DeviceSpecific size_t reduction_size = input.shape.get_volume() / output.shape.get_volume(); ReducePerDeviceState per_device_state = init_kernel(handle, op_type, reduction_size, input.shape, output.shape); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } // Note: forward_kernel only needs ReducePerDeviceState, input, output @@ -58,7 +59,7 @@ OpTaskInvocation forward(ReduceAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REDUCE_FWD_TASK_ID, binding}; + return {task_id_t::REDUCE_FWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -80,7 +81,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { OpTaskInvocation backward(ReduceAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REDUCE_BWD_TASK_ID, binding}; + return {task_id_t::REDUCE_BWD_TASK_ID, binding}; } static std::optional @@ -101,13 +102,13 @@ static std::optional } TaskImplFunction get_reduce_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_reduce_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reduce_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reduce_init_signature() { @@ -135,7 +136,9 @@ OpTaskSignature get_reduce_bwd_signature() { } std::vector get_task_ids(ReduceAttrs const &) { - return {REDUCE_INIT_TASK_ID, REDUCE_FWD_TASK_ID, REDUCE_BWD_TASK_ID}; + return {task_id_t::REDUCE_INIT_TASK_ID, + task_id_t::REDUCE_FWD_TASK_ID, + task_id_t::REDUCE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reduction.cc b/lib/local-execution/src/ops/reduction.cc index 85d9d48148..a58d79a4f8 100644 --- a/lib/local-execution/src/ops/reduction.cc +++ b/lib/local-execution/src/ops/reduction.cc @@ -33,13 +33,13 @@ OpTaskInvocation forward(ReductionAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REDUCTION_FWD_TASK_ID, binding}; + return {task_id_t::REDUCTION_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReductionAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REDUCTION_BWD_TASK_ID, binding}; + return {task_id_t::REDUCTION_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -74,10 +74,10 @@ static std::optional } TaskImplFunction 
get_reduction_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reduction_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reduction_fwd_signature() { @@ -96,7 +96,7 @@ OpTaskSignature get_reduction_bwd_signature() { } std::vector get_task_ids(ReductionAttrs const &) { - return {REDUCTION_FWD_TASK_ID, REDUCTION_BWD_TASK_ID}; + return {task_id_t::REDUCTION_FWD_TASK_ID, task_id_t::REDUCTION_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/repartition.cc b/lib/local-execution/src/ops/repartition.cc index 61050ed3d0..73692f4a13 100644 --- a/lib/local-execution/src/ops/repartition.cc +++ b/lib/local-execution/src/ops/repartition.cc @@ -31,7 +31,7 @@ OpTaskInvocation init(RepartitionAttrs const &attrs) { binding.bind_arg(HANDLE, ff_handle()); binding.bind(INPUT, input_tensor(0)); - return {REPARTITION_INIT_TASK_ID, binding}; + return {task_id_t::REPARTITION_INIT_TASK_ID, binding}; } OpTaskInvocation forward(RepartitionAttrs const &attrs) { @@ -44,16 +44,16 @@ OpTaskInvocation forward(RepartitionAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REPARTITION_FWD_TASK_ID, binding}; + return {task_id_t::REPARTITION_FWD_TASK_ID, binding}; } OpTaskInvocation backward(RepartitionAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REPARTITION_BWD_TASK_ID, binding}; + return {task_id_t::REPARTITION_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -62,7 +62,8 @@ static DeviceSpecific RepartitionPerDeviceState per_device_state = init_kernel(handle, input.data_type); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -97,13 +98,13 @@ static std::optional } TaskImplFunction get_repartition_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_repartition_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_repartition_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_repartition_init_signature() { @@ -129,9 +130,9 @@ OpTaskSignature get_repartition_bwd_signature() { } std::vector get_task_ids(RepartitionAttrs const &) { - return {REPARTITION_INIT_TASK_ID, - REPARTITION_FWD_TASK_ID, - REPARTITION_BWD_TASK_ID}; + return {task_id_t::REPARTITION_INIT_TASK_ID, + task_id_t::REPARTITION_FWD_TASK_ID, + task_id_t::REPARTITION_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/replicate.cc b/lib/local-execution/src/ops/replicate.cc index b3d3a152d6..135475a711 100644 --- a/lib/local-execution/src/ops/replicate.cc +++ b/lib/local-execution/src/ops/replicate.cc @@ -35,12 +35,12 @@ OpTaskInvocation forward(ReplicateAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind_arg(ATTRS, attrs); - return {REPLICATE_FWD_TASK_ID, binding}; + return {task_id_t::REPLICATE_FWD_TASK_ID, 
binding}; } OpTaskInvocation backward(ReplicateAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REPLICATE_BWD_TASK_ID, binding}; + return {task_id_t::REPLICATE_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -73,10 +73,10 @@ static std::optional } TaskImplFunction get_replicate_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_replicate_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_replicate_fwd_signature() { @@ -94,7 +94,7 @@ OpTaskSignature get_replicate_bwd_signature() { } std::vector get_task_ids(ReplicateAttrs const &) { - return {REPLICATE_FWD_TASK_ID, REPLICATE_BWD_TASK_ID}; + return {task_id_t::REPLICATE_FWD_TASK_ID, task_id_t::REPLICATE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reshape.cc b/lib/local-execution/src/ops/reshape.cc index 8d6c0d83c2..7584d405eb 100644 --- a/lib/local-execution/src/ops/reshape.cc +++ b/lib/local-execution/src/ops/reshape.cc @@ -28,7 +28,7 @@ OpTaskInvocation init(ReshapeAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); - return {RESHAPE_INIT_TASK_ID, binding}; + return {task_id_t::RESHAPE_INIT_TASK_ID, binding}; } OpTaskInvocation forward(ReshapeAttrs const &attrs) { @@ -40,21 +40,22 @@ OpTaskInvocation forward(ReshapeAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {RESHAPE_FWD_TASK_ID, binding}; + return {task_id_t::RESHAPE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReshapeAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {RESHAPE_BWD_TASK_ID, binding}; + return {task_id_t::RESHAPE_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto attrs = acc.get_argument(ATTRS); ReshapePerDeviceState per_device_state = init_kernel(attrs.shape.data_type); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -91,13 +92,13 @@ static std::optional } TaskImplFunction get_reshape_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_reshape_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reshape_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reshape_init_signature() { @@ -124,7 +125,9 @@ OpTaskSignature get_reshape_bwd_signature() { } std::vector get_task_ids(ReshapeAttrs const &) { - return {RESHAPE_INIT_TASK_ID, RESHAPE_FWD_TASK_ID, RESHAPE_BWD_TASK_ID}; + return {task_id_t::RESHAPE_INIT_TASK_ID, + task_id_t::RESHAPE_FWD_TASK_ID, + task_id_t::RESHAPE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reverse.cc b/lib/local-execution/src/ops/reverse.cc index b2a5107360..366a579bea 100644 --- a/lib/local-execution/src/ops/reverse.cc +++ b/lib/local-execution/src/ops/reverse.cc @@ -34,12 +34,12 @@ OpTaskInvocation forward(ReverseAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); 
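+ // INPUT and OUTPUT here are slot ids from this file's Slots enum; binding
+ // them to input_tensor(0) and output_tensor(0) maps the operator's first
+ // input and first output regions, which the task body later fetches
+ // through its TaskArgumentAccessor.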
binding.bind(OUTPUT, output_tensor(0)); - return {REVERSE_FWD_TASK_ID, binding}; + return {task_id_t::REVERSE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReverseAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REVERSE_BWD_TASK_ID, binding}; + return {task_id_t::REVERSE_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -103,10 +103,10 @@ static std::optional } TaskImplFunction get_reverse_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reverse_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reverse_fwd_signature() { @@ -124,7 +124,7 @@ OpTaskSignature get_reverse_bwd_signature() { } std::vector get_task_ids(ReverseAttrs const &) { - return {REVERSE_FWD_TASK_ID, REVERSE_BWD_TASK_ID}; + return {task_id_t::REVERSE_FWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/softmax.cc b/lib/local-execution/src/ops/softmax.cc index a0b3a047a7..4c7979ae9b 100644 --- a/lib/local-execution/src/ops/softmax.cc +++ b/lib/local-execution/src/ops/softmax.cc @@ -30,7 +30,7 @@ OpTaskInvocation init(SoftmaxAttrs const &attrs) { binding.bind_arg(HANDLE, ff_handle()); binding.bind_arg(ATTRS, attrs); - return {SOFTMAX_INIT_TASK_ID, binding}; + return {task_id_t::SOFTMAX_INIT_TASK_ID, binding}; } OpTaskInvocation forward(SoftmaxAttrs const &attrs) { @@ -43,16 +43,16 @@ OpTaskInvocation forward(SoftmaxAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {SOFTMAX_FWD_TASK_ID, binding}; + return {task_id_t::SOFTMAX_FWD_TASK_ID, binding}; } OpTaskInvocation backward(SoftmaxAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {SOFTMAX_BWD_TASK_ID, binding}; + return {task_id_t::SOFTMAX_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -67,7 +67,8 @@ static DeviceSpecific SoftmaxPerDeviceState per_device_state = init_kernel( handle, attrs.dim.value, output_n, output_c, output_h, output_w); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -107,13 +108,13 @@ static std::optional } TaskImplFunction get_softmax_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_softmax_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_softmax_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_softmax_init_signature() { @@ -140,7 +141,9 @@ OpTaskSignature get_softmax_bwd_signature() { } std::vector get_task_ids(SoftmaxAttrs const &) { - return {SOFTMAX_INIT_TASK_ID, SOFTMAX_FWD_TASK_ID, SOFTMAX_BWD_TASK_ID}; + return {task_id_t::SOFTMAX_INIT_TASK_ID, + task_id_t::SOFTMAX_FWD_TASK_ID, + task_id_t::SOFTMAX_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/split.cc b/lib/local-execution/src/ops/split.cc index 
59d68dc8f5..9f039d84f8 100644 --- a/lib/local-execution/src/ops/split.cc +++ b/lib/local-execution/src/ops/split.cc @@ -35,13 +35,13 @@ OpTaskInvocation forward(SplitAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {SPLIT_FWD_TASK_ID, binding}; + return {task_id_t::SPLIT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(SplitAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {SPLIT_BWD_TASK_ID, binding}; + return {task_id_t::SPLIT_BWD_TASK_ID, binding}; } void calc_block_size(coord_t &num_blocks, @@ -114,10 +114,10 @@ static std::optional } TaskImplFunction get_split_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_split_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_split_fwd_signature() { @@ -135,7 +135,7 @@ OpTaskSignature get_split_bwd_signature() { } std::vector get_task_ids(SplitAttrs const &) { - return {SPLIT_FWD_TASK_ID, SPLIT_BWD_TASK_ID}; + return {task_id_t::SPLIT_FWD_TASK_ID, task_id_t::SPLIT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/topk.cc b/lib/local-execution/src/ops/topk.cc index 1669f58c3b..7f3519529a 100644 --- a/lib/local-execution/src/ops/topk.cc +++ b/lib/local-execution/src/ops/topk.cc @@ -33,7 +33,7 @@ OpTaskInvocation init(TopKAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); - return {TOPK_INIT_TASK_ID, binding}; + return {task_id_t::TOPK_INIT_TASK_ID, binding}; } OpTaskInvocation forward(TopKAttrs const &attrs) { @@ -47,22 +47,23 @@ OpTaskInvocation forward(TopKAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind(INDICES, output_tensor(1)); - return {TOPK_FWD_TASK_ID, binding}; + return {task_id_t::TOPK_FWD_TASK_ID, binding}; } OpTaskInvocation backward(TopKAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {TOPK_BWD_TASK_ID, binding}; + return {task_id_t::TOPK_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto attrs = acc.get_argument(ATTRS); TopKPerDeviceState per_device_state = init_kernel(attrs.sorted); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -119,13 +120,13 @@ static std::optional } TaskImplFunction get_topk_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_topk_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_topk_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_topk_init_signature() { @@ -154,7 +155,9 @@ OpTaskSignature get_topk_bwd_signature() { } std::vector get_task_ids(TopKAttrs const &) { - return {TOPK_INIT_TASK_ID, TOPK_FWD_TASK_ID, TOPK_BWD_TASK_ID}; + return {task_id_t::TOPK_INIT_TASK_ID, + task_id_t::TOPK_FWD_TASK_ID, + task_id_t::TOPK_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc index 5fa57772e2..5c3c1dd1ca 100644 --- 
a/lib/local-execution/src/ops/transpose.cc +++ b/lib/local-execution/src/ops/transpose.cc @@ -34,16 +34,17 @@ enum Slots { OpTaskInvocation init(TransposeAttrs const &attrs) { OpTaskBinding binding; binding.bind_arg(ATTRS, attrs); - return {TRANSPOSE_INIT_TASK_ID, binding}; + return {task_id_t::TRANSPOSE_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); std::vector perm = inner_to_outer_idxs(attrs.perm); TransposePerDeviceState per_device_state = init_kernel(perm.size(), perm); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } OpTaskInvocation forward(TransposeAttrs const &attrs) { @@ -56,7 +57,7 @@ OpTaskInvocation forward(TransposeAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {TRANSPOSE_FWD_TASK_ID, binding}; + return {task_id_t::TRANSPOSE_FWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -95,17 +96,17 @@ static std::optional OpTaskInvocation backward(TransposeAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {TRANSPOSE_BWD_TASK_ID, binding}; + return {task_id_t::TRANSPOSE_BWD_TASK_ID, binding}; } TaskImplFunction get_transpose_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_transpose_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_transpose_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_transpose_init_signature() { @@ -131,7 +132,9 @@ OpTaskSignature get_transpose_bwd_signature() { } std::vector get_task_ids(TransposeAttrs const &) { - return {TRANSPOSE_INIT_TASK_ID, TRANSPOSE_FWD_TASK_ID, TRANSPOSE_BWD_TASK_ID}; + return {task_id_t::TRANSPOSE_INIT_TASK_ID, + task_id_t::TRANSPOSE_FWD_TASK_ID, + task_id_t::TRANSPOSE_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/weight.cc b/lib/local-execution/src/ops/weight.cc new file mode 100644 index 0000000000..5537163e85 --- /dev/null +++ b/lib/local-execution/src/ops/weight.cc @@ -0,0 +1,9 @@ +#include "weight.h" + +namespace FlexFlow { + +std::vector get_task_ids(WeightAttrs const &attrs) { + return {}; +} + +}; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/weight.h b/lib/local-execution/src/ops/weight.h new file mode 100644 index 0000000000..e59a88f07d --- /dev/null +++ b/lib/local-execution/src/ops/weight.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_WEIGHT_H +#define _FLEXFLOW_WEIGHT_H + +#include "local-execution/op_task_invocation.h" +#include "op-attrs/ops/weight_attrs.dtg.h" + +namespace FlexFlow { + +std::vector get_task_ids(WeightAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/src/per_device_state.cc b/lib/local-execution/src/per_device_state.cc new file mode 100644 index 0000000000..fa470b196d --- /dev/null +++ b/lib/local-execution/src/per_device_state.cc @@ -0,0 +1,12 @@ +#include "local-execution/per_device_op_state.h" +#include "utils/overload.h" + +namespace FlexFlow { + +PerDeviceOpState get_device_state_from_device_specific( + DeviceSpecificDeviceStates const &device_specific, size_t device_idx) { + return 
device_specific.visit( + [&](auto const &x) { return PerDeviceOpState{*(x.get(device_idx))}; }); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/task_registry.cc b/lib/local-execution/src/task_registry.cc index 23a0c96d28..dad5c1fc69 100644 --- a/lib/local-execution/src/task_registry.cc +++ b/lib/local-execution/src/task_registry.cc @@ -1,31 +1,44 @@ #include "local-execution/task_registry.h" +#include "local-execution/task_signature_impl.h" namespace FlexFlow { -void TaskRegistry::register_task(task_id_t const &task_id, - layer_guid_t const &op_id, - ComputationGraphOpAttrs const &attrs) { - TaskSignatureAndImpl task_signature_impl = get_task_sig_impl(task_id); - switch (task_signature_impl.task_signature.type) { - case OpTaskType::INIT: - assert( - is_invocation_valid(task_signature_impl.task_signature, init(attrs))); - this->init_task_ids.insert({op_id, task_id}); - break; - case OpTaskType::FWD: - assert(is_invocation_valid(task_signature_impl.task_signature, - forward(attrs))); - this->forward_task_ids.insert({op_id, task_id}); - break; - case OpTaskType::BWD: - assert(is_invocation_valid(task_signature_impl.task_signature, - backward(attrs))); - this->backward_task_ids.insert({op_id, task_id}); - break; - default: - throw mk_runtime_error("Invalid OpTaskType"); +TaskRegistry empty_task_registry() { + return TaskRegistry{{}, {}, {}, {}}; +} + +void register_tasks_for_layer(TaskRegistry &task_registry, + layer_guid_t const &op_id, + ComputationGraphOpAttrs const &attrs) { + task_registry.init_task_ids.insert({op_id, std::nullopt}); + task_registry.forward_task_ids.insert({op_id, std::nullopt}); + task_registry.backward_task_ids.insert({op_id, std::nullopt}); + + // register tasks + std::vector task_ids = get_task_ids(attrs); + for (task_id_t task_id : task_ids) { + TaskSignatureAndImpl task_signature_impl = get_task_sig_impl(task_id); + switch (task_signature_impl.task_signature.type) { + case OpTaskType::INIT: + assert(is_invocation_valid(task_signature_impl.task_signature, + init(attrs))); + task_registry.init_task_ids[op_id] = task_id; + break; + case OpTaskType::FWD: + assert(is_invocation_valid(task_signature_impl.task_signature, + forward(attrs))); + task_registry.forward_task_ids[op_id] = task_id; + break; + case OpTaskType::BWD: + assert(is_invocation_valid(task_signature_impl.task_signature, + backward(attrs))); + task_registry.backward_task_ids[op_id] = task_id; + break; + default: + throw mk_runtime_error("Invalid OpTaskType"); + } + task_registry.task_mapping.insert({task_id, task_signature_impl}); } - this->task_mapping.insert({task_id, task_signature_impl}); } } // namespace FlexFlow diff --git a/lib/local-execution/src/task_signature_impl.cc b/lib/local-execution/src/task_signature_impl.cc index 62e77e3199..ca428aad25 100644 --- a/lib/local-execution/src/task_signature_impl.cc +++ b/lib/local-execution/src/task_signature_impl.cc @@ -1,5 +1,5 @@ #include "local-execution/task_signature_impl.h" -#include "ops/attention.h" +#include "local-execution/ops/attention.h" #include "ops/batch_matmul.h" #include "ops/batch_norm.h" #include "ops/cast.h" @@ -12,8 +12,10 @@ #include "ops/embedding.h" #include "ops/flat.h" #include "ops/gather.h" +#include "ops/input.h" #include "ops/layer_norm.h" #include "ops/linear.h" +#include "ops/noop.h" #include "ops/pool_2d.h" #include "ops/reduce.h" #include "ops/reduction.h" @@ -25,160 +27,344 @@ #include "ops/split.h" #include "ops/topk.h" #include "ops/transpose.h" +#include "ops/weight.h" +#include "utils/overload.h" 
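+// utils/overload.h (included above) supplies the usual C++17 visitor helper
+// used below: a struct that inherits a pack of lambdas and re-exports all of
+// their call operators, so a variant's visit() can dispatch on whichever
+// alternative is active. A minimal sketch of the idiom, assumed here for
+// illustration (the real header may differ in details):
+//
+//   template <typename... Fs>
+//   struct overload : Fs... {
+//     using Fs::operator()...;            // expose every lambda's operator()
+//   };
+//   template <typename... Fs>
+//   overload(Fs...) -> overload<Fs...>;   // C++17 deduction guide
+//
+// With it, per-operator dispatch reads as one flat case list, e.g.:
+//   attrs.visit<std::vector<task_id_t>>(overload{
+//       [](DropoutAttrs const &a) { return get_task_ids(a); },
+//       /* ...one lambda per operator attr type... */});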
namespace FlexFlow { TaskSignatureAndImpl get_task_sig_impl(task_id_t const &task_id) { switch (task_id) { - case ELEMENTBINARY_INIT_TASK_ID: - return {get_element_binary_init_task_impl(), - get_element_binary_init_signature()}; - case ELEMENTBINARY_FWD_TASK_ID: - return {get_element_binary_fwd_task_impl(), - get_element_binary_fwd_signature()}; - case ELEMENTBINARY_BWD_TASK_ID: - return {get_element_binary_bwd_task_impl(), - get_element_binary_bwd_signature()}; - case ELEMENTUNARY_INIT_TASK_ID: - return {get_element_unary_init_task_impl(), - get_element_unary_init_signature()}; - case ELEMENTUNARY_FWD_TASK_ID: - return {get_element_unary_fwd_task_impl(), - get_element_unary_fwd_signature()}; - case ELEMENTUNARY_BWD_TASK_ID: - return {get_element_binary_bwd_task_impl(), - get_element_binary_bwd_signature()}; - case CONV2D_INIT_TASK_ID: - return {get_conv_2d_init_task_impl(), get_conv_2d_init_signature()}; - case CONV2D_FWD_TASK_ID: - return {get_conv_2d_fwd_task_impl(), get_conv_2d_fwd_signature()}; - case CONV2D_BWD_TASK_ID: - return {get_conv_2d_bwd_task_impl(), get_conv_2d_bwd_signature()}; - case DROPOUT_INIT_TASK_ID: - return {get_dropout_init_task_impl(), get_dropout_init_signature()}; - case DROPOUT_FWD_TASK_ID: - return {get_dropout_fwd_task_impl(), get_dropout_fwd_signature()}; - case DROPOUT_BWD_TASK_ID: - return {get_dropout_bwd_task_impl(), get_dropout_bwd_signature()}; - case EMBED_FWD_TASK_ID: - return {get_embedding_fwd_task_impl(), get_embedding_fwd_signature()}; - case EMBED_BWD_TASK_ID: - return {get_embedding_bwd_task_impl(), get_embedding_bwd_signature()}; - case GATHER_INIT_TASK_ID: - return {get_gather_init_task_impl(), get_gather_init_signature()}; - case GATHER_FWD_TASK_ID: - return {get_gather_fwd_task_impl(), get_embedding_fwd_signature()}; - case GATHER_BWD_TASK_ID: - return {get_gather_bwd_task_impl(), get_embedding_bwd_signature()}; - case CAST_FWD_TASK_ID: - return {get_cast_fwd_task_impl(), get_cast_fwd_signature()}; - case CAST_BWD_TASK_ID: - return {get_cast_bwd_task_impl(), get_cast_bwd_signature()}; - case POOL2D_INIT_TASK_ID: - return {get_pool_2d_init_task_impl(), get_pool_2d_init_signature()}; - case POOL2D_FWD_TASK_ID: - return {get_pool_2d_fwd_task_impl(), get_pool_2d_fwd_signature()}; - case POOL2D_BWD_TASK_ID: - return {get_pool_2d_bwd_task_impl(), get_pool_2d_bwd_signature()}; - case BATCHNORM_INIT_TASK_ID: - return {get_batch_norm_init_task_impl(), get_batch_norm_init_signature()}; - case BATCHNORM_FWD_TASK_ID: - return {get_batch_norm_fwd_task_impl(), get_batch_norm_fwd_signature()}; - case BATCHNORM_BWD_TASK_ID: - return {get_batch_norm_bwd_task_impl(), get_batch_norm_bwd_signature()}; - case BATCHMATMUL_FWD_TASK_ID: - return {get_batch_matmul_fwd_task_impl(), - get_batch_matmul_fwd_signature()}; - case BATCHMATMUL_BWD_TASK_ID: - return {get_batch_matmul_bwd_task_impl(), - get_batch_matmul_bwd_signature()}; - case LAYERNORM_INIT_TASK_ID: - return {get_layer_norm_init_task_impl(), get_layer_norm_init_signature()}; - case LAYERNORM_FWD_TASK_ID: - return {get_layer_norm_fwd_task_impl(), get_layer_norm_init_signature()}; - case LAYERNORM_BWD_TASK_ID: - return {get_layer_norm_bwd_task_impl(), get_layer_norm_bwd_signature()}; - case LINEAR_INIT_TASK_ID: - return {get_linear_init_task_impl(), get_linear_init_signature()}; - case LINEAR_FWD_TASK_ID: - return {get_linear_fwd_task_impl(), get_linear_fwd_signature()}; - case LINEAR_BWD_TASK_ID: - return {get_linear_bwd_task_impl(), get_linear_bwd_signature()}; - case FLAT_FWD_TASK_ID: - return 
{get_flat_fwd_task_impl(), get_flat_fwd_signature()}; - case FLAT_BWD_TASK_ID: - return {get_flat_bwd_task_impl(), get_flat_bwd_signature()}; - case SOFTMAX_INIT_TASK_ID: - return {get_softmax_init_task_impl(), get_softmax_init_signature()}; - case SOFTMAX_FWD_TASK_ID: - return {get_softmax_fwd_task_impl(), get_softmax_fwd_signature()}; - case SOFTMAX_BWD_TASK_ID: - return {get_softmax_bwd_task_impl(), get_softmax_bwd_signature()}; - case CONCAT_FWD_TASK_ID: - return {get_concat_fwd_task_impl(), get_concat_fwd_signature()}; - case CONCAT_BWD_TASK_ID: - return {get_concat_bwd_task_impl(), get_concat_bwd_signature()}; - case SPLIT_FWD_TASK_ID: - return {get_split_fwd_task_impl(), get_split_fwd_signature()}; - case SPLIT_BWD_TASK_ID: - return {get_split_bwd_task_impl(), get_split_bwd_signature()}; - case REDUCE_INIT_TASK_ID: - return {get_reduce_init_task_impl(), get_reduce_init_signature()}; - case REDUCE_FWD_TASK_ID: - return {get_reduce_fwd_task_impl(), get_reduce_fwd_signature()}; - case REDUCE_BWD_TASK_ID: - return {get_reduce_bwd_task_impl(), get_reduce_bwd_signature()}; - case RESHAPE_INIT_TASK_ID: - return {get_reshape_init_task_impl(), get_reshape_init_signature()}; - case RESHAPE_FWD_TASK_ID: - return {get_reshape_fwd_task_impl(), get_reshape_fwd_signature()}; - case RESHAPE_BWD_TASK_ID: - return {get_reshape_bwd_task_impl(), get_reshape_bwd_signature()}; - case REVERSE_FWD_TASK_ID: - return {get_reverse_fwd_task_impl(), get_reverse_fwd_signature()}; - case REVERSE_BWD_TASK_ID: - return {get_reverse_bwd_task_impl(), get_reverse_bwd_signature()}; - case TOPK_INIT_TASK_ID: - return {get_topk_init_task_impl(), get_topk_init_signature()}; - case TOPK_FWD_TASK_ID: - return {get_topk_fwd_task_impl(), get_topk_fwd_signature()}; - case TOPK_BWD_TASK_ID: - return {get_topk_bwd_task_impl(), get_topk_bwd_signature()}; - case TRANSPOSE_INIT_TASK_ID: - return {get_transpose_init_task_impl(), get_transpose_init_signature()}; - case TRANSPOSE_FWD_TASK_ID: - return {get_transpose_fwd_task_impl(), get_transpose_fwd_signature()}; - case TRANSPOSE_BWD_TASK_ID: - return {get_transpose_bwd_task_impl(), get_transpose_bwd_signature()}; - case ATTENTION_INIT_TASK_ID: - return {get_attention_init_task_impl(), get_attention_init_signature()}; - case ATTENTION_FWD_TASK_ID: - return {get_attention_fwd_task_impl(), get_attention_fwd_signature()}; - case ATTENTION_BWD_TASK_ID: - return {get_attention_bwd_task_impl(), get_attention_bwd_signature()}; - case COMBINE_FWD_TASK_ID: - return {get_combine_fwd_task_impl(), get_combine_fwd_signature()}; - case COMBINE_BWD_TASK_ID: - return {get_combine_bwd_task_impl(), get_combine_bwd_signature()}; - case REDUCTION_FWD_TASK_ID: - return {get_reduction_fwd_task_impl(), get_reduction_fwd_signature()}; - case REDUCTION_BWD_TASK_ID: - return {get_reduction_bwd_task_impl(), get_reduction_bwd_signature()}; - case REPARTITION_INIT_TASK_ID: - return {get_repartition_init_task_impl(), - get_repartition_init_signature()}; - case REPARTITION_FWD_TASK_ID: - return {get_repartition_fwd_task_impl(), get_repartition_fwd_signature()}; - case REPARTITION_BWD_TASK_ID: - return {get_repartition_bwd_task_impl(), get_repartition_bwd_signature()}; - case REPLICATE_FWD_TASK_ID: - return {get_replicate_fwd_task_impl(), get_replicate_fwd_signature()}; - case REPLICATE_BWD_TASK_ID: - return {get_replicate_bwd_task_impl(), get_replicate_bwd_signature()}; + case task_id_t::ELEMENTBINARY_INIT_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_init_task_impl(), + 
get_element_binary_init_signature()};
+ case task_id_t::ELEMENTBINARY_FWD_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_fwd_task_impl(), + get_element_binary_fwd_signature()};
+ case task_id_t::ELEMENTBINARY_BWD_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_bwd_task_impl(), + get_element_binary_bwd_signature()};
+ case task_id_t::ELEMENTUNARY_INIT_TASK_ID: + return TaskSignatureAndImpl{get_element_unary_init_task_impl(), + get_element_unary_init_signature()};
+ case task_id_t::ELEMENTUNARY_FWD_TASK_ID: + return TaskSignatureAndImpl{get_element_unary_fwd_task_impl(), + get_element_unary_fwd_signature()};
+ case task_id_t::ELEMENTUNARY_BWD_TASK_ID: + return TaskSignatureAndImpl{get_element_unary_bwd_task_impl(), + get_element_unary_bwd_signature()};
+ case task_id_t::CONV2D_INIT_TASK_ID: + return TaskSignatureAndImpl{get_conv_2d_init_task_impl(), + get_conv_2d_init_signature()};
+ case task_id_t::CONV2D_FWD_TASK_ID: + return TaskSignatureAndImpl{get_conv_2d_fwd_task_impl(), + get_conv_2d_fwd_signature()};
+ case task_id_t::CONV2D_BWD_TASK_ID: + return TaskSignatureAndImpl{get_conv_2d_bwd_task_impl(), + get_conv_2d_bwd_signature()};
+ case task_id_t::DROPOUT_INIT_TASK_ID: + return TaskSignatureAndImpl{get_dropout_init_task_impl(), + get_dropout_init_signature()};
+ case task_id_t::DROPOUT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_dropout_fwd_task_impl(), + get_dropout_fwd_signature()};
+ case task_id_t::DROPOUT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_dropout_bwd_task_impl(), + get_dropout_bwd_signature()};
+ // case task_id_t::EMBED_FWD_TASK_ID: + // return TaskSignatureAndImpl{get_embedding_fwd_task_impl(), + // get_embedding_fwd_signature()};
+ // case task_id_t::EMBED_BWD_TASK_ID: + // return TaskSignatureAndImpl{get_embedding_bwd_task_impl(), + // get_embedding_bwd_signature()};
+ case task_id_t::GATHER_INIT_TASK_ID: + return TaskSignatureAndImpl{get_gather_init_task_impl(), + get_gather_init_signature()};
+ case task_id_t::GATHER_FWD_TASK_ID: + return TaskSignatureAndImpl{get_gather_fwd_task_impl(), + get_gather_fwd_signature()};
+ case task_id_t::GATHER_BWD_TASK_ID: + return TaskSignatureAndImpl{get_gather_bwd_task_impl(), + get_gather_bwd_signature()};
+ case task_id_t::CAST_FWD_TASK_ID: + return TaskSignatureAndImpl{get_cast_fwd_task_impl(), + get_cast_fwd_signature()};
+ case task_id_t::CAST_BWD_TASK_ID: + return TaskSignatureAndImpl{get_cast_bwd_task_impl(), + get_cast_bwd_signature()};
+ case task_id_t::POOL2D_INIT_TASK_ID: + return TaskSignatureAndImpl{get_pool_2d_init_task_impl(), + get_pool_2d_init_signature()};
+ case task_id_t::POOL2D_FWD_TASK_ID: + return TaskSignatureAndImpl{get_pool_2d_fwd_task_impl(), + get_pool_2d_fwd_signature()};
+ case task_id_t::POOL2D_BWD_TASK_ID: + return TaskSignatureAndImpl{get_pool_2d_bwd_task_impl(), + get_pool_2d_bwd_signature()};
+ case task_id_t::BATCHNORM_INIT_TASK_ID: + return TaskSignatureAndImpl{get_batch_norm_init_task_impl(), + get_batch_norm_init_signature()};
+ case task_id_t::BATCHNORM_FWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_norm_fwd_task_impl(), + get_batch_norm_fwd_signature()};
+ case task_id_t::BATCHNORM_BWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_norm_bwd_task_impl(), + get_batch_norm_bwd_signature()};
+ case task_id_t::BATCHMATMUL_FWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_matmul_fwd_task_impl(), + get_batch_matmul_fwd_signature()};
+ case task_id_t::BATCHMATMUL_BWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_matmul_bwd_task_impl(), +
get_batch_matmul_bwd_signature()};
+ case task_id_t::LAYERNORM_INIT_TASK_ID: + return TaskSignatureAndImpl{get_layer_norm_init_task_impl(), + get_layer_norm_init_signature()};
+ case task_id_t::LAYERNORM_FWD_TASK_ID: + return TaskSignatureAndImpl{get_layer_norm_fwd_task_impl(), + get_layer_norm_fwd_signature()};
+ case task_id_t::LAYERNORM_BWD_TASK_ID: + return TaskSignatureAndImpl{get_layer_norm_bwd_task_impl(), + get_layer_norm_bwd_signature()};
+ case task_id_t::LINEAR_INIT_TASK_ID: + return TaskSignatureAndImpl{get_linear_init_task_impl(), + get_linear_init_signature()};
+ case task_id_t::LINEAR_FWD_TASK_ID: + return TaskSignatureAndImpl{get_linear_fwd_task_impl(), + get_linear_fwd_signature()};
+ case task_id_t::LINEAR_BWD_TASK_ID: + return TaskSignatureAndImpl{get_linear_bwd_task_impl(), + get_linear_bwd_signature()};
+ case task_id_t::FLAT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_flat_fwd_task_impl(), + get_flat_fwd_signature()};
+ case task_id_t::FLAT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_flat_bwd_task_impl(), + get_flat_bwd_signature()};
+ case task_id_t::SOFTMAX_INIT_TASK_ID: + return TaskSignatureAndImpl{get_softmax_init_task_impl(), + get_softmax_init_signature()};
+ case task_id_t::SOFTMAX_FWD_TASK_ID: + return TaskSignatureAndImpl{get_softmax_fwd_task_impl(), + get_softmax_fwd_signature()};
+ case task_id_t::SOFTMAX_BWD_TASK_ID: + return TaskSignatureAndImpl{get_softmax_bwd_task_impl(), + get_softmax_bwd_signature()};
+ case task_id_t::CONCAT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_concat_fwd_task_impl(), + get_concat_fwd_signature()};
+ case task_id_t::CONCAT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_concat_bwd_task_impl(), + get_concat_bwd_signature()};
+ case task_id_t::SPLIT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_split_fwd_task_impl(), + get_split_fwd_signature()};
+ case task_id_t::SPLIT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_split_bwd_task_impl(), + get_split_bwd_signature()};
+ case task_id_t::REDUCE_INIT_TASK_ID: + return TaskSignatureAndImpl{get_reduce_init_task_impl(), + get_reduce_init_signature()};
+ case task_id_t::REDUCE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reduce_fwd_task_impl(), + get_reduce_fwd_signature()};
+ case task_id_t::REDUCE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reduce_bwd_task_impl(), + get_reduce_bwd_signature()};
+ case task_id_t::RESHAPE_INIT_TASK_ID: + return TaskSignatureAndImpl{get_reshape_init_task_impl(), + get_reshape_init_signature()};
+ case task_id_t::RESHAPE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reshape_fwd_task_impl(), + get_reshape_fwd_signature()};
+ case task_id_t::RESHAPE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reshape_bwd_task_impl(), + get_reshape_bwd_signature()};
+ case task_id_t::REVERSE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reverse_fwd_task_impl(), + get_reverse_fwd_signature()};
+ case task_id_t::REVERSE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reverse_bwd_task_impl(), + get_reverse_bwd_signature()};
+ case task_id_t::TOPK_INIT_TASK_ID: + return TaskSignatureAndImpl{get_topk_init_task_impl(), + get_topk_init_signature()};
+ case task_id_t::TOPK_FWD_TASK_ID: + return TaskSignatureAndImpl{get_topk_fwd_task_impl(), + get_topk_fwd_signature()};
+ case task_id_t::TOPK_BWD_TASK_ID: + return TaskSignatureAndImpl{get_topk_bwd_task_impl(), + get_topk_bwd_signature()};
+ case task_id_t::TRANSPOSE_INIT_TASK_ID: + return TaskSignatureAndImpl{get_transpose_init_task_impl(), + get_transpose_init_signature()};
+ case
task_id_t::TRANSPOSE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_transpose_fwd_task_impl(), + get_transpose_fwd_signature()};
+ case task_id_t::TRANSPOSE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_transpose_bwd_task_impl(), + get_transpose_bwd_signature()};
+ case task_id_t::ATTENTION_INIT_TASK_ID: + return TaskSignatureAndImpl{get_attention_init_task_impl(), + get_attention_init_signature()};
+ case task_id_t::ATTENTION_FWD_TASK_ID: + return TaskSignatureAndImpl{get_attention_fwd_task_impl(), + get_attention_fwd_signature()};
+ case task_id_t::ATTENTION_BWD_TASK_ID: + return TaskSignatureAndImpl{get_attention_bwd_task_impl(), + get_attention_bwd_signature()};
+ case task_id_t::COMBINE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_combine_fwd_task_impl(), + get_combine_fwd_signature()};
+ case task_id_t::COMBINE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_combine_bwd_task_impl(), + get_combine_bwd_signature()};
+ case task_id_t::REDUCTION_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reduction_fwd_task_impl(), + get_reduction_fwd_signature()};
+ case task_id_t::REDUCTION_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reduction_bwd_task_impl(), + get_reduction_bwd_signature()};
+ case task_id_t::REPARTITION_INIT_TASK_ID: + return TaskSignatureAndImpl{get_repartition_init_task_impl(), + get_repartition_init_signature()};
+ case task_id_t::REPARTITION_FWD_TASK_ID: + return TaskSignatureAndImpl{get_repartition_fwd_task_impl(), + get_repartition_fwd_signature()};
+ case task_id_t::REPARTITION_BWD_TASK_ID: + return TaskSignatureAndImpl{get_repartition_bwd_task_impl(), + get_repartition_bwd_signature()};
+ case task_id_t::REPLICATE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_replicate_fwd_task_impl(), + get_replicate_fwd_signature()};
+ case task_id_t::REPLICATE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_replicate_bwd_task_impl(), + get_replicate_bwd_signature()};
default: throw mk_runtime_error( fmt::format("Invalid task ID")); // inserting task_id yields // "type_is_unformattable" error } }
+
+std::vector<task_id_t> get_task_ids(ComputationGraphOpAttrs const &op) { + return op.visit<std::vector<task_id_t>>(overload{
+ [](BatchMatmulAttrs const &attrs) { return get_task_ids(attrs); },
+ [](BatchNormAttrs const &attrs) { return get_task_ids(attrs); },
+ [](CastAttrs const &attrs) { return get_task_ids(attrs); },
+ [](ConcatAttrs const &attrs) { return get_task_ids(attrs); },
+ [](Conv2DAttrs const &attrs) { return get_task_ids(attrs); },
+ [](DropoutAttrs const &attrs) { return get_task_ids(attrs); },
+ [](ElementBinaryAttrs const &attrs) { return get_task_ids(attrs); },
+ [](ElementUnaryAttrs const &attrs) { return get_task_ids(attrs); },
+ // [](EmbeddingAttrs const & attrs) { + // return get_task_ids(attrs); + // },
+ [](FlatAttrs const &attrs) { return get_task_ids(attrs); },
+ [](GatherAttrs const &attrs) { return get_task_ids(attrs); },
+ [](InputAttrs const &attrs) { return get_task_ids(attrs); },
+ [](LayerNormAttrs const &attrs) { return get_task_ids(attrs); },
+ [](LinearAttrs const &attrs) { return get_task_ids(attrs); },
+ [](MultiHeadAttentionAttrs const &attrs) { return get_task_ids(attrs); },
+ [](NoopAttrs const &attrs) { return get_task_ids(attrs); },
+ [](Pool2DAttrs const &attrs) { return get_task_ids(attrs); },
+ [](ReduceAttrs const &attrs) { return get_task_ids(attrs); },
+ [](ReverseAttrs const &attrs) { return get_task_ids(attrs); },
+ [](ReshapeAttrs const &attrs) { return get_task_ids(attrs); },
+ [](SplitAttrs const &attrs) { return get_task_ids(attrs); },
+ [](SoftmaxAttrs const &attrs) { return get_task_ids(attrs); },
+ [](TopKAttrs const &attrs) { return get_task_ids(attrs); },
+ [](TransposeAttrs const &attrs) { return get_task_ids(attrs); },
+ [](WeightAttrs const &attrs) { return get_task_ids(attrs); },
+ [](auto const &attrs) -> std::vector<task_id_t> { + throw mk_runtime_error(fmt::format("Unhandled attr type: {}", attrs)); + }, + }); +}
+
+OpTaskInvocation init(ComputationGraphOpAttrs const &op) { + return op.visit<OpTaskInvocation>(overload{
+ [](BatchNormAttrs const &attrs) { return init(attrs); },
+ [](Conv2DAttrs const &attrs) { return init(attrs); },
+ [](DropoutAttrs const &attrs) { return init(attrs); },
+ [](ElementBinaryAttrs const &attrs) { return init(attrs); },
+ [](ElementUnaryAttrs const &attrs) { return init(attrs); },
+ [](GatherAttrs const &attrs) { return init(attrs); },
+ [](LayerNormAttrs const &attrs) { return init(attrs); },
+ [](LinearAttrs const &attrs) { return init(attrs); },
+ [](MultiHeadAttentionAttrs const &attrs) { return init(attrs); },
+ [](Pool2DAttrs const &attrs) { return init(attrs); },
+ [](ReduceAttrs const &attrs) { return init(attrs); },
+ [](ReshapeAttrs const &attrs) { return init(attrs); },
+ [](SoftmaxAttrs const &attrs) { return init(attrs); },
+ [](TopKAttrs const &attrs) { return init(attrs); },
+ [](TransposeAttrs const &attrs) { return init(attrs); },
+ [](auto const &attrs) -> OpTaskInvocation { + throw mk_runtime_error(fmt::format("Unhandled attr type {}", attrs)); + }, + }); +}
+
+OpTaskInvocation forward(ComputationGraphOpAttrs const &op) { + return op.visit<OpTaskInvocation>(overload{
+ [](BatchMatmulAttrs const &attrs) { return forward(attrs); },
+ [](BatchNormAttrs const &attrs) { return forward(attrs); },
+ [](CastAttrs const &attrs) { return forward(attrs); },
+ [](ConcatAttrs const &attrs) { return forward(attrs); },
+ [](Conv2DAttrs const &attrs) { return forward(attrs); },
+ [](DropoutAttrs const &attrs) { return forward(attrs); },
+ [](ElementBinaryAttrs const &attrs) { return forward(attrs); },
+ [](ElementUnaryAttrs const &attrs) { return forward(attrs); },
+ // [](EmbeddingAttrs const & attrs) { + // return forward(attrs); + // },
+ [](FlatAttrs const &attrs) { return forward(attrs); },
+ [](GatherAttrs const &attrs) { return forward(attrs); },
+ [](LayerNormAttrs const &attrs) { return forward(attrs); },
+ [](LinearAttrs const &attrs) { return forward(attrs); },
+ [](MultiHeadAttentionAttrs const &attrs) { return forward(attrs); },
+ [](Pool2DAttrs const &attrs) { return forward(attrs); },
+ [](ReduceAttrs const &attrs) { return forward(attrs); },
+ [](ReverseAttrs const &attrs) { return forward(attrs); },
+ [](ReshapeAttrs const &attrs) { return forward(attrs); },
+ [](SplitAttrs const &attrs) { return forward(attrs); },
+ [](SoftmaxAttrs const &attrs) { return forward(attrs); },
+ [](TopKAttrs const &attrs) { return forward(attrs); },
+ [](TransposeAttrs const &attrs) { return forward(attrs); },
+ [](auto const &attrs) -> OpTaskInvocation { + throw mk_runtime_error(fmt::format("Unhandled attr type {}", attrs)); + }, + }); +}
+
+OpTaskInvocation backward(ComputationGraphOpAttrs const &op) { + return op.visit<OpTaskInvocation>(overload{
+ [](BatchMatmulAttrs const &attrs) { return backward(attrs); },
+ [](BatchNormAttrs const &attrs) { return backward(attrs); },
+ [](CastAttrs const &attrs) { return backward(attrs); },
+ [](ConcatAttrs const &attrs) { return backward(attrs); },
+ [](Conv2DAttrs const &attrs) { return backward(attrs); },
+ [](DropoutAttrs const &attrs) { return backward(attrs); },
+ [](ElementBinaryAttrs const &attrs) { return
diff --git a/lib/local-execution/test/CMakeLists.txt b/lib/local-execution/test/CMakeLists.txt
new file mode 100644
index 0000000000..930ab5c4e2
--- /dev/null
+++ b/lib/local-execution/test/CMakeLists.txt
@@ -0,0 +1,14 @@
+ff_add_test_executable(
+  NAME
+    local-execution-tests
+  SRC_PATTERNS
+    src/*.cc
+  PRIVATE_INCLUDE
+    src/
+  DEPS
+    doctest
+    utils-test-common
+    local-execution
+    kernels
+    op-attrs
+)
diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc
new file mode 100644
index 0000000000..2bd0acc222
--- /dev/null
+++ b/lib/local-execution/test/src/test_local_cost_estimator.cc
@@ -0,0 +1,77 @@
+#include "doctest/doctest.h"
+#include "kernels/local_cuda_allocator.h"
+#include "kernels/managed_per_device_ff_handle.h"
+#include "local-execution/local_cost_estimator.h"
+#include "pcg/computation_graph_builder.h"
+#include "test_utils.h"
+
+namespace FlexFlow {
+
+TEST_SUITE(FF_CUDA_TEST_SUITE) {
+  TEST_CASE("Local Cost Estimator") {
+    // local backing initialization
+    ManagedPerDeviceFFHandle managed_handle{};
+
+    RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{
+        DeviceSpecific<PerDeviceFFHandle>::create(managed_handle.raw_handle()),
+        EnableProfiling::YES,
+        ProfilingSettings{/*warmup_iters=*/0,
+                          /*measure_iters=*/1}};
+
+    LocalCostEstimator cost_estimator = LocalCostEstimator{runtime_arg_config};
+
+    SUBCASE("Estimate cost -- Attention Op") {
+      int embed_dim = 32;
+      int num_heads = 10;
+      MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{
+          /*embed_dim=*/embed_dim,
+          /*num_heads=*/num_heads,
+          /*kdim=*/embed_dim,
+          /*vdim=*/embed_dim,
+          /*dropout=*/0.0,
+          /*bias=*/true,
+          /*add_bias_kv=*/false,
+          /*add_zero_attn=*/false,
+      };
+
+      size_t batch_size = 40;
+      size_t seq_len = 48;
+      size_t feature_size = 36;
+
+      DataType dtype = DataType::FLOAT;
+      ParallelTensorShape inputs_shape = lift_to_parallel(TensorShape{
+          TensorDims{FFOrdered<size_t>{batch_size, seq_len, feature_size}},
+          dtype,
+      });
+
+      ParallelTensorShape weights_shape = throw_if_unexpected(
+          get_weights_shape(attrs, inputs_shape, inputs_shape, inputs_shape));
+      ParallelTensorAttrs weight_attrs =
+          ParallelTensorAttrs{weights_shape,
+                              /*sync_type=*/std::nullopt,
+                              /*initializer=*/std::nullopt,
+                              CreateGrad::YES};
+
+      ParallelTensorShape output_shape = throw_if_unexpected(
+          get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape));
+      ParallelTensorAttrs output_attrs =
+          ParallelTensorAttrs{output_shape,
+                              /*sync_type=*/std::nullopt,
+                              /*initializer=*/std::nullopt,
+                              CreateGrad::YES};
+
+      CostDetails result = cost_estimator.estimate_cost(
+          PCGOperatorAttrs{attrs},
+          std::vector<ParallelTensorShape>{
+              inputs_shape, inputs_shape, inputs_shape},
+          std::vector<ParallelTensorAttrs>{weight_attrs},
+          std::vector<ParallelTensorAttrs>{output_attrs},
+          make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}));
+
+      CHECK(result.total_elapsed_time > 0);
+      CHECK(result.total_mem_usage > 0);
+    }
+  }
+}
+
+} // namespace FlexFlow
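These tests lean on doctest's `SUBCASE` semantics: each leaf subcase re-runs the enclosing `TEST_CASE` from the top, so the shared setup above the subcases is rebuilt fresh for every branch. A minimal standalone example (doctest's `TEST_CASE`/`SUBCASE`/`CHECK` macros, nothing from the patch):

```cpp
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
#include "doctest/doctest.h"
#include <vector>

TEST_CASE("subcases re-run shared setup") {
  std::vector<int> v = {1, 2, 3}; // rebuilt for every subcase below

  SUBCASE("push_back grows the vector") {
    v.push_back(4);
    CHECK(v.size() == 4);
  }
  SUBCASE("pop_back shrinks the vector") {
    v.pop_back();
    CHECK(v.size() == 2); // unaffected by the push_back subcase
  }
}
```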
diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc
new file mode 100644
index 0000000000..542aa66087
--- /dev/null
+++ b/lib/local-execution/test/src/test_local_slots_backing.cc
@@ -0,0 +1,273 @@
+#include "doctest/doctest.h"
+#include "kernels/attention_kernels.h"
+#include "local-execution/local_cost_estimator.h"
+#include "local-execution/local_cpu_allocator.h"
+#include "local-execution/local_slots_backing.h"
+#include "pcg/computation_graph_builder.h"
+#include "test_utils.h"
+#include "utils/fmt/unordered_map.h"
+#include "utils/fmt/variant.h"
+#include "utils/fmt/vector.h"
+
+namespace FlexFlow {
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("LocalSlotsBacking -- Attention Op") {
+    // allocate input memory
+    Allocator allocator = create_local_cpu_memory_allocator();
+    int embed_dim = 32;
+    int num_heads = 10;
+
+    size_t batch_size = 40;
+    size_t seq_len = 48;
+    size_t feature_size = 36;
+
+    DataType dtype = DataType::FLOAT;
+    TensorShape input_tensor_shape = TensorShape{
+        TensorDims{FFOrdered<size_t>{batch_size, seq_len, feature_size}},
+        dtype,
+    };
+    TensorShape query_shape = input_tensor_shape;
+    TensorShape key_shape = input_tensor_shape;
+    TensorShape value_shape = input_tensor_shape;
+    GenericTensorAccessorW query = allocator.allocate_tensor(query_shape);
+    GenericTensorAccessorW key = allocator.allocate_tensor(key_shape);
+    GenericTensorAccessorW value = allocator.allocate_tensor(value_shape);
+
+    // build graph
+    ComputationGraphBuilder cg_builder;
+    tensor_guid_t query_guid =
+        cg_builder.create_tensor(query_shape, CreateGrad::YES);
+    tensor_guid_t key_guid =
+        cg_builder.create_tensor(key_shape, CreateGrad::YES);
+    tensor_guid_t value_guid =
+        cg_builder.create_tensor(value_shape, CreateGrad::YES);
+
+    std::string layer_name = "attn1";
+    tensor_guid_t output_guid =
+        cg_builder.multihead_attention(query_guid,
+                                       key_guid,
+                                       value_guid,
+                                       embed_dim,
+                                       num_heads,
+                                       /*kdim=*/embed_dim,
+                                       /*vdim=*/embed_dim,
+                                       /*dropout=*/0.0f,
+                                       /*bias=*/true,
+                                       /*add_bias_kv=*/false,
+                                       /*add_zero_attn=*/false,
+                                       /*initializer=*/std::nullopt,
+                                       /*maybe_name=*/layer_name);
+
+    layer_guid_t layer_guid =
+        get_layer_by_name(cg_builder.computation_graph, layer_name);
+
+    TensorBackingMap tensor_backing_map = {
+        {query_guid, query}, {key_guid, key}, {value_guid, value}};
+
+    // runtime arg config
+    ProfilingSettings settings = ProfilingSettings{/*warmup_iters=*/0,
+                                                   /*measure_iters=*/0};
+    PerDeviceFFHandle handle = get_mock_per_device_ff_handle();
+    RuntimeArgConfig runtime_arg_config =
+        RuntimeArgConfig{DeviceSpecific<PerDeviceFFHandle>::create(handle),
+                         EnableProfiling::NO,
+                         settings};
+
+    LocalSlotsBacking local_slots_backing = {tensor_backing_map,
+                                             runtime_arg_config};
+
+    SUBCASE("LocalSlotsBacking::allocate_outgoing_tensors") {
+      auto get_result_shape_and_dtype_for_tensor_guid_and_map =
+          [&](tensor_guid_t t,
+              TensorBackingMap const &m) -> std::pair<ArrayShape, DataType> {
+        GenericTensorAccessorW accessor = m.at(t);
+        return get_shape_and_datatype(accessor);
+      };
+
+      SUBCASE("Input (QKV) and gradient tensors allocation") {
+
+        // allocate all tensors from input nodes
+        for (layer_guid_t const &node :
+             topological_ordering(cg_builder.computation_graph)) {
+          if (node == layer_guid) {
+            break;
+          }
+          local_slots_backing.allocate_outgoing_tensors(
+              node, cg_builder.computation_graph, allocator);
+        }
+
+        SUBCASE("Query grad") {
+          std::pair<ArrayShape, DataType> result =
+              get_result_shape_and_dtype_for_tensor_guid_and_map(
+                  query_guid, local_slots_backing.gradient_tensor_mapping);
+          std::pair<ArrayShape, DataType> correct = {ArrayShape{query_shape},
+                                                     dtype};
+          CHECK(result == correct);
+        }
+        SUBCASE("Key grad") {
+          std::pair<ArrayShape, DataType> result =
+              get_result_shape_and_dtype_for_tensor_guid_and_map(
+                  key_guid, local_slots_backing.gradient_tensor_mapping);
+          std::pair<ArrayShape, DataType> correct = {ArrayShape{key_shape},
+                                                     dtype};
+          CHECK(result == correct);
+        }
+        SUBCASE("Value grad") {
+          std::pair<ArrayShape, DataType> result =
+              get_result_shape_and_dtype_for_tensor_guid_and_map(
+                  value_guid, local_slots_backing.gradient_tensor_mapping);
+          std::pair<ArrayShape, DataType> correct = {ArrayShape{value_shape},
+                                                     dtype};
+          CHECK(result == correct);
+        }
+      }
+      SUBCASE("Output and gradient tensors allocation") {
+        local_slots_backing.allocate_outgoing_tensors(
+            layer_guid, cg_builder.computation_graph, allocator);
+        SUBCASE("Output") {
+          std::pair<ArrayShape, DataType> result =
+              get_result_shape_and_dtype_for_tensor_guid_and_map(
+                  output_guid, local_slots_backing.tensor_mapping);
+          std::pair<ArrayShape, DataType> correct = {
+              ArrayShape{
+                  get_tensor_attrs(cg_builder.computation_graph, output_guid)
+                      .shape},
+              dtype};
+          CHECK(result == correct);
+        }
+        SUBCASE("Output grad") {
+          std::pair<ArrayShape, DataType> result =
+              get_result_shape_and_dtype_for_tensor_guid_and_map(
+                  output_guid, local_slots_backing.gradient_tensor_mapping);
+          std::pair<ArrayShape, DataType> correct = {
+              ArrayShape{
+                  get_tensor_attrs(cg_builder.computation_graph, output_guid)
+                      .shape},
+              dtype};
+          CHECK(result == correct);
+        }
+      }
+
+      SUBCASE("Tensor slots") {
+        local_slots_backing.allocate_outgoing_tensors(
+            layer_guid, cg_builder.computation_graph, allocator);
+        SUBCASE("Input tensor slots") {
+          std::vector<tensor_guid_t> correct_incoming_tensors =
+              get_incoming_tensors(cg_builder.computation_graph, layer_guid);
+          CHECK(correct_incoming_tensors ==
+                local_slots_backing.input_tensor_slots.at(layer_guid));
+        }
+        SUBCASE("Output tensor slots") {
+          std::vector<tensor_guid_t> correct_outgoing_tensors =
+              get_outgoing_tensors(cg_builder.computation_graph, layer_guid);
+          CHECK(correct_outgoing_tensors ==
+                local_slots_backing.output_tensor_slots.at(layer_guid));
+        }
+      }
+    }
+
+    SUBCASE("Construct Slots Backings") {
+      enum Slots {
+        QUERY,
+        KEY,
+        VALUE,
+        WEIGHTS,
+        OUTPUT,
+        QUERY_PARALLEL_TENSOR_SHAPE,
+        QPROJSIZE,
+        ATTRS,
+        PROFILING,
+        HANDLE,
+      };
+      MultiHeadAttentionAttrs attrs =
+          get_layer_attrs(cg_builder.computation_graph, layer_guid)
+              .attrs.get<MultiHeadAttentionAttrs>();
+      OpTaskBinding binding = [&] {
+        OpTaskBinding b;
+        b.bind(QUERY, input_tensor(0));
+        b.bind(KEY, input_tensor(1));
+        b.bind(VALUE, input_tensor(2));
+        b.bind(WEIGHTS, weight_tensor(3));
+        b.bind(OUTPUT, output_tensor(0));
+
+        b.bind_grad(QUERY, input_tensor(0));
+
+        b.bind_arg(QPROJSIZE, get_qProjSize(attrs));
+        b.bind_arg(ATTRS, attrs);
+        b.bind_arg(QUERY_PARALLEL_TENSOR_SHAPE, input_parallel_tensor_shape(0));
+        b.bind_arg(PROFILING, profiling_settings());
+        b.bind_arg(HANDLE, ff_handle());
+        return b;
+      }();
+
+      // allocate all incoming and outgoing tensors for graph
+      for (layer_guid_t const &node :
+           topological_ordering(cg_builder.computation_graph)) {
+        local_slots_backing.allocate_outgoing_tensors(
+            node, cg_builder.computation_graph, allocator);
+      }
+
+      SUBCASE("LocalSlotsBacking::construct_tensor_slots_backing") {
+        TensorSlotsBackingWithoutAddresses result =
+            get_slots_backing_without_tensor_allocation_addresses(
+                local_slots_backing.construct_tensor_slots_backing(binding,
+                                                                   layer_guid));
+        TensorSlotsBackingWithoutAddresses correct = [&] {
+          TensorShape weights_shape = throw_if_unexpected(
+              get_weights_shape(attrs, query_shape, key_shape, value_shape));
+          GenericTensorAccessorW weights =
+              allocator.allocate_tensor(weights_shape);
+
+          TensorAttrs output_attrs =
+              get_tensor_attrs(cg_builder.computation_graph, output_guid);
+          GenericTensorAccessorW output =
+              allocator.allocate_tensor(output_attrs.shape);
+          return get_slots_backing_without_tensor_allocation_addresses(
+              TensorSlotsBacking{
+                  {SlotGradId{slot_id_t{QUERY}, IsGrad::NO}, query},
+                  {SlotGradId{slot_id_t{KEY}, IsGrad::NO}, key},
+                  {SlotGradId{slot_id_t{VALUE}, IsGrad::NO}, value},
+                  {SlotGradId{slot_id_t{WEIGHTS}, IsGrad::NO}, weights},
+                  {SlotGradId{slot_id_t{OUTPUT}, IsGrad::NO}, output},
+                  {SlotGradId{slot_id_t{QUERY}, IsGrad::YES}, query}});
+        }();
+
+        CHECK(result == correct);
+      }
+      SUBCASE("LocalSlotsBacking::construct_arg_slots_backing") {
+        ArgSlotsBacking result =
+            local_slots_backing.construct_arg_slots_backing(binding,
+                                                            layer_guid);
+
+        ArgSlotsBacking correct = [&] {
+          ParallelTensorShape query_parallel_tensor_shape =
+              lift_to_parallel(query_shape);
+
+          return ArgSlotsBacking{
+              {slot_id_t{QPROJSIZE},
+               ConcreteArgSpec::create(get_qProjSize(attrs))},
+              {slot_id_t{ATTRS}, ConcreteArgSpec::create(attrs)},
+              {slot_id_t{QUERY_PARALLEL_TENSOR_SHAPE},
+               ConcreteArgSpec::create(query_parallel_tensor_shape)},
+              {slot_id_t{PROFILING},
+               ConcreteArgSpec::create(runtime_arg_config.profiling_settings)},
+              {slot_id_t{HANDLE}, ConcreteArgSpec::create(handle)}};
+        }();
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("LocalSlotsBacking::resolve_runtime_arg_ref_spec") {
+        RuntimeArgRefSpec ref_spec = RuntimeArgRefSpec::create(ff_handle());
+        ConcreteArgSpec arg_spec =
+            local_slots_backing.resolve_runtime_arg_ref_spec(ref_spec);
+
+        PerDeviceFFHandle result_handle = arg_spec.get<PerDeviceFFHandle>();
+        CHECK(result_handle == handle);
+      }
+    }
+  }
+}
+
+} // namespace FlexFlow
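The expected values in this test are built with immediately-invoked lambdas (`[&] { ... }()`), which keeps scratch variables like the locally allocated `weights` tensor out of the enclosing subcase scope. A generic sketch of the idiom (nothing here is from the patch):

```cpp
#include <string>
#include <vector>

int main() {
  // Immediately-invoked lambda: build a complex "correct" value in one
  // expression, keeping helpers (here `parts`) local to the initializer.
  std::string correct = [&] {
    std::vector<std::string> parts = {"a", "b", "c"};
    std::string joined;
    for (std::string const &p : parts) {
      joined += p;
    }
    return joined;
  }();
  return correct == "abc" ? 0 : 1;
}
```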
diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc
new file mode 100644
index 0000000000..819d773e00
--- /dev/null
+++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc
@@ -0,0 +1,143 @@
+#include "doctest/doctest.h"
+#include "local-execution/local_cpu_allocator.h"
+#include "local-execution/local_task_argument_accessor.h"
+#include "local-execution/task_signature_impl.h"
+
+namespace FlexFlow {
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("LocalTaskArgumentAccessor") {
+    Allocator allocator = create_local_cpu_memory_allocator();
+    int embed_dim = 32;
+    int num_heads = 10;
+
+    size_t batch_size = 40;
+    size_t seq_len = 48;
+    size_t feature_size = 36;
+
+    DataType dtype = DataType::FLOAT;
+    TensorShape input_tensor_shape = TensorShape{
+        TensorDims{FFOrdered<size_t>{batch_size, seq_len, feature_size}},
+        dtype,
+    };
+
+    GenericTensorAccessorW input =
+        allocator.allocate_tensor(input_tensor_shape);
+    GenericTensorAccessorW input_grad =
+        allocator.allocate_tensor(input_tensor_shape);
+
+    std::vector<GenericTensorAccessorW> variadic_tensors = {input, input};
+    std::vector<GenericTensorAccessorW> variadic_tensors_grad = {input_grad,
+                                                                 input_grad};
+
+    enum Slots {
+      INPUT,
+      VARIADIC_TENSORS,
+    };
+
+    TensorSlotsBacking tensor_slots_backing = {
+        {SlotGradId{slot_id_t{INPUT}, IsGrad::NO}, input},
+        {SlotGradId{slot_id_t{INPUT}, IsGrad::YES}, input_grad},
+        {SlotGradId{slot_id_t{VARIADIC_TENSORS}, IsGrad::NO}, variadic_tensors},
+        {SlotGradId{slot_id_t{VARIADIC_TENSORS}, IsGrad::YES},
+         variadic_tensors_grad},
+    };
+
+    LocalTaskArgumentAccessor acc = {allocator, tensor_slots_backing, {}};
+
+    SUBCASE("get_tensor") {
+      SUBCASE("get_tensor(slot_id_t, Permissions::RO, IsGrad::NO)") {
+        GenericTensorAccessor correct = GenericTensorAccessor{
+            read_only_accessor_from_write_accessor(input)};
+        GenericTensorAccessor result =
+            acc.get_tensor(slot_id_t{INPUT}, Permissions::RO, IsGrad::NO);
+        CHECK(correct == result);
+      }
+      SUBCASE("get_tensor(slot_id_t, Permissions::RO, IsGrad::YES)") {
+        GenericTensorAccessor correct = GenericTensorAccessor{
+            read_only_accessor_from_write_accessor(input_grad)};
+        GenericTensorAccessor result =
+            acc.get_tensor(slot_id_t{INPUT}, Permissions::RO, IsGrad::YES);
+        CHECK(correct == result);
+      }
+      SUBCASE("get_tensor(slot_id_t, Permissions::WO, IsGrad::NO)") {
+        GenericTensorAccessor correct = GenericTensorAccessor{input};
+        GenericTensorAccessor result =
+            acc.get_tensor(slot_id_t{INPUT}, Permissions::WO, IsGrad::NO);
+        CHECK(correct == result);
+      }
+      SUBCASE("get_tensor(slot_id_t, Permissions::WO, IsGrad::YES)") {
+        GenericTensorAccessor correct = GenericTensorAccessor{input_grad};
+        GenericTensorAccessor result =
+            acc.get_tensor(slot_id_t{INPUT}, Permissions::WO, IsGrad::YES);
+        CHECK(correct == result);
+      }
+      SUBCASE("get_tensor(slot_id_t, Permissions::RW, IsGrad::NO)") {
+        GenericTensorAccessor correct = GenericTensorAccessor{input};
+        GenericTensorAccessor result =
+            acc.get_tensor(slot_id_t{INPUT}, Permissions::RW, IsGrad::NO);
+        CHECK(correct == result);
+      }
+      SUBCASE("get_tensor(slot_id_t, Permissions::RW, IsGrad::YES)") {
+        GenericTensorAccessor correct = GenericTensorAccessor{input_grad};
+        GenericTensorAccessor result =
+            acc.get_tensor(slot_id_t{INPUT}, Permissions::RW, IsGrad::YES);
+        CHECK(correct == result);
+      }
+    }
+
+    SUBCASE("get_variadic_tensor") {
+      SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RO, IsGrad::NO)") {
+        VariadicGenericTensorAccessor correct =
+            VariadicGenericTensorAccessor{std::vector<GenericTensorAccessorR>{
+                read_only_accessor_from_write_accessor(variadic_tensors.at(0)),
+                read_only_accessor_from_write_accessor(
+                    variadic_tensors.at(1))}};
+        VariadicGenericTensorAccessor result = acc.get_variadic_tensor(
+            slot_id_t{VARIADIC_TENSORS}, Permissions::RO, IsGrad::NO);
+        CHECK(result == correct);
+      }
+      SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RO, IsGrad::YES)") {
+        VariadicGenericTensorAccessor correct =
+            VariadicGenericTensorAccessor{std::vector<GenericTensorAccessorR>{
+                read_only_accessor_from_write_accessor(
+                    variadic_tensors_grad.at(0)),
+                read_only_accessor_from_write_accessor(
+                    variadic_tensors_grad.at(1))}};
+        VariadicGenericTensorAccessor result = acc.get_variadic_tensor(
+            slot_id_t{VARIADIC_TENSORS}, Permissions::RO, IsGrad::YES);
+        CHECK(result == correct);
+      }
+      SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, IsGrad::NO)") {
+        VariadicGenericTensorAccessor correct =
+            VariadicGenericTensorAccessor{variadic_tensors};
+        VariadicGenericTensorAccessor result = acc.get_variadic_tensor(
+            slot_id_t{VARIADIC_TENSORS}, Permissions::WO, IsGrad::NO);
+        CHECK(result == correct);
+      }
+      SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, IsGrad::YES)") {
+        VariadicGenericTensorAccessor correct =
+            VariadicGenericTensorAccessor{variadic_tensors_grad};
+        VariadicGenericTensorAccessor result = acc.get_variadic_tensor(
+            slot_id_t{VARIADIC_TENSORS}, Permissions::WO, IsGrad::YES);
+        CHECK(result == correct);
+      }
+      SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RW, IsGrad::NO)") {
+        VariadicGenericTensorAccessor correct =
+            VariadicGenericTensorAccessor{variadic_tensors};
+        VariadicGenericTensorAccessor result = acc.get_variadic_tensor(
+            slot_id_t{VARIADIC_TENSORS}, Permissions::RW, IsGrad::NO);
+        CHECK(result == correct);
+      }
+      SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RW, IsGrad::YES)") {
+        VariadicGenericTensorAccessor correct =
+            VariadicGenericTensorAccessor{variadic_tensors_grad};
+        VariadicGenericTensorAccessor result = acc.get_variadic_tensor(
+            slot_id_t{VARIADIC_TENSORS}, Permissions::RW, IsGrad::YES);
+        CHECK(result == correct);
+      }
+    }
+  }
+}
+
+} // namespace FlexFlow
diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc
new file mode 100644
index 0000000000..fa3b068425
--- /dev/null
+++ b/lib/local-execution/test/src/test_task_registry.cc
@@ -0,0 +1,131 @@
+#include "doctest/doctest.h"
+#include "kernels/local_cuda_allocator.h"
+#include "local-execution/local_cost_estimator.h"
+#include "local-execution/ops/attention.h"
+#include "local-execution/task_signature_impl.h"
+#include "pcg/computation_graph_builder.h"
+#include "utils/fmt/optional.h"
+#include "utils/fmt/unordered_map.h"
+
+namespace FlexFlow {
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("Task Registry") {
+    TaskRegistry task_registry = empty_task_registry();
+
+    layer_guid_t layer_guid = layer_guid_t{Node{0}};
+    int embed_dim = 32;
+    int num_heads = 10;
+    ComputationGraphOpAttrs attrs =
+        ComputationGraphOpAttrs{MultiHeadAttentionAttrs{
+            /*embed_dim=*/embed_dim,
+            /*num_heads=*/num_heads,
+            /*kdim=*/embed_dim,
+            /*vdim=*/embed_dim,
+            /*dropout=*/0.0,
+            /*bias=*/true,
+            /*add_bias_kv=*/false,
+            /*add_zero_attn=*/false,
+        }};
+
+    SUBCASE("register single layer") {
+      register_tasks_for_layer(task_registry, layer_guid, attrs);
+
+      TaskRegistry correct_task_registry = [&] {
+        std::unordered_map<layer_guid_t, std::optional<task_id_t>>
+            init_task_ids = {{layer_guid, task_id_t::ATTENTION_INIT_TASK_ID}};
+        std::unordered_map<layer_guid_t, std::optional<task_id_t>>
+            fwd_task_ids = {{layer_guid, task_id_t::ATTENTION_FWD_TASK_ID}};
+        std::unordered_map<layer_guid_t, std::optional<task_id_t>>
+            bwd_task_ids = {{layer_guid, task_id_t::ATTENTION_BWD_TASK_ID}};
+        std::unordered_map<task_id_t, TaskSignatureAndImpl> task_mapping = {
+            {task_id_t::ATTENTION_INIT_TASK_ID,
+             get_task_sig_impl(task_id_t::ATTENTION_INIT_TASK_ID)},
+            {task_id_t::ATTENTION_FWD_TASK_ID,
+             get_task_sig_impl(task_id_t::ATTENTION_FWD_TASK_ID)},
+            {task_id_t::ATTENTION_BWD_TASK_ID,
+             get_task_sig_impl(task_id_t::ATTENTION_BWD_TASK_ID)}};
+        return TaskRegistry{
+            init_task_ids, fwd_task_ids, bwd_task_ids, task_mapping};
+      }();
+
+      CHECK(task_registry == correct_task_registry);
+    }
+
+    SUBCASE("multiple layers same task") {
+      layer_guid_t other_layer_guid = layer_guid_t{Node{1}};
+      register_tasks_for_layer(task_registry, layer_guid, attrs);
+      register_tasks_for_layer(task_registry, other_layer_guid, attrs);
+
+      SUBCASE("layer to task ids") {
+        std::unordered_map<layer_guid_t, std::optional<task_id_t>> correct = {
+            {layer_guid, task_id_t::ATTENTION_INIT_TASK_ID},
+            {other_layer_guid, task_id_t::ATTENTION_INIT_TASK_ID},
+        };
+        CHECK(correct == task_registry.init_task_ids);
+      }
+
+      std::unordered_map<task_id_t, TaskSignatureAndImpl> correct_task_mapping =
+          {{task_id_t::ATTENTION_INIT_TASK_ID,
+            get_task_sig_impl(task_id_t::ATTENTION_INIT_TASK_ID)},
+           {task_id_t::ATTENTION_FWD_TASK_ID,
+            get_task_sig_impl(task_id_t::ATTENTION_FWD_TASK_ID)},
+           {task_id_t::ATTENTION_BWD_TASK_ID,
+            get_task_sig_impl(task_id_t::ATTENTION_BWD_TASK_ID)}};
+      SUBCASE("task to signature+impl mapping") {
+        CHECK(correct_task_mapping == task_registry.task_mapping);
+      }
+      SUBCASE("different attrs, still same task fn mapping") {
+        int embed_dim = 100;
+        layer_guid_t layer_3 = layer_guid_t{Node{3}};
+        ComputationGraphOpAttrs other_attrs =
+            ComputationGraphOpAttrs{MultiHeadAttentionAttrs{
+                /*embed_dim=*/embed_dim,
+                /*num_heads=*/num_heads,
+                /*kdim=*/embed_dim,
+                /*vdim=*/embed_dim,
+                /*dropout=*/0.0,
+                /*bias=*/true,
+                /*add_bias_kv=*/false,
+                /*add_zero_attn=*/false,
+            }};
+        register_tasks_for_layer(task_registry, layer_3, other_attrs);
+
+        CHECK(correct_task_mapping == task_registry.task_mapping);
+      }
+    }
+
+    SUBCASE("equality") {
+      TaskRegistry other_task_registry = empty_task_registry();
+      SUBCASE("different attrs is still equal") {
+        int embed_dim = 100;
+        ComputationGraphOpAttrs other_attrs =
+            ComputationGraphOpAttrs{MultiHeadAttentionAttrs{
+                /*embed_dim=*/embed_dim,
+                /*num_heads=*/num_heads,
+                /*kdim=*/embed_dim,
+                /*vdim=*/embed_dim,
+                /*dropout=*/0.0,
+                /*bias=*/true,
+                /*add_bias_kv=*/false,
+                /*add_zero_attn=*/false,
+            }};
+
+        register_tasks_for_layer(task_registry, layer_guid, attrs);
+        register_tasks_for_layer(other_task_registry, layer_guid, other_attrs);
+
+        CHECK(task_registry == other_task_registry);
+      }
+
+      SUBCASE("different layer_guid is not equal") {
+        register_tasks_for_layer(task_registry, layer_guid, attrs);
+        layer_guid_t other_layer_guid = layer_guid_t{Node{1}};
+        register_tasks_for_layer(other_task_registry, other_layer_guid, attrs);
+
+        CHECK(task_registry != other_task_registry);
+      }
+    }
+  }
+}
+
+} // namespace FlexFlow
diff --git a/lib/local-execution/test/src/test_utils.cc b/lib/local-execution/test/src/test_utils.cc
new file mode 100644
index 0000000000..095e1272a2
--- /dev/null
+++ b/lib/local-execution/test/src/test_utils.cc
@@ -0,0 +1,9 @@
+#include "test_utils.h"
+
+namespace FlexFlow {
+
+PerDeviceFFHandle get_mock_per_device_ff_handle() {
+  return {nullptr, nullptr, nullptr, 0, false};
+}
+
+} // namespace FlexFlow
diff --git a/lib/local-execution/test/src/test_utils.h b/lib/local-execution/test/src/test_utils.h
new file mode 100644
index 0000000000..9a7b3f5991
--- /dev/null
+++ b/lib/local-execution/test/src/test_utils.h
@@ -0,0 +1,12 @@
+#ifndef _FLEXFLOW_LOCAL_EXECUTION_TEST_UTILS
+#define _FLEXFLOW_LOCAL_EXECUTION_TEST_UTILS
+
+#include "kernels/ff_handle.h"
+
+namespace FlexFlow {
+
+PerDeviceFFHandle get_mock_per_device_ff_handle();
+
+} // namespace FlexFlow
+
+#endif
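The mock helper zero-fills the handle so CPU-only tests never touch CUDA state. The field order is defined by `kernels/ff_handle.h`, not by this patch; assuming the aggregate's members are (dnn, blas, workSpace, workSpaceSize, allowTensorOpMathConversion), a hypothetical mirror of the initializer looks like this (illustration only):

```cpp
#include <cstddef>

// Hypothetical mirror of the handle's layout, for illustration only; the
// real PerDeviceFFHandle lives in kernels/ff_handle.h.
struct MockHandle {
  void *dnn;                        // cuDNN handle slot (unused on CPU)
  void *blas;                       // cuBLAS handle slot (unused on CPU)
  void *workSpace;                  // scratch buffer
  std::size_t workSpaceSize;        // scratch buffer size in bytes
  bool allowTensorOpMathConversion; // tensor-core math opt-in
};

MockHandle make_mock_handle() {
  return {nullptr, nullptr, nullptr, 0, false}; // mirrors the test helper
}
```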
diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc
index 3e4095eca8..036daa6e67 100644
--- a/lib/op-attrs/src/op-attrs/ops/attention.cc
+++ b/lib/op-attrs/src/op-attrs/ops/attention.cc
@@ -43,28 +43,52 @@ int get_vSize(TensorShape const &value_shape) {
   return dim_at_idx(value_shape, ff_dim_t(0));
 }
 
-int get_qSize(MultiHeadAttentionParallelInputs const &) {
-  NOT_IMPLEMENTED();
+int get_qSize(MultiHeadAttentionParallelInputs const &inputs) {
+  return inputs.query_dim.size;
 }
 
-int get_qSize(MultiHeadAttentionInputs const &) {
-  NOT_IMPLEMENTED();
+int get_qSize(MultiHeadAttentionInputs const &inputs) {
+  return inputs.query_size;
 }
 
-int get_kSize(MultiHeadAttentionParallelInputs const &) {
-  NOT_IMPLEMENTED();
+int get_kSize(MultiHeadAttentionParallelInputs const &inputs) {
+  return inputs.key_dim.size;
 }
 
-int get_kSize(MultiHeadAttentionInputs const &) {
-  NOT_IMPLEMENTED();
+int get_kSize(MultiHeadAttentionInputs const &inputs) {
+  return inputs.key_size;
 }
 
-int get_vSize(MultiHeadAttentionParallelInputs const &) {
-  NOT_IMPLEMENTED();
+int get_vSize(MultiHeadAttentionParallelInputs const &inputs) {
+  return inputs.value_dim.size;
 }
 
-int get_vSize(MultiHeadAttentionInputs const &) {
-  NOT_IMPLEMENTED();
+int get_vSize(MultiHeadAttentionInputs const &inputs) {
+  return inputs.value_size;
+}
+
+int get_kvSeqLength(MultiHeadAttentionParallelInputs const &inputs) {
+  return inputs.sequence_dim.size;
+}
+
+int get_kvSeqLength(MultiHeadAttentionInputs const &inputs) {
+  return inputs.sequence_length;
+}
+
+int get_qoSeqLength(MultiHeadAttentionParallelInputs const &inputs) {
+  return inputs.sequence_dim.size; // FIXME -- assumes only prefill
+}
+
+int get_qoSeqLength(MultiHeadAttentionInputs const &inputs) {
+  return inputs.sequence_length; // FIXME -- assumes only prefill
+}
+
+int get_num_samples(MultiHeadAttentionParallelInputs const &inputs) {
+  return inputs.batch_dim.size;
+}
+
+int get_num_samples(MultiHeadAttentionInputs const &inputs) {
+  return inputs.batch_size;
 }
 
 tl::expected<TensorShape, std::string>
diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h
index cdb20c2303..46d5b22afb 100644
--- a/lib/pcg/include/pcg/computation_graph.h
+++ b/lib/pcg/include/pcg/computation_graph.h
@@ -29,6 +29,9 @@ std::vector<tensor_guid_t> get_incoming_tensors(ComputationGraph const &cg,
 
 LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n);
 
+layer_guid_t get_layer_by_name(ComputationGraph const &cg,
+                               std::string const &name);
+
 } // namespace FlexFlow
 
 #endif
diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h
index c52ec2d5bb..0e453c50c3 100644
--- a/lib/pcg/include/pcg/computation_graph_builder.h
+++ b/lib/pcg/include/pcg/computation_graph_builder.h
@@ -168,10 +168,11 @@ struct ComputationGraphBuilder {
       DataType dtype,
       std::optional<std::string> const &name = std::nullopt);
   // Add a concat layer
-  tensor_guid_t concat(int n,
-                       std::vector<tensor_guid_t> const &tensors,
-                       int axis,
-                       std::optional<std::string> const &name = std::nullopt);
+  tensor_guid_t
+      concat(int n,
+             std::vector<tensor_guid_t> const &tensors,
+             int axis,
+             std::optional<std::string> const &maybe_name = std::nullopt);
   // Add a mean layer
   tensor_guid_t mean(tensor_guid_t const &input,
                      std::vector<int> const &dims,
@@ -224,7 +225,7 @@ struct ComputationGraphBuilder {
       bool add_bias_kv = false,
       bool add_zero_attn = false,
       std::optional<InitializerAttrs> initializer = std::nullopt,
-      std::optional<std::string> const &name = std::nullopt);
+      std::optional<std::string> const &maybe_name = std::nullopt);
   tensor_guid_t create_tensor(TensorShape const &, CreateGrad);
   tensor_guid_t create_weight(
       TensorShape const &,
diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc
index 43eb3ac42b..afa1774858 100644
--- a/lib/pcg/src/pcg/computation_graph.cc
+++ b/lib/pcg/src/pcg/computation_graph.cc
@@ -1,4 +1,5 @@
 #include "pcg/computation_graph.h"
+#include "utils/containers/get_only.h"
 #include "utils/containers/reversed.h"
 #include "utils/containers/transform.h"
 #include "utils/graph/dataflow_graph/algorithms.h"
@@ -54,4 +55,13 @@ LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) {
   return cg.raw_graph.at(n.raw_node);
 }
 
+layer_guid_t get_layer_by_name(ComputationGraph const &cg,
+                               std::string const &name) {
+  std::unordered_set<layer_guid_t> found =
+      filter(get_layers(cg), [&](layer_guid_t const &l) {
+        return get_layer_attrs(cg, l).name == name;
+      });
+  return get_only(found);
+}
+
 } // namespace FlexFlow
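`get_layer_by_name` filters all layers by name and then calls `get_only`, so the lookup succeeds only when exactly one layer carries the requested name; the tests rely on this by giving the attention layer a unique `"attn1"` name. A generic sketch of the same unique-or-fail contract using only the standard library (all names below are hypothetical):

```cpp
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>

using LayerId = int;

LayerId get_layer_by_name(
    std::unordered_map<LayerId, std::string> const &layers,
    std::string const &name) {
  std::optional<LayerId> found;
  for (auto const &[id, layer_name] : layers) {
    if (layer_name == name) {
      if (found.has_value()) {
        // get_only-style failure: the name must be unique
        throw std::runtime_error("duplicate layer name: " + name);
      }
      found = id;
    }
  }
  if (!found.has_value()) {
    throw std::runtime_error("no layer named: " + name);
  }
  return *found;
}
```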
diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc
index 1dbe191970..5028ed5709 100644
--- a/lib/pcg/src/pcg/computation_graph_builder.cc
+++ b/lib/pcg/src/pcg/computation_graph_builder.cc
@@ -502,6 +502,47 @@ tensor_guid_t ComputationGraphBuilder::batch_norm(
   return this->add_layer(layer, {input}, {}, output_shape);
 }
 
+tensor_guid_t ComputationGraphBuilder::multihead_attention(
+    tensor_guid_t const &query,
+    tensor_guid_t const &key,
+    tensor_guid_t const &value,
+    int embed_dim,
+    int num_heads,
+    int kdim,
+    int vdim,
+    float dropout,
+    bool bias,
+    bool add_bias_kv,
+    bool add_zero_attn,
+    std::optional<InitializerAttrs> initializer,
+    std::optional<std::string> const &maybe_name) {
+
+  MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{embed_dim,
+                                                          num_heads,
+                                                          kdim,
+                                                          vdim,
+                                                          dropout,
+                                                          bias,
+                                                          add_bias_kv,
+                                                          add_zero_attn};
+
+  std::string name =
+      maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs}));
+
+  LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name};
+  TensorShape output_shape = throw_if_unexpected(get_output_shape(
+      attrs, get_shape(query), get_shape(key), get_shape(value)));
+
+  TensorShape weights_shape = throw_if_unexpected(get_weights_shape(
+      attrs, get_shape(query), get_shape(key), get_shape(value)));
+  TensorAttrs weight_attrs = make_weight_attrs(weights_shape, initializer);
+
+  return this->add_layer(layer,
+                         std::vector<tensor_guid_t>{query, key, value},
+                         {weight_attrs},
+                         output_shape);
+}
+
 TensorShape ComputationGraphBuilder::get_broadcast_target_shape(
     std::vector<TensorShape> const &) {
   NOT_IMPLEMENTED();
diff --git a/lib/local-execution/src/ops/embedding.cc b/lib/runtime/src/ops/embedding.cc
similarity index 100%
rename from lib/local-execution/src/ops/embedding.cc
rename to lib/runtime/src/ops/embedding.cc
diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h
index db82004317..9eb717b066 100644
--- a/lib/utils/include/utils/join_strings.h
+++ b/lib/utils/include/utils/join_strings.h
@@ -18,7 +18,7 @@ std::string join_strings(InputIt first,
     if (!first_iter) {
       oss << delimiter;
     }
-    oss << *first;
+    oss << f(*first);
     /* break; */
    first_iter = false;
    /* i++; */
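The `join_strings` hunk fixes the function to apply the caller-supplied transform `f` to each element instead of streaming elements raw, which previously made the transform argument a no-op. A standalone sketch of the fixed behavior with a usage example (a minimal re-implementation for illustration, not the in-tree header):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

template <typename InputIt, typename F>
std::string join_strings(InputIt first,
                         InputIt last,
                         std::string const &delimiter,
                         F const &f) {
  std::ostringstream oss;
  bool first_iter = true;
  for (; first != last; ++first) {
    if (!first_iter) {
      oss << delimiter;
    }
    oss << f(*first); // the fix: transform each element before streaming
    first_iter = false;
  }
  return oss.str();
}

int main() {
  std::vector<int> xs = {1, 2, 3};
  std::cout << join_strings(xs.begin(), xs.end(), ", ",
                            [](int x) { return std::to_string(x); })
            << "\n"; // prints "1, 2, 3"
}
```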