From e9e0c3dab9adfc80d78200d5354ace778f6c92dc Mon Sep 17 00:00:00 2001
From: Xiang Zhang
Date: Tue, 10 Sep 2024 04:34:19 -0700
Subject: [PATCH] DML-specific change in model builder for Phi-3.5

---
 src/python/py/models/builder.py | 74 ++++++++++++++++++++++++++------
 1 file changed, 58 insertions(+), 16 deletions(-)

diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 1887ff5cb..f2737e45e 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -571,6 +571,11 @@ def make_greater(self, name, inputs, shape):
         output = f"{name}/output_0"
         self.make_node("Greater", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, TensorProto.BOOL, shape=shape)
+
+    def make_greater_or_equal(self, name, inputs, shape):
+        output = f"{name}/output_0"
+        self.make_node("GreaterOrEqual", inputs=inputs, outputs=[output], name=name)
+        self.make_value_info(output, TensorProto.BOOL, shape=shape)
 
     def make_isinf(self, name, root_input, shape):
         output = f"{name}/output_0"
@@ -597,6 +602,11 @@ def make_reduce_sum(self, name, inputs, dtype, shape):
         self.make_node("ReduceSum", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, dtype, shape=shape)
 
+    def make_reduce_max(self, name, inputs, dtype, shape):
+        output = f"{name}/output_0"
+        self.make_node("ReduceMax", inputs=inputs, outputs=[output], name=name, keepdims=False)
+        self.make_value_info(output, dtype, shape=shape)
+
     def make_cast(self, name, root_input, dtype, shape):
         output = f"{name}/output_0"
         self.make_node("Cast", inputs=[root_input], outputs=[output], name=name, to=dtype)
@@ -924,7 +934,16 @@ def make_rotary_embedding_caches(self, rotemb, **kwargs):
         if self.rotemb_attrs["create_rotary_embedding_caches"]:
             if not hasattr(rotemb, "cos_cached"):
                 # Create cos/sin caches if not already created
-                cos_cache, sin_cache = self.make_rotary_embedding_caches_from_scratch()
+                if self.ep == "dml":
+                    cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches_from_scratch()
+                    self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["short_factor"]
+                    self.rotemb_attrs["cache_length"] = self.original_context_length
+                    self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["short_mscale"]
+                    cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches_from_scratch()
+                    cos_cache = torch.cat((cos_cache_small, cos_cache_large), dim=0)
+                    sin_cache = torch.cat((sin_cache_small, sin_cache_large), dim=0)
+                else:
+                    cos_cache, sin_cache = self.make_rotary_embedding_caches_from_scratch()
             else:
                 cos_cache, sin_cache = rotemb.cos_cached, rotemb.sin_cached
 
@@ -2215,22 +2234,45 @@ def make_position_ids_reformatting(self):
         # position_ids input for RotaryEmbedding
         basename = "/model/pos_ids_reformat"
-        shape_name = f"{basename}/Shape"
-        self.make_shape(shape_name, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", shape=[2] if not self.exclude_embeds else [3])
-        gather_name = f"{basename}/Gather"
-        gather_inputs = [f"{shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
-        self.make_gather(gather_name, gather_inputs, axis=0)
-        unsqueeze_name = f"{basename}/Unsqueeze"
-        unsqueeze_inputs = [f"{gather_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
-        self.make_unsqueeze(unsqueeze_name, unsqueeze_inputs, dtype=TensorProto.INT64, shape=[1])
-        concat_name = f"{basename}/Concat"
-        concat_inputs = ["/model/constants/TensorProto.INT64/1D/-1", f"{unsqueeze_name}/output_0"]
-        self.make_concat(concat_name, concat_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
-        reshape_name = f"{basename}/Reshape"
-        reshape_inputs = ["position_ids", f"{concat_name}/output_0"]
-        self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None)
-        return reshape_name
+        if self.ep == "dml":
+            reduce_max_name = f"{basename}/ReduceMax"
+            reduce_max_inputs = ["position_ids"]
+            self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=TensorProto.INT64, shape=[1])
+
+            greater_or_equal_name = f"{basename}/GreaterOrEqual"
+            greater_or_equal_inputs = [f"{reduce_max_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"]
+            self.make_greater_or_equal(greater_or_equal_name, greater_or_equal_inputs, shape=[])
+
+            cast_name = f"{basename}/Cast"
+            self.make_cast(cast_name, f"{greater_or_equal_name}/output_0", dtype=TensorProto.INT64, shape=None)
+
+            mul_name = f"{basename}/Mul"
+            mul_inputs = [f"{cast_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"]
+            self.make_mul(mul_name, mul_inputs, dtype=TensorProto.INT64, shape=None)
+
+            add_1_name = f"{basename}/Add_1"
+            add_1_inputs = [f"{mul_name}/output_0", "position_ids"]
+            self.make_add(add_1_name, add_1_inputs, dtype=TensorProto.INT64, shape=["batch_size", "sequence_length"])
+
+            return add_1_name
+        else:
+            shape_name = f"{basename}/Shape"
+            self.make_shape(shape_name, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", shape=[2] if not self.exclude_embeds else [3])
+            gather_name = f"{basename}/Gather"
+            gather_inputs = [f"{shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
+            self.make_gather(gather_name, gather_inputs, axis=0)
+            unsqueeze_name = f"{basename}/Unsqueeze"
+            unsqueeze_inputs = [f"{gather_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+            self.make_unsqueeze(unsqueeze_name, unsqueeze_inputs, dtype=TensorProto.INT64, shape=[1])
+            concat_name = f"{basename}/Concat"
+            concat_inputs = ["/model/constants/TensorProto.INT64/1D/-1", f"{unsqueeze_name}/output_0"]
+            self.make_concat(concat_name, concat_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
+            reshape_name = f"{basename}/Reshape"
+            reshape_inputs = ["position_ids", f"{concat_name}/output_0"]
+            self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None)
+
+            return reshape_name
 
 
 class LlamaModel(Model):
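
Note: the DML path above bakes two RoPE caches into one tensor (short-factor rows first, long-factor rows after) and remaps position_ids so a single cache lookup serves both LongRoPE scaling regimes. Below is a minimal PyTorch sketch of that runtime behavior; it is not part of the patch, the sizes are illustrative only, and remap is a hypothetical helper that mirrors the ReduceMax -> GreaterOrEqual -> Cast -> Mul -> Add subgraph built in make_position_ids_reformatting.

    import torch

    orig_len, max_len, head_dim = 4096, 131072, 96  # illustrative sizes only

    # Stand-ins for the short-factor and long-factor cos caches built above.
    cos_small = torch.randn(orig_len, head_dim // 2)
    cos_large = torch.randn(max_len, head_dim // 2)
    cos_cache = torch.cat((cos_small, cos_large), dim=0)  # as in the DML branch

    def remap(position_ids, orig_len):
        # Mirrors the ONNX subgraph: ReduceMax -> GreaterOrEqual -> Cast -> Mul -> Add.
        offset = int(position_ids.max() >= orig_len) * orig_len
        return position_ids + offset

    # Short regime: max position < orig_len, so ids index the short-factor rows.
    short_ids = torch.arange(10)
    assert torch.equal(cos_cache[remap(short_ids, orig_len)], cos_small[short_ids])

    # Long regime: max position >= orig_len, so every id is shifted into the
    # long-factor region of the combined cache.
    long_ids = torch.arange(orig_len - 5, orig_len + 5)
    assert torch.equal(cos_cache[remap(long_ids, orig_len)], cos_large[long_ids])

Because the offset is derived from the maximum position id, the whole batch switches regimes together once any position reaches original_context_length, which matches the scalar ReduceMax in the graph.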