forked from quic/efficient-transformers
Refactor custom ops into proper file locations (quic#55)
* Refactor custom ops into proper file locations
  Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Remove the old native RMSNorm code
  Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

---------

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>
Showing 8 changed files with 123 additions and 379 deletions.
Deleted files:
- QEfficient/customop/CustomRMSNorm/src/customrmsnorm_aic.cpp (49 deletions)
- QEfficient/customop/CustomRMSNorm/src/customrmsnorm_functions.cpp (85 deletions)
- Two further deleted files whose names are not shown in this view
@@ -0,0 +1,81 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import onnxscript
import torch

ops = onnxscript.opset13


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
    # Find dims
    batch_size = ops.Gather(ops.Shape(data), [0])
    num_heads = ops.Gather(ops.Shape(data), [1])
    seq_len = ops.Gather(ops.Shape(position_ids), [1])

    # Expanded shape to create indices
    zero = ops.Constant(value_ints=[0])
    one = ops.Constant(value_ints=[1])
    exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0)

    # Create indices
    batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape)
    head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
    ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape)
    indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

    return ops.ScatterND(data, indices, updates)


class CtxScatterFunc(torch.autograd.Function):
    """
    Function to scatter the current key values into KV-cache.
    """

    @staticmethod
    def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
        batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
        head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
        ctx_idx = position_ids.unsqueeze(1)
        data[batch_idx, head_idx, ctx_idx] = updates
        return data

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
        return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
    ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
    return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFunc(torch.autograd.Function):
    """
    Function to gather only the valid key values from KV-cache.
    """

    @staticmethod
    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
        batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
        head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
        return data[batch_indices, head_indices, ctx_indices]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
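The eager `forward` paths of the two autograd functions above can be exercised directly as a quick sanity check. Below is a minimal usage sketch; the tensor sizes and the `position_ids`/`ctx_indices` values are illustrative assumptions, not values taken from this diff.

import torch

# Hypothetical sizes (assumptions for illustration only)
B, H, CTX, D, S = 2, 4, 8, 16, 3

kv_cache = torch.zeros(B, H, CTX, D)                 # pre-allocated KV-cache: (batch, heads, ctx_len, head_dim)
position_ids = torch.tensor([[0, 1, 2], [3, 4, 5]])  # cache slots for the new tokens: (batch, seq_len)
updates = torch.randn(B, H, S, D)                    # new key (or value) states: (batch, heads, seq_len, head_dim)

# Scatter the new states into the cache at the given positions
kv_cache = CtxScatterFunc.apply(kv_cache, position_ids, updates)

# Gather cache rows back out by context index
ctx_indices = torch.arange(CTX).view(1, 1, -1)       # broadcasts against batch and head indices
gathered = CtxGatherFunc.apply(kv_cache, ctx_indices)
print(gathered.shape)                                # torch.Size([2, 4, 8, 16])

During ONNX export, the `symbolic` methods replace these eager indexing operations with the `ScatterND`/`GatherND` graphs defined via onnxscript in the "com.qualcomm.cloud" opset above.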