
Commit

complete merge
april-yyt committed Jan 4, 2024
2 parents 326d953 + 7b00e81 commit 27c9b71
Showing 51 changed files with 2,908 additions and 1,281 deletions.
11 changes: 7 additions & 4 deletions CMakeLists.txt
@@ -14,6 +14,7 @@ endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake)
set(FLEXFLOW_ROOT ${CMAKE_CURRENT_LIST_DIR})
set(CMAKE_CXX_FLAGS "-std=c++17 ${CMAKE_CXX_FLAGS} -fPIC -UNDEBUG")
set(CMAKE_HIP_FLAGS "-std=c++17 ${CMAKE_HIP_FLAGS} -fPIC -UNDEBUG")

# set std 17
#set(CMAKE_CXX_STANDARD 17)
@@ -51,6 +52,7 @@ endif()

# do not disable assertions even if in release mode
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG")
set(CMAKE_HIP_FLAGS_RELEASE "${CMAKE_HIP_FLAGS_RELEASE} -UNDEBUG")

if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
set(LIBEXT ".so")
@@ -157,6 +159,7 @@ endif()

# HIP
if (FF_GPU_BACKEND STREQUAL "hip_rocm" OR FF_GPU_BACKEND STREQUAL "hip_cuda")
enable_language(HIP)
include(hip)
endif()

@@ -299,7 +302,10 @@ if(NOT BUILD_LEGION_ONLY)
LIST_DIRECTORIES False
${FLEXFLOW_ROOT}/src/*.cpp)

if(BUILD_SHARED_LIBS)
set_source_files_properties(${FLEXFLOW_GPU_SRC} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${FLEXFLOW_SRC} PROPERTIES LANGUAGE HIP)

if(BUILD_SHARED_LIBS)
add_library(flexflow SHARED ${FLEXFLOW_GPU_SRC} ${FLEXFLOW_SRC})
else()
add_library(flexflow STATIC ${FLEXFLOW_GPU_SRC} ${FLEXFLOW_SRC})
@@ -474,9 +480,6 @@ if(NOT BUILD_LEGION_ONLY)
endif()

if(FF_BUILD_ALL_INFERENCE_EXAMPLES OR FF_BUILD_TOKENIZER)
if (FF_GPU_BACKEND STREQUAL "hip_rocm")
SET(SPM_USE_BUILTIN_PROTOBUF OFF CACHE BOOL "Use builtin version of protobuf to compile SentencePiece")
endif()
# Ensure Rust is installed
execute_process(COMMAND rustc --version
RESULT_VARIABLE RUST_COMMAND_RESULT
4 changes: 2 additions & 2 deletions cmake/hip.cmake
@@ -2,11 +2,11 @@ if (NOT FF_HIP_ARCH STREQUAL "")
if (FF_HIP_ARCH STREQUAL "all")
set(FF_HIP_ARCH "gfx900,gfx902,gfx904,gfx906,gfx908,gfx909,gfx90a,gfx90c,gfx940,gfx1010,gfx1011,gfx1012,gfx1013,gfx1030,gfx1031,gfx1032,gfx1033,gfx1034,gfx1035,gfx1036,gfx1100,gfx1101,gfx1102,gfx1103")
endif()
string(REPLACE "," " " HIP_ARCH_LIST "${FF_HIP_ARCH}")
string(REPLACE "," "," HIP_ARCH_LIST "${FF_HIP_ARCH}")
endif()

message(STATUS "FF_HIP_ARCH: ${FF_HIP_ARCH}")
if(FF_GPU_BACKEND STREQUAL "hip_rocm")
set(HIP_CLANG_PATH ${ROCM_PATH}/llvm/bin CACHE STRING "Path to the clang compiler by ROCM" FORCE)
#set(HIP_CLANG_PATH ${ROCM_PATH}/llvm/bin CACHE STRING "Path to the clang compiler by ROCM" FORCE)
set(GPU_TARGETS "${FF_HIP_ARCH}" CACHE STRING "The GPU TARGETs")
endif()
7 changes: 5 additions & 2 deletions config/config.inc
@@ -190,6 +190,8 @@ if [ -n "$ROCM_PATH" ]; then
SET_ROCM_PATH="-DROCM_PATH=${ROCM_PATH}"
fi

ADD_ROCM_TO_PATH=""

# set GPU backend
if [ -n "$FF_GPU_BACKEND" ]; then
SET_FF_GPU_BACKEND="-DFF_GPU_BACKEND=${FF_GPU_BACKEND}"
@@ -222,7 +224,8 @@ if [ -n "$FF_GPU_BACKEND" ]; then
chmod +x "$(pwd)/nvidia_hipcc"
SET_CXX="-DCMAKE_CXX_COMPILER=$(pwd)/nvidia_hipcc -DCMAKE_CXX_LINKER=$(pwd)/nvidia_hipcc"
else
SET_CXX="-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -DCMAKE_CXX_LINKER=/opt/rocm/bin/hipcc"
ADD_ROCM_TO_PATH="PATH=${PATH}:${ROCM_PATH}/bin"
#SET_CXX="-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -DCMAKE_CXX_LINKER=/opt/rocm/bin/hipcc"
fi
fi
fi
@@ -232,7 +235,7 @@ CMAKE_FLAGS="-DCUDA_USE_STATIC_CUDA_RUNTIME=OFF -DLegion_HIJACK_CUDART=OFF ${SET

function run_cmake() {
SRC_LOCATION=${SRC_LOCATION:=`dirname $0`/../}
CMAKE_COMMAND="${SET_CC_FLAGS} ${SET_NVCC_FLAGS} ${SET_LD_FLAGS} ${SET_CUDA_LIB_PATH} cmake ${CMAKE_FLAGS} $* ${SRC_LOCATION}"
CMAKE_COMMAND="${SET_CC_FLAGS} ${SET_NVCC_FLAGS} ${SET_LD_FLAGS} ${SET_CUDA_LIB_PATH} ${ADD_ROCM_TO_PATH} cmake ${CMAKE_FLAGS} $* ${SRC_LOCATION}"
echo $CMAKE_COMMAND
eval $CMAKE_COMMAND
}
5 changes: 1 addition & 4 deletions docker/flexflow-environment/Dockerfile
@@ -74,11 +74,8 @@ RUN if [ "$FF_GPU_BACKEND" = "hip_cuda" ] || [ "$FF_GPU_BACKEND" = "hip_rocm" ]
rm ./${AMD_GPU_SCRIPT_NAME}; \
amdgpu-install -y --usecase=hip,rocm --no-dkms; \
apt-get install -y hip-dev hipblas miopen-hip rocm-hip-sdk rocm-device-libs; \
# Install protobuf v3.20.x manually
# Install protobuf dependencies
apt-get update -y && sudo apt-get install -y pkg-config zip g++ zlib1g-dev autoconf automake libtool make; \
git clone -b 3.20.x https://github.com/protocolbuffers/protobuf.git; cd protobuf/ ; git submodule update --init --recursive; \
./autogen.sh; ./configure; cores_available=$(nproc --all); n_build_cores=$(( cores_available -1 )); \
if (( n_build_cores < 1 )) ; then n_build_cores=1 ; fi; make -j $n_build_cores; make install; ldconfig; cd .. ; \
else \
echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping installing HIP dependencies"; \
fi
58 changes: 48 additions & 10 deletions include/flexflow/batch_config.h
@@ -45,6 +45,7 @@ class BatchConfig {
int num_active_tokens() const;
static int max_requests_per_batch();
static int max_tokens_per_batch();
static int max_verify_tokens_per_batch();
static int max_sequence_length();
friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc);
void print() const;
@@ -56,6 +57,7 @@ class BatchConfig {
// across workers
static int const MAX_NUM_REQUESTS = 64;
static int const MAX_NUM_TOKENS = 1024;
static int const MAX_SPEC_TREE_TOKEN_NUM = 64;

// Set by update
int num_tokens;
@@ -68,13 +70,35 @@
int first_token_offset_in_batch;
int num_tokens_in_batch;
int max_sequence_length;

// request id in batch config:
int batch_config_request_id;
bool prompt_phase = false;
RequestGuid request_guid;
};
struct PerTokenInfo {
int abs_depth_in_request;
int request_index;
TokenId token_id;
};

struct BitMask {
unsigned long long mask[MAX_SPEC_TREE_TOKEN_NUM] = {0};

// how many tokens come before the tree; every sub-request needs this
// part of the cache
int non_tree_cache_size = 0;

// current tree size
int tree_size = 0;

int this_layer_size = 0;

// input length -> prompt/root
int prompt_size = 0;
};

BitMask causalMask[MAX_NUM_REQUESTS];
PerRequestInfo requestsInfo[MAX_NUM_REQUESTS];
PerTokenInfo tokensInfo[MAX_NUM_TOKENS];

@@ -123,32 +147,43 @@ class BeamSearchBatchConfig : public BatchConfig {
bool done() const;
int max_beam_depth_all_requests() const;
int current_depth_all_requests() const;
int get_speculative_request_num() const;

size_t beam_width;
size_t target_iterations;
inline static int const MAX_BEAM_WIDTH = 1;

// how many requests are in the speculative phase
int speculative_request_num = 0;
inline static int const MAX_BEAM_WIDTH = 3;
inline static int const MAX_BEAM_DEPTH = 8;

// maximum tree branches for a request
inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3;

int model_id;

struct BeamSearchPerRequestInfo {
int beam_size;
int current_depth = -1;
int max_depth = MAX_BEAM_DEPTH;

BatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int parent_id[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
BatchConfig::TokenId
tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int parent_id[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int sub_request_num;
};

struct BeamSearchPerTokenInfo {
int sub_request_index;
};

BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS];
BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH];
// why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH?
int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH];
BeamSearchPerTokenInfo
beamTokenInfo[MAX_NUM_TOKENS +
MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS];

int sub_requests[MAX_NUM_REQUESTS];

private:
size_t current_iteration;
@@ -157,9 +192,12 @@
struct BeamInferenceResult {
static int const MAX_NUM_TOKENS = BatchConfig::MAX_NUM_TOKENS;
BatchConfig::TokenId
token_ids[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int parent_id[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH];
token_ids[MAX_NUM_TOKENS *
BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
float probs[MAX_NUM_TOKENS *
BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int parent_id[MAX_NUM_TOKENS *
BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
};

}; // namespace FlexFlow
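
A note on the new BitMask: MAX_SPEC_TREE_TOKEN_NUM is 64, so each row of mask is a single 64-bit word and can cover every token in the speculative tree. The header does not spell out the bit convention; a plausible reading (an assumption here — the attention kernels are authoritative) is that bit j of mask[i] marks tree token j as visible to tree token i. A minimal self-contained sketch of that convention:

#include <cassert>
#include <cstdio>
#include <initializer_list>

// Stand-in mirroring BatchConfig::BitMask from this commit.
constexpr int MAX_SPEC_TREE_TOKEN_NUM = 64;

struct BitMask {
  unsigned long long mask[MAX_SPEC_TREE_TOKEN_NUM] = {0};
  int non_tree_cache_size = 0;
  int tree_size = 0;
  int this_layer_size = 0;
  int prompt_size = 0;
};

// Assumed convention: bit j of mask[i] == 1 means tree token i may attend
// to tree token j.
void allow_attend(BitMask &bm, int i, int j) {
  assert(i < MAX_SPEC_TREE_TOKEN_NUM && j < 64);
  bm.mask[i] |= 1ULL << j;
}

bool may_attend(BitMask const &bm, int i, int j) {
  return (bm.mask[i] >> j) & 1ULL;
}

int main() {
  BitMask bm;
  bm.tree_size = 3;
  allow_attend(bm, 0, 0);      // root sees itself
  for (int child : {1, 2}) {   // two speculative branches off the root
    allow_attend(bm, child, 0);
    allow_attend(bm, child, child);
  }
  std::printf("branch 1 sees root: %d, sees branch 2: %d\n",
              may_attend(bm, 1, 0), may_attend(bm, 1, 2)); // prints 1, 0
}

Relatedly, the enlarged beamTokenInfo array (MAX_NUM_TOKENS + MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS entries) leaves room for one full speculative tree per request on top of the regular token budget.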
12 changes: 12 additions & 0 deletions include/flexflow/config.h
@@ -16,6 +16,7 @@
#ifndef _FLEXFLOW_CONFIG_H_
#define _FLEXFLOW_CONFIG_H_
#include "ffconst.h"
#include "flexflow/batch_config.h"
#include "legion.h"
#include <cstring>
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
@@ -75,6 +76,16 @@ struct FFHandler {
#endif
void *workSpace;
size_t workSpaceSize;
void *batch_config_metadata;

// request info + token info + topology mask info
size_t batch_config_metadata_size =
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
sizeof(BeamSearchBatchConfig::beamTokenInfo) +
sizeof(BeamSearchBatchConfig::beamRequestsInfo) +
sizeof(BatchConfig::causalMask) +
sizeof(TreeVerifyBatchConfig::committed_tokens) +
sizeof(BatchConfig::request_completed);
void *offload_reserve_space;
size_t offload_reserve_space_size;
DataType quantization_type;
@@ -132,6 +143,7 @@ class FFConfig {
size_t workSpaceSize;
Legion::Context lg_ctx;
Legion::Runtime *lg_hlr;
Legion::IndexSpaceT<1> all_gpu_task_is;
// Legion::FieldSpace field_space;
bool syntheticInput, profiling, perform_fusion;
bool inference_debugging;
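
batch_config_metadata_size is simply the sum of the sizeof's of the arrays that get staged to the device before each iteration, so one pinned allocation can hold them all. A toy sketch of the same layout pattern, with stand-in structs (illustrative sizes, not FlexFlow's real field lists), showing how a single reservation is carved into sub-buffers by running offset:

#include <cstddef>
#include <cstdio>

// Stand-ins for the arrays named in FFHandler::batch_config_metadata_size.
struct TokensInfo { int abs_depth_in_request, request_index, token_id; };
struct RequestsInfo { int first_token_depth, num_tokens, max_seq_len; };

constexpr int MAX_NUM_REQUESTS = 64;
constexpr int MAX_NUM_TOKENS = 1024;

TokensInfo tokensInfo[MAX_NUM_TOKENS];
RequestsInfo requestsInfo[MAX_NUM_REQUESTS];
bool request_completed[MAX_NUM_REQUESTS];

int main() {
  // Same pattern as the header: one size that is the sum of the parts.
  size_t metadata_size = sizeof(tokensInfo) + sizeof(requestsInfo) +
                         sizeof(request_completed);

  // The single allocation is then sliced by offset, so every consumer
  // finds its region at a fixed, known position.
  size_t off = 0;
  size_t tokens_off = off;    off += sizeof(tokensInfo);
  size_t requests_off = off;  off += sizeof(requestsInfo);
  size_t completed_off = off; off += sizeof(request_completed);

  std::printf("total=%zu tokens@%zu requests@%zu completed@%zu\n",
              metadata_size, tokens_off, requests_off, completed_off);
  return off == metadata_size ? 0 : 1;
}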
1 change: 1 addition & 0 deletions include/flexflow/flexflow_c.h
@@ -571,6 +571,7 @@ flexflow_tensor_t flexflow_model_add_arg_top_k(flexflow_model_t handle_,
const flexflow_tensor_t input_,
int k,
bool sorted,
bool speculative_decoding,
char const *name);

flexflow_tensor_t flexflow_model_add_beam_top_k(flexflow_model_t handle_,
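
For reference, a hypothetical call through the extended C binding. The opaque-struct declarations below imitate flexflow_c.h's style only so the sketch compiles standalone; real code includes the header and links against the FlexFlow C library instead:

extern "C" {
typedef struct flexflow_model_t { void *impl; } flexflow_model_t;
typedef struct flexflow_tensor_t { void *impl; } flexflow_tensor_t;
flexflow_tensor_t flexflow_model_add_arg_top_k(flexflow_model_t handle,
                                               const flexflow_tensor_t input,
                                               int k,
                                               bool sorted,
                                               bool speculative_decoding,
                                               char const *name);
}

flexflow_tensor_t add_speculative_topk(flexflow_model_t model,
                                       flexflow_tensor_t logits) {
  // speculative_decoding=true is the parameter added in this commit; it
  // selects the variant that also emits branch probabilities (see the
  // ArgTopK changes below).
  return flexflow_model_add_arg_top_k(model, logits, /*k=*/3,
                                      /*sorted=*/false,
                                      /*speculative_decoding=*/true,
                                      "spec_arg_topk");
}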
15 changes: 13 additions & 2 deletions include/flexflow/model.h
@@ -73,6 +73,7 @@ enum TaskIDs {
DROPOUT_BWD_TASK_ID,
EMBED_INIT_TASK_ID,
EMBED_FWD_TASK_ID,
EMBED_INF_TASK_ID,
EMBED_BWD_TASK_ID,
GATHER_INIT_TASK_ID,
GATHER_FWD_TASK_ID,
@@ -146,6 +147,7 @@ enum TaskIDs {
TOPK_BWD_TASK_ID,
ARG_TOPK_INIT_TASK_ID,
ARG_TOPK_INF_TASK_ID,
ARG_TOPK_INF_SPECULATIVE_TASK_ID,
SAMPLING_INIT_TASK_ID,
SAMPLING_INF_TASK_ID,
ARGMAX_INIT_TASK_ID,
@@ -240,6 +242,7 @@ enum TaskIDs {
// InferenceManager & RequestManager
RM_LOAD_TOKENS_TASK_ID,
RM_LOAD_POSITION_TASK_ID,
RM_LOAD_BATCH_CONFIG_TASK_ID,
RM_PREPARE_NEXT_BATCH_TASK_ID,
RM_PREPARE_NEXT_BATCH_INIT_TASK_ID,
RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID,
@@ -674,6 +677,7 @@ class FFModel {
// Tensor *outputs,
int k,
bool sorted,
bool speculative_decoding,
char const *name = NULL);
Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL);
Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
@@ -1034,8 +1038,15 @@ class FFModel {
void get_metrics();
void backward(int seq_length = -1);
void update();
bool apply_fusion(std::vector<Op *> const &operators,
std::vector<Op *> &new_operators);
bool apply_fusion(
std::vector<Op *> const &operators,
std::vector<Op *> &new_operators,
std::unordered_map<ParallelTensor, std::vector<ParallelTensor>>
*parallel_tensor_mapping = nullptr);
bool check_operators_integrity(
std::vector<Op *> const &old_operators,
std::unordered_map<ParallelTensor, std::vector<ParallelTensor>>
*pt_mapping = nullptr);
Op *get_final_operator() const;
void compile(LossType loss_type,
std::vector<MetricsType> const &metrics,
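
On the C++ side the same flag flows through FFModel::arg_top_k. A hedged sketch of a call site, assuming the usual FlexFlow headers are available; in a speculative setting k would plausibly be capped by the new MAX_SPECULATIVE_TREE_BRANCHES:

#include "flexflow/model.h"

using namespace FlexFlow;

// Sketch: build the speculative top-k head of a draft model. The extra
// boolean is the parameter added in this commit.
Tensor build_speculative_head(FFModel &ff, Tensor logits) {
  int k = BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; // = 3
  return ff.arg_top_k(logits, k, /*sorted=*/false,
                      /*speculative_decoding=*/true);
}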
16 changes: 14 additions & 2 deletions include/flexflow/ops/arg_topk.h
@@ -12,6 +12,8 @@ class ArgTopKMeta : public OpMeta {
public:
ArgTopKMeta(FFHandler handle, Op const *op);
bool sorted;
int k;
bool speculative_decoding;
};

class ArgTopK : public Op {
@@ -23,6 +25,7 @@
const ParallelTensor input,
int k,
bool sorted,
bool speculative_decoding,
char const *name);
ArgTopK(FFModel &model,
LayerID const &layer_guid,
@@ -61,6 +64,11 @@ class ArgTopK : public Op {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static BeamInferenceResult inference_speculative_task(
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
void serialize(Legion::Serializer &s) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
@@ -75,22 +83,26 @@ class ArgTopK : public Op {
template <typename DT>
static void forward_kernel(ArgTopKMeta const *m,
DT const *input_ptr,
// float *output_ptr,
float *output_ptr,
int *indices_ptr,
size_t batch_size,
int length,
int k,
bool sorted,
BeamSearchBatchConfig const *bc,
ffStream_t stream);
static void forward_kernel_wrapper(ArgTopKMeta const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &prob,
GenericTensorAccessorW const &indices,
int batch_size);
int batch_size,
BeamSearchBatchConfig const *bc);
Params get_params() const;

public:
int k;
bool sorted;
bool speculative_decoding;
};

}; // namespace FlexFlow
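
The kernel now returns probabilities alongside indices (the previously commented-out output_ptr is live again) and consults the BeamSearchBatchConfig. As a conceptual CPU reference only — not the GPU kernel in src/ops — here is top-k over one logit row plus softmax probabilities, which the speculative tree needs to weight branches:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// Illustrative CPU analogue of what the speculative ArgTopK variant
// conceptually produces: top-k token ids plus their probabilities.
void arg_topk_with_probs(std::vector<float> const &logits, int k,
                         std::vector<int> &indices,
                         std::vector<float> &probs) {
  std::vector<int> order(logits.size());
  std::iota(order.begin(), order.end(), 0);
  std::partial_sort(order.begin(), order.begin() + k, order.end(),
                    [&](int a, int b) { return logits[a] > logits[b]; });

  // Softmax normalizer over the full row (numerically stabilized).
  float mx = *std::max_element(logits.begin(), logits.end());
  double z = 0;
  for (float v : logits) z += std::exp(v - mx);

  indices.assign(order.begin(), order.begin() + k);
  probs.clear();
  for (int idx : indices) probs.push_back(std::exp(logits[idx] - mx) / z);
}

int main() {
  std::vector<float> logits = {0.1f, 2.0f, 1.5f, -0.3f};
  std::vector<int> idx;
  std::vector<float> p;
  arg_topk_with_probs(logits, /*k=*/3, idx, p);
  for (size_t i = 0; i < idx.size(); ++i)
    std::printf("token %d prob %.3f\n", idx[i], p[i]);
}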
1 change: 1 addition & 0 deletions include/flexflow/ops/arg_topk_params.h
@@ -11,6 +11,7 @@ struct ArgTopKParams {
LayerID layer_guid;
int k;
bool sorted;
bool speculative_decoding;
bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(ArgTopKParams const &, ArgTopKParams const &);
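
ArgTopKParams feeds operator deduplication, so the new field presumably has to join both operator== and the params hash. A sketch of that pattern (the real definitions live in the corresponding .cc file and hash specializations, which this diff does not show):

#include <cstddef>
#include <functional>

struct ArgTopKParamsSketch {  // stand-in; the real struct also carries LayerID
  int k;
  bool sorted;
  bool speculative_decoding;
};

bool operator==(ArgTopKParamsSketch const &a, ArgTopKParamsSketch const &b) {
  return a.k == b.k && a.sorted == b.sorted &&
         a.speculative_decoding == b.speculative_decoding;
}

std::size_t hash_value(ArgTopKParamsSketch const &p) {
  // Boost-style hash combining; any field left out here would let two
  // distinct operators collapse into one cache entry.
  std::size_t h = std::hash<int>{}(p.k);
  h ^= std::hash<bool>{}(p.sorted) + 0x9e3779b9 + (h << 6) + (h >> 2);
  h ^= std::hash<bool>{}(p.speculative_decoding) + 0x9e3779b9 + (h << 6) + (h >> 2);
  return h;
}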
4 changes: 4 additions & 0 deletions include/flexflow/ops/embedding.h
@@ -80,6 +80,10 @@ class Embedding : public Op {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void inference_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void backward_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
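
The new inference_task will also need a task-variant registration against EMBED_INF_TASK_ID. FlexFlow's model.cc registers its tasks with the Legion pattern below; this is a sketch of that pattern (the constraint details are assumptions — the actual registration is outside this excerpt):

// Compiles within the FlexFlow source tree, following the registration
// pattern used for the other Embedding tasks.
void register_embedding_inference_task() {
  using namespace Legion;
  TaskVariantRegistrar registrar(FlexFlow::EMBED_INF_TASK_ID,
                                 "Embedding Inference");
  registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
  registrar.set_leaf();
  Runtime::preregister_task_variant<FlexFlow::Embedding::inference_task>(
      registrar, "Embedding Inference Task");
}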
(Diff truncated: the remaining 40 of 51 changed files are not shown.)
