diff --git a/examples/python/ml/flax_whisper/3pc.json b/examples/python/ml/flax_whisper/3pc.json
new file mode 100644
index 00000000..7cbcf43f
--- /dev/null
+++ b/examples/python/ml/flax_whisper/3pc.json
@@ -0,0 +1,46 @@
+{
+    "id": "outsourcing.3pc",
+    "nodes": {
+        "node:0": "127.0.0.1:61920",
+        "node:1": "127.0.0.1:61921",
+        "node:2": "127.0.0.1:61922",
+        "node:3": "127.0.0.1:61923",
+        "node:4": "127.0.0.1:61924"
+    },
+    "devices": {
+        "SPU": {
+            "kind": "SPU",
+            "config": {
+                "node_ids": [
+                    "node:0",
+                    "node:1",
+                    "node:2"
+                ],
+                "spu_internal_addrs": [
+                    "127.0.0.1:61930",
+                    "127.0.0.1:61931",
+                    "127.0.0.1:61932"
+                ],
+                "runtime_config": {
+                    "protocol": "ABY3",
+                    "field": "FM128",
+                    "enable_pphlo_profile": true,
+                    "enable_hal_profile": true,
+                    "fxp_exp_mode": 1
+                }
+            }
+        },
+        "P1": {
+            "kind": "PYU",
+            "config": {
+                "node_id": "node:3"
+            }
+        },
+        "P2": {
+            "kind": "PYU",
+            "config": {
+                "node_id": "node:4"
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/python/ml/flax_whisper/BUILD.bazel b/examples/python/ml/flax_whisper/BUILD.bazel
new file mode 100644
index 00000000..30bc8d3c
--- /dev/null
+++ b/examples/python/ml/flax_whisper/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2024 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_binary")
+
+package(default_visibility = ["//visibility:public"])
+
+py_binary(
+    name = "flax_whisper",
+    srcs = ["flax_whisper.py"],
+    data = [
+        "//examples/python/ml/flax_whisper:3pc.json",
+    ],
+    deps = [
+        "//spu/utils:distributed",
+    ],
+)
diff --git a/examples/python/ml/flax_whisper/README.md b/examples/python/ml/flax_whisper/README.md
new file mode 100644
index 00000000..1a331f4e
--- /dev/null
+++ b/examples/python/ml/flax_whisper/README.md
@@ -0,0 +1,26 @@
+# Flax Whisper Example
+
+This example demonstrates how to use SPU to run private inference on a pre-trained
+[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.FlaxWhisperForConditionalGeneration) model.
+
+1. Install the Hugging Face transformers library and the audio dependencies
+
+    ```sh
+    pip install 'transformers[flax]' soundfile librosa
+    ```
+
+2. Enable `While` with a secret condition
+
+    Edit `libspu/kernel/hlo/control_flow.cc`, change `ENABLE_DEBUG_ONLY_REVEAL_SECRET_CONDITION` to `true`, and rebuild SPU so that Whisper's autoregressive decoding loop (a `While` whose condition depends on secret values) can run.
+
+3. Launch the SPU backend runtime
+
+    ```sh
+    bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_whisper/3pc.json up
+    ```
+
+4. Run the `flax_whisper` example
+
+    ```sh
+    bazel run -c opt //examples/python/ml/flax_whisper -- --config `pwd`/examples/python/ml/flax_whisper/3pc.json
+    ```
diff --git a/examples/python/ml/flax_whisper/flax_whisper.py b/examples/python/ml/flax_whisper/flax_whisper.py
new file mode 100644
index 00000000..e2c84066
--- /dev/null
+++ b/examples/python/ml/flax_whisper/flax_whisper.py
@@ -0,0 +1,84 @@
+# Copyright 2024 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Start nodes. +# > bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_whisper/3pc.json up +# +# Run this example script. +# > bazel run -c opt //examples/python/ml/flax_whisper:flax_whisper + +import argparse +import json +import os + +import jax.numpy as jnp +from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration +from datasets import load_dataset + +import spu.utils.distributed as ppd + +parser = argparse.ArgumentParser(description='distributed driver.') +parser.add_argument( + "-c", "--config", default="examples/python/ml/flax_whisper/3pc.json" +) +args = parser.parse_args() + +with open(args.config, 'r') as file: + conf = json.load(file) + +ppd.init(conf["nodes"], conf["devices"]) + +processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") +pretrained_model = FlaxWhisperForConditionalGeneration.from_pretrained( + "openai/whisper-tiny.en", from_pt=True +) + + +def text_generation(input_features, params): + pretrained_model.params = params + generated_ids = pretrained_model.generate(input_features=input_features) + return generated_ids.sequences + + +def run_on_cpu(): + ds = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ) + inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + generated_ids = pretrained_model.generate(input_features=inputs.input_features) + return generated_ids.sequences + + +def run_on_spu(): + ds = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ) + inputs_ids = processor(ds[0]["audio"]["array"], return_tensors="np") + + input_ids = ppd.device("P1")(lambda x: x)(inputs_ids.input_features) + params = ppd.device("P2")(lambda x: x)(pretrained_model.params) + outputs_ids = ppd.device("SPU")( + text_generation, + )(input_ids, params) + outputs_ids = ppd.get(outputs_ids) + return outputs_ids + + +if __name__ == '__main__': + print('\n------\nRun on CPU') + generated_ids = run_on_cpu() + print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) + print('\n------\nRun on SPU') + generated_ids = run_on_spu() + print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0]) diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 248fe9e4..32878a03 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -417,31 +417,74 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, auto lhs = lookupValue(sscope, op.getLhs(), opts); auto rhs = lookupValue(sscope, op.getRhs(), opts); - Strides window_strides(dnums.getInputSpatialDimensions().size(), 1); - if (op.getWindowStrides().has_value()) { - window_strides = *op.getWindowStrides(); - } - kernel::hlo::ConvolutionConfig config; - config.featureGroupCount = op.getFeatureGroupCount(); - config.batchGroupCount = op.getBatchGroupCount(); - config.window_strides = window_strides; - config.inputBatchDimension = dnums.getInputBatchDimension(); - config.inputFeatureDimension = dnums.getInputFeatureDimension(); - 
config.inputSpatialDimensions = dnums.getInputSpatialDimensions();
-  config.kernelInputFeatureDimension = dnums.getKernelInputFeatureDimension();
-  config.kernelOutputFeatureDimension = dnums.getKernelOutputFeatureDimension();
-  config.kernelSpatialDimensions = dnums.getKernelSpatialDimensions();
-  config.outputBatchDimension = dnums.getOutputBatchDimension();
-  config.outputFeatureDimension = dnums.getOutputFeatureDimension();
-  config.outputSpatialDimensions = dnums.getOutputSpatialDimensions();
-
-  SPU_ENFORCE(
-      dnums.getInputSpatialDimensions().size() == 2,
-      "Convolution with more than 2 spatial dimensions is not supported");
-
-  spu::Value result =
-      kernel::hlo::Convolution2D(sctx, lhs, rhs, config, ret_shape);
+  spu::Value result;
+
+  switch (dnums.getInputSpatialDimensions().size()) {
+    case 1: {
+      // Lift 1-D conv (NWC, WCF) to 2-D (N1WC, 1WCF) to reuse Convolution2D.
+      Strides window_strides(2, 1);
+      if (op.getWindowStrides().has_value()) {
+        window_strides[1] = (*op.getWindowStrides())[0];
+      }
+      config.featureGroupCount = op.getFeatureGroupCount();
+      config.batchGroupCount = op.getBatchGroupCount();
+      config.window_strides = window_strides;
+      config.inputBatchDimension = dnums.getInputBatchDimension();
+      config.inputFeatureDimension = dnums.getInputFeatureDimension() + 1;
+      config.inputSpatialDimensions = {1, 2};
+      config.kernelInputFeatureDimension =
+          dnums.getKernelInputFeatureDimension() + 1;
+      config.kernelOutputFeatureDimension =
+          dnums.getKernelOutputFeatureDimension() + 1;
+      config.kernelSpatialDimensions = {0, 1};
+      config.outputBatchDimension = dnums.getOutputBatchDimension();
+      config.outputFeatureDimension = dnums.getOutputFeatureDimension() + 1;
+      config.outputSpatialDimensions = {1, 2};
+
+      auto reshaped_lhs = kernel::hlo::Reshape(
+          sctx, lhs, {lhs.shape()[0], 1, lhs.shape()[1], lhs.shape()[2]});
+      auto reshaped_rhs = kernel::hlo::Reshape(
+          sctx, rhs, {1, rhs.shape()[0], rhs.shape()[1], rhs.shape()[2]});
+
+      Shape reshaped_result = {ret_shape[0], 1, ret_shape[1], ret_shape[2]};
+
+      result = kernel::hlo::Convolution2D(sctx, reshaped_lhs, reshaped_rhs,
+                                          config, reshaped_result);
+
+      result = kernel::hlo::Reshape(sctx, result, ret_shape);
+      break;
+    }
+    case 2: {
+      Strides window_strides(dnums.getInputSpatialDimensions().size(), 1);
+      if (op.getWindowStrides().has_value()) {
+        window_strides = *op.getWindowStrides();
+      }
+      config.featureGroupCount = op.getFeatureGroupCount();
+      config.batchGroupCount = op.getBatchGroupCount();
+      config.window_strides = window_strides;
+      config.inputBatchDimension = dnums.getInputBatchDimension();
+      config.inputFeatureDimension = dnums.getInputFeatureDimension();
+      config.inputSpatialDimensions = dnums.getInputSpatialDimensions();
+      config.kernelInputFeatureDimension =
+          dnums.getKernelInputFeatureDimension();
+      config.kernelOutputFeatureDimension =
+          dnums.getKernelOutputFeatureDimension();
+      config.kernelSpatialDimensions = dnums.getKernelSpatialDimensions();
+      config.outputBatchDimension = dnums.getOutputBatchDimension();
+      config.outputFeatureDimension = dnums.getOutputFeatureDimension();
+      config.outputSpatialDimensions = dnums.getOutputSpatialDimensions();
+
+      result = kernel::hlo::Convolution2D(sctx, lhs, rhs, config, ret_shape);
+      break;
+    }
+    default: {
+      SPU_THROW(
+          "Convolution with {} spatial dimensions is not "
+          "supported: {}",
+          dnums.getInputSpatialDimensions().size(), mlirObjectToString(op));
+    }
+  }
 
   addValue(sscope, op.getResult(), std::move(result), opts);
 }
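
Note on the 1-D convolution lowering (a reviewer sketch, not part of the patch): the `case 1` branch above reuses the existing `Convolution2D` kernel by inserting a unit spatial dimension into both operands and shifting the feature/kernel dimension numbers by one. The snippet below checks the same lifting in plain JAX with `jax.lax.conv_general_dilated`; the shapes, the stride value, and the tolerances are illustrative assumptions, not values taken from the patch.

```python
# Sanity check (illustrative): a 1-D convolution lifted to a 2-D convolution
# with a unit "height" dimension produces the same result.
import numpy as np
from jax import lax

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 9, 3)).astype(np.float32)  # lhs, NWC layout
w = rng.standard_normal((4, 3, 5)).astype(np.float32)  # rhs, WIO layout (WCF in the patch comment)
stride = 2                                              # assumed 1-D window stride

# Reference: a genuine 1-D convolution.
ref = lax.conv_general_dilated(
    x, w,
    window_strides=(stride,),
    padding="VALID",
    dimension_numbers=("NWC", "WIO", "NWC"),
)

# Lifted: insert a unit height dimension (N1WC / 1WIO), run a 2-D convolution
# with strides (1, stride) -- mirroring the case-1 branch -- then drop the
# unit dimension from the result.
x2 = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
w2 = w.reshape(1, w.shape[0], w.shape[1], w.shape[2])
lifted = lax.conv_general_dilated(
    x2, w2,
    window_strides=(1, stride),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
).reshape(ref.shape)

np.testing.assert_allclose(np.asarray(ref), np.asarray(lifted), rtol=1e-4, atol=1e-5)
print("1-D and lifted 2-D convolutions agree, output shape:", ref.shape)
```

Because the inserted spatial dimension has extent 1 in both the input and the kernel, the 2-D window only ever sits at one position along it, which is why only `window_strides[1]` needs to carry the original 1-D stride.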