LayerNormalization broadcast (limited support for axis=2) (#23297)
### Description

The LayerNormalization spec supports broadcasting: tensors Scale and B must be
unidirectionally broadcastable to tensor X.
https://onnx.ai/onnx/operators/onnx__LayerNormalization.html
However, the current implementation only allows the scale and bias shapes to be
X.shape()[axis:].

Examples of input tensors normalized with axis=2:

| X shape |  Scale shape | B shape | Before | After |
| - | - | - | - | - |
| (B, S, D) | (D) | (D) | Supported | Supported |
| (B, S, D) | (1, 1, D) | (1, 1, D) | Supported | Supported |
| (B, S, D) | (B, 1, D) | (B, 1, D) | Not Supported | Supported |
| (B, S, D) | (1, S, D) | (1, S, D) | Not Supported | Supported |
| (B, S, D) | (B, S, D) | (B, S, D) | Not Supported | Supported |


Here we add limited support: axis=2; scale and bias have the same shape; and
scale, bias, and X have the same number of dimensions. This covers common use
cases in LLM and vision models (see the sketch below).
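
For illustration, here is a minimal sketch (hypothetical standalone code, not part of this commit; the function name and parameters are made up) of the per-row offset into scale/bias that each supported pattern implies:

```cpp
#include <cstdint>

// For X of shape (B, S, D) normalized with axis=2, row r is in [0, B * S).
// The starting offset of the scale/bias elements used for row r depends on
// the scale/bias shape.
int64_t ScaleBiasOffset(int64_t r, int64_t S, int64_t D,
                        bool scale_has_batch_dim, bool scale_has_seq_dim) {
  if (!scale_has_batch_dim && !scale_has_seq_dim) return 0;    // (D) or (1, 1, D)
  if (scale_has_batch_dim && scale_has_seq_dim) return r * D;  // (B, S, D)
  if (scale_has_batch_dim) return (r / S) * D;                 // (B, 1, D): shared by all S rows of a batch
  return (r % S) * D;                                          // (1, S, D): repeats across batches
}
```

This matches the `broadcast_param` encoding introduced in `layer_norm_helper.h` below (0, 1, S, and -S respectively).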

### Motivation and Context

Support Stable Diffusion 3.x and Flux models.
tianleiwu authored Jan 11, 2025
1 parent a74817a commit 73f5b0c
Showing 9 changed files with 333 additions and 66 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(double)epsilon_, // epsilon
reinterpret_cast<const CudaT*>(gamma->Data<T>()), // gamma
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
0, // no broadcast for gamma/beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
116 changes: 116 additions & 0 deletions onnxruntime/core/providers/cpu/nn/layer_norm_helper.h
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/tensor_shape.h"
#include "core/common/status.h"
#include "core/common/narrow.h"

namespace onnxruntime {

constexpr const char* kLayerNormInputShapeMismatchError =
"Size of scale and bias (if provided) must match X.shape[axis:], "
"or scale and bias (with same shape) can be broadcasted to X when axis is 2.";

constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";

constexpr int64_t kLayerNormInvalidInput = -1;

struct LayerNormParams {
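// num_rows: product of X dims before the normalization axis (rows to normalize).
// norm_size: product of X dims from the axis onward (elements per row).
// scale_size / bias_size: total element counts of the Scale and B tensors.
// broadcast_param: 0, 1, S or -S (see the comment below the struct), or
// kLayerNormInvalidInput when the inputs are invalid.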
int64_t num_rows;
int64_t norm_size; // size per row
int64_t scale_size;
int64_t bias_size;
int64_t broadcast_param;
};

// We support broadcasting for axis=2, where the first two dimensions are rows and the rest are columns.
// When X has shape (B, S, ...), x_row (the index of one row in X) is in the range [0, B * S).
// We support the following scale and bias shapes:
// When the scale and bias shape is (1, 1, ...) or (...), broadcast_param is 0.
// When the scale and bias shape is (B, 1, ...), broadcast_param is S.
// When the scale and bias shape is (B, S, ...), broadcast_param is 1.
// When the scale and bias shape is (1, S, ...), broadcast_param is -S.

// Below is a macro to compute the offset for scale and bias data for a row of X.
#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
((broadcast_param == 0) ? 0 \
: norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))
#endif
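// Worked example: X of shape (2, 3, 4) with axis=2 gives norm_size = 4 and x_row in [0, 6):
//   scale shape (2, 1, 4) => broadcast_param = 3 (S):   offset = 4 * (x_row / 3)
//   scale shape (1, 3, 4) => broadcast_param = -3 (-S): offset = 4 * (x_row % 3)
//   scale shape (2, 3, 4) => broadcast_param = 1:       offset = 4 * x_row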

class LayerNormHelper {
public:
static Status CheckInputs(const TensorShape& x_shape,
const TensorShape& scale_shape,
const TensorShape& bias_shape,
bool has_bias,
int64_t axis,
LayerNormParams& params) {
params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
params.scale_size = scale_shape.Size();
params.bias_size = bias_shape.Size();
params.broadcast_param = 0;

if (params.norm_size <= 1) {
params.broadcast_param = kLayerNormInvalidInput;
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
} else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
if (params.broadcast_param == kLayerNormInvalidInput) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
kLayerNormInputShapeMismatchError,
" X.shape=", x_shape,
" scale.shape=", scale_shape,
" bias.shape=", bias_shape,
" and axis=", axis);
}
}
return Status::OK();
}

private:
static int64_t GetBroadcastParam(const TensorShape& x_shape,
const TensorShape& scale_shape,
const TensorShape* bias_shape,
int64_t axis) {
// Note that when the sizes of scale and bias equal norm_size, this function is not called (see CheckInputs).

// X shape is (B, S, ...)
if (axis == 2 &&
x_shape.NumDimensions() >= 3 &&
x_shape.NumDimensions() == scale_shape.NumDimensions() &&
(bias_shape == nullptr || *bias_shape == scale_shape)) {
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) {
// scale cannot be broadcast to X, so the input is invalid.
return kLayerNormInvalidInput;
}
}

if (x_shape.GetDims()[0] == scale_shape.GetDims()[0]) {
// scale and bias shape is (B, S, ...).
if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
return 1;
}

// scale and bias shape is (B, 1, ...), returns S
if (scale_shape.GetDims()[1] == 1) {
return x_shape.GetDims()[1];
}
} else if (scale_shape.GetDims()[0] == 1) {
// scale and bias shape is (1, S, ...), returns -S
if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
return -(x_shape.GetDims()[1]);
}
}
}

// All other cases are not supported.
return kLayerNormInvalidInput;
}
};

} // namespace onnxruntime
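
To show how the helper and the macro fit together, here is a hedged scalar reference sketch (hypothetical simplified code, not part of this commit; it assumes fp32 buffers, the non-simplified variant, and a successful `CheckInputs`):

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical single-threaded reference; the real kernels below parallelize
// over rows and additionally handle fp16 conversion and the simplified variant.
void LayerNormRows(const float* X, const float* scale, const float* bias,
                   float* Y, const LayerNormParams& params, float epsilon) {
  for (int64_t row = 0; row < params.num_rows; ++row) {
    const float* x = X + row * params.norm_size;
    float* y = Y + row * params.norm_size;

    // Per-row mean and variance.
    double mean = 0.0, mean_sq = 0.0;
    for (int64_t h = 0; h < params.norm_size; ++h) {
      mean += x[h];
      mean_sq += static_cast<double>(x[h]) * x[h];
    }
    mean /= static_cast<double>(params.norm_size);
    const double inv_std =
        1.0 / std::sqrt(mean_sq / params.norm_size - mean * mean + epsilon);

    // Broadcast-aware offset into scale/bias for this row.
    const int64_t off =
        LAYER_NORM_SCALE_BIAS_OFFSET(params.broadcast_param, row, params.norm_size);
    for (int64_t h = 0; h < params.norm_size; ++h) {
      y[h] = static_cast<float>((x[h] - mean) * inv_std) * scale[off + h] +
             (bias != nullptr ? bias[off + h] : 0.0f);
    }
  }
}
```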
62 changes: 31 additions & 31 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "layer_norm_impl.h"
#include "layer_norm_helper.h"

#include "core/common/safeint.h"
#include "core/framework/tensor.h"
@@ -24,6 +25,7 @@ void ComputeJob(
const T* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
const int64_t broadcast_param,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
@@ -55,13 +57,16 @@
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

for (int64_t h = 0; h < norm_size; h++) {
// Compute the offset of gamma and beta to support broadcasting.
int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);

for (int64_t h = 0; h < norm_size; h++, i++) {
if (simplified) {
p_output[h] = p_output[h] / mean_square * scale_data[h];
p_output[h] = p_output[h] / mean_square * scale_data[i];
} else if (nullptr == bias_data) {
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i];
}
}

@@ -82,6 +87,7 @@ void ComputeJob(
const MLFloat16* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
const int64_t broadcast_param,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
@@ -120,13 +126,16 @@
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

for (size_t h = 0; h < num_elems; h++) {
// Compute the offset of gamma and beta to support broadcasting.
int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);

for (size_t h = 0; h < num_elems; h++, i++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i];
} else if (nullptr == bias_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i];
}
}

@@ -161,9 +170,7 @@ LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified
simplified_{simplified},
contrib_op_{contrib_op},
prepacked_scale_fp32_data_(nullptr),
prepacked_scale_fp32_size_(0),
prepacked_bias_fp32_data_(nullptr),
prepacked_bias_fp32_size_(0) {
prepacked_bias_fp32_data_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
}
@@ -179,8 +186,8 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

const TensorShape& x_shape = X->Shape();
size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_;
const TensorShape& bias_shape = bias ? bias->Shape() : prepacked_bias_fp32_shape_;
Tensor* Y = p_ctx->Output(0, x_shape);
T* Y_data = Y->MutableData<T>();

@@ -215,7 +222,7 @@

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
}

@@ -234,10 +241,10 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr

is_packed = false;
if (input_idx == 1) { // scale
prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
prepacked_scale_fp32_shape_ = tensor.Shape();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
} else if (input_idx == 2) { // bias
prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
prepacked_bias_fp32_shape_ = tensor.Shape();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

@@ -249,9 +256,9 @@ Status LayerNormImpl::ComputeWithoutContext(
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
size_t scale_size,
const TensorShape& scale_shape,
const T* bias_data,
size_t bias_size,
const TensorShape& bias_shape,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
@@ -260,35 +267,28 @@
float epsilon,
bool simplified,
AllocatorPtr alloc) const {
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale and bias (if provided) must match this. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

IAllocatorUniquePtr<float> scale_fp32;
IAllocatorUniquePtr<float> bias_fp32;
if constexpr (std::is_same_v<T, MLFloat16>) {
if (prepacked_scale_fp32_data_ == nullptr) {
const size_t num_elems = static_cast<size_t>(norm_size);
const size_t num_elems = static_cast<size_t>(params.scale_size);
scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
}
if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
const size_t num_elems = static_cast<size_t>(norm_size);
const size_t num_elems = static_cast<size_t>(params.bias_size);
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}
}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
thread_pool, static_cast<int32_t>(params.num_rows),
[&](ptrdiff_t task_idx) {
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel {
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
size_t scale_size,
const TensorShape& scale_shape,
const T* bias_data,
size_t bias_size,
const TensorShape& bias_shape,
T* Y_data,
U* mean_data,
U* inv_std_dev,
@@ -64,9 +64,9 @@ class LayerNormImpl : public OpKernel {
const bool simplified_;
const bool contrib_op_;
IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
size_t prepacked_scale_fp32_size_;
TensorShape prepacked_scale_fp32_shape_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
size_t prepacked_bias_fp32_size_;
TensorShape prepacked_bias_fp32_shape_;
};

} // namespace onnxruntime
32 changes: 15 additions & 17 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
@@ -4,6 +4,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/nn/layer_norm.h"
#include "core/providers/cuda/nn/layer_norm_impl.h"
#include "core/providers/cpu/nn/layer_norm_helper.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
@@ -44,28 +45,22 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

const TensorShape& scale_shape = scale->Shape();
const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape();

LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

// Outputs
Tensor* Y = ctx->Output(0, x_shape);
auto Y_data = reinterpret_cast<CudaV*>(Y->MutableData<V>());

// Mean and variance
std::vector<int64_t> mean_inv_std_var_dim;
for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
if (i < axis) {
mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
} else {
@@ -93,8 +88,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
return Status::OK();
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data);
HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(
GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, X_data,
onnxruntime::narrow<int>(params.num_rows), onnxruntime::narrow<int>(params.norm_size), epsilon_,
scale_data, bias_data,
onnxruntime::narrow<int>(params.broadcast_param));
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}