
Python interface for inference (part 2) #893

Merged · 35 commits merged into inference on Aug 2, 2023

Conversation

@goliaro (Collaborator) commented Jul 28, 2023

Description of changes:

This PR introduces the Python interface for inference. It allows users to run FlexFlow serve as shown below. For more complete examples, see the inference/python/incr_decoding.py and inference/python/spec_infer.py scripts.

Incremental decoding

import flexflow.serve as ff
import json

# Initialize the FlexFlow runtime. ff.init() takes a dictionary or the path to a JSON file with the configs
ff.init(
    {
        "num_gpus": 4,
        "memory_per_gpu": 14000,
        "zero_copy_memory_per_gpu": 30000,
        "pipeline_parallelism_degree": 4,
    }
)

# Create the FlexFlow LLM
llm = ff.LLM(
    "decapoda-research/llama-7b-hf",
    data_type=ff.DataType.DT_FLOAT,         # or ff.DataType.DT_HALF
    tokenizer_path="",                      # leave empty to use HF's tokenizer
    weights_path="",                        # leave empty to use HF's weights directly
    clean_cache=False,                      # set to True if you'd like to discard the FlexFlow weights/tokenizer cache for the given model
    output_file="output.txt",
)
sampling_config = ff.SamplingConfig(
    do_sample=False, temperature=0.9, topp=0.8, topk=1
)
# Compile the LLM for inference and load the weights into memory
llm.compile(
    ff.InferenceMode.INC_DECODING_MODE,
    sampling_config,
    max_batch_size=1,
    max_seq_length=256,
    max_tokens_per_batch=64,
)
# Generation begins!
prompts = [s for s in json.load(open("chatgpt.json"))]
results = llm.generate(prompts)
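
As noted in the comment above, ff.init() also accepts the path to a JSON file with the configs instead of a dictionary. A minimal sketch of that variant (the file name configs.json, its contents, and the exact argument form are illustrative assumptions, not confirmed by this PR):

import flexflow.serve as ff

# Hypothetical configs.json holding the same settings as the dictionary above:
# {
#     "num_gpus": 4,
#     "memory_per_gpu": 14000,
#     "zero_copy_memory_per_gpu": 30000,
#     "pipeline_parallelism_degree": 4
# }

# Initialize the runtime from a JSON file instead of an inline dictionary
# (the precise way the path is passed may differ from this sketch)
ff.init("configs.json")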

Speculative Inference

import flexflow.serve as ff
import json
from types import SimpleNamespace

# Initialize the FlexFlow runtime. ff.init() takes a dictionary or the path to a JSON file with the configs
ff.init(
    {
        "num_gpus": 4,
        "memory_per_gpu": 14000,
        "zero_copy_memory_per_gpu": 30000,
        "pipeline_parallelism_degree": 4,
    }
)

# Configure the LLM and SSM
configs = {
    "llm_model": "decapoda-research/llama-7b-hf",
    "llm_weight": "",
    "llm_tokenizer": "",
    "clean_model_cache": False,
    "full_precision": False,
    "ssms": [
        {
            "ssm_model": "JackFram/llama-160m",
            "ssm_weight": "",
            "ssm_tokenizer": "",
            "clean_model_cache": False,
            "full_precision": False,
        },
        {
            "ssm_model": "facebook/opt-125m",
            "ssm_weight": "",
            "ssm_tokenizer": "",
            "clean_model_cache": False,
            "full_precision": False,
        },
    ],
    "prompt": "../prompt/test.json",
    "output_file": "",
}
configs = SimpleNamespace(**configs)


# Create the FlexFlow LLM
ff_data_type = (
    ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF
)
llm = ff.LLM(
    configs.llm_model,
    data_type=ff_data_type,
    tokenizer_path=configs.llm_tokenizer,
    weights_path=configs.llm_weight,
    clean_cache=configs.clean_model_cache,
    output_file=configs.output_file,
)

# Create the SSMs
ssms = []
for ssm_config in configs.ssms:
    ssm_config = SimpleNamespace(**ssm_config)
    ff_data_type = (
        ff.DataType.DT_FLOAT if ssm_config.full_precision else ff.DataType.DT_HALF
    )
    ssm = ff.SSM(
        ssm_config.ssm_model,
        data_type=ff_data_type,
        tokenizer_path=ssm_config.ssm_tokenizer,
        weights_path=ssm_config.ssm_weight,
        clean_cache=ssm_config.clean_model_cache,
        output_file=configs.output_file,
    )
    ssms.append(ssm)

# Create the sampling configs
sampling_config = ff.SamplingConfig(
    do_sample=False, temperature=0.9, topp=0.8, topk=1
)

# Compile the SSMs for inference and load the weights into memory
for ssm in ssms:
    ssm.compile(
        ff.InferenceMode.BEAM_SEARCH_MODE,
        sampling_config,
        max_batch_size=1,
        max_seq_length=256,
        max_tokens_per_batch=64,
    )

# Compile the LLM for inference and load the weights into memory
llm.compile(
    ff.InferenceMode.TREE_VERIFY_MODE,
    sampling_config,
    max_batch_size=1,
    max_seq_length=256,
    max_tokens_per_batch=64,
    ssms=ssms,
)
# Generation begins! Load the prompts from the file specified in the configs
prompts = [s for s in json.load(open(configs.prompt))]
results = llm.generate(prompts)
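
Both examples read their prompts with json.load and iterate over the result, so the prompt file is assumed to be a flat JSON array of strings. A small sketch of preparing such a file (the prompt text here is purely illustrative):

import json

# Hypothetical prompts, written in the list-of-strings format the examples above expect
prompts = [
    "Example prompt 1",
    "Example prompt 2",
]
with open("../prompt/test.json", "w") as f:
    json.dump(prompts, f)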

TODOs:

  • Implement speculative inference
  • Implement OPT model
  • Implement Falcon model
  • Unify argument parsing
  • Download tokenizers directly from HF
  • Add code to automatically generate set_ff_envs.sh
  • Update readme
  • Debug speculative inference with different types of models
  • Separate flexflow_inference.py example file into two, one for incremental decoding, and one for specinfer
  • Make example args more uniform with C++
  • Replace C++ tests in inference_tests.sh with Python ones
  • Update PR description
  • Update READMEs and docs

Related Issues:

Linked Issues:

  • Issue #

Issues closed by this PR:

  • Closes #

Before merging:

  • Did you update the flexflow-third-party repo if you modified any of the CMake files, the build configs, or the submodules?

@goliaro goliaro added the inference Features and fixes related to the inference project. label Jul 28, 2023
@goliaro goliaro marked this pull request as ready for review August 2, 2023 03:38
@goliaro (Collaborator, Author) commented Aug 2, 2023

The only missing parts of this PR are updating the docs and replacing the C++ tests with Python tests in CI. Everything else is working, so if anyone is blocked by this PR, feel free to merge it; in that case, I'll open a new PR for the final polishing. Otherwise, I'll keep pushing here.

@jiazhihao jiazhihao enabled auto-merge (squash) August 2, 2023 03:45
@jiazhihao (Collaborator) commented:
Great! Let's merge this PR after it passes CI.

@jiazhihao jiazhihao merged commit ba91733 into inference Aug 2, 2023
42 checks passed