From 8265edf36e522d7b7b9812cc05f87c8553274bbe Mon Sep 17 00:00:00 2001 From: Ilango Rajagopal Date: Wed, 10 Jul 2024 14:41:07 +0530 Subject: [PATCH] Refactor custom ops into proper file locations (#55) * Refactor custom ops into proper file locations Signed-off-by: Ilango Rajagopal * Remove the old native RMSNorm code Signed-off-by: Ilango Rajagopal --------- Signed-off-by: Ilango Rajagopal --- QEfficient/customop/CMakeLists.txt | 25 ---- .../CustomRMSNorm/src/customrmsnorm_aic.cpp | 49 -------- .../src/customrmsnorm_functions.cpp | 85 ------------- QEfficient/customop/README.md | 20 --- QEfficient/customop/__init__.py | 119 +----------------- QEfficient/customop/ctx_scatter_gather.py | 81 ++++++++++++ QEfficient/customop/custom_rms_op_config.yaml | 35 ------ QEfficient/customop/rms_norm.py | 88 ++++++------- 8 files changed, 123 insertions(+), 379 deletions(-) delete mode 100644 QEfficient/customop/CMakeLists.txt delete mode 100644 QEfficient/customop/CustomRMSNorm/src/customrmsnorm_aic.cpp delete mode 100644 QEfficient/customop/CustomRMSNorm/src/customrmsnorm_functions.cpp delete mode 100644 QEfficient/customop/README.md create mode 100644 QEfficient/customop/ctx_scatter_gather.py delete mode 100644 QEfficient/customop/custom_rms_op_config.yaml diff --git a/QEfficient/customop/CMakeLists.txt b/QEfficient/customop/CMakeLists.txt deleted file mode 100644 index 3f16c930..00000000 --- a/QEfficient/customop/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -# ----------------------------------------------------------------------------- -# -#Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. -#SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -# Set the minimum version of CMake required to build the project -cmake_minimum_required(VERSION 3.0) - -# Set the name of the project -project(custom_lib) - -# Set the C++ standard to use -set(CMAKE_CXX_STANDARD 11) - -# Set the source files for the library -set(SOURCE_FILES ${SRC_FILE}) - -# Set the include directories for the library -include_directories(/opt/qti-aic/dev/inc) - -# Create the shared library -add_library(${CUSTOM_LIB} SHARED ${SOURCE_FILES}) - diff --git a/QEfficient/customop/CustomRMSNorm/src/customrmsnorm_aic.cpp b/QEfficient/customop/CustomRMSNorm/src/customrmsnorm_aic.cpp deleted file mode 100644 index f9c9c02b..00000000 --- a/QEfficient/customop/CustomRMSNorm/src/customrmsnorm_aic.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/* ----------------------------------------------------------------------------- - * - * Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * ----------------------------------------------------------------------------- - */ - -/* Interface version: 5.0.0 */ -#include "CustomOpAICInterface.h" -#include "stddef.h" -#include "CustomOpLog.h" -#include -#include -#include -// #include "defs.h" -#define NUM_THREADS (4) -extern "C" { -/* The AIC compilation target supports an API similar to the Interpreter -API. Additionally, threadId, which is the AIC thread ID, is passed. -Kernel is invoked by four AIC threads with threadId equal to 0, 1, 2, and 3. */ -void CustomRMSNormAIC(const CustomOpContext *context, const int32_t threadId) { - int32_t *input_dims = context->inputs[0].sizes; - int32_t input_rank = context->inputs[0].rank; - int32_t batch_size = input_dims[0]; - int32_t sequence_length = input_dims[1]; - int32_t hidden_size = input_dims[2]; - float eps = *(float *)context->params[0].data; - float16_ty *hidden_states = (float16_ty *)context->inputs[0].data; - float16_ty *weight = (float16_ty *)context->inputs[1].data; - float16_ty *output = (float16_ty *)context->outputs[0].data; - // Calculate reciprocal hidden size to avoid division multiple times - float r_hidden_size = 1.0f / hidden_size; - for (int i = threadId; i < sequence_length; i += NUM_THREADS) { - float variance = 0.0f; - int layer_offset = i * hidden_size; - for (int j = 0; j < hidden_size; j++) { - float val = hidden_states[j + layer_offset]; - variance += val * val; - } - variance *= r_hidden_size; - float rms = sqrt(variance + eps); - float r_rms = 1.0f / rms; - for (int j = 0; j < hidden_size; j++){ - output[j + layer_offset] = hidden_states[j + layer_offset] * r_rms * weight[j]; - } - } - } -} diff --git a/QEfficient/customop/CustomRMSNorm/src/customrmsnorm_functions.cpp b/QEfficient/customop/CustomRMSNorm/src/customrmsnorm_functions.cpp deleted file mode 100644 index 7dc78998..00000000 --- a/QEfficient/customop/CustomRMSNorm/src/customrmsnorm_functions.cpp +++ /dev/null @@ -1,85 +0,0 @@ -/* ----------------------------------------------------------------------------- - * - * Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * ----------------------------------------------------------------------------- - */ - -/* Interface version: 5.0.0 */ -#include "CustomOpFunctions.h" -#include "CustomOpInterpreterInterface.h" -#include "CustomOpTypes.h" -#include -extern "C" { -bool customOpVerify(const CustomOpProperties *const opProp) { - /* Refer to function declaration at CustomOpFunctions.h for usage. */ - // Must have two params. - if (opProp->params.size() < 1) - return false; - auto ¶m0 = opProp->params[0]; - // Check params names are valid. - if (strcmp(param0.name, "eps")) - return false; - // Op must have only 2 input and 1 output. - if (opProp->inputs.size() != 2 || opProp->outputs.size() != 1) - return false; - // Input and Output must have the same data type. - if (opProp->inputs[0].dtype != opProp->outputs[0].dtype) - return false; - // Input and Output must have the same dimensions. - if (opProp->inputs[0].rank != opProp->outputs[0].rank) - return false; - for (int i = 0; i < opProp->inputs[0].rank; i++) { - if (opProp->inputs[0].sizes[i] != opProp->outputs[0].sizes[i]) - return false; - } - return true; -} -const char *customOpSelectImpl(const CustomOpProperties *const opProp, - const CustomOpKernelInfo *const kernelInfos, - const int32_t numKernels, const char *backend) { - /* Refer to function declaration at CustomOpFunctions.h for usage. */ - /* For AIC pick 'AIC', for Interpreter pick 'Interpreter' */ - if (strcmp(backend, "AIC") == 0) { - return "CustomRMSNormAIC"; - } else if (strcmp(backend, "Interpreter") == 0) { - return "CustomReluInterpreter"; - } -} -bool customOpInferShape(CustomOpProperties *const opProp) { - /* Refer to function declaration at CustomOpFunctions.h for usage. */ - if (opProp->inputs.size() != 2 || opProp->outputs.size() != 1) - return false; - // There is only 1 output. - // Output has the same type as input. - CustomOpIOTensor &out = opProp->outputs[0]; - out.rank = opProp->inputs[0].rank; - for (int i = 0; i < opProp->inputs[0].rank; i++) { - out.sizes[i] = opProp->inputs[0].sizes[i]; - } - out.dtype = opProp->inputs[0].dtype; - return true; -} -bool customOpSetProperties(CustomOpProperties *const opProp) { - /* Refer to function declaration at CustomOpFunctions.h for usage. */ - if (opProp->inputs[0].sizes[0] > 1) - { - setTileConfig(opProp, "output", {0, 1}); - return true; - } - return false; -} -bool customOpMapTiles(CustomOpProperties *const opProp) { - /* Refer to function declaration at CustomOpFunctions.h for usage. */ - // Get output start and end indices - if (opProp->inputs[0].sizes[0] > 1) - { - const std::vector startIndices = tileStartIndices(opProp->outputs[0]); - const std::vector endIndices = tileEndIndices(opProp->outputs[0]); - createInputTile(opProp, 0, startIndices, endIndices); - return true; - } - return false; -} -} diff --git a/QEfficient/customop/README.md b/QEfficient/customop/README.md deleted file mode 100644 index 2c172d27..00000000 --- a/QEfficient/customop/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# CustomOP Compilation - -This is an example for compiling the custom rms norm op, this should be used when we see the deviation in outputs or out of range fp16 values due to rms norm kernel. This will handle the out of range issues without affecting accuracy and perf. - -Note: Minimum CMAKE version is 3. (g++ compiler and c++11) - -## Steps to compile the custom op -**Note**: Provide specific SRC_FILE path accordingly, below example is for custom rms norm -``` -mkdir build && cd build -cmake .. -D CMAKE_CXX_FLAGS="-Wall" -D SRC_FILE=../CustomRMSNorm/src/customrmsnorm_functions.cpp -D CUSTOM_LIB=customrmsnorm_lib -make all -cd .. -``` - -## Move the custom op shared library to the customop src directory -**Note**: Provide specific build so path accordingly, below example is for custom rms norm -``` -mv build/customrmsnorm_lib.so CustomRMSNorm/src/ -``` \ No newline at end of file diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index d13c8d90..9bfd0899 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -5,120 +5,7 @@ # # ----------------------------------------------------------------------------- -""" -RMS Norm CustomOp Node in com.qti.aisw.onnx Domain for Cloud AI 100 -This is to handle the FP16 Overflow seen in RMS Norm for LLMs +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxScatterFunc +from QEfficient.customop.rms_norm import CustomRMSNormAIC -""" - -import onnxscript -import torch -from onnxscript.onnx_opset import opset13 as ops -from torch import nn - -opset_version = 13 -custom_opset = onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1) - - -# Version 1 -@onnxscript.script(custom_opset) -def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float): - weight = ops.Cast(weight, to=1) - variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1) - epsilon = ops.Expand(epsilon, ops.Shape(variance)) - hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon)) - return weight * hidden_states - - -class CustomRMSNormOp(torch.autograd.Function): - @staticmethod - def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float): - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + epsilon) - return weight * hidden_states - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic( - g: torch.onnx._internal.jit_utils.GraphContext, - hidden_states: torch.Value, - weight: torch.Value, - epsilon: torch.Value, - ) -> torch.Value: - return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states) - - -class CustomRMSNormAIC(nn.Module): - def __init__(self, hidden_size, eps=1e-05): - super(CustomRMSNormAIC, self).__init__() - self.variance_epsilon = eps - self.weight = torch.nn.Parameter(torch.ones(hidden_size)) - - def forward(self, hidden_states): - output = CustomRMSNormOp.apply(hidden_states, self.weight, self.variance_epsilon) - return output - - -@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: - # Find dims - batch_size = ops.Gather(ops.Shape(data), [0]) - num_heads = ops.Gather(ops.Shape(data), [1]) - seq_len = ops.Gather(ops.Shape(position_ids), [1]) - - # Expanded shape to create indices - zero = ops.Constant(value_ints=[0]) - one = ops.Constant(value_ints=[1]) - exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0) - - # Create indices - batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape) - head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) - ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape) - indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) - - return ops.ScatterND(data, indices, updates) - - -class CtxScatterFunc(torch.autograd.Function): - @staticmethod - def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): - batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) - head_idx = torch.arange(data.shape[1]).view(1, -1, 1) - ctx_idx = position_ids.unsqueeze(1) - data[batch_idx, head_idx, ctx_idx] = updates - return data - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) - - -@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) - ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) - return ops.GatherND(data, ctx_indices, batch_dims=2) - - -class CtxGatherFunc(torch.autograd.Function): - @staticmethod - def forward(data: torch.Tensor, ctx_indices: torch.Tensor): - batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) - head_indices = torch.arange(data.shape[1]).view(1, -1, 1) - return data[batch_indices, head_indices, ctx_indices] - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) +__all__ = ["CtxGatherFunc", "CtxScatterFunc", "CustomRMSNormAIC"] diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py new file mode 100644 index 00000000..fa615b46 --- /dev/null +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -0,0 +1,81 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import onnxscript +import torch + +ops = onnxscript.opset13 + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: + # Find dims + batch_size = ops.Gather(ops.Shape(data), [0]) + num_heads = ops.Gather(ops.Shape(data), [1]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape) + head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape) + indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFunc(torch.autograd.Function): + """ + Function to scatter the current key values into KV-cache. + """ + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) + head_idx = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_idx = position_ids.unsqueeze(1) + data[batch_idx, head_idx, ctx_idx] = updates + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: + ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) + ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) + return ops.GatherND(data, ctx_indices, batch_dims=2) + + +class CtxGatherFunc(torch.autograd.Function): + """ + Function to gather only the valid key values from KV-cache. + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + return data[batch_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) diff --git a/QEfficient/customop/custom_rms_op_config.yaml b/QEfficient/customop/custom_rms_op_config.yaml deleted file mode 100644 index 446921b2..00000000 --- a/QEfficient/customop/custom_rms_op_config.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# ----------------------------------------------------------------------------- -# -#Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. -#SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - ---- -version: 5.0.0 -CustomOps: - - type: CustomRMSNorm - package: QAic - inputs: - - name: hidden_states - maxDims: 3 - - name: weight - maxDims: 1 - parameters: - - name: eps - dataType: float - scalar: true - outputs: - - name: output - maxDims: 3 - functionsLibrary: CustomRMSNorm/src/customrmsnorm_lib.so - implementations: - - backend: AIC - type: CustomRMSNormAIC - impl: CustomRMSNorm/src/customrmsnorm_aic.cpp - memoryConfig: - DDR: - CacheableDDR: - VTCM: [hidden_states, weight, output] - requiredFor: -... diff --git a/QEfficient/customop/rms_norm.py b/QEfficient/customop/rms_norm.py index d7f4fd2d..210cca68 100644 --- a/QEfficient/customop/rms_norm.py +++ b/QEfficient/customop/rms_norm.py @@ -1,61 +1,51 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- -""" -RMS Norm CustomOp Node in QAic Domain for Cloud AI 100 -This is to handle the FP16 Overflow seen in RMS Norm for LLMs -""" - +import onnxscript import torch -from torch.onnx.symbolic_helper import parse_args - -op_source = """ -#include - -torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) { - torch::Tensor output; - torch::Tensor variance; - bool keepdim; - // double eps = 1e-5; - variance = hidden_states.pow(2).mean(-1, keepdim=true); - output = hidden_states * torch::rsqrt(variance + eps); - output = output * weight; - return output; -} - -TORCH_LIBRARY(QAic, m) { - m.def("QEffCustomRMSNorm", &custom_rms_norm); -} -""" - -# Compile and load the custom op -torch.utils.cpp_extension.load_inline( - name="custom_rms_norm", - cpp_sources=op_source, - is_python_module=False, - verbose=True, -) - - -# Wrapper module for custom relu C++ op -class QEffCustomRMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps +from torch import nn + +ops = onnxscript.opset13 + + +@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)) +def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float): + weight = ops.Cast(weight, to=1) + variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1) + epsilon = ops.Expand(epsilon, ops.Shape(variance)) + hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon)) + return weight * hidden_states - def forward(self, hidden_states): - return torch.ops.QAic.QEffCustomRMSNorm(hidden_states, self.weight, self.eps) +class CustomRMSNormFunc(torch.autograd.Function): + @staticmethod + def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float): + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + epsilon) + return weight * hidden_states -# ONNX export symbolic helper -@parse_args("v", "v", "f") -def custom_rms_norm(g, hidden_states, weight, eps): - return g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps).setTypeAs(hidden_states) + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + @staticmethod + def symbolic(g: torch.Graph, hidden_states: torch.Value, weight: torch.Value, epsilon: torch.Value) -> torch.Value: + return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states) -torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1) + +class CustomRMSNormAIC(nn.Module): + """ + RMSNorm module that works by replacing the current module with compiler known custom-op. + """ + + def __init__(self, hidden_size, eps=1e-05): + super(CustomRMSNormAIC, self).__init__() + self.variance_epsilon = eps + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + + def forward(self, hidden_states): + return CustomRMSNormFunc.apply(hidden_states, self.weight, self.variance_epsilon)