dml specific change in model builder for phi3.5 (#888)
zhangxiang1993 authored Sep 16, 2024
1 parent e689880 commit e8a89ad
Showing 2 changed files with 84 additions and 2 deletions.
12 changes: 10 additions & 2 deletions examples/python/awq-quantized-model.py
@@ -30,6 +30,13 @@ def parse_args():
         help="Folder to save AWQ-quantized ONNX model and associated files in",
     )
 
+    parser.add_argument(
+        "-e",
+        "--execution_provider",
+        default="cuda",
+        help="Target execution provider to apply quantization (e.g. dml, cuda)",
+    )
+
     args = parser.parse_args()
     return args

@@ -108,13 +115,14 @@ def main():
     input_folder = args.quant_path
     output_folder = args.output_path
     precision = "int4"
-    execution_provider = "cuda"
+    execution_provider = args.execution_provider
     cache_dir = os.path.join(".", "cache_dir")
 
     create_model(model_name, input_folder, output_folder, precision, execution_provider, cache_dir)
 
     # Run ONNX model
-    run_model(args)
+    if args.execution_provider != "dml":
+        run_model(args)
 
 if __name__ == "__main__":
     main()
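The new flag threads the target execution provider through to `create_model`, and the example now skips the run step when targeting DML. A minimal standalone sketch of the gating behavior (the print stubs stand in for the example's `run_model`):

```python
import argparse

# Mirrors the flag added in the diff above; "cuda" stays the default.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-e",
    "--execution_provider",
    default="cuda",
    help="Target execution provider to apply quantization (e.g. dml, cuda)",
)

args = parser.parse_args(["--execution_provider", "dml"])

# The ONNX model is always created; only the inference step is gated.
if args.execution_provider != "dml":
    print("run_model(args) would execute here")
else:
    print("dml target: model created, run step skipped")
```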
74 changes: 74 additions & 0 deletions src/python/py/models/builder.py
@@ -571,6 +571,11 @@ def make_greater(self, name, inputs, shape):
         output = f"{name}/output_0"
         self.make_node("Greater", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, TensorProto.BOOL, shape=shape)
 
+    def make_greater_or_equal(self, name, inputs, shape):
+        output = f"{name}/output_0"
+        self.make_node("GreaterOrEqual", inputs=inputs, outputs=[output], name=name)
+        self.make_value_info(output, TensorProto.BOOL, shape=shape)
+
     def make_isinf(self, name, root_input, shape):
         output = f"{name}/output_0"

@@ -597,6 +602,11 @@ def make_reduce_sum(self, name, inputs, dtype, shape):
         self.make_node("ReduceSum", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, dtype, shape=shape)
 
+    def make_reduce_max(self, name, inputs, dtype, shape):
+        output = f"{name}/output_0"
+        self.make_node("ReduceMax", inputs=inputs, outputs=[output], name=name, keepdims=False)
+        self.make_value_info(output, dtype, shape=shape)
+
     def make_cast(self, name, root_input, dtype, shape):
         output = f"{name}/output_0"
         self.make_node("Cast", inputs=[root_input], outputs=[output], name=name, to=dtype)
@@ -2215,6 +2225,7 @@ def make_position_ids_reformatting(self):
         # position_ids input for RotaryEmbedding
 
         basename = "/model/pos_ids_reformat"
+
         shape_name = f"{basename}/Shape"
         self.make_shape(shape_name, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", shape=[2] if not self.exclude_embeds else [3])
         gather_name = f"{basename}/Gather"
@@ -2372,7 +2383,70 @@ class Phi3Mini128KModel(Phi3Mini4KModel):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
         self.make_rotary_embedding_multi_cache()
 
+    def make_position_ids_reformatting(self):
+        if self.ep != "dml":
+            position_ids_input_to_rotemb = super().make_position_ids_reformatting()
+            return position_ids_input_to_rotemb
+
+        basename = "/model/pos_ids_reformat"
+        reduce_max_name = f"{basename}/ReduceMax"
+        reduce_max_inputs = ["position_ids"]
+        self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=TensorProto.INT64, shape=[1])
+        greater_or_equal_name = f"{basename}/GreaterOrEqual"
+        greater_or_equal_inputs = [f"{reduce_max_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"]
+        self.make_greater_or_equal(greater_or_equal_name, greater_or_equal_inputs, shape=[])
+        cast_name = f"{basename}/Cast"
+        self.make_cast(cast_name, f"{greater_or_equal_name}/output_0", dtype=TensorProto.INT64, shape=None)
+        mul_name = f"{basename}/Mul"
+        mul_inputs = [f"{cast_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"]
+        self.make_mul(mul_name, mul_inputs, dtype=TensorProto.INT64, shape=None)
+        add_1_name = f"{basename}/Add_1"
+        add_1_inputs = [f"{mul_name}/output_0", "position_ids"]
+        self.make_add(add_1_name, add_1_inputs, dtype=TensorProto.INT64, shape=["batch_size", "sequence_length"])
+
+        return add_1_name
+
+    def make_rotary_embedding_caches(self, rotemb, **kwargs):
+        if self.ep != "dml":
+            cos_cache_name, sin_cache_name = super().make_rotary_embedding_caches(rotemb, **kwargs)
+            return cos_cache_name, sin_cache_name
+
+        cos_cache_name = kwargs.get("cos_cache_name", "cos_cache")
+        sin_cache_name = kwargs.get("sin_cache_name", "sin_cache")
+
+        if self.rotemb_attrs["create_rotary_embedding_caches"]:
+            if not hasattr(rotemb, "cos_cached"):
+                # Create cos/sin caches if not already created
+                # Concatenate 4K and 128K cos/sin caches for Phi-3/Phi-3.5 with the DML EP only
+                cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches_from_scratch()
+                self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["short_factor"]
+                self.rotemb_attrs["cache_length"] = self.original_context_length
+                self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["short_mscale"]
+                cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches_from_scratch()
+                cos_cache = torch.cat((cos_cache_small, cos_cache_large), dim=0)
+                sin_cache = torch.cat((sin_cache_small, sin_cache_large), dim=0)
+            else:
+                cos_cache, sin_cache = rotemb.cos_cached, rotemb.sin_cached
+
+            # Reshape cos/sin cache from (M, H) to (M, H/2)
+            hidden_dim = cos_cache.shape[-1]
+            cos_cache = cos_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
+            cos_cache = cos_cache.astype(self.to_numpy_dtype[self.io_dtype])
+            sin_cache = sin_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
+            sin_cache = sin_cache.astype(self.to_numpy_dtype[self.io_dtype])
+
+            if "cos_cache_name" not in kwargs and "sin_cache_name" not in kwargs:
+                # Save cos/sin caches to disk
+                self.make_external_tensor(cos_cache, cos_cache_name)
+                self.make_external_tensor(sin_cache, sin_cache_name)
+            else:
+                # Return cos/sin caches since they will be custom-saved
+                return cos_cache, sin_cache
+
+        self.rotemb_attrs["create_rotary_embedding_caches"] = False
+
+        return cos_cache_name, sin_cache_name
 
 class Phi3Small8KModel(Model):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
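The DML-only `make_position_ids_reformatting` above builds a small select-and-offset subgraph: when the largest position id reaches the original context length, every position id is shifted by that length so RotaryEmbedding reads from the long-context half of the concatenated cos/sin caches. A numpy sketch of the equivalent computation (4096 is assumed as the original context length for Phi-3.5-mini):

```python
import numpy as np

def reformat_position_ids(position_ids, original_context_length=4096):
    # ReduceMax -> GreaterOrEqual: has any position left the original window?
    is_long_context = np.max(position_ids) >= original_context_length
    # Cast -> Mul: the boolean becomes an offset of 0 or original_context_length.
    offset = np.int64(is_long_context) * original_context_length
    # Add: shift the ids into the second (long-context) half of the caches.
    return position_ids + offset

print(reformat_position_ids(np.arange(10, dtype=np.int64))[:3])    # [0 1 2]
print(reformat_position_ids(np.arange(5000, dtype=np.int64))[:3])  # [4096 4097 4098]
```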

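Correspondingly, the DML branch of `make_rotary_embedding_caches` stacks the short (4K) cache on top of the long (128K) cache so the offset position ids index the correct rows. A simplified torch sketch of that layout (cache construction and dtype handling omitted; the (M, H) -> (M, H/2) slice matches the reshape comment in the diff):

```python
import torch

def build_dml_rotary_caches(cos_small, sin_small, cos_large, sin_large):
    # Rows [0, len(cos_small)) serve short contexts; the rows after them serve
    # long contexts and are reached via the position-id offset sketched above.
    cos_cache = torch.cat((cos_small, cos_large), dim=0)
    sin_cache = torch.cat((sin_small, sin_large), dim=0)
    # Keep the first half of the hidden dim: (M, H) -> (M, H/2).
    h = cos_cache.shape[-1]
    return cos_cache[:, : h // 2], sin_cache[:, : h // 2]
```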