Commit 3b9d8bb: merge
goliaro committed Oct 17, 2024
1 parent aafc3d7 commit 3b9d8bb
Showing 8 changed files with 41 additions and 24 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/gpu-ci.yml
@@ -56,7 +56,7 @@ jobs:
CONDA: "3"
needs: gpu-ci-concierge
container:
image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest
image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest
options: --gpus all --shm-size=8192m
steps:
- name: Keep alive
@@ -75,7 +75,7 @@ jobs:
CONDA: "3"
needs: gpu-ci-concierge
container:
image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest
image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest
options: --gpus all --shm-size=8192m
steps:
- name: Install updated git version
@@ -151,7 +151,7 @@ jobs:
HUGGINGFACE_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
needs: gpu-ci-concierge
container:
image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest
image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest
options: --gpus all --shm-size=8192m
steps:
- name: Install updated git version
@@ -239,7 +239,7 @@ jobs:
CONDA: "3"
needs: inference-tests
container:
image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest
image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest
options: --gpus all --shm-size=8192m
steps:
- name: Install updated git version
17 changes: 10 additions & 7 deletions inference/python/incr_decoding.py
@@ -51,12 +51,12 @@ def get_configs():
"tensor_parallelism_degree": 1,
"pipeline_parallelism_degree": 2,
"offload": False,
"offload_reserve_space_size": 8 * 1024, # 8GB
"offload_reserve_space_size": 8 * 1024, # 8GB
"use_4bit_quantization": False,
"use_8bit_quantization": False,
"enable_peft": False,
"peft_activation_reserve_space_size": 1024, # 1GB
"peft_weight_reserve_space_size": 1024, # 1GB
"peft_activation_reserve_space_size": 1024, # 1GB
"peft_weight_reserve_space_size": 1024, # 1GB
"profiling": False,
"benchmarking": False,
"inference_debugging": False,
@@ -71,6 +71,7 @@ def get_configs():
"full_precision": False,
"prompt": "",
"output_file": "",
"max_length": 128,
}
# Merge dictionaries
ff_init_configs.update(llm_configs)
@@ -106,9 +107,9 @@ def main():
max_seq_length=256,
max_tokens_per_batch=64,
)

llm.start_server()

if len(configs.prompt) > 0:
prompts = [s for s in json.load(open(configs.prompt))]
if "max_length" not in configs_dict:
@@ -119,8 +120,10 @@ def main():
if "max_length" not in configs_dict:
result = llm.generate("Three tips for staying healthy are: ")
else:
result = llm.generate("Three tips for staying healthy are: ", max_length=configs.max_length)

result = llm.generate(
"Three tips for staying healthy are: ", max_length=configs.max_length
)

llm.stop_server()


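For context, a minimal sketch of the pattern this hunk introduces: max_length now ships as a default entry in the config dictionary and is forwarded to llm.generate only when the (possibly user-supplied) config actually contains it. The sketch follows the names used in the example above; the llm object and its generate() signature are taken from that example, so treat this as an illustration rather than a definitive FlexFlow API reference.

def generate_with_config(llm, prompt, configs_dict):
    # Forward max_length only when the config supplies one; otherwise rely on
    # the library's default generation length (mirrors the branching above).
    if "max_length" in configs_dict:
        return llm.generate(prompt, max_length=configs_dict["max_length"])
    return llm.generate(prompt)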
24 changes: 17 additions & 7 deletions inference/python/spec_infer.py
@@ -51,12 +51,12 @@ def get_configs():
"tensor_parallelism_degree": 1,
"pipeline_parallelism_degree": 2,
"offload": False,
"offload_reserve_space_size": 8 * 1024, # 8GB
"offload_reserve_space_size": 8 * 1024, # 8GB
"use_4bit_quantization": False,
"use_8bit_quantization": False,
"enable_peft": False,
"peft_activation_reserve_space_size": 1024, # 1GB
"peft_weight_reserve_space_size": 1024, # 1GB
"peft_activation_reserve_space_size": 1024, # 1GB
"peft_weight_reserve_space_size": 1024, # 1GB
"profiling": False,
"benchmarking": False,
"inference_debugging": False,
@@ -81,6 +81,7 @@ def get_configs():
],
"prompt": "",
"output_file": "",
"max_length": 128,
}
# Merge dictionaries
ff_init_configs.update(llm_configs)
@@ -144,17 +145,26 @@ def main():
max_tokens_per_batch=64,
ssms=ssms,
)

llm.start_server()

if len(configs.prompt) > 0:
prompts = [s for s in json.load(open(configs.prompt))]
results = llm.generate(prompts)
if "max_length" not in configs_dict:
results = llm.generate(prompts)
else:
results = llm.generate(prompts, max_length=configs.max_length)
else:
result = llm.generate("Three tips for staying healthy are: ")

if "max_length" not in configs_dict:
result = llm.generate("Three tips for staying healthy are: ")
else:
result = llm.generate(
"Three tips for staying healthy are: ", max_length=configs.max_length
)

llm.stop_server()


if __name__ == "__main__":
print("flexflow inference example (speculative inference)")
main()
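The same "max_length" presence check now appears in both the prompt-file path and the hard-coded-prompt path of incr_decoding.py and spec_infer.py. A hypothetical consolidation, not part of this commit, could route everything through one helper; run_generation and its parameters below are invented for illustration.

def run_generation(llm, prompts, configs_dict):
    # Hypothetical helper (not in the commit): build the keyword arguments once
    # so callers do not repeat the "max_length" presence check.
    kwargs = {}
    if "max_length" in configs_dict:
        kwargs["max_length"] = configs_dict["max_length"]
    return llm.generate(prompts, **kwargs)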
2 changes: 1 addition & 1 deletion src/ops/inc_multihead_self_attention.cpp
@@ -147,7 +147,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m,
int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
bc->requestsInfo[i].num_tokens_in_batch;
int max_peft_tokens = bc->requestsInfo[i].max_sequence_length;
int max_peft_tokens = bc->requestsInfo[i].max_length;
// Copy query to m->query_activation_buffer if we need to compute
// PEFT backward
if (bc->requestsInfo[i].peft_bwd) {
6 changes: 3 additions & 3 deletions src/ops/spec_inc_multihead_self_attention.cc
@@ -170,7 +170,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer(
Layer const *layer,
std::vector<ParallelTensor> const &inputs) {

std::cout << "spec create operator: " << layer->name << "\n";
// std::cout << "spec create operator: " << layer->name << "\n";
long long value;
layer->get_int_property("embed_dim", value);
int embed_dim = value;
@@ -182,10 +182,10 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer(
int kdim = value;
layer->get_int_property("vdim", value);
int vdim = value;
float dropout;
layer->get_float_property("dropout", dropout);
layer->get_int_property("add_zero_attn", value);
bool add_zero_attn = (bool)value;
float dropout;
layer->get_float_property("dropout", dropout);
RotaryEmbeddingMeta rotary_embedding_meta;
layer->get_int_property("apply_rotary_embedding", value);
rotary_embedding_meta.apply_rotary_embedding = (bool)value;
5 changes: 3 additions & 2 deletions src/ops/tree_inc_multihead_self_attention.cc
@@ -188,10 +188,10 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer(
int kdim = value;
layer->get_int_property("vdim", value);
int vdim = value;
float dropout;
layer->get_float_property("dropout", dropout);
layer->get_int_property("add_zero_attn", value);
bool add_zero_attn = (bool)value;
float dropout;
layer->get_float_property("dropout", dropout);
RotaryEmbeddingMeta rotary_embedding_meta;
layer->get_int_property("apply_rotary_embedding", value);
rotary_embedding_meta.apply_rotary_embedding = (bool)value;
@@ -204,6 +204,7 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer(
rotary_embedding_meta.high_freq_factor);
layer->get_int_property("original_max_position_embeddings", value);
rotary_embedding_meta.original_max_position_embeddings = (int)value;
layer->get_int_property("scaling_query", value);
bool scaling_query = (bool)value;
float scaling_factor;
layer->get_float_property("scaling_factor", scaling_factor);
2 changes: 2 additions & 0 deletions src/runtime/model.cc
@@ -1677,6 +1677,7 @@ void FFModel::finish_nccl_comms() {
false /*must*/,
0 /*mapper_id*/,
comm.first);
index_launcher.concurrent = true;
FutureMap fm = runtime->execute_index_space(ctx, index_launcher);
fm.wait_all_results();
}
@@ -7685,6 +7686,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"NCCL Finish Communicators");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<Op::finish_nccl_comms_task>(
registrar, "NCCL Finish Communicators Task", 111 /*variant ID*/);
1 change: 1 addition & 0 deletions tests/inference/python_test_configs/generate_configs.py
@@ -34,6 +34,7 @@
"full_precision": True,
"prompt": "",
"output_file": "",
"max_length": 128,
}
ssm_configs = {
"ssms": [
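The test-config generator gains the same max_length default, so the Python inference tests exercise the new argument end to end. Roughly, a generated config file would now carry an entry like the sketch below; the file name and the exact set of surrounding keys are illustrative assumptions, not the literal output of generate_configs.py.

import json

# Illustrative only: a generated test config now includes the new default.
test_config = {
    "full_precision": True,
    "prompt": "",
    "output_file": "",
    "max_length": 128,  # default added by this commit
}
with open("example_test_config.json", "w") as f:  # hypothetical file name
    json.dump(test_config, f, indent=2)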
