Commit

[Cherrypick][ci] fix hf hub flakiness remove unused prepare (#2278) (#2294)

Co-authored-by: Tyler Osterberg <tylertosterberg@gmail.com>
Qing Lan and tosterberg authored Aug 7, 2024
1 parent a4dd4d6 commit c343d60
Showing 4 changed files with 72 additions and 115 deletions.
55 changes: 34 additions & 21 deletions serving/docker/partition/partition.py
@@ -24,8 +24,6 @@
import utils
from properties_manager import PropertiesManager
from huggingface_hub import snapshot_download
from awq import AutoAWQForCausalLM
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
from datasets import load_dataset

from utils import (get_partition_cmd, extract_python_jar,
@@ -266,14 +264,21 @@ def autoawq_quantize(self):
"version": "GEMM"
}
logging.info(f"Model loading kwargs: {hf_configs.kwargs}")
awq_model = AutoAWQForCausalLM.from_pretrained(
hf_configs.model_id_or_path, **hf_configs.kwargs)
awq_model.quantize(tokenizer, quant_config=quant_config)

output_path = self.properties['option.save_mp_checkpoint_path']
logging.info(f"Saving model and tokenizer to: {output_path}")
awq_model.save_quantized(output_path)
tokenizer.save_pretrained(output_path)
try:
from awq import AutoAWQForCausalLM
awq_model = AutoAWQForCausalLM.from_pretrained(
hf_configs.model_id_or_path, **hf_configs.kwargs)
awq_model.quantize(tokenizer, quant_config=quant_config)

output_path = self.properties['option.save_mp_checkpoint_path']
logging.info(f"Saving model and tokenizer to: {output_path}")
awq_model.save_quantized(output_path)
tokenizer.save_pretrained(output_path)
except ImportError:
logging.error(
"AutoAWQ is not installed. Failing during quantization.")
raise ImportError(
"AutoAWQ is not installed. Failing during quantization.")

def autofp8_quantize(self, config: Optional[dict] = None):
"""
@@ -304,17 +309,25 @@ def autofp8_quantize(self, config: Optional[dict] = None):
truncation=True,
return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(**config)
logging.info(
f"Using the following configurations for fp8 quantization: {vars(quantize_config)}"
)
model = AutoFP8ForCausalLM.from_pretrained(hf_configs.model_id_or_path,
quantize_config,
**hf_configs.kwargs)
model.quantize(examples)
output_path = self.properties['option.save_mp_checkpoint_path']
logging.info(f"Quantization complete. Saving model to: {output_path}")
model.save_quantized(output_path)
try:
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
quantize_config = BaseQuantizeConfig(**config)
logging.info(
f"Using the following configurations for fp8 quantization: {vars(quantize_config)}"
)
model = AutoFP8ForCausalLM.from_pretrained(
hf_configs.model_id_or_path, quantize_config,
**hf_configs.kwargs)
model.quantize(examples)
output_path = self.properties['option.save_mp_checkpoint_path']
logging.info(
f"Quantization complete. Saving model to: {output_path}")
model.save_quantized(output_path)
except ImportError:
logging.error(
"AutoFP8 is not installed. Failing during quantization.")
raise ImportError(
"AutoFP8 is not installed. Failing during quantization.")


def main():
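The two hunks above move the AutoAWQ and AutoFP8 imports from module level into the quantization methods, so a missing optional dependency only fails when quantization is actually requested rather than at import time. A minimal standalone sketch of that deferred-import pattern (the helper name and argument list below are illustrative, not part of this commit):

    import logging

    def quantize_with_awq(model_id, tokenizer, quant_config, output_path, **kwargs):
        # Defer the optional dependency: environments without AutoAWQ can still
        # import this module and run non-quantized partitioning paths.
        try:
            from awq import AutoAWQForCausalLM
        except ImportError:
            logging.error("AutoAWQ is not installed. Failing during quantization.")
            raise
        model = AutoAWQForCausalLM.from_pretrained(model_id, **kwargs)
        model.quantize(tokenizer, quant_config=quant_config)
        model.save_quantized(output_path)
        tokenizer.save_pretrained(output_path)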
18 changes: 3 additions & 15 deletions tests/integration/llm/client.py
@@ -171,19 +171,10 @@ def get_model_name():
"batch_size": [1],
"seq_length": [256],
"tokenizer": "TheBloke/Llama-2-13B-fp16"
}
}

transformers_neuronx_aot_model_spec = {
"gpt2": {
"worker": 1,
"seq_length": [512],
"batch_size": [4]
},
"gpt2-quantize": {
"worker": 1,
"seq_length": [512],
"batch_size": [4]
"tiny-llama-rb": {
"batch_size": [1, 4],
"seq_length": [256],
},
}

@@ -1778,9 +1769,6 @@ def run(raw_args):
transformers_neuronx_model_spec)
elif args.handler == "transformers_neuronx_rolling_batch":
test_handler_rolling_batch(args.model, transformers_neuronx_model_spec)
elif args.handler == "transformers_neuronx-aot":
test_transformers_neuronx_handler(args.model,
transformers_neuronx_aot_model_spec)
elif args.handler == "transformers_neuronx_neo":
test_transformers_neuronx_handler(args.model,
transformers_neuronx_neo_model_spec)
74 changes: 19 additions & 55 deletions tests/integration/llm/prepare.py
@@ -76,58 +76,17 @@
}
}

transformers_neuronx_aot_handler_list = {
"gpt2": {
"option.model_id":
"gpt2",
"option.batch_size":
4,
"option.tensor_parallel_degree":
2,
"option.n_positions":
512,
"option.dtype":
"fp16",
"option.model_loading_timeout":
600,
"option.enable_streaming":
False,
"option.save_mp_checkpoint_path":
"/opt/ml/input/data/training/partition-test"
},
"gpt2-quantize": {
"option.model_id":
"gpt2",
"option.batch_size":
4,
"option.tensor_parallel_degree":
2,
"option.n_positions":
512,
"option.dtype":
"fp16",
"option.model_loading_timeout":
600,
"option.quantize":
"static_int8",
"option.enable_streaming":
False,
"option.save_mp_checkpoint_path":
"/opt/ml/input/data/training/partition-test"
},
}

transformers_neuronx_handler_list = {
"gpt2": {
"option.model_id": "gpt2",
"option.model_id": "s3://djl-llm/gpt2/",
"max_dynamic_batch_size": 4,
"option.tensor_parallel_degree": 2,
"option.n_positions": 512,
"option.dtype": "fp16",
"option.model_loading_timeout": 600
},
"gpt2-quantize": {
"option.model_id": "gpt2",
"option.model_id": "s3://djl-llm/gpt2/",
"batch_size": 4,
"option.tensor_parallel_degree": 2,
"option.n_positions": 512,
@@ -276,6 +235,23 @@
"option.max_rolling_batch_size": 1,
"option.model_loading_timeout": 3600,
"option.output_formatter": "jsonlines"
},
"tiny-llama-rb-aot": {
"option.model_id": "s3://djl-llm/tinyllama-1.1b-chat/",
"option.tensor_parallel_degree": 2,
"option.n_positions": 1024,
"option.max_rolling_batch_size": 4,
"option.rolling_batch": 'auto',
"option.model_loading_timeout": 1200,
},
"tiny-llama-rb-aot-quant": {
"option.model_id": "s3://djl-llm/tinyllama-1.1b-chat/",
"option.quantize": "static_int8",
"option.tensor_parallel_degree": 2,
"option.n_positions": 1024,
"option.max_rolling_batch_size": 4,
"option.rolling_batch": 'auto',
"option.model_loading_timeout": 1200,
}
}

@@ -1217,17 +1193,6 @@ def build_transformers_neuronx_handler_model(model):
write_model_artifacts(options)


def build_transformers_neuronx_aot_handler_model(model):
if model not in transformers_neuronx_aot_handler_list.keys():
raise ValueError(
f"{model} is not one of the supporting handler {list(transformers_neuronx_aot_handler_list.keys())}"
)
options = transformers_neuronx_aot_handler_list[model]
options["engine"] = "Python"
options["option.entryPoint"] = "djl_python.transformers_neuronx"
write_model_artifacts(options)


def build_rolling_batch_model(model):
if model not in rolling_batch_model_list.keys():
raise ValueError(
@@ -1364,7 +1329,6 @@ def build_text_embedding_model(model):
supported_handler = {
'huggingface': build_hf_handler_model,
'transformers_neuronx': build_transformers_neuronx_handler_model,
'transformers_neuronx_aot': build_transformers_neuronx_aot_handler_model,
'performance': build_performance_model,
'handler_performance': build_handler_performance_model,
'rolling_batch_scheduler': build_rolling_batch_model,
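With build_transformers_neuronx_aot_handler_model and its handler entry removed, AOT test models are now prepared through the regular transformers_neuronx builder using the new tiny-llama-rb-aot specs above. A small usage sketch, mirroring how the integration tests below invoke it (the standalone call is illustrative):

    import llm.prepare as prepare

    # Builds artifacts from the "tiny-llama-rb-aot" entry in
    # transformers_neuronx_handler_list via write_model_artifacts().
    prepare.build_transformers_neuronx_handler_model("tiny-llama-rb-aot")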
40 changes: 16 additions & 24 deletions tests/integration/tests.py
@@ -2,6 +2,7 @@

import os
import subprocess
import logging
import pytest
import llm.prepare as prepare
import llm.client as client
@@ -62,7 +63,9 @@ def __exit__(self, *args):
f"cp client_logs/{esc_test_name}_client.log all_logs/{esc_test_name}/ || true"
)
os.system(f"cp -r logs all_logs/{esc_test_name}")
subprocess.run(["./remove_container.sh"], check=True)
subprocess.run(["./remove_container.sh"],
check=True,
capture_output=True)
os.system("cat logs/serving.log")

def launch(self, env_vars=None, container=None, cmd=None):
Expand Down Expand Up @@ -715,34 +718,23 @@ def test_bloom(self):
r.launch(container='pytorch-inf2-2')
client.run("transformers_neuronx bloom-7b1".split())

@pytest.mark.parametrize("model", ["gpt2", "gpt2-quantize"])
@pytest.mark.parametrize("model",
["tiny-llama-rb-aot", "tiny-llama-rb-aot-quant"])
def test_partition(self, model):
try:
with Runner('pytorch-inf2', f'partition-{model}') as r:
with Runner('pytorch-inf2', f'partition-{model}') as r:
try:
prepare.build_transformers_neuronx_handler_model(model)
with open("models/test/requirements.txt", "a") as f:
f.write("dummy_test")
partition_output = r.launch(
r.launch(
container="pytorch-inf2-1",
cmd=
'partition --model-dir /opt/ml/input/data/training/ --save-mp-checkpoint-path /opt/ml/input/data/training/partition --skip-copy'
"partition --model-dir /opt/ml/input/data/training --save-mp-checkpoint-path /opt/ml/input/data/training/aot --skip-copy"
)

# Check if neff files are generated
if len([
fn
for fn in os.listdir("models/test/partition/compiled")
if fn.endswith(".neff")
]) == 0:
raise Exception("Failed to generate any .neff files")

# Check whether requirements.txt download is sufficient
if 'pip install requirements succeed!' not in partition_output.stdout.decode(
"utf-8"):
raise Exception(
"Requirements.txt not installed successfully")
finally:
os.system('sudo rm -rf models')
r.launch(container="pytorch-inf2-1",
cmd="serve -m test=file:/opt/ml/model/test/aot")
client.run(
"transformers_neuronx_rolling_batch tiny-llama-rb".split())
finally:
os.system('sudo rm -rf models')


@pytest.mark.inf
