DML-specific change in model builder for Phi-3.5
zhangxiang1993 committed Sep 12, 2024
1 parent 39a0da1 commit 33fda89
Showing 1 changed file with 59 additions and 16 deletions: src/python/py/models/builder.py
@@ -571,6 +571,11 @@ def make_greater(self, name, inputs, shape):
         output = f"{name}/output_0"
         self.make_node("Greater", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, TensorProto.BOOL, shape=shape)
+
+    def make_greater_or_equal(self, name, inputs, shape):
+        output = f"{name}/output_0"
+        self.make_node("GreaterOrEqual", inputs=inputs, outputs=[output], name=name)
+        self.make_value_info(output, TensorProto.BOOL, shape=shape)
 
     def make_isinf(self, name, root_input, shape):
         output = f"{name}/output_0"
@@ -597,6 +602,11 @@ def make_reduce_sum(self, name, inputs, dtype, shape):
         self.make_node("ReduceSum", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, dtype, shape=shape)
+
+    def make_reduce_max(self, name, inputs, dtype, shape):
+        output = f"{name}/output_0"
+        self.make_node("ReduceMax", inputs=inputs, outputs=[output], name=name, keepdims=False)
+        self.make_value_info(output, dtype, shape=shape)
 
     def make_cast(self, name, root_input, dtype, shape):
         output = f"{name}/output_0"
         self.make_node("Cast", inputs=[root_input], outputs=[output], name=name, to=dtype)
@@ -924,7 +934,16 @@ def make_rotary_embedding_caches(self, rotemb, **kwargs):
         if self.rotemb_attrs["create_rotary_embedding_caches"]:
             if not hasattr(rotemb, "cos_cached"):
                 # Create cos/sin caches if not already created
-                cos_cache, sin_cache = self.make_rotary_embedding_caches_from_scratch()
+                if self.ep == "dml":
+                    cos_cache_large, sin_cache_large = self.make_rotary_embedding_caches_from_scratch()
+                    self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["short_factor"]
+                    self.rotemb_attrs["cache_length"] = self.original_context_length
+                    self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["short_mscale"]
+                    cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches_from_scratch()
+                    cos_cache = torch.cat((cos_cache_small, cos_cache_large), dim=0)
+                    sin_cache = torch.cat((sin_cache_small, sin_cache_large), dim=0)
+                else:
+                    cos_cache, sin_cache = self.make_rotary_embedding_caches_from_scratch()
             else:
                 cos_cache, sin_cache = rotemb.cos_cached, rotemb.sin_cached
 
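For context: Phi-3.5 uses LongRoPE-style scaled rotary embeddings with two sets of rescale factors, short factors for positions within the original context length and long factors beyond it. Because the DML path cannot swap caches at runtime, this change bakes both into one cache: rows [0, original_context_length) hold the short-factor values and the remaining rows hold the long-factor values. A rough PyTorch sketch of that layout follows; the sizes (4096 original context, 131072 long context, rotary half-dimension 48) are assumptions about Phi-3.5-mini, and random tensors stand in for the real caches.

import torch

original_context_length = 4096  # assumed Phi-3.5-mini original context
long_cache_length = 131072      # assumed long-context length (32x the original)
rot_dim_half = 48               # assumed rotary dim // 2

# Stand-ins for caches built with the short and long rescale factors.
cos_cache_small = torch.randn(original_context_length, rot_dim_half)
cos_cache_large = torch.randn(long_cache_length, rot_dim_half)

# Same concatenation as the diff: short-factor rows first, long-factor after.
cos_cache = torch.cat((cos_cache_small, cos_cache_large), dim=0)

# Short sequences index rows [0, 4096) directly; once the context exceeds
# 4096, every position is offset by 4096 and lands in the long-factor region.
pos = torch.tensor([0, 100, 4095])
assert torch.equal(cos_cache[pos], cos_cache_small[pos])
assert torch.equal(cos_cache[pos + original_context_length], cos_cache_large[pos])
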
@@ -2215,22 +2234,46 @@ def make_position_ids_reformatting(self):
         # position_ids input for RotaryEmbedding
 
         basename = "/model/pos_ids_reformat"
-        shape_name = f"{basename}/Shape"
-        self.make_shape(shape_name, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", shape=[2] if not self.exclude_embeds else [3])
-        gather_name = f"{basename}/Gather"
-        gather_inputs = [f"{shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
-        self.make_gather(gather_name, gather_inputs, axis=0)
-        unsqueeze_name = f"{basename}/Unsqueeze"
-        unsqueeze_inputs = [f"{gather_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
-        self.make_unsqueeze(unsqueeze_name, unsqueeze_inputs, dtype=TensorProto.INT64, shape=[1])
-        concat_name = f"{basename}/Concat"
-        concat_inputs = ["/model/constants/TensorProto.INT64/1D/-1", f"{unsqueeze_name}/output_0"]
-        self.make_concat(concat_name, concat_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
-        reshape_name = f"{basename}/Reshape"
-        reshape_inputs = ["position_ids", f"{concat_name}/output_0"]
-        self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None)
-
-        return reshape_name
+        if self.ep == "dml":
+            reduce_max_name = f"{basename}/ReduceMax"
+            reduce_max_inputs = ["position_ids"]
+            self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=TensorProto.INT64, shape=[1])
+
+            greater_or_equal_name = f"{basename}/GreaterOrEqual"
+            greater_or_equal_inputs = [f"{reduce_max_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"]
+            self.make_greater_or_equal(greater_or_equal_name, greater_or_equal_inputs, shape=[])
+
+            cast_name = f"{basename}/Cast"
+            self.make_cast(cast_name, f"{greater_or_equal_name}/output_0", dtype=TensorProto.INT64, shape=None)
+
+            mul_name = f"{basename}/Mul"
+            mul_inputs = [f"{cast_name}/output_0", f"/model/constants/TensorProto.INT64/0D/{self.original_context_length}"]
+            self.make_mul(mul_name, mul_inputs, dtype=TensorProto.INT64, shape=None)
+
+            add_1_name = f"{basename}/Add_1"
+            add_1_inputs = [f"{mul_name}/output_0", "position_ids"]
+            self.make_add(add_1_name, add_1_inputs, dtype=TensorProto.INT64, shape=["batch_size", "sequence_length"])
+
+            return add_1_name
+        else:
+            shape_name = f"{basename}/Shape"
+            self.make_shape(shape_name, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", shape=[2] if not self.exclude_embeds else [3])
+            gather_name = f"{basename}/Gather"
+            gather_inputs = [f"{shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
+            self.make_gather(gather_name, gather_inputs, axis=0)
+            unsqueeze_name = f"{basename}/Unsqueeze"
+            unsqueeze_inputs = [f"{gather_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+            self.make_unsqueeze(unsqueeze_name, unsqueeze_inputs, dtype=TensorProto.INT64, shape=[1])
+            concat_name = f"{basename}/Concat"
+            concat_inputs = ["/model/constants/TensorProto.INT64/1D/-1", f"{unsqueeze_name}/output_0"]
+            self.make_concat(concat_name, concat_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
+
+            reshape_name = f"{basename}/Reshape"
+            reshape_inputs = [f"{add_1_name}/output_0", f"{concat_name}/output_0"]

    [Check failure] Code scanning / CodeQL: Potentially uninitialized local variable (Error).
    Local variable 'add_1_name' may be used before it is initialized.

+            self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None)
+
+            return reshape_name
 
 
 class LlamaModel(Model):
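Taken together, the DML branch computes position_ids + (max(position_ids) >= original_context_length) * original_context_length, which shifts cache lookups into the long-factor half of the concatenated cache exactly when the context exceeds the original length. (Note the CodeQL failure above: the rewritten else branch references add_1_name, which only the DML branch defines; the pre-change code reshaped position_ids directly.) A hypothetical NumPy equivalent of the ReduceMax -> GreaterOrEqual -> Cast -> Mul -> Add chain, with 4096 standing in for self.original_context_length:

import numpy as np

def reformat_position_ids(position_ids: np.ndarray, original_context_length: int = 4096) -> np.ndarray:
    """Mirror of the graph's ReduceMax -> GreaterOrEqual -> Cast -> Mul -> Add chain."""
    is_long = position_ids.max() >= original_context_length  # ReduceMax + GreaterOrEqual
    offset = np.int64(is_long) * original_context_length     # Cast + Mul
    return position_ids + offset                              # Add

short_ids = np.array([[0, 1, 2]], dtype=np.int64)
long_ids = np.array([[4100, 4101]], dtype=np.int64)
print(reformat_position_ids(short_ids))  # [[0 1 2]]      -> short-factor cache rows
print(reformat_position_ids(long_ids))   # [[8196 8197]]  -> long-factor cache rows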
