Merge branch 'inference' into prep_model_weights
jiazhihao authored Jul 17, 2023
2 parents 7ee168a + 58b745d commit 88d6255
Showing 85 changed files with 2,793 additions and 490 deletions.
7 changes: 5 additions & 2 deletions .github/README.md
@@ -29,7 +29,7 @@ for serving generative LLMs while provably preserving model quality.
</p>

## Build/Install SpecInfer
SpecInfer is built on top of FlexFlow. You can build/install SpecInfer by building the inference branch of FlexFlow. Please read the [instructions](../INSTALL.md) for building/installing FlexFlow from source code. If you would like to quickly try SpecInfer, we also provide pre-built Docker packages ([flexflow-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-cuda) with a CUDA backend, [flexflow-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually.
SpecInfer is built on top of FlexFlow. You can build/install SpecInfer by building the inference branch of FlexFlow. Please read the [instructions](../INSTALL.md) for building/installing FlexFlow from source code. If you would like to quickly try SpecInfer, we also provide pre-built Docker packages ([specinfer-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/specinfer-cuda) with a CUDA backend, [specinfer-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/specinfer-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually.

## Run SpecInfer
The source code of the SpecInfer pipeline is available at [this folder](../inference/spec_infer/). The SpecInfer executable will be available at `/build_dir/inference/spec_infer/spec_infer` after compilation. You can use the following command-line arguments to run SpecInfer:
@@ -44,7 +44,10 @@ The source code of the SpecInfer pipeline is available at [this folder](../infer
* `-ssm-weight`: path to the folder that stores the small speculative models' weights. The number of `-ssm-weight`s must match the number of `-ssm-model`s and `-ssm-config`s.
* `-ssm-config`: path to the json file that stores the SSM model configs. The number of `-ssm-config`s must match the number of `-ssm-model`s and `-ssm-weight`s.
* `-tokenizer`: path to the tokenizer file (see [Tokenizers](#tokenizers) for preparing a tokenizer for SpecInfer).
* `-data-parallelism-degree`, `-tensor-parallelism-degree` and `-pipeline-parallelism-degree`: parallelization degrees in the data, tensor, and pipeline dimensions. Their product must equal the number of GPUs available on the machine. When any of the three parallelism degree arguments is omitted, a default value of 1 will be used.
* `-prompt`: (optional) path to the prompt file. SpecInfer expects a json format file for prompts, all of which will be served by SpecInfer. In addition, users can also use the following API for registering requests:
* `-output-file`: (optional) path to the file in which to save the model's output, together with the generation latency


```c++
class RequestManager {
@@ -54,7 +57,7 @@ class RequestManager {
For example, you can use the following command line to serve a LLaMA-7B or LLaMA-13B model on 4 GPUs and use two collectively boost-tuned LLaMA-190M models for speculative inference.
```bash
./inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight /path/to/llm/weights -llm-config /path/to/llm/config.json -ssm-model llama -ssm-weight /path/to/ssm1/weights -ssm-config /path/to/ssm/config.json -ssm-model llama -ssm-weight /path/to/ssm2/weights -ssm-config /path/to/ssm2/config.json -tokenizer /path/to/tokenizer.model -prompt /path/to/prompt.json --use-full-precision
./inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight /path/to/llm/weights -llm-config /path/to/llm/config.json -ssm-model llama -ssm-weight /path/to/ssm1/weights -ssm-config /path/to/ssm/config.json -ssm-model llama -ssm-weight /path/to/ssm2/weights -ssm-config /path/to/ssm2/config.json -tokenizer /path/to/tokenizer.model -prompt /path/to/prompt.json --use-full-precision -tensor-parallelism-degree 2 -pipeline-parallelism-degree 2
```
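
In addition to the `-prompt` JSON file, requests can also be registered programmatically through the `RequestManager` API mentioned above. Because the class body is elided in this diff, the header path, method name, and signature in the following sketch are assumptions, not SpecInfer's documented API:

```c++
#include <string>
#include <vector>

#include "flexflow/inference.h" // assumed to declare RequestManager

using namespace FlexFlow;

// Hypothetical sketch: register each prompt with the RequestManager instead of
// passing a -prompt file. The register_new_request name/signature are assumed.
void register_prompts(RequestManager &rm, std::vector<std::string> const &prompts) {
  for (std::string const &prompt : prompts) {
    rm.register_new_request(prompt, /*max_sequence_length=*/128); // 128 is an arbitrary budget
  }
}
```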

### Tokenizers
10 changes: 10 additions & 0 deletions .github/workflows/gpu-ci.yml
@@ -180,7 +180,17 @@ jobs:
./tests/gpt_tokenizer_test.sh
# Inference tests
export TENSOR_PARALLELISM_TESTS=ON
./tests/inference_tests.sh
cd inference
tar -zcvf output.tar.gz ./output
cd ..
- name: Save inference output as an artifact
uses: actions/upload-artifact@v3
with:
name: output
path: inference/output.tar.gz

gpu-ci-flexflow:
name: Single Machine, Multiple GPUs Tests
6 changes: 3 additions & 3 deletions INSTALL.md
@@ -90,11 +90,11 @@ To run the Python examples, you have two options: you can use the `flexflow_pyth
* `export PYTHONPATH="${FF_HOME}/python:${FF_HOME}/build/python"`
* `export FF_USE_NATIVE_PYTHON=1`

**We recommend that you run the `mnist_mlp` test under `native` using the following cmd to check if FlexFlow has been installed correctly:**
**We recommend that you run the** `mnist_mlp` **test under** `native` **using the following cmd to check if FlexFlow has been installed correctly:**

```
cd python
./flexflow_python examples/python/native/mnist_mlp.py -ll:py 1 -ll:gpu 1 -ll:fsize <size of gpu buffer> -ll:zsize <size of zero buffer>
cd "$FF_HOME"
./python/flexflow_python examples/python/native/mnist_mlp.py -ll:py 1 -ll:gpu 1 -ll:fsize <size of gpu buffer> -ll:zsize <size of zero buffer>
```
A script to run all the Python examples is available at `tests/multi_gpu_tests.sh`

9 changes: 4 additions & 5 deletions README.md
@@ -4,7 +4,7 @@
FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization strategies. FlexFlow provides a drop-in replacement for PyTorch and TensorFlow Keras. Running existing PyTorch and Keras programs in FlexFlow only requires [a few lines of changes to the program](https://flexflow.ai/keras).

## Install FlexFlow
To install FlexFlow from source code, please read the [instructions](INSTALL.md). If you would like to quickly try FlexFlow, we also provide pre-built Docker packages ([flexflow-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-cuda) with a CUDA backend, [flexflow-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually. You can also use `conda` to install the FlexFlow Python package (coming soon).
To install FlexFlow from source code, please read the [instructions](https://flexflow.readthedocs.io/en/latest/installation.html). If you would like to quickly try FlexFlow, we also provide pre-built Docker packages ([flexflow-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-cuda) with a CUDA backend, [flexflow-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually. You can also use `conda` to install the FlexFlow Python package (coming soon).

## PyTorch Support
Users can also use FlexFlow to optimize the parallelization performance of existing PyTorch models in two steps. First, a PyTorch model can be exported to the FlexFlow model format using `flexflow.torch.fx.torch_to_flexflow`.
@@ -18,7 +18,7 @@ fx.torch_to_flexflow(model, "mymodel.ff")

Second, a FlexFlow program can directly import a previously saved PyTorch model and [autotune](https://www.usenix.org/conference/osdi22/presentation/unger) the parallelization performance for a given parallel machine.

```
```python
from flexflow.pytorch.model import PyTorchModel

def top_level_task():
@@ -39,7 +39,7 @@ FlexFlow prioritizes PyTorch compatibility, but also includes frontends for [Ten
## C++ Interface
For users who prefer to program in C/C++, FlexFlow supports a C++ programming interface that is equivalent to its Python APIs.

**More FlexFlow C++ examples**: see the [C++ examples folder](https://github.com/flexflow/FlexFlow/tree/master/examples/c++).
**More FlexFlow C++ examples**: see the [C++ examples folder](https://github.com/flexflow/FlexFlow/tree/master/examples/cpp).


## Command-Line Flags
@@ -69,12 +69,11 @@ Performance auto-tuning flags:
For performance tuning related flags: see [performance autotuning](https://flexflow.ai/search).

## Contributing

Please let us know if you encounter any bugs or have any suggestions by [submitting an issue](https://github.com/flexflow/flexflow/issues).

We welcome all contributions to FlexFlow from bug fixes to new features and extensions.

Please subscribe to the FlexFlow users mailing list for

## Citations
* Colin Unger, Zhihao Jia, Wei Wu, Sina Lin, Mandeep Baines, Carlos Efrain Quintero Narvaez, Vinay Ramakrishnaiah, Nirmal Prajapati, Pat McCormick, Jamaludin Mohd-Yusof, Xi Luo, Dheevatsa Mudigere, Jongsoo Park, Misha Smelyanskiy, and Alex Aiken. [Unity: Accelerating DNN Training Through Joint Optimization of Algebraic Transformations and Parallelization](https://www.usenix.org/conference/osdi22/presentation/unger). In Proceedings of the Symposium on Operating Systems Design and Implementation (OSDI), July 2022.

2 changes: 1 addition & 1 deletion docker/flexflow-environment/Dockerfile
@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.7.0-cudnn8-devel-ubuntu20.04
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04

LABEL org.opencontainers.image.source=https://github.com/flexflow/FlexFlow
LABEL org.opencontainers.image.description="FlexFlow environment container"
12 changes: 5 additions & 7 deletions examples/cpp/inference/mixture_of_experts/moe.cc
@@ -139,10 +139,8 @@ void FlexFlow::top_level_task(Task const *task,
Tensor output = ff.arg_top_k(t, /*k=*/1, /*sorted=*/false);

//------------------- Initialize the inference manager ------------------
InferenceManager im(
ff.config, moeConfig.batch_size, moeConfig.num_inflight_batches);
std::unordered_map<Tensor, std::vector<MachineView>> mapping;
im.compile_model_and_allocate_buffer(&ff, mapping);
InferenceManager im(ff.config, moeConfig.batch_size);
im.compile_model_and_allocate_buffer(&ff);
im.init_operators_inference(&ff);

//------------ Initialize the data loader and data generator ------------
@@ -162,7 +160,7 @@ void FlexFlow::top_level_task(Task const *task,
ParallelTensor input_pt;
ff.get_parallel_tensor_from_tensor(input, input_pt);
assert(im.tensor_buffer.find(input_pt) != im.tensor_buffer.end());
assert(im.tensor_buffer[input_pt].size() == im.max_num_inflight_batches);
assert(im.tensor_buffer[input_pt].size() == ffConfig.data_parallelism_degree);
DataLoader data_loader(
ff, moeConfig, data_generator, im.tensor_buffer[input_pt]);

@@ -184,13 +182,13 @@
std::map<int, BatchConfig *> batch_configs;
std::pair<size_t, size_t> new_prompts;
BatchConfig *bc = nullptr;
std::map<size_t, int> batch_predictions[im.max_num_inflight_batches];
std::map<size_t, int> batch_predictions[ffConfig.data_parallelism_degree];

assert(im.max_num_tokens_per_batch == moeConfig.batch_size);

// simulation loop. For deployment, we will use a while(true)
while (processed_requests < moeConfig.total_requests) {
for (int bid = 0; bid < im.max_num_inflight_batches; bid++) {
for (int bid = 0; bid < ffConfig.data_parallelism_degree; bid++) {
size_t max_reqs, max_tkns;
if (future_handlers.find(bid) == future_handlers.end()) {
max_reqs = moeConfig.incremental_mode ? bc->MAX_NUM_REQUESTS
13 changes: 5 additions & 8 deletions examples/cpp/inference/transformers/transformers.cc
@@ -114,11 +114,8 @@ void FlexFlow::top_level_task(Task const *task,
Tensor output = ff.arg_top_k(t, /*k=*/1, false);

//------------------- Initialize the inference manager ------------------
InferenceManager im(ff.config,
transformerConfig.batch_size,
transformerConfig.num_inflight_batches);
std::unordered_map<Tensor, std::vector<MachineView>> mapping;
im.compile_model_and_allocate_buffer(&ff, mapping);
InferenceManager im(ff.config, transformerConfig.batch_size);
im.compile_model_and_allocate_buffer(&ff);
im.init_operators_inference(&ff);

//------------ Initialize the data loader and data generator ------------
@@ -138,7 +135,7 @@ void FlexFlow::top_level_task(Task const *task,
ParallelTensor input_pt;
ff.get_parallel_tensor_from_tensor(input, input_pt);
assert(im.tensor_buffer.find(input_pt) != im.tensor_buffer.end());
assert(im.tensor_buffer[input_pt].size() == im.max_num_inflight_batches);
assert(im.tensor_buffer[input_pt].size() == ffConfig.data_parallelism_degree);
DataLoader data_loader(
ff, transformerConfig, data_generator, im.tensor_buffer[input_pt]);

@@ -160,14 +157,14 @@
std::map<int, BatchConfig *> batch_configs;
std::pair<size_t, size_t> new_prompts;
BatchConfig *bc = nullptr;
std::map<size_t, int> batch_predictions[im.max_num_inflight_batches];
std::map<size_t, int> batch_predictions[ffConfig.data_parallelism_degree];

assert(im.max_num_tokens_per_batch == transformerConfig.batch_size);
// assert(transformerConfig.batch_size <= BatchConfig::MAX_NUM_REQUESTS);

// simulation loop. For deployment, we will use a while(true)
while (processed_requests < transformerConfig.total_requests) {
for (int bid = 0; bid < im.max_num_inflight_batches; bid++) {
for (int bid = 0; bid < ffConfig.data_parallelism_degree; bid++) {
size_t max_reqs, max_tkns;
if (future_handlers.find(bid) == future_handlers.end()) {
max_reqs = transformerConfig.incremental_mode
1 change: 1 addition & 0 deletions include/flexflow/batch_config.h
@@ -116,6 +116,7 @@ class BeamSearchBatchConfig : public BatchConfig {
inline static int const MAX_BEAM_DEPTH = 8;

int model_id;
int max_init_length = 0;

struct BeamSearchPerRequestInfo {
int beam_size;
15 changes: 9 additions & 6 deletions include/flexflow/config.h
@@ -37,14 +37,15 @@ namespace FlexFlow {
// ========================================================
// Define Runtime Constants
// ========================================================
#define MAX_NUM_INPUTS 256
#define MAX_NUM_WEIGHTS 64
#define MAX_NUM_OUTPUTS 256
#define MAX_NUM_FUSED_OPERATORS 64
#define MAX_NUM_FUSED_TENSORS 64
#define MAX_NUM_INPUTS 2048
#define MAX_NUM_WEIGHTS 2048
#define MAX_NUM_OUTPUTS 2048
#define MAX_NUM_FUSED_OPERATORS 2048
#define MAX_NUM_FUSED_TENSORS 2048
#define MAX_NUM_WORKERS 1024
#define MAX_FILENAME 200
#define MAX_OPNAME 128
#define MAX_NUM_TRANSFORMER_LAYERS 100
// DataLoader
#define MAX_SAMPLES_PER_LOAD 64
#define MAX_FILE_LENGTH 128
@@ -143,8 +144,10 @@ class FFConfig {
bool enable_parameter_parallel;
bool enable_attribute_parallel;
bool enable_inplace_optimizations;
// Control tensor model parallelism degree in inference
// Control parallelism degrees in inference
int data_parallelism_degree;
int tensor_parallelism_degree;
int pipeline_parallelism_degree;
// Control Tensor Op Math Conversion
bool allow_tensor_op_math_conversion;
std::string dataset_path;
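
The three new degree fields in `FFConfig` back the `-data-parallelism-degree`, `-tensor-parallelism-degree`, and `-pipeline-parallelism-degree` flags documented in the README diff above, whose product must equal the number of available GPUs, with omitted degrees defaulting to 1. A minimal standalone check of that constraint follows; the assertion is illustrative only and is not FlexFlow's actual validation code:

```c++
#include <cassert>

// Illustrative only: the documented rule is that the product of the three
// parallelism degrees equals the number of GPUs; each omitted degree is 1.
void check_parallelism_degrees(int data_degree,
                               int tensor_degree,
                               int pipeline_degree,
                               int num_gpus) {
  assert(data_degree >= 1 && tensor_degree >= 1 && pipeline_degree >= 1);
  assert(data_degree * tensor_degree * pipeline_degree == num_gpus);
}
```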
2 changes: 2 additions & 0 deletions include/flexflow/ffconst.h
@@ -173,6 +173,7 @@ enum OperatorType {
OP_REPLICATE,
OP_REDUCTION,
OP_PIPELINE,
OP_ALLREDUCE,
OP_FUSED_PARALLEL,
OP_INVALID,
};
@@ -207,6 +208,7 @@ enum PMParameter {
PM_COMBINE_DEGREE, // Combine
PM_REDUCTION_DIM, // Reduction
PM_REDUCTION_DEGREE, // Reduction
PM_ALLREDUCE_DIM, // AllReduce
PM_SOFTMAX_DIM, // Softmax
PM_NUM_HEADS, // MultiHeadAttention
PM_INVALID,
7 changes: 4 additions & 3 deletions include/flexflow/fftype.h
@@ -8,15 +8,16 @@ namespace FlexFlow {

class LayerID {
public:
static const LayerID NO_ID;
LayerID();
LayerID(size_t id);
LayerID(size_t id, size_t transformer_layer_id);
bool is_valid_id() const;
friend bool operator==(LayerID const &lhs, LayerID const &rhs);

public:
size_t id;
size_t id, transformer_layer_id;
};

}; // namespace FlexFlow

#endif // _FF_TYPE_H
#endif // _FF_TYPE_H
9 changes: 2 additions & 7 deletions include/flexflow/inference.h
@@ -28,12 +28,8 @@ using tokenizers::Tokenizer;

class InferenceManager {
public:
InferenceManager(FFConfig const &config,
int max_num_tokens_per_batch,
int max_num_inflight_batches);
void compile_model_and_allocate_buffer(
FFModel *model,
std::unordered_map<Tensor, std::vector<MachineView>> const &mapping);
InferenceManager(FFConfig const &config, int max_num_tokens_per_batch);
void compile_model_and_allocate_buffer(FFModel *model);
void init_operators_inference(FFModel *model);
MachineView *get_machine_view(int mv_id);
Legion::FutureMap inference(FFModel *model, int index, BatchConfig const &bc);
@@ -45,7 +41,6 @@ class InferenceManager {
FFConfig ff_config;
std::unordered_map<ParallelTensor, std::vector<ParallelTensor>> tensor_buffer;
int max_num_tokens_per_batch;
int max_num_inflight_batches;
int num_devices;
std::vector<MachineView> machine_views;
};
13 changes: 13 additions & 0 deletions include/flexflow/model.h
@@ -104,6 +104,7 @@ enum TaskIDs {
LAYERNORM_BWD_TASK_ID,
LINEAR_INIT_TASK_ID,
LINEAR_INIT_PARA_TASK_ID,
LINEAR_INF_TASK_ID,
LINEAR_FWD_TASK_ID,
LINEAR_BWD_TASK_ID,
LINEAR_BWD2_TASK_ID,
@@ -159,6 +160,7 @@ enum TaskIDs {
FUSEDOP_INIT_TASK_ID,
FUSEDOP_FWD_TASK_ID,
FUSEDOP_BWD_TASK_ID,
FUSEDOP_INF_TASK_ID,
NOOP_INIT_TASK_ID,
// Metrics tasks
METRICS_COMP_TASK_ID,
@@ -212,6 +214,9 @@ enum TaskIDs {
PIPELINE_INIT_TASK_ID,
PIPELINE_FWD_TASK_ID,
PIPELINE_BWD_TASK_ID,
ALLREDUCE_INIT_TASK_ID,
ALLREDUCE_FWD_TASK_ID,
ALLREDUCE_BWD_TASK_ID,
FUSED_PARALLELOP_INIT_TASK_ID,
FUSED_PARALLELOP_FWD_TASK_ID,
FUSED_PARALLELOP_BWD_TASK_ID,
@@ -311,6 +316,7 @@ class Combine;
class Repartition;
class Reduction;
class Replicate;
class AllReduce;
class FusedParallelOp;
class ParallelOpInfo;

@@ -897,6 +903,9 @@ class FFModel {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
// ========================================
// Internal APIs that should not be invoked from applications
// ========================================
void reset_metrics();
void init_operators();
void init_operators_inference(
@@ -919,6 +928,7 @@ class FFModel {
std::vector<MetricsType> const &metrics,
CompMode comp_mode = COMP_MODE_TRAINING);
void compile_inference();
void set_transformer_layer_id(int id);
void graph_optimize(size_t budget,
bool only_data_parallel,
std::unique_ptr<PCG::Graph> &best_graph,
@@ -975,6 +985,7 @@ class FFModel {
public:
size_t op_global_guid, layer_global_guid;
size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid;
size_t current_transformer_layer_id;
FFConfig config;
FFIterationConfig iter_config;
Optimizer *optimizer;
@@ -1078,6 +1089,8 @@ class FFModel {
Reduction *>,
std::unordered_map<std::pair<ParallelTensorShape, CombineParams>,
Combine *>,
std::unordered_map<std::pair<ParallelTensorShape, AllReduceParams>,
AllReduce *>,
std::unordered_map<std::pair<ParallelTensorShape, FusedParallelOpParams>,
FusedParallelOp *>>
cached_ops;
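
The `OP_ALLREDUCE` operator type, `PM_ALLREDUCE_DIM` parameter, `ALLREDUCE_*` task IDs, and `AllReduce` parallel op added across ffconst.h and model.h support the new tensor-parallel path: each device produces a partial result, and an allreduce sums the partials so every device ends up holding the complete tensor. A minimal single-process illustration of that reduction follows; it is conceptual only and unrelated to FlexFlow's Legion-task implementation of the op:

```c++
#include <vector>

// Conceptual allreduce (sum): every shard contributes its partial values and
// afterwards holds the element-wise total. Assumes all shards have equal length.
void allreduce_sum(std::vector<std::vector<float>> &shards) {
  if (shards.empty()) {
    return;
  }
  size_t n = shards[0].size();
  std::vector<float> total(n, 0.0f);
  for (auto const &shard : shards) {
    for (size_t i = 0; i < n; i++) {
      total[i] += shard[i]; // reduce: accumulate partial results
    }
  }
  for (auto &shard : shards) {
    shard = total; // broadcast: every shard receives the full sum
  }
}
```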