Skip to content

Commit

Permalink
repo-sync-2024-01-22T22:48:42+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Jan 22, 2024
1 parent 893301f commit e2da1d2
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 24 deletions.
46 changes: 46 additions & 0 deletions examples/python/ml/flax_whisper/3pc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{
"id": "outsourcing.3pc",
"nodes": {
"node:0": "127.0.0.1:61920",
"node:1": "127.0.0.1:61921",
"node:2": "127.0.0.1:61922",
"node:3": "127.0.0.1:61923",
"node:4": "127.0.0.1:61924"
},
"devices": {
"SPU": {
"kind": "SPU",
"config": {
"node_ids": [
"node:0",
"node:1",
"node:2"
],
"spu_internal_addrs": [
"127.0.0.1:61930",
"127.0.0.1:61931",
"127.0.0.1:61932"
],
"runtime_config": {
"protocol": "ABY3",
"field": "FM128",
"enable_pphlo_profile": true,
"enable_hal_profile": true,
"fxp_exp_mode": 1
}
}
},
"P1": {
"kind": "PYU",
"config": {
"node_id": "node:3"
}
},
"P2": {
"kind": "PYU",
"config": {
"node_id": "node:4"
}
}
}
}
28 changes: 28 additions & 0 deletions examples/python/ml/flax_whisper/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

# Runnable driver for the secure Whisper inference example; the 3pc.json
# cluster config is bundled as runtime data (usage described in README.md).
py_binary(
    name = "flax_whisper",
    srcs = ["flax_whisper.py"],
    data = [
        "//examples/python/ml/flax_whisper:3pc.json",
    ],
    deps = [
        "//spu/utils:distributed",
    ],
)
26 changes: 26 additions & 0 deletions examples/python/ml/flax_whisper/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Flax Whisper Example

This example demonstrates how to use SPU to run private inference on a pre-trained
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.FlaxWhisperForConditionalGeneration) model.

1. Install huggingface transformers library

```sh
pip install 'transformers[flax]' soundfile librosa
```

2. Enable `While` loops with a secret condition (debug-only feature)

Edit `libspu/kernel/hlo/control_flow.cc` and change `ENABLE_DEBUG_ONLY_REVEAL_SECRET_CONDITION` to `true`.

3. Launch SPU backend runtime

```sh
bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_whisper/3pc.json up
```

4. Run `flax_whisper` example

```sh
bazel run -c opt //examples/python/ml/flax_whisper -- --config `pwd`/examples/python/ml/flax_whisper/3pc.json
```
84 changes: 84 additions & 0 deletions examples/python/ml/flax_whisper/flax_whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Start nodes.
# > bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_whisper/3pc.json up
#
# Run this example script.
# > bazel run -c opt //examples/python/ml/flax_whisper:flax_whisper

import argparse
import json
import os

import jax.numpy as jnp
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
from datasets import load_dataset

import spu.utils.distributed as ppd

# CLI: allow overriding the cluster-topology config path.
parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument(
    "-c", "--config", default="examples/python/ml/flax_whisper/3pc.json"
)
args = parser.parse_args()

# Load node/device topology and bring up the distributed runtime.
with open(args.config, 'r') as conf_file:
    conf = json.load(conf_file)

ppd.init(conf["nodes"], conf["devices"])

# Pre-trained Whisper tiny model; weights are converted from PyTorch to Flax.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
pretrained_model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en", from_pt=True
)


def text_generation(input_features, params):
    """Autoregressive Whisper decoding; this is the function run on SPU.

    Args:
        input_features: audio features tensor fed to the encoder.
        params: model parameters to install on the shared pretrained model.

    Returns:
        The generated token-id sequences.
    """
    pretrained_model.params = params
    return pretrained_model.generate(input_features=input_features).sequences


def run_on_cpu():
    """Plaintext baseline: transcribe one librispeech sample locally."""
    audio_ds = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )
    features = processor(audio_ds[0]["audio"]["array"], return_tensors="np")
    result = pretrained_model.generate(input_features=features.input_features)
    return result.sequences


def run_on_spu():
    """Private inference: audio features come from P1, weights from P2,
    and decoding runs on the SPU device; token ids are revealed at the end."""
    audio_ds = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )
    features = processor(audio_ds[0]["audio"]["array"], return_tensors="np")

    # Place each party's secret input on its owning PYU before SPU execution.
    secret_features = ppd.device("P1")(lambda x: x)(features.input_features)
    secret_params = ppd.device("P2")(lambda x: x)(pretrained_model.params)
    generated = ppd.device("SPU")(
        text_generation,
    )(secret_features, secret_params)
    # Fetch (reveal) the generated token ids back to the driver process.
    return ppd.get(generated)


if __name__ == '__main__':
    # Run the plaintext baseline first, then the private SPU run,
    # decoding and printing the first transcript of each.
    print('\n------\nRun on CPU')
    cpu_ids = run_on_cpu()
    print(processor.batch_decode(cpu_ids, skip_special_tokens=True)[0])

    print('\n------\nRun on SPU')
    spu_ids = run_on_spu()
    print(processor.batch_decode(spu_ids, skip_special_tokens=True)[0])
91 changes: 67 additions & 24 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,31 +417,74 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
auto lhs = lookupValue(sscope, op.getLhs(), opts);
auto rhs = lookupValue(sscope, op.getRhs(), opts);

Strides window_strides(dnums.getInputSpatialDimensions().size(), 1);
if (op.getWindowStrides().has_value()) {
window_strides = *op.getWindowStrides();
}

kernel::hlo::ConvolutionConfig config;
config.featureGroupCount = op.getFeatureGroupCount();
config.batchGroupCount = op.getBatchGroupCount();
config.window_strides = window_strides;
config.inputBatchDimension = dnums.getInputBatchDimension();
config.inputFeatureDimension = dnums.getInputFeatureDimension();
config.inputSpatialDimensions = dnums.getInputSpatialDimensions();
config.kernelInputFeatureDimension = dnums.getKernelInputFeatureDimension();
config.kernelOutputFeatureDimension = dnums.getKernelOutputFeatureDimension();
config.kernelSpatialDimensions = dnums.getKernelSpatialDimensions();
config.outputBatchDimension = dnums.getOutputBatchDimension();
config.outputFeatureDimension = dnums.getOutputFeatureDimension();
config.outputSpatialDimensions = dnums.getOutputSpatialDimensions();

SPU_ENFORCE(
dnums.getInputSpatialDimensions().size() == 2,
"Convolution with more than 2 spatial dimensions is not supported");

spu::Value result =
kernel::hlo::Convolution2D(sctx, lhs, rhs, config, ret_shape);
spu::Value result;

switch (dnums.getInputSpatialDimensions().size()) {
case 1: {
// Nwc Wcf to N1wc 1Wcf
Strides window_strides(2, 1);
if (op.getWindowStrides().has_value()) {
window_strides[1] = (*op.getWindowStrides())[0];
}
config.featureGroupCount = op.getFeatureGroupCount();
config.batchGroupCount = op.getBatchGroupCount();
config.window_strides = window_strides;
config.inputBatchDimension = dnums.getInputBatchDimension();
config.inputFeatureDimension = dnums.getInputFeatureDimension() + 1;
config.inputSpatialDimensions = {1, 2};
config.kernelInputFeatureDimension =
dnums.getKernelInputFeatureDimension() + 1;
config.kernelOutputFeatureDimension =
dnums.getKernelOutputFeatureDimension() + 1;
config.kernelSpatialDimensions = {0, 1};
config.outputBatchDimension = dnums.getOutputBatchDimension();
config.outputFeatureDimension = dnums.getOutputFeatureDimension() + 1;
config.outputSpatialDimensions = {1, 2};

auto reshaped_lhs = kernel::hlo::Reshape(
sctx, lhs, {lhs.shape()[0], 1, lhs.shape()[1], lhs.shape()[2]});
auto reshaped_rhs = kernel::hlo::Reshape(
sctx, rhs, {1, rhs.shape()[0], rhs.shape()[1], rhs.shape()[2]});

Shape reshaped_result = {ret_shape[0], 1, ret_shape[1], ret_shape[2]};

result = kernel::hlo::Convolution2D(sctx, reshaped_lhs, reshaped_rhs,
config, reshaped_result);

result = kernel::hlo::Reshape(sctx, result, ret_shape);
break;
}
case 2: {
Strides window_strides(dnums.getInputSpatialDimensions().size(), 1);
if (op.getWindowStrides().has_value()) {
window_strides = *op.getWindowStrides();
}
config.featureGroupCount = op.getFeatureGroupCount();
config.batchGroupCount = op.getBatchGroupCount();
config.window_strides = window_strides;
config.inputBatchDimension = dnums.getInputBatchDimension();
config.inputFeatureDimension = dnums.getInputFeatureDimension();
config.inputSpatialDimensions = dnums.getInputSpatialDimensions();
config.kernelInputFeatureDimension =
dnums.getKernelInputFeatureDimension();
config.kernelOutputFeatureDimension =
dnums.getKernelOutputFeatureDimension();
config.kernelSpatialDimensions = dnums.getKernelSpatialDimensions();
config.outputBatchDimension = dnums.getOutputBatchDimension();
config.outputFeatureDimension = dnums.getOutputFeatureDimension();
config.outputSpatialDimensions = dnums.getOutputSpatialDimensions();

result = kernel::hlo::Convolution2D(sctx, lhs, rhs, config, ret_shape);
break;
}
default: {
SPU_THROW(
"Convolution with {} spatial dimensions is not "
"supported, {}",
dnums.getInputSpatialDimensions().size(), mlirObjectToString(op));
}
}

addValue(sscope, op.getResult(), std::move(result), opts);
}
Expand Down

0 comments on commit e2da1d2

Please sign in to comment.