diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7fecfd96..9d68d197 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@
 >
 > please add your unreleased change here.
 
+- [Feature] Support more generic Torch model inference
 - [Improvement] Optimize one-time setup for yacl ot
 - [Improvement] Optimize sort performance
diff --git a/docs/reference/pphlo_op_doc.md b/docs/reference/pphlo_op_doc.md
index 69367ba5..74725045 100644
--- a/docs/reference/pphlo_op_doc.md
+++ b/docs/reference/pphlo_op_doc.md
@@ -18,9 +18,9 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>edge_padding_low</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>edge_padding_high</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>interior_padding</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>edge_padding_low</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>edge_padding_high</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>interior_padding</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -135,9 +135,9 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>window_dilations</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>window_dilations</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 <tr><td><code>onehot_index</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
 </table>
@@ -213,7 +213,7 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>broadcast_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>broadcast_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -455,7 +455,7 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 <tr><td><code>dimension_numbers</code></td><td>::mlir::pphlo::ConvDimensionNumbersAttr</td><td>Structure of dimension information for conv op</td></tr>
 <tr><td><code>feature_group_count</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
 <tr><td><code>batch_group_count</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
 </table>
@@ -658,7 +658,7 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>slice_sizes</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>slice_sizes</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -837,43 +837,6 @@ Effects: MemoryEffects::Effect{}
 | :-----: | ----------- |
 | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values
 
-### `pphlo.gather` (pphlo::GatherOp)
-
-_Gather operator_
-
-Stitches together several slices of `operand` from offsets specified in
-`start_indices` (each slice at a potentially different runtime offset).
-
-See https://www.tensorflow.org/xla/operation_semantics#gather.
-
-Traits: AlwaysSpeculatableImplTrait
-
-Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
-
-Effects: MemoryEffects::Effect{}
-
-#### Attributes:
-
-<table>
-<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>dimension_numbers</code></td><td>::mlir::pphlo::GatherDimensionNumbersAttr</td><td>Attribute that models the dimension information for gather</td></tr>
-<tr><td><code>slice_sizes</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>indices_are_sorted</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
-</table>
-
-#### Operands:
-
-| Operand | Description |
-| :-----: | ----------- |
-| `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values
-| `start_indices` | statically shaped tensor of public integer type or secret integer type values
-
-#### Results:
-
-| Result | Description |
-| :----: | ----------- |
-«unnamed» | statically shaped tensor of PPHlo public type or PPHlo secret type values
-
 ### `pphlo.greater_equal` (pphlo::GreaterEqualOp)
 
 _Greater_equal comparison operator_
@@ -1185,8 +1148,8 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -1491,7 +1454,7 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlock, SingleBloc
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -1522,9 +1485,9 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlock, SingleBloc
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>window_dilations</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>window_dilations</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -1630,7 +1593,7 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -1745,8 +1708,8 @@ Traits: RecursiveMemoryEffects
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -1986,9 +1949,9 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>start_indices</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>limit_indices</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
-<tr><td><code>strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>start_indices</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>limit_indices</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
+<tr><td><code>strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
@@ -2136,7 +2099,7 @@ Effects: MemoryEffects::Effect{}
 <table>
 <tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>permutation</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
+<tr><td><code>permutation</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
 </table>
 
 #### Operands:
diff --git a/examples/python/ml/BUILD.bazel b/examples/python/ml/BUILD.bazel
index 065208db..264e7537 100644
--- a/examples/python/ml/BUILD.bazel
+++ b/examples/python/ml/BUILD.bazel
@@ -40,7 +40,8 @@ py_test(
         "//examples/python/ml/stax_mnist_classifier",
         "//examples/python/ml/stax_nn",
         "//examples/python/ml/tf_experiment",
-        "//examples/python/ml/torch_experiment",
+        "//examples/python/ml/torch_lr_experiment",
+        "//examples/python/ml/torch_resnet_experiment",
         "//spu/utils:distributed",
     ],
 )
diff --git a/examples/python/ml/README.md b/examples/python/ml/README.md
index feb4f10c..20124747 100644
--- a/examples/python/ml/README.md
+++ b/examples/python/ml/README.md
@@ -27,4 +27,5 @@ library, and private inference of a pre-trained ResNet-50 model based on [Micros
 * [jraph_gnn](jraph_gnn/): Private training of a [graph convolutional network](https://arxiv.org/abs/1609.02907) model with [Jraph](https://github.com/deepmind/jraph).
 * [tf_experiment](tf_experiment/): Private training of a logistic regression model with TensorFlow (**experimental**).
-* [torch_experiment](torch_experiment/): Private inference of a linear regression model with PyTorch (**experimental**).
+* [torch_lr_experiment](torch_lr_experiment/): Private inference of a logistic regression model with PyTorch (**experimental**).
+* [torch_resnet_experiment](torch_resnet_experiment/): Private inference of a [ResNet](https://arxiv.org/abs/1512.03385) model with PyTorch (**experimental**).
diff --git a/examples/python/ml/haiku_lstm/README.md b/examples/python/ml/haiku_lstm/README.md
index 4a150d4b..97f0990c 100644
--- a/examples/python/ml/haiku_lstm/README.md
+++ b/examples/python/ml/haiku_lstm/README.md
@@ -9,7 +9,7 @@ This example comes from Haiku official github repo:
 1. Install dependencies
 
    ```sh
-   pip install -r requirements.txt
+   pip install -r ../requirements.txt
    ```
 
 2. Launch SPU backend runtime
diff --git a/examples/python/ml/haiku_lstm/requirements.txt b/examples/python/ml/haiku_lstm/requirements.txt
deleted file mode 100644
index e4ca8fe6..00000000
--- a/examples/python/ml/haiku_lstm/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-dm-haiku==0.0.10
-plotnine
diff --git a/examples/python/ml/jraph_gnn/README.md b/examples/python/ml/jraph_gnn/README.md
index 0739ddd4..32883888 100644
--- a/examples/python/ml/jraph_gnn/README.md
+++ b/examples/python/ml/jraph_gnn/README.md
@@ -9,7 +9,7 @@ This example comes from Jraph official github repo:
 1. Install dependencies
 
    ```sh
-   pip install -r requirements.txt
+   pip install -r ../requirements.txt
    ```
 
 2. Set runtime configuration
diff --git a/examples/python/ml/jraph_gnn/requirements.txt b/examples/python/ml/jraph_gnn/requirements.txt
deleted file mode 100644
index 01b28ae8..00000000
--- a/examples/python/ml/jraph_gnn/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-dm-haiku
-jraph
\ No newline at end of file
diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py
index 7526b2b9..cedd148e 100644
--- a/examples/python/ml/ml_test.py
+++ b/examples/python/ml/ml_test.py
@@ -213,20 +213,27 @@ def test_tf_experiment(self):
         score = tf_experiment.run_fit_manual_grad_spu()
         self.assertGreater(score, 0.9)
 
-    def test_torch_experiment(self):
-        from examples.python.ml.torch_experiment import torch_experiment
+    def test_torch_lr_experiment(self):
+        from examples.python.ml.torch_lr_experiment import torch_lr_experiment
 
-        model = torch_experiment.LinearRegression()
-        torch_experiment.train(model)
-        score = torch_experiment.run_inference_on_spu(model)
+        model = torch_lr_experiment.LinearRegression()
+        torch_lr_experiment.train(model)
+        score = torch_lr_experiment.run_inference_on_spu(model)
         self.assertGreater(score, 0.9)
 
+    def test_torch_resnet_experiment(self):
+        from examples.python.ml.torch_resnet_experiment import torch_resnet_experiment
+
+        model = torch_resnet_experiment.resnet
+        image = torch_resnet_experiment.input_batch
+        label = torch_resnet_experiment.run_inference_on_spu(model, image)
+        self.assertEqual(label, 258)
+
     def test_save_and_load_model(self):
         from examples.python.ml.jax_lr import jax_lr
 
         score = jax_lr.save_and_load_model()
         self.assertGreater(score, 0.9)
-        pass
 
 
 def suite():
@@ -246,7 +253,8 @@ def suite():
     suite.addTest(UnitTests('test_save_and_load_model'))
     # should put JAX tests above
     suite.addTest(UnitTests('test_tf_experiment'))
-    suite.addTest(UnitTests('test_torch_experiment'))
+    suite.addTest(UnitTests('test_torch_lr_experiment'))
+    # suite.addTest(UnitTests('test_torch_resnet_experiment'))
     return suite
diff --git a/examples/python/ml/requirements.txt b/examples/python/ml/requirements.txt
new file mode 100644
index 00000000..133e8a38
--- /dev/null
+++ b/examples/python/ml/requirements.txt
@@ -0,0 +1,7 @@
+dm-haiku==0.0.10
+plotnine
+jraph
+optax==0.1.7
+torch==2.1.0
+torch_xla==2.1.0
+torchvision
\ No newline at end of file
diff --git a/examples/python/ml/torch_experiment/README.md b/examples/python/ml/torch_experiment/README.md
deleted file mode 100644
index dbc728c2..00000000
--- a/examples/python/ml/torch_experiment/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# Torch Example
-
-This example demonstrates how to use SPU to make inferences on a linear regression model privately with PyTorch.
-
-The model is trained with plaintext publicly. Currently, SPU's support of PyTorch is **experimental** and we only tested on Linux.
-
-1. Install a third-party dependency [Torch-MLIR](https://github.com/llvm/torch-mlir).
-
-   ```sh
-   pip install https://github.com/llvm/torch-mlir/releases/download/snapshot-20220830.581/torch-1.13.0.dev20220830+cpu-cp38-cp38-linux_x86_64.whl
-   pip install https://github.com/llvm/torch-mlir/releases/download/snapshot-20220830.581/torch_mlir-20220830.581-cp38-cp38-linux_x86_64.whl
-   ```
-
-2. Launch SPU backend runtime
-
-   ```sh
-   bazel run -c opt //examples/python/utils:nodectl -- up
-   ```
-
-3. Run `torch_experiment` example
-
-   ```sh
-   bazel run -c opt //examples/python/ml/torch_experiment
-   ```
diff --git a/examples/python/ml/torch_experiment/BUILD.bazel b/examples/python/ml/torch_lr_experiment/BUILD.bazel
similarity index 87%
rename from examples/python/ml/torch_experiment/BUILD.bazel
rename to examples/python/ml/torch_lr_experiment/BUILD.bazel
index 715ed42d..36cdcce1 100644
--- a/examples/python/ml/torch_experiment/BUILD.bazel
+++ b/examples/python/ml/torch_lr_experiment/BUILD.bazel
@@ -1,4 +1,4 @@
-# Copyright 2023 Ant Group Co., Ltd.
+# Copyright 2024 Ant Group Co., Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@ load("@rules_python//python:defs.bzl", "py_binary")
 package(default_visibility = ["//visibility:public"])
 
 py_binary(
-    name = "torch_experiment",
-    srcs = ["torch_experiment.py"],
+    name = "torch_lr_experiment",
+    srcs = ["torch_lr_experiment.py"],
     data = [
         "//examples/python/conf",
     ],
diff --git a/examples/python/ml/torch_lr_experiment/README.md b/examples/python/ml/torch_lr_experiment/README.md
new file mode 100644
index 00000000..6bcc2a82
--- /dev/null
+++ b/examples/python/ml/torch_lr_experiment/README.md
@@ -0,0 +1,23 @@
+# Torch Example
+
+This example demonstrates how to use SPU to make private inferences on PyTorch models.
+
+**Note**: Currently, SPU's support of PyTorch is **experimental**.
+
+1. Install a third-party dependency [PyTorch/XLA](https://github.com/pytorch/xla).
+
+   ```sh
+   pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+   ```
+
+2. Launch SPU backend runtime
+
+   ```sh
+   bazel run -c opt //examples/python/utils:nodectl -- up
+   ```
+
+3. Run `torch_lr_experiment` example
+
+   ```sh
+   bazel run -c opt //examples/python/ml/torch_lr_experiment
+   ```
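For context on step 1 above: unlike the old torch-mlir flow, the experimental torch frontend consumes StableHLO emitted by PyTorch/XLA. Below is a minimal sketch of that export step, assuming the pinned `torch==2.1.0`/`torch_xla==2.1.0` wheels; `TinyModel` is a hypothetical stand-in module, and the `torch_xla.stablehlo` API surface may shift between releases. This is an editorial illustration, not part of the patch.

```python
# Sketch: lower an eval-mode torch module to StableHLO text, the IR that
# SPU's experimental torch frontend consumes (torch/torch_xla 2.1 assumed).
import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo


class TinyModel(torch.nn.Module):  # hypothetical stand-in module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(30, 1)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
exported = export(model, (torch.randn(4, 30),))  # static example shape
stablehlo = exported_program_to_stablehlo(exported)
print(stablehlo.get_stablehlo_text('forward'))
```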
diff --git a/examples/python/ml/torch_experiment/torch_experiment.py b/examples/python/ml/torch_lr_experiment/torch_lr_experiment.py
similarity index 66%
rename from examples/python/ml/torch_experiment/torch_experiment.py
rename to examples/python/ml/torch_lr_experiment/torch_lr_experiment.py
index 8cb502a7..d06ccd56 100644
--- a/examples/python/ml/torch_experiment/torch_experiment.py
+++ b/examples/python/ml/torch_lr_experiment/torch_lr_experiment.py
@@ -22,18 +22,12 @@
 
 import spu.utils.distributed as ppd
 
-# This is an experimental example to show legacy pytorch program could be run
-# by SPU. Currently we rely on torch-mlir to convert torch code into MLIR
-# (specifically MHLO) which is then consumed by SPU. To run this example,
-# torch-mlir python package should be installed. This example here trains a
-# linear regression model in plaintext and makes private inferences with joint
-# features.
 # Start nodes.
 # > bazel run -c opt //examples/python/utils:nodectl -- up
 #
 # Run this example script.
-# > bazel run -c opt //examples/python/ml/torch_experiment:torch_experiment
+# > bazel run -c opt //examples/python/ml/torch_lr_experiment:torch_lr_experiment
 
 
 class LinearRegression(torch.nn.Module):
@@ -41,17 +35,15 @@ def __init__(self):
         super(LinearRegression, self).__init__()
         self.linear = torch.nn.Linear(30, 1)
 
-    def forward(self, x1, x2):
-        y_pred = self.linear(torch.cat((x1, x2), 1))
+    def forward(self, x):
+        y_pred = self.linear(x)
         return y_pred
 
 
 def train(model, n_epochs=500, lr=0.01):
     print('Train model with plaintext features\n------\n')
     x, y = breast_cancer()
-    x1, x2 = x[:, :15], x[:, 15:]
-    x1 = torch.Tensor(x1)
-    x2 = torch.Tensor(x2)
+    x = torch.Tensor(x)
     y = torch.Tensor(y).view(-1, 1)
 
     criterion = torch.nn.BCEWithLogitsLoss()
@@ -59,7 +51,7 @@ def train(model, n_epochs=500, lr=0.01):
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
     for _ in range(n_epochs):
-        pred_y = model(x1, x2)
+        pred_y = model(x)
         loss = criterion(pred_y, y)
         optimizer.zero_grad()
         loss.backward()
@@ -70,7 +62,6 @@ def train(model, n_epochs=500, lr=0.01):
 
 # prepare test datasets
 def breast_cancer(
-    col_slicer=slice(None, None, None),
     train: bool = True,
     *,
     normalize: bool = True,
@@ -96,7 +87,6 @@ def breast_cancer(
     else:
         x_ = x_test
         y_ = y_test
-    x_ = x_[:, col_slicer]
 
     return x_.astype(dtype=np.float32), y_.astype(dtype=np.float32)
 
@@ -105,10 +95,10 @@ def breast_cancer(
 
 def run_inference_on_cpu(model):
     print('Run on CPU\n------\n')
-    x_test, y_test = breast_cancer(slice(None, None, None), False)
-    x1, x2 = torch.Tensor(x_test[:, :15]), torch.Tensor(x_test[:, 15:])
+    x_test, y_test = breast_cancer(False)
+    x = torch.Tensor(x_test)
     start_ts = time.time()
-    y_pred = model.forward(x1, x2).cpu().detach().numpy()
+    y_pred = model(x).cpu().detach().numpy()
     end_ts = time.time()
     auc = metrics.roc_auc_score(y_test, y_pred)
     print(f"AUC(cpu)={auc}, time={end_ts-start_ts}\n------\n")
@@ -123,35 +113,36 @@ def run_inference_on_cpu(model):
 
 ppd.init(conf["nodes"], conf["devices"], framework=ppd.Framework.EXP_TORCH)
 
+from collections import OrderedDict
+from jax.tree_util import tree_map
+
 
 def run_inference_on_spu(model):
     print('Run on SPU\n------\n')
-    x1, _ = ppd.device("P1")(breast_cancer)(slice(None, 15), False)
-    x2, _ = ppd.device("P2")(breast_cancer)(slice(15, None), False)
+
+    # load parameters and buffers on P1
+    params_buffers = OrderedDict()
+    for k, v in model.named_parameters():
+        params_buffers[k] = v
+    for k, v in model.named_buffers():
+        params_buffers[k] = v
+    params = ppd.device("P1")(
+        lambda input: tree_map(lambda x: x.detach().numpy(), input)
+    )(params_buffers)
+
+    # load inputs on P2
+    x, _ = ppd.device("P2")(breast_cancer)(False)
+
     start_ts = time.time()
-    y_pred_ciphertext = ppd.device('SPU')(model)(x1, x2)
+    y_pred_ciphertext = ppd.device('SPU')(model)(params, x)
     end_ts = time.time()
 
     y_pred_plaintext = ppd.get(y_pred_ciphertext)
-    _, y_test = breast_cancer(slice(None, None, None), False)
+    _, y_test = breast_cancer(False)
     auc = metrics.roc_auc_score(y_test, y_pred_plaintext)
-    print(f"AUC(cpu)={auc}, time={end_ts-start_ts}\n------\n")
+    print(f"AUC(spu)={auc}, time={end_ts-start_ts}\n------\n")
     return auc
 
 
-def compile_torch_to_mhlo(model):
-    print('Compile torch program to mhlo test\n------\n')
-    x_test, _ = breast_cancer(slice(None, None, None), False)
-    x1, x2 = torch.Tensor(x_test[:, :15]), torch.Tensor(x_test[:, 15:])
-    import torch_mlir
-
-    module = torch_mlir.compile(
-        model,
-        [x1, x2],
-        output_type=torch_mlir.OutputType.MHLO,
-    )
-    print(f"MHLO={module}\n------\n")
-
-
 if __name__ == '__main__':
     # For reproducibility
     torch.manual_seed(0)
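Note the calling convention introduced above: under `ppd.Framework.EXP_TORCH`, the hoisted parameter/buffer dict is passed as the first argument to the deviced model. The same convention can be reproduced in plaintext with `torch.func.functional_call` (available in torch >= 2.0). A local sanity check, not part of this patch:

```python
# Sketch: reproduce the (params, input) convention on CPU with torch.func.
import torch
from torch.func import functional_call

model = torch.nn.Linear(30, 1).eval()
params_buffers = {**dict(model.named_parameters()), **dict(model.named_buffers())}
x = torch.randn(8, 30)

# Calling with an explicit parameter dict matches the ordinary stateful call.
assert torch.allclose(functional_call(model, params_buffers, (x,)), model(x))
```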
@@ -159,8 +150,7 @@ def compile_torch_to_mhlo(model):
     model = LinearRegression()
     # Train model with plaintext features
     train(model)
-    # Torch-mlho conversion test
-    compile_torch_to_mhlo(model)
+    model.eval()
     # Native torch inference
     run_inference_on_cpu(model)
     # SPU inference
diff --git a/examples/python/ml/torch_resnet_experiment/BUILD.bazel b/examples/python/ml/torch_resnet_experiment/BUILD.bazel
new file mode 100644
index 00000000..91d89e45
--- /dev/null
+++ b/examples/python/ml/torch_resnet_experiment/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2024 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_binary")
+
+package(default_visibility = ["//visibility:public"])
+
+py_binary(
+    name = "torch_resnet_experiment",
+    srcs = ["torch_resnet_experiment.py"],
+    data = [
+        "//examples/python/conf",
+    ],
+    deps = [
+        "//spu/utils:distributed",
+    ],
+)
diff --git a/examples/python/ml/torch_resnet_experiment/README.md b/examples/python/ml/torch_resnet_experiment/README.md
new file mode 100644
index 00000000..ab1da24d
--- /dev/null
+++ b/examples/python/ml/torch_resnet_experiment/README.md
@@ -0,0 +1,24 @@
+# Torch Example
+
+This example demonstrates how to use SPU to make private inferences on PyTorch models.
+
+**Note**: Currently, SPU's support of PyTorch is **experimental**.
+
+1. Install a third-party dependency [PyTorch/XLA](https://github.com/pytorch/xla).
+
+   ```sh
+   pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+   pip install torchvision
+   ```
+
+2. Launch SPU backend runtime
+
+   ```sh
+   bazel run -c opt //examples/python/utils:nodectl -- up
+   ```
+
+3. Run `torch_resnet_experiment` example
+
+   ```sh
+   bazel run -c opt //examples/python/ml/torch_resnet_experiment
+   ```
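The expected label `258` asserted in `ml_test.py` earlier is the ImageNet class index for "Samoyed", which matches the `dog.jpg` sample used by this example. One way to double-check the mapping, assuming the weights-metadata API torchvision ships since roughly 0.13 (an illustration, not part of the patch):

```python
# Sketch: map the predicted index back to its ImageNet category name.
from torchvision.models import ResNet50_Weights

categories = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
print(categories[258])  # expected: 'Samoyed'
```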
diff --git a/examples/python/ml/torch_resnet_experiment/torch_resnet_experiment.py b/examples/python/ml/torch_resnet_experiment/torch_resnet_experiment.py
new file mode 100644
index 00000000..16427ebc
--- /dev/null
+++ b/examples/python/ml/torch_resnet_experiment/torch_resnet_experiment.py
@@ -0,0 +1,102 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import json
+import urllib.request
+from collections import OrderedDict
+
+import torch
+from jax.tree_util import tree_map
+from PIL import Image
+from torchvision import transforms
+from torchvision.models import ResNet50_Weights, resnet50
+
+import spu.utils.distributed as ppd
+
+# This is an experimental example to show that a native PyTorch program can be
+# run by SPU. Currently we rely on torch-xla to convert torch code into MLIR
+# (specifically StableHLO) which is then consumed by SPU. To run this example,
+# the torch-xla Python package should be installed.
+
+# Start nodes.
+# > bazel run -c opt //examples/python/utils:nodectl -- up
+#
+# Run this example script.
+# > bazel run -c opt //examples/python/ml/torch_resnet_experiment:torch_resnet_experiment
+
+
+parser = argparse.ArgumentParser(description='distributed driver.')
+parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json")
+args = parser.parse_args()
+
+with open(args.config, 'r') as file:
+    conf = json.load(file)
+
+ppd.init(conf["nodes"], conf["devices"], framework=ppd.Framework.EXP_TORCH)
+
+url, filename = (
+    "https://github.com/pytorch/hub/raw/master/images/dog.jpg",
+    "dog.jpg",
+)
+
+urllib.request.urlretrieve(url, filename)
+
+
+input_image = Image.open(filename)
+preprocess = transforms.Compose(
+    [
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ]
+)
+input_tensor = preprocess(input_image)
+input_batch = input_tensor.unsqueeze(0)
+resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+resnet.eval()
+
+
+def run_inference_on_cpu(model, image):
+    print('Run on CPU\n------\n')
+    output = model(image)
+    # model predicts one of the 1000 ImageNet classes
+    predicted_label = output.argmax(-1).item()
+    print(f"predicted_label={predicted_label}\n------\n")
+    return predicted_label
+
+
+def run_inference_on_spu(model, image):
+    print('Run on SPU\n------\n')
+    params_buffers = OrderedDict()
+    for k, v in model.named_parameters():
+        params_buffers[k] = v
+    for k, v in model.named_buffers():
+        params_buffers[k] = v
+    params = ppd.device("P1")(
+        lambda input: tree_map(lambda x: x.detach().numpy(), input)
+    )(params_buffers)
+    image_hat = ppd.device("P2")(lambda x: x.detach().numpy())(image)
+    res = ppd.device("SPU")(model)(params, image_hat)
+    predicted_label = ppd.get(res).argmax(-1).item()
+    print(f"predicted_label={predicted_label}\n------\n")
+    return predicted_label
+
+
+if __name__ == '__main__':
+    torch.manual_seed(0)
+    run_inference_on_cpu(resnet, input_batch)
+    run_inference_on_spu(resnet, input_batch)
diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc
index e5be8f8c..08398f8b 100644
--- a/libspu/compiler/core/core.cc
+++ b/libspu/compiler/core/core.cc
@@ -62,8 +62,6 @@ void Core::buildPipeline(mlir::PassManager *pm) {
     optPM.addPass(mlir::pphlo::createRewriteDivSqrtPatterns());
   }
 
-  optPM.addPass(mlir::pphlo::createExpandSecretGatherPass());
-
   if (options.enable_optimize_denominator_with_broadcast()) {
     optPM.addPass(mlir::pphlo::createOptimizeDenominatorWithBroadcast());
   }
diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel
index 8a6e1f97..399e30b8 100644
--- a/libspu/compiler/front_end/BUILD.bazel
+++ b/libspu/compiler/front_end/BUILD.bazel
@@ -76,5 +76,6 @@ spu_cc_library(
         "@llvm-project//mlir:FuncExtensions",
         "@llvm-project//mlir:Parser",
         "@xla//xla/mlir_hlo:mhlo_passes",
"@xla//xla/translate/mhlo_to_hlo:translate", ], ) diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index 5d89e7c3..f1185296 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -24,6 +24,7 @@ #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/translate/mhlo_to_hlo/translate.h" #include "libspu/compiler/common/compilation_context.h" #include "libspu/compiler/front_end/hlo_importer.h" @@ -44,21 +45,33 @@ FE::FE(CompilationContext *ctx) : ctx_(ctx) { } mlir::OwningOpRef FE::doit(const CompilationSource &source) { + HloImporter importer(ctx_); mlir::OwningOpRef module; - switch (source.ir_type()) { - case spu::SourceIRType::XLA: { - HloImporter importer(ctx_); - module = importer.parseXlaModuleFromString(source.ir_txt()); - break; - } - case spu::SourceIRType::MLIR_HLO: { + + if (source.ir_type() == spu::SourceIRType::STABLEHLO) { module = mlir::parseSourceString(source.ir_txt(), ctx_->getMLIRContext()); - break; - } - default: { - SPU_THROW("Unsupported input IR type = {}", source.ir_type()); - } + + // Convert stablehlo to mhlo first + mlir::PassManager pm(ctx_->getMLIRContext()); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + if (pm.run(module.get()).failed()) { + SPU_THROW("Failed to legalized stablehlo to mhlo"); + } + + // Convert back to XLA, SPU still relies on XLA to eliminate ops like + // batch-normal-inference + std::string xla_text; + llvm::raw_string_ostream out(xla_text); + if (!mlir::failed(xla::MlirHloToHloTranslateFunction(module.get(), out, + true, true))) { + out.flush(); + module = importer.parseXlaModuleFromString(xla_text); + } + } else if (source.ir_type() == spu::SourceIRType::XLA) { + module = importer.parseXlaModuleFromString(source.ir_txt()); + } else { + SPU_THROW("Unhandled IR type = {}", source.ir_type()); } std::string input_vis_str; diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 149f5a4b..80ba9aba 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -127,7 +127,7 @@ void runHloPasses(xla::HloModule *module) { /*allow_mixed_precision=*/false); pipeline.AddPass(); - pipeline.AddPass(GatherExpander::kEliminateSimpleGathers); + pipeline.AddPass(GatherExpander::kEliminateAllGathers); pipeline.AddPass(ScatterExpander::kEliminateAllScatters); pipeline.AddPass(options); pipeline.AddPass(); @@ -163,7 +163,10 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { // If parse as HloModuleProto fails, try HloProto. 
xla::HloProto hlo_proto; if (!hlo_proto.ParseFromString(content)) { - SPU_THROW("Failed to parse hlo module from string"); + // Try human-readable format + if (!google::protobuf::TextFormat::ParseFromString(content, &hlo_proto)) { + SPU_THROW("Failed to parse hlo module from string {}", content); + } } hlo_module = hlo_proto.hlo_module(); } diff --git a/libspu/compiler/passes/BUILD.bazel b/libspu/compiler/passes/BUILD.bazel index c470e86e..ef5ca7d3 100644 --- a/libspu/compiler/passes/BUILD.bazel +++ b/libspu/compiler/passes/BUILD.bazel @@ -197,18 +197,6 @@ spu_cc_library( ], ) -spu_cc_library( - name = "expand_secret_gather", - srcs = ["expand_secret_gather.cc"], - hdrs = ["passes.h"], - deps = [ - ":pass_details", - "//libspu/dialect:pphlo_dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TransformUtils", - ], -) - spu_cc_library( name = "rewrite_div_sqrt_patterns", srcs = ["rewrite_div_sqrt_patterns.cc"], @@ -276,7 +264,6 @@ spu_cc_library( ":convert_push_down", ":decompose_comparison", ":decompose_minmax", - ":expand_secret_gather", ":hlo_legalize_to_pphlo", ":insert_deallocation", ":lower_conversion_cast", diff --git a/libspu/compiler/passes/expand_secret_gather.cc b/libspu/compiler/passes/expand_secret_gather.cc deleted file mode 100644 index 10117f31..00000000 --- a/libspu/compiler/passes/expand_secret_gather.cc +++ /dev/null @@ -1,641 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "libspu/compiler/passes/pass_details.h" -#include "libspu/dialect/pphlo_ops.h" - -namespace mlir::pphlo { - -namespace { - -bool GatherIsBroadcast(GatherOp &op) { - auto gather_slice_size = op.getSliceSizes(); - auto op_shape = op.getOperand().getType().getShape(); - return (gather_slice_size.size() == op_shape.size()) && - (std::equal(gather_slice_size.begin(), gather_slice_size.end(), - op_shape.begin())); -} - -std::vector DeleteDimensions(llvm::ArrayRef dims_to_delete, - llvm::ArrayRef shape) { - std::unordered_set ordered_dims_to_delete(dims_to_delete.begin(), - dims_to_delete.end()); - - std::vector result; - result.reserve(shape.size() - ordered_dims_to_delete.size()); - - for (size_t idx = 0; idx < shape.size(); ++idx) { - if (ordered_dims_to_delete.count(idx) != 0) { - continue; - } - result.emplace_back(idx); - } - return result; -} - -// Computes how many trips a loop implementing this gather op would take. 
-int64_t GatherLoopTripCount(GatherOp op) { - auto start_indices = op.getStartIndices(); - const auto start_indices_shape = start_indices.getType().getShape(); - const auto &dim_numbers = op.getDimensionNumbers(); - - int64_t trip_count = 1; - for (int64_t i = 0, e = start_indices_shape.size(); i < e; i++) { - if (i != dim_numbers.getIndexVectorDim()) { - trip_count *= start_indices_shape[i]; - } - } - return trip_count; -} - -llvm::SmallVector -ComputePermutedShape(llvm::ArrayRef shape, - llvm::ArrayRef permutation) { - llvm::SmallVector result_shape; - for (auto dim : permutation) { - result_shape.emplace_back(shape[dim]); - } - return result_shape; -} - -TypedValue -TransposeIndexVectorDimToLast(TypedValue &start_indices, - int64_t index_vector_dim) { - const auto start_indices_shape = start_indices.getType().getShape(); - - if (static_cast(start_indices_shape.size()) == index_vector_dim) { - return start_indices; - } - - if (index_vector_dim == - static_cast(start_indices_shape.size() - 1)) { - return start_indices; - } - - std::vector permutation; - permutation.reserve(start_indices_shape.size()); - for (int64_t i = 0, e = start_indices_shape.size(); i < e; i++) { - if (i != index_vector_dim) { - permutation.emplace_back(i); - } - } - permutation.emplace_back(index_vector_dim); - - auto result_shape = ComputePermutedShape(start_indices_shape, permutation); - - OpBuilder builder(start_indices.getContext()); - if (auto *ip = start_indices.getDefiningOp()) { - builder.setInsertionPointAfter(ip); - } else { - builder.setInsertionPointToStart(start_indices.getParentBlock()); - } - - auto transpose = builder.create( - start_indices.getLoc(), - RankedTensorType::get(result_shape, - start_indices.getType().getElementType()), - start_indices, permutation); - - return transpose.getResult(); -} - -TypedValue -PrependDegenerateDims(TypedValue operand, int64_t n) { - SPU_ENFORCE(n > 0); - std::vector new_shape_dims; - const auto operand_shape = operand.getType().getShape(); - new_shape_dims.reserve(n + operand_shape.size()); - new_shape_dims.insert(new_shape_dims.begin(), n, 1); - std::copy(operand_shape.begin(), operand_shape.end(), - std::back_inserter(new_shape_dims)); - - OpBuilder builder(operand.getContext()); - if (auto *ip = operand.getDefiningOp()) { - builder.setInsertionPointAfter(ip); - } else { - builder.setInsertionPointToStart(operand.getParentBlock()); - } - - auto reshape = builder.create( - operand.getLoc(), - RankedTensorType::get(new_shape_dims, operand.getType().getElementType()), - operand); - - return reshape.getResult(); -} - -TypedValue -CollapseFirstNDims(TypedValue operand, int64_t n) { - SPU_ENFORCE(n > 0); - - const auto operand_shape = operand.getType().getShape(); - SPU_ENFORCE((int64_t)operand_shape.size() >= n); - - int64_t new_shape_leading_bound = 1; - for (int64_t i = 0; i < n; i++) { - new_shape_leading_bound *= operand_shape[i]; - } - - std::vector new_shape_dims; - new_shape_dims.reserve(operand_shape.size() - n + 1); - new_shape_dims.push_back(new_shape_leading_bound); - - std::copy(operand_shape.begin() + n, operand_shape.end(), - std::back_inserter(new_shape_dims)); - - auto output_type = - RankedTensorType::get(new_shape_dims, operand.getType().getElementType()); - - OpBuilder builder(operand.getContext()); - if (auto *ip = operand.getDefiningOp()) { - builder.setInsertionPointAfter(ip); - } else { - builder.setInsertionPointToStart(operand.getParentBlock()); - } - - auto reshape = - builder.create(operand.getLoc(), output_type, operand); - - return 
reshape.getResult(); -} - -// Canonicalizes the start_indices tensors so that we only have deal with some -// specific cases in the while loop that does the heavy lifting. -// -// See the "High Level Algorithm" section for a broader picture. -TypedValue -CanonicalizeGatherIndices(TypedValue &start_indices, - int64_t index_vector_dim) { - // Transpose the non-index-vector dimensions to the front. - auto transposed_start_indices = - TransposeIndexVectorDimToLast(start_indices, index_vector_dim); - bool indices_are_scalar = - index_vector_dim == - static_cast(start_indices.getType().getShape().size()); - - // The number of dimensions in start_indices that are index dimensions. - const int64_t index_dims_in_start_indices = indices_are_scalar ? 0 : 1; - - // If there is only one index (i.e. start_indices has rank 1 and this gather - // is really just a dynamic slice) add a leading degenerate dimension for - // uniformity. Otherwise create a "collapsed" leading dimension that subsumes - // all of the non-index-vector dimensions. - const auto shape = transposed_start_indices.getType().getShape(); - if (static_cast(shape.size()) == index_dims_in_start_indices) { - return PrependDegenerateDims(transposed_start_indices, 1); - } else { - // Collapse all but the dimensions (0 or 1) in start_indices containing the - // index vectors. - return CollapseFirstNDims(transposed_start_indices, - shape.size() - index_dims_in_start_indices); - } -} - -TypedValue CreateGatherLoopAccumulatorInitValue( - GatherOp op, Type element_type, llvm::ArrayRef slice_sizes, - int64_t gather_loop_trip_count, - const GatherDimensionNumbersAttr &dim_numbers) { - std::vector accumulator_state_shape_dims; - accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); - accumulator_state_shape_dims.push_back(gather_loop_trip_count); - for (int64_t i = 0; i < static_cast(slice_sizes.size()); i++) { - if (!std::binary_search(dim_numbers.getCollapsedSliceDims().begin(), - dim_numbers.getCollapsedSliceDims().end(), i)) { - accumulator_state_shape_dims.emplace_back(slice_sizes[i]); - } - } - - OpBuilder builder(op); - TypeTools type_tools; - - auto express_type = type_tools.getExpressedType(element_type); - auto shaped_type = - RankedTensorType::get(accumulator_state_shape_dims, express_type); - auto zero_attr = builder.getZeroAttr(shaped_type); - - if (zero_attr == nullptr && express_type.isa()) { - std::complex zero = {APFloat(0.0F), APFloat(0.0F)}; - zero_attr = DenseElementsAttr::get(shaped_type, - std::vector>( - shaped_type.getNumElements(), zero)); - } - - auto c = builder.create(op->getLoc(), zero_attr); - - if (type_tools.getTypeVisibility(element_type) != Visibility::VIS_PUBLIC) { - auto convert = builder.create( - op.getLoc(), - RankedTensorType::get(accumulator_state_shape_dims, element_type), - c.getResult()); - return convert.getResult(); - } else { - return c.getResult(); - } -} - -TypedValue -ExpandFirstDimIntoNDims(TypedValue operand, - llvm::ArrayRef expanded_dims) { - SPU_ENFORCE_GT(operand.getType().getShape().size(), size_t(0)); - SPU_ENFORCE_EQ(operand.getType().getShape()[0], - std::accumulate(expanded_dims.begin(), expanded_dims.end(), 1, - std::multiplies())); - - std::vector expanded_shape_dim_bounds; - expanded_shape_dim_bounds.reserve(expanded_dims.size() + - operand.getType().getShape().size() - 1); - std::copy(expanded_dims.begin(), expanded_dims.end(), - std::back_inserter(expanded_shape_dim_bounds)); - std::copy(operand.getType().getShape().begin() + 1, - operand.getType().getShape().end(), - 
std::back_inserter(expanded_shape_dim_bounds)); - - auto result_type = RankedTensorType::get(expanded_shape_dim_bounds, - operand.getType().getElementType()); - - OpBuilder builder(operand.getContext()); - if (auto *ip = operand.getDefiningOp()) { - builder.setInsertionPointAfter(ip); - } else { - builder.setInsertionPointToStart(operand.getParentBlock()); - } - auto reshaped = - builder.create(operand.getLoc(), result_type, operand); - return reshaped.getResult(); -} - -TypedValue -ElideDegenerateDims(OpBuilder *builder, TypedValue operand, - absl::Span dims_to_elide) { - std::unordered_set dims_to_elide_set(dims_to_elide.begin(), - dims_to_elide.end()); - std::vector new_shape; - for (size_t idx = 0; idx < operand.getType().getShape().size(); ++idx) { - if (dims_to_elide_set.count(idx) > 0) { - continue; - } - new_shape.emplace_back(operand.getType().getShape()[idx]); - } - - auto reshape = builder->create( - operand.getLoc(), - RankedTensorType::get(new_shape, operand.getType().getElementType()), - operand); - return reshape.getResult(); -} - -// Expands out or contracts away the gather dimensions in the accumulator -// produced by the while loop. -TypedValue AdjustBatchDimsInAccumulator( - OpBuilder *builder, llvm::ArrayRef start_indices_shape, - TypedValue accumulator, int64_t index_vector_dim) { - std::vector batch_dim_bounds; - batch_dim_bounds.reserve(start_indices_shape.size()); - for (int64_t i = 0, e = start_indices_shape.size(); i < e; i++) { - if (i != index_vector_dim) { - batch_dim_bounds.push_back(start_indices_shape[i]); - } - } - - if (batch_dim_bounds.empty()) { - // If batch_dim_bounds is empty we must be lowering a (effectively) - // dynamic-slice. In that case, there is a leading degenerate gather - // dimension that we added to make this special case play well with the - // general while loop which we need to remove now. - return ElideDegenerateDims(builder, accumulator, {0}); - } - - return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds); -} - -void BuildWhileCondition(Region &cond, Value /*counter*/, - Value /*canonical_start_indices*/, - Value /*accumulator_init*/, Value loop_upper_bound) { - OpBuilder builder(cond); - TypeTools type_tool; - - auto lt = builder.create( - cond.getLoc(), - RankedTensorType::get( - {}, type_tool.getTypeWithVisibility(builder.getI1Type(), - Visibility::VIS_PUBLIC)), - cond.getArgument(0), loop_upper_bound); - - builder.create(cond.getLoc(), ValueRange{lt.getResult()}); -} - -int64_t FindIndex(llvm::ArrayRef c, int64_t value) { - const auto *it = std::find(c.begin(), c.end(), value); - return std::distance(c.begin(), it); -} - -// Expand an index vector from the start_indices tensor into a vector that can -// be used to dynamic-slice out of the gather operand. -llvm::SmallVector ExpandIndexVectorIntoOperandSpace( - OpBuilder *builder, TypedValue index_vector, - const GatherDimensionNumbersAttr &dim_numbers, int64_t operand_rank) { - - TypeTools typetool; - auto index_type = - typetool.getExpressedType(index_vector.getType().getElementType()); - - if (operand_rank == 0) { - // This is Gather from a scalar. So, the index vector in operand space must - // be a zero-sized vector. 
- // return computation->AddInstruction(HloInstruction::CreateConstant( - // LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); - auto zero_const = builder->create( - index_vector.getLoc(), - builder->getZeroAttr(RankedTensorType::get({}, index_type))); - return {zero_const.getResult()}; - } - - auto p_zero_const = builder->create( - index_vector.getLoc(), - builder->getZeroAttr(RankedTensorType::get({}, index_type))); - - auto zero_const = builder->create( - index_vector.getLoc(), - RankedTensorType::get({}, typetool.toMPCType(index_type)), - p_zero_const); - - // We extract out individual components from the smaller index and concatenate - // them (interspersing zeros as needed) into the larger index. - llvm::SmallVector expanded_index_components; - - for (int64_t i = 0; i < operand_rank; i++) { - int64_t index_vector_dim_index = - FindIndex(dim_numbers.getStartIndexMap(), i); - if (index_vector_dim_index != - static_cast(dim_numbers.getStartIndexMap().size())) { - - auto component_to_concat = builder->create( - index_vector.getLoc(), - RankedTensorType::get({1}, index_vector.getType().getElementType()), - index_vector, - DenseI64ArrayAttr::get(builder->getContext(), - {index_vector_dim_index}), - DenseI64ArrayAttr::get(builder->getContext(), - {index_vector_dim_index + 1}), - DenseI64ArrayAttr::get(builder->getContext(), {1})); - auto reshaped = builder->create( - index_vector.getLoc(), - RankedTensorType::get({}, index_vector.getType().getElementType()), - component_to_concat); - expanded_index_components.push_back(reshaped); - } else { - expanded_index_components.push_back(zero_const); - } - } - - return expanded_index_components; -} - -// This generates the body of the while that implements the main data movement -// behavior of gather using dynamic-slice and dynamic-update-slice. -void GatherLoopBody(GatherOp gather, Region &body, - TypedValue operand, - TypedValue start_indices) { - OpBuilder builder(body); - - auto induction_var = body.getArgument(0); - auto output_accumulator = body.getArgument(1); - - TypeTools typetools; - auto index_type = typetools.getExpressedType( - induction_var.getType().dyn_cast().getElementType()); - - // Increment counter first - auto const_one = builder.create( - gather->getLoc(), - DenseElementsAttr::get(RankedTensorType::get({}, index_type), - builder.getIntegerAttr(index_type, 1))); - - // counter + 1 - auto incremented_counter = - builder.create(induction_var.getLoc(), induction_var.getType(), - induction_var, const_one); - - const auto &dim_numbers = gather.getDimensionNumbers(); - - bool has_scalar_indices = start_indices.getType().getShape().size() == 1; - SPU_ENFORCE_EQ( - has_scalar_indices, - dim_numbers.getIndexVectorDim() == - (int64_t)gather.getStartIndices().getType().getShape().size()); - - auto index_zero = builder.create( - gather->getLoc(), - builder.getZeroAttr(RankedTensorType::get({}, index_type))); - - TypedValue index_vector; - - if (has_scalar_indices) { - // In this case start_indices has rank 1 and induction_var_as_vector (of - // shape {1}) is an index into this rank 1 tensor. - auto ds = builder.create( - gather->getLoc(), start_indices, ValueRange{induction_var}, - DenseI64ArrayAttr::get(builder.getContext(), {1})); - index_vector = ds.getResult(); - } else { - // In this case start_indices has rank 2 and induction_var_as_vector (of - // shape {1}) is an index into just the first dimension of this rank 2 - // tensor. 
- - int64_t index_vector_size = start_indices.getType().getShape()[1]; - - auto index_vector_2d = builder.create( - gather->getLoc(), start_indices, ValueRange{induction_var, index_zero}, - DenseI64ArrayAttr::get(builder.getContext(), {1, index_vector_size})); - - index_vector = ElideDegenerateDims(&builder, index_vector_2d, {0}); - } - - auto gathered_slice_start = ExpandIndexVectorIntoOperandSpace( - &builder, index_vector, dim_numbers, operand.getType().getShape().size()); - - auto gathered_slice = builder.create( - gather->getLoc(), operand, gathered_slice_start, gather.getSliceSizes()); - - auto gathered_slice_with_dims_collapsed = ElideDegenerateDims( - &builder, gathered_slice, dim_numbers.getCollapsedSliceDims()); - - auto gathered_slice_for_update = - PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1); - - SmallVector index_vector_into_accumulator; - index_vector_into_accumulator.push_back(induction_var); - for (size_t idx = 0; - idx < gathered_slice_with_dims_collapsed.getType().getShape().size(); - ++idx) { - index_vector_into_accumulator.push_back(index_zero); - } - - auto updated_accumulator = builder.create( - gather->getLoc(), output_accumulator, gathered_slice_for_update, - index_vector_into_accumulator); - - builder.create( - gather->getLoc(), ValueRange{incremented_counter, updated_accumulator}); -} - -struct GatherConverter : public OpRewritePattern { - explicit GatherConverter(MLIRContext *context) : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(GatherOp op, - PatternRewriter &rewriter) const override { - - TypeTools type_tool; - if (type_tool.getTypeVisibility(op.getStartIndices().getType()) != - Visibility::VIS_SECRET) { - // Do not expand public gather - return failure(); - } - - OpBuilder builder(op); - - // Secret gather - if (GatherIsBroadcast(op)) { - // Replace gather with broadcast - auto broadcast_operand_shape = - DeleteDimensions(op.getDimensionNumbers().getCollapsedSliceDims(), - op.getOperand().getType().getShape()); - auto reshaped_type = RankedTensorType::get( - broadcast_operand_shape, op.getOperand().getType().getElementType()); - auto broadcast_operand = builder.create( - op->getLoc(), reshaped_type, op.getOperand()); - rewriter.replaceOpWithNewOp( - op, op->getResults().getType(), broadcast_operand, - DenseI64ArrayAttr::get(builder.getContext(), - op.getDimensionNumbers().getOffsetDims())); - return success(); - } - - auto index_type = type_tool.getExpressedType( - op.getStartIndices().getType().getElementType()); - auto operand = op.getOperand(); - auto start_indices = op.getStartIndices(); - auto output_type = op->getResultTypes()[0].dyn_cast(); - auto output_shape = output_type.getShape(); - int64_t output_rank = output_shape.size(); - - const auto &dim_numbers = op.getDimensionNumbers(); - - int64_t gather_loop_trip_count = GatherLoopTripCount(op); - - auto canonical_start_indices = CanonicalizeGatherIndices( - start_indices, dim_numbers.getIndexVectorDim()); - - SPU_ENFORCE(gather_loop_trip_count == - canonical_start_indices.getType().getShape()[0]); - - auto accumulator_init = CreateGatherLoopAccumulatorInitValue( - op, output_type.getElementType(), op.getSliceSizes(), - gather_loop_trip_count, op.getDimensionNumbers()); - - auto loopUpperBound = builder.create( - op->getLoc(), - DenseElementsAttr::get( - RankedTensorType::get({}, index_type), - builder.getIntegerAttr(index_type, gather_loop_trip_count))); - - auto counter = builder.create( - op->getLoc(), - builder.getZeroAttr(RankedTensorType::get({}, index_type))); 
- - auto loop = builder.create( - op->getLoc(), - TypeRange{counter.getResult().getType(), accumulator_init.getType()}, - ValueRange{counter, accumulator_init}); - { - loop.getCond().push_back(new Block()); - loop.getCond().front().addArguments( - TypeRange{counter.getType(), accumulator_init.getType()}, - {counter.getLoc(), accumulator_init.getLoc()}); - } - { - loop.getBody().push_back(new Block()); - - loop.getBody().front().addArguments( - TypeRange{counter.getType(), accumulator_init.getType()}, - {counter.getLoc(), accumulator_init.getLoc()}); - } - // Generate loop condition - BuildWhileCondition(loop.getCond(), counter.getResult(), - canonical_start_indices, accumulator_init, - loopUpperBound.getResult()); - - GatherLoopBody(op, loop.getBody(), operand, canonical_start_indices); - - OpResult accumulator_result = loop->getResults().back(); - - auto accumulator_with_batch_dims_decanonicalized = - AdjustBatchDimsInAccumulator( - &builder, start_indices.getType().getShape(), - cast>(accumulator_result), - dim_numbers.getIndexVectorDim()); - - std::vector permutation; - permutation.reserve(output_rank); - - int64_t batch_idx_counter = 0; - int64_t offset_idx_counter = - output_rank - dim_numbers.getOffsetDims().size(); - for (int64_t i = 0; i < output_rank; i++) { - bool is_offset_dim = - std::binary_search(dim_numbers.getOffsetDims().begin(), - dim_numbers.getOffsetDims().end(), i); - if (is_offset_dim) { - permutation.push_back(offset_idx_counter++); - } else { - permutation.push_back(batch_idx_counter++); - } - } - - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), - accumulator_with_batch_dims_decanonicalized, - DenseI64ArrayAttr::get(builder.getContext(), permutation)); - - return success(); - } -}; - -struct ExpandSecretGather : public ExpandSecretGatherBase { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateOwningPatterns(&patterns, &getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } - -private: - static void populateOwningPatterns(RewritePatternSet *patterns, - MLIRContext *ctx) { - patterns->insert(ctx); - } -}; -} // namespace - -std::unique_ptr> createExpandSecretGatherPass() { - return std::make_unique(); -} - -} // namespace mlir::pphlo diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index c5c78b7f..c55e552c 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -1226,41 +1226,6 @@ class HloToPPHloOpConverter } }; -template <> -class HloToPPHloOpConverter - : public OpConversionPattern { -private: - const ValueVisibilityMap &vis_; - -public: - HloToPPHloOpConverter(TypeConverter &type_converter, MLIRContext *context, - const ValueVisibilityMap &vis) - : OpConversionPattern(type_converter, context), - vis_(vis) {} - - LogicalResult - matchAndRewrite(stablehlo::GatherOp op, stablehlo::GatherOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto old_attr = op.getDimensionNumbers(); - pphlo::GatherDimensionNumbersAttr attr = GatherDimensionNumbersAttr::get( - op.getContext(), old_attr.getOffsetDims(), - old_attr.getCollapsedSliceDims(), old_attr.getStartIndexMap(), - old_attr.getIndexVectorDim()); - - auto result_vis = vis_.getValueVisibility(op.getResult()); - - Type resultType = HloToPPHloTypeConverter::getTypeWithVisibility( - this->getTypeConverter()->convertType(op.getType()), result_vis); - - rewriter.replaceOpWithNewOp( - op, 
resultType, adaptor.getOperands()[0], adaptor.getOperands()[1], - attr, ConvertDenseIntElementAttr(op.getSliceSizes()), - op.getIndicesAreSorted()); - - return success(); - } -}; - template <> class HloToPPHloOpConverter : public OpConversionPattern { @@ -1628,7 +1593,6 @@ struct HloLegalizeToPPHlo HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, - HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, diff --git a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h index 4aa3f8a5..b409051a 100644 --- a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h +++ b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h @@ -58,7 +58,6 @@ MAP_HLO_TO_PPHLO(DivOp) MAP_HLO_TO_PPHLO(DotOp) MAP_HLO_TO_PPHLO(ExpOp) MAP_HLO_TO_PPHLO(Expm1Op) -MAP_HLO_TO_PPHLO(GatherOp) MAP_HLO_TO_PPHLO(IotaOp) MAP_HLO_TO_PPHLO(FloorOp) MAP_HLO_TO_PPHLO(LogOp) diff --git a/libspu/compiler/passes/passes.h b/libspu/compiler/passes/passes.h index fbda4cfb..16c70cab 100644 --- a/libspu/compiler/passes/passes.h +++ b/libspu/compiler/passes/passes.h @@ -62,8 +62,6 @@ std::unique_ptr> createOptimizeSelectPass(); // Optimize sqrt(x) + very_small_const) -> sqrt(x + eps) std::unique_ptr> createOptimizeSqrtPlusEps(); -std::unique_ptr> createExpandSecretGatherPass(); - // Rewrite x/sqrt(x+eps) -> x*rsqrt(x+eps) std::unique_ptr> createRewriteDivSqrtPatterns(); diff --git a/libspu/compiler/passes/passes.td b/libspu/compiler/passes/passes.td index 281649d4..942d531f 100644 --- a/libspu/compiler/passes/passes.td +++ b/libspu/compiler/passes/passes.td @@ -81,12 +81,6 @@ def RewriteDivSqrtPatterns: Pass<"rewrite-div-sqrt-pattern", "func::FuncOp"> { let dependentDialects = ["pphlo::PPHloDialect"]; } -def ExpandSecretGather: Pass<"expand-secret-gather", "func::FuncOp"> { - let summary = "Rewrite Gather with secret indexing to loop with DynamicUpdateSlice"; - let constructor = "createExpandSecretGatherPass()"; - let dependentDialects = ["pphlo::PPHloDialect"]; -} - def OptimizeDenominatorWithBcast: Pass<"optimize-denominator-with-broadcast", "func::FuncOp"> { let summary = "Optimize x/broadcast(y) into x*broadcast(1/y)"; let constructor = "createOptimizeDenominatorWithBroadcast()"; diff --git a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc index a975799e..e4a4ae21 100644 --- a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc +++ b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc @@ -27,23 +27,48 @@ namespace mlir::pphlo { namespace { struct DivRewriter : public OpRewritePattern { +private: + Operation *rewriteSqrtIfPossible(PatternRewriter &rewriter, + Operation *op) const { + if (op == nullptr || op->getNumOperands() != 1) { + return nullptr; + } + + if (mlir::isa(op)) { + return rewriter.create(op->getLoc(), op->getResultTypes(), + op->getOperand(0)); + } + + if (auto bcastOp = mlir::dyn_cast(op)) { + if (auto *inner = rewriteSqrtIfPossible( + rewriter, bcastOp.getOperand().getDefiningOp())) { + return rewriter.create( + op->getLoc(), bcastOp->getResultTypes(), inner->getResult(0), + bcastOp.getBroadcastDimensions()); + } + return nullptr; + } + + return nullptr; + } + +public: explicit DivRewriter(MLIRContext *context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(DivOp op, PatternRewriter &rewriter) const override { // Pattern 1: - // y/sqrt(x + eps) + // y/sqrt(x) -> y*rsqrt(x) auto denominator = op.getRhs(); - if (auto sqrt = denominator.getDefiningOp()) 
{ - auto newRsqrt = rewriter.create( - denominator.getLoc(), denominator.getType(), sqrt.getOperand()); + if (auto *newop = + rewriteSqrtIfPossible(rewriter, denominator.getDefiningOp())) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), - newRsqrt); + newop->getResult(0)); return success(); } else { // Pattern 2: - // y/(k*sqrt(x + eps)) -> y/k*rsqrt(x+eps) + // y/(k*sqrt(x)) -> y/k*rsqrt(x) if (auto mulOp = denominator.getDefiningOp()) { auto sqrtOp = mulOp.getRhs().getDefiningOp(); auto k = mulOp.getLhs(); @@ -55,10 +80,10 @@ struct DivRewriter : public OpRewritePattern { // y/k auto newDiv = rewriter.create( op.getLoc(), op->getResultTypes(), op.getLhs(), k); - // rsqrt(x+eps) + // rsqrt(x) auto newRsqrt = rewriter.create( op->getLoc(), sqrtOp->getResultTypes(), sqrtOp->getOperand(0)); - // y/k*rsqrt(x+eps) + // y/k*rsqrt(x) rewriter.replaceOpWithNewOp(op, op.getType(), newDiv, newRsqrt); return success(); diff --git a/libspu/compiler/passes/visibility_inference.cc b/libspu/compiler/passes/visibility_inference.cc index 3b6eae6c..b238d34f 100644 --- a/libspu/compiler/passes/visibility_inference.cc +++ b/libspu/compiler/passes/visibility_inference.cc @@ -306,14 +306,6 @@ void VisibilityInference::inferOperation(Operation &op) { value_vis_.setValueVisibility(op.getResult(0), Visibility::VIS_PUBLIC); } else if (llvm::isa(op)) { inferSort(op); - } else if (llvm::isa(op)) { - // For gather op, if either operand or indices is a secret, result is a - // secret - auto operand_vis = value_vis_.getValueVisibility(op.getOperand(0)); - auto indices_vis = value_vis_.getValueVisibility(op.getOperand(1)); - value_vis_.setValueVisibility( - op.getResult(0), - TypeTools::inferResultVisibility({operand_vis, indices_vis})); } else if (llvm::isa(op)) { inferSelectAndScatter(op); } else if (llvm::isa(op)) { diff --git a/libspu/compiler/tests/expand_secret_gather.mlir b/libspu/compiler/tests/expand_secret_gather.mlir deleted file mode 100644 index 36ac1553..00000000 --- a/libspu/compiler/tests/expand_secret_gather.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: mlir-pphlo-opt --expand-secret-gather --split-input-file %s | FileCheck %s - -func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<1x!pphlo.sec>) -> (tensor>) { - //CHECK-NOT: pphlo.gather - //CHECK : pphlo.while - %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = array} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.sec>) -> tensor> - return %0: tensor> -} - -// ----- -func.func @main(%arg0: tensor<3x3x!pphlo.pub>, %arg1: tensor<2x!pphlo.sec>) -> (tensor<2x3x!pphlo.sec>) { - %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x3x!pphlo.pub>, tensor<2x!pphlo.sec>) -> tensor<2x3x!pphlo.sec> - return %0 : tensor<2x3x!pphlo.sec> -} diff --git a/libspu/compiler/tests/hlo_to_pphlo_dynamic_slice.mlir b/libspu/compiler/tests/hlo_to_pphlo_dynamic_slice.mlir new file mode 100644 index 00000000..3196a62c --- /dev/null +++ b/libspu/compiler/tests/hlo_to_pphlo_dynamic_slice.mlir @@ -0,0 +1,7 @@ +// RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC,VIS_SECRET --split-input-file %s | FileCheck %s + +func.func @main(%arg0: tensor<15xi32>,%arg1: tensor) -> (tensor<1xi32>) { + // CHECK: %0 = "pphlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = array} : (tensor<15x!pphlo.pub>, tensor>) -> tensor<1x!pphlo.sec> + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) {slice_sizes = array} : (tensor<15xi32>, tensor) -> 
tensor<1xi32> + return %0 : tensor<1xi32> +} diff --git a/libspu/compiler/tests/no_expand_secret_gather.mlir b/libspu/compiler/tests/no_expand_secret_gather.mlir deleted file mode 100644 index 155a4dca..00000000 --- a/libspu/compiler/tests/no_expand_secret_gather.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: mlir-pphlo-opt --expand-secret-gather --split-input-file %s | FileCheck %s - -func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<1x!pphlo.pub>) -> (tensor>) { - //CHECK-NOT: pphlo.while - //CHECK : pphlo.gather - %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = array} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.pub>) -> tensor> - return %0: tensor> -} diff --git a/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir b/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir index 88b7738d..6289c118 100644 --- a/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir +++ b/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir @@ -34,3 +34,29 @@ func.func @main(%arg0: tensor>, %arg1: tensor>) return %4: tensor> } +// ----- + +func.func @main(%arg0: tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> { + // CHECK: %[[RSQRT:.+]] = "pphlo.rsqrt"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + // CHECK: "pphlo.multiply"(%arg0, %[[RSQRT]]) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + %0 = "pphlo.sqrt"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + %1 = "pphlo.divide"(%arg0, %0) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + return %1 : tensor<3x4x!pphlo.sec> +} + +// ----- + +func.func @main(%arg0: tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> { + %0 = "pphlo.convert"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + %1 = "pphlo.reshape"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x1x!pphlo.sec> + %2 = "pphlo.transpose"(%1) {permutation = array} : (tensor<3x4x1x!pphlo.sec>) -> tensor<3x1x4x!pphlo.sec> + %3 = "pphlo.dot_general"(%2, %1) {dot_dimension_numbers = #pphlo.dot} : (tensor<3x1x4x!pphlo.sec>, tensor<3x4x1x!pphlo.sec>) -> tensor<3x!pphlo.sec> + %4 = "pphlo.convert"(%3) : (tensor<3x!pphlo.sec>) -> tensor<3x!pphlo.sec> + // CHECK: %[[RSQRT:.+]] = "pphlo.rsqrt" + // CHECK: %[[BCAST:.+]] = "pphlo.broadcast"(%[[RSQRT]]) {broadcast_dimensions = array} : (tensor<3x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + // CHECK: "pphlo.multiply"(%0, %[[BCAST]]) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + %5 = "pphlo.sqrt"(%4) : (tensor<3x!pphlo.sec>) -> tensor<3x!pphlo.sec> + %6 = "pphlo.broadcast"(%5) {broadcast_dimensions = array} : (tensor<3x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + %7 = "pphlo.divide"(%0, %6) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> + return %7 : tensor<3x4x!pphlo.sec> +} diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 6ed81a4d..248fe9e4 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -481,36 +481,6 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, opts); } -void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope, - mlir::pphlo::GatherOp &op, const ExecutionOptions &opts) { - // If input is empty, short circuit - auto operand = lookupValue(sscope, op.getOperand(), opts); - auto start_indices = lookupValue(sscope, op.getStartIndices(), opts); - if (operand.numel() == 0) { - addValue(sscope, op.getResult(), operand, opts); - return; - } - - 
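Side note on the `DivRewriter` hunk above: it generalizes pattern 2 from `y/(k*sqrt(x + eps))` to any `y/(k*sqrt(x))`. Both patterns are plain algebraic rewrites; a quick numpy check of the identities they rely on (illustrative only, not SPU code):

```python
# Identities behind the two DivRewriter patterns:
#   pattern 1:  y / sqrt(x)       ->  y * rsqrt(x)
#   pattern 2:  y / (k * sqrt(x)) ->  (y / k) * rsqrt(x)
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 10.0, size=(3, 4))   # strictly positive, as sqrt requires
y = rng.normal(size=(3, 4))
k = rng.uniform(0.5, 2.0, size=(3, 4))

rsqrt = lambda v: 1.0 / np.sqrt(v)

assert np.allclose(y / np.sqrt(x), y * rsqrt(x))
assert np.allclose(y / (k * np.sqrt(x)), (y / k) * rsqrt(x))
```

On secret values this trades a division by `sqrt(x)` for a multiply plus `rsqrt`, which is presumably the motivation for the pass; the new `optimize_sqrt_to_rsqrt.mlir` cases above check exactly these two shapes, including the broadcast variant.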
const auto &output_shape = - op.getResult().getType().dyn_cast().getShape(); - - const auto &dim_numbers = op.getDimensionNumbers(); - - kernel::hlo::GatherConfig config; - // Sizes ss; - // convertDenseIntElementAttr(op.getSliceSizes(), ss); - config.sliceSizes = op.getSliceSizes(); - config.indexVectorDim = dim_numbers.getIndexVectorDim(); - config.offsetDims = dim_numbers.getOffsetDims(); - config.collapsedSliceDims = dim_numbers.getCollapsedSliceDims(); - config.startIndexMap = dim_numbers.getStartIndexMap(); - - addValue( - sscope, op.getResult(), - kernel::hlo::Gather(sctx, operand, start_indices, config, output_shape), - opts); -} - void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, mlir::pphlo::SortOp &op, const ExecutionOptions &opts) { auto sort_dim = op.getDimension(); diff --git a/libspu/device/pphlo/pphlo_executor_test.cc b/libspu/device/pphlo/pphlo_executor_test.cc index 46362969..8d5a09ae 100644 --- a/libspu/device/pphlo/pphlo_executor_test.cc +++ b/libspu/device/pphlo/pphlo_executor_test.cc @@ -449,11 +449,11 @@ TEST_P(ExecutorTest, ReduceWindowStableHloTest) { r.run(r.compileMHlo(R"( func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<2x2xi32>) { - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ( { + %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () + %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) { base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, @@ -478,11 +478,11 @@ TEST_P(ExecutorTest, ReduceWindowStableHloTest2) { r.run(r.compileMHlo(R"( func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<1x2xi32>) { - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ( { + %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () + %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) { base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, @@ -583,11 +583,11 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaBaseDilation) { r.run(r.compileMHlo(R"( func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<6x6xi32>) { - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ( { + %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () + %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) { base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, @@ -615,11 +615,11 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaStrideBaseDilation) { auto compiled = r.compileMHlo(R"( func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) { - %0 = "mhlo.constant"() {value = dense<0> 
: tensor} : () -> tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ( { + %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () + %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<3x3xi32> return %1 : tensor<3x3xi32> @@ -642,11 +642,11 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaStrideBothDilation) { auto compiled = r.compileMHlo(R"( func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) { - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ( { + %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () + %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<2> : tensor<2xi64>, window_dimensions = dense<2> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<3x3xi32> return %1 : tensor<3x3xi32> @@ -669,11 +669,11 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaPaddingStrideBaseDilation) { auto compiled = r.compileMHlo(R"( func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) { - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ( { + %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ( { ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors - %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () + %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () }) {base_dilations = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<3> : tensor<2xi64>, window_strides = dense<3> : tensor<2xi64>} : (tensor<4x4xi32>, tensor) -> tensor<3x3xi32> return %1 : tensor<3x3xi32> @@ -846,124 +846,6 @@ func.func @main(%arg0: tensor>) -> (tensor>) { r.verifyOutput(reinterpret_cast(&in)); } -void testGatherImpl(size_t world_size, FieldType field, ProtocolKind protocol, - const xt::xarray &operand, - const xt::xarray &indices, - const xt::xarray &expected, const std::string &mhlo) { - // Public index - { - Runner r(world_size, field, protocol); - - r.addInput(operand); - // Start indices - r.addInput(indices); - - auto compiled = r.compileMHlo(mhlo, {VIS_PUBLIC, VIS_PUBLIC}); - - EXPECT_THAT(compiled, testing::HasSubstr("pphlo.gather")); - - r.run(compiled); - - r.verifyOutput(expected.data()); - } - - // Secret index - { - Runner r(world_size, field, protocol); - - r.addInput(operand); - // Start indices - r.addInput(indices, VIS_SECRET); - - auto compiled = r.compileMHlo(mhlo, {VIS_PUBLIC, VIS_SECRET}); - - 
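For readers following the `mhlo` → `stablehlo` migration in these tests: `reduce_window` semantics are unchanged, only the dialect prefix moves. As a reference point, a slow generic max-`reduce_window` in numpy (a sketch; the demo parameters below are made up, not copied from the tests):

```python
import numpy as np

def reduce_window_max(x, window_dims, strides, padding,
                      base_dilations, window_dilations, init=0):
    # 1. dilate the base: insert (d - 1) init elements between inputs
    dilated_shape = [(s - 1) * d + 1 for s, d in zip(x.shape, base_dilations)]
    dilated = np.full(dilated_shape, init, dtype=x.dtype)
    dilated[tuple(slice(None, None, d) for d in base_dilations)] = x
    # 2. pad with the init value
    padded = np.pad(dilated, padding, constant_values=init)
    # 3. slide the (dilated) window and reduce with max
    out_shape = [
        (p - ((w - 1) * wd + 1)) // st + 1
        for p, w, st, wd in zip(padded.shape, window_dims,
                                strides, window_dilations)
    ]
    out = np.full(out_shape, init, dtype=x.dtype)
    for idx in np.ndindex(*out_shape):
        window = padded[tuple(
            slice(i * st, i * st + (w - 1) * wd + 1, wd)
            for i, st, w, wd in zip(idx, strides, window_dims,
                                    window_dilations)
        )]
        out[idx] = window.max()
    return out

x = np.array([[1, 2], [3, 4], [5, 6]])
print(reduce_window_max(x, window_dims=(2, 1), strides=(1, 1),
                        padding=((0, 0), (0, 0)),
                        base_dilations=(1, 1), window_dilations=(1, 1)))
# [[3 4]
#  [5 6]]
```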
EXPECT_THAT(compiled, testing::Not(testing::HasSubstr("pphlo.gather"))); - - r.run(compiled); - - r.verifyOutput(expected.data()); - } -} - -TEST_P(ExecutorTest, Gather1) { - std::string mhlo = R"( -func.func @main(%arg0: tensor<3x3xi32>, %arg1: tensor<2xi32>) -> (tensor<2x3xi32>) { - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<3x3xi32>, tensor<2xi32>) -> tensor<2x3xi32> - return %0 : tensor<2x3xi32> -})"; - - auto operand = xt::xarray{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; - auto indices = xt::xarray{0, 2}; - xt::xarray expected = {{1, 2, 3}, {7, 8, 9}}; - - testGatherImpl(std::get<0>(GetParam()), std::get<1>(GetParam()), - std::get<2>(GetParam()), operand, indices, expected, mhlo); -} - -TEST_P(ExecutorTest, Gather2) { - std::string mhlo = R"( -func.func @main(%arg0: tensor<3x3xi32>, %arg1: tensor<2xi32>) -> (tensor<3x2xi32>) { - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[3,1]> : tensor<2xi64>} : (tensor<3x3xi32>, tensor<2xi32>) -> tensor<3x2xi32> - return %0 : tensor<3x2xi32> -})"; - - auto operand = xt::xarray{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; - auto indices = xt::xarray{0, 2}; - xt::xarray expected = {{1, 3}, {4, 6}, {7, 9}}; - - testGatherImpl(std::get<0>(GetParam()), std::get<1>(GetParam()), - std::get<2>(GetParam()), operand, indices, expected, mhlo); -} - -TEST_P(ExecutorTest, GatherBatch) { - std::string mhlo = R"( -func.func @main(%arg0: tensor<3x3xi32>, %arg1: tensor<2x2xi32>) -> (tensor<2x3x2xi32>) { - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[3,1]> : tensor<2xi64>} : (tensor<3x3xi32>, tensor<2x2xi32>) -> tensor<2x3x2xi32> - return %0 : tensor<2x3x2xi32> -})"; - - auto operand = xt::xarray{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; - auto indices = xt::xarray{{0, 2}, {2, 1}}; - - xt::xarray expected = {{{1, 3}, {4, 6}, {7, 9}}, - {{3, 2}, {6, 5}, {9, 8}}}; - - testGatherImpl(std::get<0>(GetParam()), std::get<1>(GetParam()), - std::get<2>(GetParam()), operand, indices, expected, mhlo); -} - -TEST_P(ExecutorTest, GatherNd) { - std::string mhlo = R"( -func.func @main(%arg0: tensor<3x3x2xi32>, %arg1: tensor<2x2xi32>) -> (tensor<2x2xi32>) { - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1,1,2]> : tensor<3xi64>} : (tensor<3x3x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - return %0 : tensor<2x2xi32> -})"; - const xt::xarray operand = {{{-1, 1}, {-2, 2}, {-3, 3}}, - {{-4, 4}, {-5, 5}, {-6, 6}}, - {{-7, 7}, {-8, 8}, {-9, 9}}}; - auto indices = xt::xarray{{0, 0}, {1, 0}}; - xt::xarray expected = {{-1, 1}, {-4, 4}}; - - testGatherImpl(std::get<0>(GetParam()), std::get<1>(GetParam()), - std::get<2>(GetParam()), operand, indices, expected, mhlo); -} - -TEST_P(ExecutorTest, GatherNdNonDefaultIndexVectorDim) { - std::string mhlo = R"( -func.func @main(%arg0: tensor<3x3x2xi32>, %arg1: tensor<2x2xi32>) -> (tensor<2x2xi32>) { - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1,1,2]> : tensor<3xi64>} : (tensor<3x3x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - return %0 : tensor<2x2xi32> -})"; - xt::xarray operand = {{{-1, 1}, {-2, 2}, {-3, 3}}, - {{-4, 4}, {-5, 5}, {-6, 6}}, - {{-7, 7}, {-8, 8}, {-9, 9}}}; - auto indices = xt::xarray{{0, 0}, {1, 0}}; - xt::xarray expected = {{-2, 2}, {-1, 1}}; - - 
testGatherImpl(std::get<0>(GetParam()), std::get<1>(GetParam()), - std::get<2>(GetParam()), operand, indices, expected, mhlo); -} - TEST_P(ExecutorTest, Simple4x4Conv2DWith2x2Kernel) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); @@ -984,7 +866,7 @@ TEST_P(ExecutorTest, Simple4x4Conv2DWith2x2Kernel) { auto ir = r.compileMHlo(R"( func.func @main(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x1x2x2xf32>) -> (tensor<1x1x4x4xf32>) { - %0 = mhlo.convolution(%arg0, %arg1) + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 1], [0, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { @@ -1023,7 +905,7 @@ TEST_P(ExecutorTest, Conv2DGeneralDimensions) { auto ir = r.compileMHlo(R"( func.func @main(%arg0: tensor<2x3x1x4xf32>, %arg1:tensor<1x3x2x3xf32>) -> (tensor<1x1x1x2xf32>) { - %0 = mhlo.convolution(%arg0, %arg1) + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [f, 0, b, 1]x[o, 1, i,0]->[f, 0, b, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} { @@ -1061,7 +943,7 @@ TEST_P(ExecutorTest, DilatedBaseConv2DWithHighPadding) { auto ir = r.compileMHlo(R"( func.func @main(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x1x2x2xf32>) -> (tensor<1x1x7x7xf32>) { - %0 = mhlo.convolution(%arg0, %arg1) + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 1], [0, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} { @@ -1105,7 +987,7 @@ TEST_P(ExecutorTest, DilatedBaseConv2DWithLowAndHighPadding) { auto ir = r.compileMHlo(R"( func.func @main(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x1x2x2xf32>) -> (tensor<1x1x8x8xf32>) { - %0 = mhlo.convolution(%arg0, %arg1) + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} { @@ -1151,7 +1033,7 @@ TEST_P(ExecutorTest, FlatRhsDilation) { auto ir = r.compileMHlo(R"( func.func @main(%arg0: tensor<1x1x4x6xf32>, %arg1: tensor<1x1x2x3xf32>) -> (tensor<1x1x2x2xf32>) { - %0 = mhlo.convolution(%arg0, %arg1) + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [2, 2]} { @@ -2509,21 +2391,21 @@ TEST_P(ExecutorTest, OptimizedMaxPool1) { auto ir = r.compileMHlo(R"( func.func @main(%arg0: tensor<4x6xi32>, %arg1: tensor<2x2xi32>) -> (tensor<2x2xi32>, tensor<4x6xi32>) { - %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.reduce_window"(%arg0, %0) ({ + %0 = stablehlo.constant dense<0> : tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ({ ^bb0(%arg2: tensor, %arg3: tensor): - %3 = mhlo.maximum %arg2, %arg3 : tensor - "mhlo.return"(%3) : (tensor) -> () + %3 = stablehlo.maximum %arg2, %arg3 : tensor + "stablehlo.return"(%3) : (tensor) -> () }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x6xi32>, tensor) -> tensor<2x2xi32> - %2 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + %2 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor, %arg4: tensor): - %3 = "mhlo.compare"(%arg3, %arg4) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () 
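The gather reference tests deleted above (`Gather1` through `GatherNdNonDefaultIndexVectorDim`) pinned down XLA gather semantics end to end, and also asserted that a secret index forced the gather expansion while a public index kept `pphlo.gather` intact. The first two expectations reduce to plain row/column selection; a numpy restatement (only `slice_sizes` is echoed in comments, the full dimension-number plumbing is elided):

```python
import numpy as np

operand = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
indices = np.array([0, 2])

# Gather1: slice_sizes = [1, 3]  ->  whole rows 0 and 2
assert np.array_equal(np.take(operand, indices, axis=0),
                      [[1, 2, 3], [7, 8, 9]])

# Gather2: slice_sizes = [3, 1]  ->  whole columns 0 and 2
assert np.array_equal(np.take(operand, indices, axis=1),
                      [[1, 3], [4, 6], [7, 9]])
```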
+ %3 = "stablehlo.compare"(%arg3, %arg4) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%3) : (tensor) -> () }, { ^bb0(%arg3: tensor, %arg4: tensor): - %3 = mhlo.add %arg3, %arg4 : tensor - "mhlo.return"(%3) : (tensor) -> () + %3 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%3) : (tensor) -> () }) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<[2,3]> : tensor<2xi64>, window_strides = dense<[2,3]> : tensor<2xi64>} : (tensor<4x6xi32>, tensor<2x2xi32>, tensor) -> tensor<4x6xi32> return %1, %2 : tensor<2x2xi32>, tensor<4x6xi32> })", @@ -2758,11 +2640,11 @@ TEST_P(ExecutorTest, MixedPayload) { r.run(r.compileMHlo( R"( func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) { - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<20xi32> - %1:2 = "mhlo.sort"(%arg0, %0) ({ + %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<20xi32> + %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): - %2 = mhlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %2 : tensor + %2 = stablehlo.compare LT, %arg1, %arg2 : (tensor, tensor) -> tensor + stablehlo.return %2 : tensor }) {dimension = 0 : i64, is_stable = true} : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) return %1#0, %1#1: tensor<20xi32>, tensor<20xi32> })", @@ -2788,11 +2670,11 @@ TEST_P(ExecutorTest, MixedPayloadDescending) { r.run(r.compileMHlo( R"( func.func @main(%arg0: tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) { - %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<20xi32> - %1:2 = "mhlo.sort"(%arg0, %0) ({ + %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<20xi32> + %1:2 = "stablehlo.sort"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): - %2 = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %2 : tensor + %2 = stablehlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor + stablehlo.return %2 : tensor }) {dimension = 0 : i64, is_stable = true} : (tensor<20xi32>, tensor<20xi32>) -> (tensor<20xi32>, tensor<20xi32>) return %1#0, %1#1: tensor<20xi32>, tensor<20xi32> })", diff --git a/libspu/device/pphlo/pphlo_executor_test_runner.cc b/libspu/device/pphlo/pphlo_executor_test_runner.cc index be485319..ca6bed0d 100644 --- a/libspu/device/pphlo/pphlo_executor_test_runner.cc +++ b/libspu/device/pphlo/pphlo_executor_test_runner.cc @@ -35,7 +35,7 @@ Runner::Runner(size_t world_size, FieldType field, ProtocolKind protocol) std::string Runner::compileMHlo(const std::string &mhlo, const std::vector &vis) { CompilationSource source; - source.set_ir_type(SourceIRType::MLIR_HLO); + source.set_ir_type(SourceIRType::STABLEHLO); source.set_ir_txt(mhlo); for (const auto v : vis) { source.add_input_visibility(v); diff --git a/libspu/device/pphlo/pphlo_verifier.cc b/libspu/device/pphlo/pphlo_verifier.cc index ce684730..e338f283 100644 --- a/libspu/device/pphlo/pphlo_verifier.cc +++ b/libspu/device/pphlo/pphlo_verifier.cc @@ -490,12 +490,6 @@ void PPHloVerifier::verify(mlir::pphlo::SortOp, SPDLOG_WARN("Missing stablehlo interpreter support"); } -void PPHloVerifier::verify(mlir::pphlo::GatherOp, - absl::Span /*operands*/, - absl::Span /*expected*/) { - SPDLOG_WARN("Missing stablehlo interpreter support"); -} - void PPHloVerifier::verify(mlir::pphlo::BitcastConvertOp, absl::Span /*operands*/, absl::Span /*expected*/) { diff --git a/libspu/device/pphlo/pphlo_verifier.h 
b/libspu/device/pphlo/pphlo_verifier.h index 50697e83..001ffb50 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -98,9 +98,6 @@ class PPHloVerifier { VERIFY_DECL(DynamicSliceOp) VERIFY_DECL(DynamicUpdateSliceOp) - // Gather - VERIFY_DECL(GatherOp) - // Geometrical VERIFY_DECL(PadOp) VERIFY_DECL(BroadcastOp) diff --git a/libspu/dialect/pphlo_attrs.td b/libspu/dialect/pphlo_attrs.td index 7e919913..653c3672 100644 --- a/libspu/dialect/pphlo_attrs.td +++ b/libspu/dialect/pphlo_attrs.td @@ -22,17 +22,6 @@ include "pphlo_dialect.td" def PPHloDim : ArrayRefParameter<"int64_t", "Dimension">; -def GatherDimensionNumbers : AttrDef { - let mnemonic = "gather"; - let summary = "Attribute that models the dimension information for gather"; - let parameters = (ins - PPHloDim: $offsetDims, - PPHloDim: $collapsedSliceDims, - PPHloDim: $startIndexMap, - "int64_t": $indexVectorDim); - let hasCustomAssemblyFormat = 1; -} - def ConvDimensionNumbers : AttrDef { let mnemonic = "conv"; let summary = "Structure of dimension information for conv op"; diff --git a/libspu/dialect/pphlo_ops.cc b/libspu/dialect/pphlo_ops.cc index 0dd64ca0..12aa584e 100644 --- a/libspu/dialect/pphlo_ops.cc +++ b/libspu/dialect/pphlo_ops.cc @@ -1058,9 +1058,19 @@ LogicalResult inferDynamicSliceOp(std::optional location, } } + TypeTools tools; // dynamic_slice_c5 - inferredReturnTypes.emplace_back( - RankedTensorType::get(sliceSizes, rankedOperandType.getElementType())); + llvm::SmallVector vis(startIndicesTypes.size() + 1); + vis[0] = tools.getTypeVisibility(operandType); + for (const auto& index_type : llvm::enumerate(startIndicesTypes)) { + vis[index_type.index() + 1] = tools.getTypeVisibility(index_type.value()); + } + + inferredReturnTypes.emplace_back(RankedTensorType::get( + sliceSizes, + tools.getTypeWithVisibility(rankedOperandType.getElementType(), + tools.inferResultVisibility(vis)))); + return success(); } @@ -1227,41 +1237,6 @@ static ParseResult parseDims(AsmParser& parser, SmallVector& dims) { return success(); } -void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const { - printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()), - std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()), - std::make_pair("start_index_map", getStartIndexMap()), - std::make_pair("index_vector_dim", getIndexVectorDim())); -} - -Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type) { - if (failed(parser.parseLess())) { - return {}; - } - - SmallVector offset_dims; - SmallVector collapsed_slice_dims; - SmallVector start_index_map; - int64_t index_vector_dim = 0; - - if (failed(parseStruct( - parser, - {"offset_dims", "collapsed_slice_dims", "start_index_map", - "index_vector_dim"}, - {[&]() { return parseDims(parser, offset_dims); }, - [&]() { return parseDims(parser, collapsed_slice_dims); }, - [&]() { return parseDims(parser, start_index_map); }, - [&]() { return parser.parseInteger(index_vector_dim); }}))) { - parser.emitError(parser.getCurrentLocation()) - << "failed parsing gather dimension numbers attribute"; - return {}; - } - - return GatherDimensionNumbersAttr::get(parser.getContext(), offset_dims, - collapsed_slice_dims, start_index_map, - index_vector_dim); -} - // Custom printer and parser for DotDimensionNumbersAttr. 
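The `inferDynamicSliceOp` change above makes the result visibility depend on the operand *and* every start index, instead of unconditionally reusing the operand's element type. The meet it computes is simply "secret dominates public"; sketched in Python with illustrative enum names (the real logic lives in `TypeTools::inferResultVisibility`):

```python
from enum import Enum

class Vis(Enum):
    PUBLIC = 0
    SECRET = 1

def infer_result_visibility(vis_list):
    # result is secret as soon as any contributing value is secret
    return Vis.SECRET if Vis.SECRET in vis_list else Vis.PUBLIC

# a public tensor sliced at a secret index yields a secret result,
# matching the CHECK line in the new hlo_to_pphlo_dynamic_slice.mlir
assert infer_result_visibility([Vis.PUBLIC, Vis.SECRET]) is Vis.SECRET
```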
void DotDimensionNumbersAttr::print(AsmPrinter& printer) const { printStruct( diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index 125012b1..08bdf3db 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -993,26 +993,6 @@ def HLO_PadOp let results = (outs PPHLO_Tensor); } -def PPHLO_GatherOp : PPHLO_Op<"gather", [Pure]> { - let summary = "Gather operator"; - let description = [{ - Stitches together several slices of `operand` from offsets specified in - `start_indices` (each slice at a potentially different runtime offset). - - See https://www.tensorflow.org/xla/operation_semantics#gather. - }]; - - let arguments = (ins - PPHLO_Tensor:$operand, - PPHLO_IntTensor:$start_indices, - GatherDimensionNumbers:$dimension_numbers, - DenseI64ArrayAttr:$slice_sizes, - DefaultValuedAttr:$indices_are_sorted - ); - - let results = (outs PPHLO_Tensor); -} - def ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. diff --git a/libspu/kernel/hlo/indexing.cc b/libspu/kernel/hlo/indexing.cc index b815be7d..da64ac92 100644 --- a/libspu/kernel/hlo/indexing.cc +++ b/libspu/kernel/hlo/indexing.cc @@ -35,279 +35,6 @@ void hintNumberOfBits(const Value &a, size_t nbits); } namespace { -struct IndexIterationSpace { - spu::Index index_base; - spu::Index index_count; - spu::Index index_incr; -}; - -// Returns an IndexIterationSpace that iterates over the output batch -// dimensions while keeping the rest of the output dimensions clamped to 0. -IndexIterationSpace iterationSpaceForOutputBatchIndices( - const spu::Shape &output_shape, - const spu::kernel::hlo::GatherConfig &config) { - int64_t output_rank = output_shape.size(); - spu::Index index_base(output_rank, 0); - spu::Index index_count; - index_count.reserve(output_rank); - - for (int64_t i = 0; i < output_rank; i++) { - bool is_output_batch_dim = !std::binary_search(config.offsetDims.begin(), - config.offsetDims.end(), i); - index_count.push_back(is_output_batch_dim ? output_shape[i] : 1); - } - - return {std::move(index_base), std::move(index_count), - spu::Index(output_rank, 1)}; -} - -// Return an IndexIterationSpace that iterates over the output slice -// dimensions while keeping the rest of the output dimensions clamped to 0. -IndexIterationSpace iterationSpaceForOutputOffsetIndices( - int64_t output_rank, const spu::kernel::hlo::GatherConfig &config) { - spu::Index index_base(output_rank, 0); - spu::Index index_count(output_rank, 1); - int64_t slice_sizes_idx = 0; - - for (int64_t i = 0; i < output_rank; i++) { - bool is_output_window_dim = std::binary_search(config.offsetDims.begin(), - config.offsetDims.end(), i); - if (is_output_window_dim) { - while (std::binary_search(config.collapsedSliceDims.begin(), - config.collapsedSliceDims.end(), - slice_sizes_idx)) { - slice_sizes_idx++; - } - index_count[i] = config.sliceSizes[slice_sizes_idx++]; - } - } - - return {std::move(index_base), std::move(index_count), - spu::Index(output_rank, 1)}; -} - -// This functor computes the contribution of start_indices to an input index -// corresponding to an output index. That is, given an output index I, it -// picks out the batch indices in I and uses them to look up a starting index, -// G, from the start indices tensor, and expands G into the input space -// according to start_index_map. -class OutputBatchIndexToInputIndex { - public: - // The constructor does some setup work that is amortized across all - // iterations. 
- explicit OutputBatchIndexToInputIndex( - const spu::kernel::hlo::GatherConfig &config, - const spu::Shape &input_shape, const spu::Shape &output_shape, - const xt::xarray &start_indices) - : config_(config), start_indices_(start_indices) { - for (int64_t i = 0; i < static_cast(output_shape.size()); ++i) { - output_dim_is_batch_dims_.push_back(!std::binary_search( - config_.offsetDims.begin(), config_.offsetDims.end(), i)); - } - - for (int64_t i = 0; i < static_cast(input_shape.size()); ++i) { - int64_t index_of_input_dim_in_index_vector = - std::distance(config_.startIndexMap.begin(), - std::find(config_.startIndexMap.begin(), - config_.startIndexMap.end(), i)); - - if (static_cast(index_of_input_dim_in_index_vector) == - config_.startIndexMap.size()) { - input_dim_value_to_index_vector_.push_back(-1); - } else { - input_dim_value_to_index_vector_.push_back( - index_of_input_dim_in_index_vector); - } - } - - index_vector_index_.resize(start_indices_.shape().size()); - input_index_.resize(input_shape.size()); - int64_t index_vector_size = start_indices_.shape()[config.indexVectorDim]; - index_vector_.resize(index_vector_size); - - start_indices_shape_.reserve(start_indices_.shape().size()); - for (const auto &d : start_indices_.shape()) { - start_indices_shape_.emplace_back(static_cast(d)); - } - } - - // Returns the contribution of start_indices to the input index - // corresponding to output_index. See gather_inner_loop_body. - // - // This is conceptually a stateless transformation from output_index to the - // gather input index, but: - // - // - Instead of allocating memory to represent the gather input index on - // every invocation we reuse the same storage for the result - // (input_index_), mutating it in place. - // - Instead of allocating buffers for temporary values like - // index_vector_index_ and index_vector on every invocation, we reuse the - // same storage for all invocations. - // - // This returns a Span into memory owned by the class. - spu::Index &operator()(const spu::Index &output_index) { - propagateOutputIndexGatherDimsToIndexVectorIndex(output_index); - fetchIndexVector(); - propagateIndexVectorToInputIndex(); - return input_index_; - } - - private: - // Propagates the batch dimensions from the output index into - // index_vector_index_ by mutating index_vector_index_ in place. Does not - // update the dim_numbers.index_vector_dim() dimension -- that's the - // dimension we iterate over in FetchIndexVector. - void propagateOutputIndexGatherDimsToIndexVectorIndex( - absl::Span output_index) { - int64_t index_vector_index_i = 0; - for (int64_t i = 0, e = output_index.size(); i < e; i++) { - if (!output_dim_is_batch_dims_[i]) { - continue; - } - - if (index_vector_index_i == config_.indexVectorDim) { - index_vector_index_i++; - } - - index_vector_index_[index_vector_index_i++] = output_index[i]; - } - } - - // Populates index_vector_ by iterating over start_indices_ according to - // index_vector_index_. - void fetchIndexVector() { - int64_t index_vector_dim = config_.indexVectorDim; - for (int64_t i = 0, e = index_vector_.size(); i < e; i++) { - index_vector_index_[index_vector_dim] = i; - index_vector_[i] = start_indices_.data()[spu::flattenIndex( - index_vector_index_, start_indices_shape_)]; - } - } - - // Populates input_index_. 
- void propagateIndexVectorToInputIndex() { - for (int64_t i = 0, e = input_index_.size(); i < e; i++) { - if (input_dim_value_to_index_vector_[i] != -1) { - input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; - } - } - } - - // input_dim_value_to_index_vector_[i] tells us how to compute dimension i - // of the input index from the index vector. See - // PropagateIndexVectorToInputIndex. - spu::Index input_dim_value_to_index_vector_; - - // output_dim_is_batch_dims_[i] is true iff the output index i is a gather - // dimension. - std::vector output_dim_is_batch_dims_; - - // The buffer into which we construct an index into start_indices_ to fetch - // the index vector. - spu::Index index_vector_index_; - - // The index vector fetched from start_indices_. - spu::Index index_vector_; - - // The result computed by this functor. operator() returns a Span into - // this vector. - spu::Index input_index_; - - const spu::kernel::hlo::GatherConfig &config_; - const xt::xarray &start_indices_; - spu::Shape start_indices_shape_; -}; - -// This functor computes the contribution of the offset indices in an output -// index to an input index. That is, given an output index I it picks out the -// output offset indices in I and expands it into an index into the input -// shape. -class OutputOffsetIndexToInputIndex { - public: - // The constructor does some setup work that is amortized across all - // iterations. - explicit OutputOffsetIndexToInputIndex( - const spu::kernel::hlo::GatherConfig &config, - const spu::Shape &input_shape, const spu::Shape &output_shape) { - spu::Index window_index_to_output_index; - int64_t output_index_count = 0; - for (int64_t i = 0; i < static_cast(output_shape.size()); i++) { - if (std::binary_search(config.offsetDims.begin(), config.offsetDims.end(), - i)) { - window_index_to_output_index.push_back(output_index_count++); - } else { - output_index_count++; - } - } - - int64_t window_dim_count = 0; - for (int64_t i = 0; i < static_cast(input_shape.size()); i++) { - if (std::binary_search(config.collapsedSliceDims.begin(), - config.collapsedSliceDims.end(), i)) { - input_dim_value_to_output_index_.push_back(-1); - } else { - input_dim_value_to_output_index_.push_back( - window_index_to_output_index[window_dim_count++]); - } - } - - input_index_.resize(input_shape.size()); - } - - // Returns the contribution of the window indices to the input index - // corresponding to output_index. See gather_inner_loop_body. - // - // This is conceptually a stateless transformation from output_index to the - // window input index, but instead of allocating memory to represent the - // gather input index on every invocation we reuse the same storage for the - // result (input_index_), mutating it in place. - // - // This returns a Span into memory owned by the class. - spu::Index &operator()(const spu::Index &output_index) { - propagateOutputIndexWindowDimsToInputIndex(output_index); - return input_index_; - } - - // Returns for a given 'input_dim' the corresponding output dimension index, - // or -1 if 'input_dim' is an elided window dimension. - int64_t input_dim_value_to_output_index(int64_t input_dim) { - return input_dim_value_to_output_index_[input_dim]; - } - - private: - // Propagates window dimensions from the output index to input_index_ by - // mutating input_index_ in place. 
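Taken together, the two functors being deleted here decompose every gather output index into a start-index contribution (clamped so the whole slice stays in bounds) plus an intra-window offset. A compressed Python restatement of that arithmetic (a simplification of the removed `OutputBatchIndexToInputIndex` / `OutputOffsetIndexToInputIndex` pair, not a drop-in replacement):

```python
def gather_input_index(start_index, window_index, operand_shape, slice_sizes):
    # clamp the start index so the gather region fits in the operand ...
    clamped = [
        min(max(0, s), dim - size)
        for s, dim, size in zip(start_index, operand_shape, slice_sizes)
    ]
    # ... then add the intra-window offset
    return [c + w for c, w in zip(clamped, window_index)]

# a 1x3 slice whose start row 5 is clamped to 2 on a 3x3 operand
assert gather_input_index([5, 0], [0, 2], (3, 3), (1, 3)) == [2, 2]
```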
- void propagateOutputIndexWindowDimsToInputIndex( - absl::Span output_index) { - for (int64_t i = 0, e = input_index_.size(); i < e; i++) { - if (input_dim_value_to_output_index_[i] != -1) { - input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; - } - } - } - - // input_dim_value_to_index_vector_[i] tells us how to compute dimension i - // of the input index from the output index. See - // PropagateOutputIndexWindowDimsToInputIndex. - spu::Index input_dim_value_to_output_index_; - - // The result computed by this functor. operator() returns a Span into - // this vector. - spu::Index input_index_; -}; - -spu::Value reshapedGatherIndices(spu::SPUContext *ctx, int64_t index_vector_dim, - const spu::Value &start_indices) { - if (start_indices.shape().size() != static_cast(index_vector_dim)) { - return start_indices; - } - - auto new_shape = start_indices.shape(); - new_shape.push_back(1); - - return spu::kernel::hal::reshape(ctx, start_indices, new_shape); -} - spu::Value SecretLinearUpdateIndexing(spu::SPUContext *ctx, const spu::Value &operand, const spu::Value &update, @@ -421,111 +148,6 @@ std::vector ClampAndFlattenIndex( namespace spu::kernel::hlo { -spu::Value Gather(SPUContext *ctx, const spu::Value &operand, - const spu::Value &start_indices, const GatherConfig &config, - const Shape &result_shape) { - // If input is empty, short circuit - if (operand.numel() == 0) { - return operand; - } - - auto start_indices_value = - reshapedGatherIndices(ctx, config.indexVectorDim, start_indices); - - SPU_ENFORCE(start_indices.isPublic()); - - auto start_index = getIndices(ctx, start_indices_value); - - // We iterate over the gather dimensions in the output shape in an outer - // loop nest, and iterate over the window dimensions in the output shape in - // an inner loop nest. - IndexIterationSpace start_indices_iteration_space = - iterationSpaceForOutputBatchIndices(result_shape, config); - IndexIterationSpace offset_indices_iteration_space = - iterationSpaceForOutputOffsetIndices(result_shape.size(), config); - - // Scratch buffers that hold an index in the output shape and the - // corresponding index in the input shape. - // If input is empty, short circuit it - auto operand_shape = operand.shape(); - Index input_index(operand_shape.size()); - Index output_index(result_shape.size()); - Index input_index_clamped(operand_shape.size()); - - OutputBatchIndexToInputIndex output_batch_index_to_input_index( - config, /*input_shape=*/operand_shape, - /*output_shape=*/result_shape, start_index); - OutputOffsetIndexToInputIndex output_offset_index_to_input_index( - config, /*input_shape=*/operand_shape, - /*output_shape=*/result_shape); - - spu::Value result(NdArrayRef(operand.data().eltype(), result_shape), - operand.dtype()); - - if (operand.isComplex()) { - result = - Value(result.data(), NdArrayRef(operand.imag()->eltype(), result_shape), - operand.dtype()); - } - - auto gather_inner_loop_body = [&](const spu::Index &output_window_index, - const spu::Index &input_gather_index, - const spu::Index &output_gather_index) { - auto input_window_index = - output_offset_index_to_input_index(output_window_index); - for (int i = 0, e = output_index.size(); i < e; i++) { - output_index[i] = output_gather_index[i] + output_window_index[i]; - } - for (int i = 0, e = input_gather_index.size(); i < e; i++) { - int64_t output_dim = - output_offset_index_to_input_index.input_dim_value_to_output_index(i); - // If 'output_dim' is -1, it means 'i' is an elided window dim. 
This - // means we set the iteration index to 0, so for the purpose of the - // following calculations we can consider the output dimension size - // to be 1. - int64_t output_dim_size = output_dim == -1 ? 1 : result_shape[output_dim]; - // Clamp the gather index so that the gather region fits in the - // operand. input_index_clamped[i] = clamp(input_gather_index[i], 0, - // operand_shape.dimensions(i) - // - output_dim_size); - input_index_clamped[i] = - std::min(operand_shape[i] - output_dim_size, - std::max(int64_t{0}, input_gather_index[i])); - } - for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_index_clamped[i] + input_window_index[i]; - } - - result.data().update_slice(operand.data().slice_scalar_at(input_index), - output_index); - - if (result.isComplex()) { - result.imag()->update_slice(operand.imag()->slice_scalar_at(input_index), - output_index); - } - }; - - auto gather_outer_loop_body = [&](const spu::Index &output_gather_index) { - auto input_gather_index = - output_batch_index_to_input_index(output_gather_index); - forEachIndex(result_shape, offset_indices_iteration_space.index_base, - offset_indices_iteration_space.index_count, - offset_indices_iteration_space.index_incr, - [&](const spu::Index &output_window_index) { - return gather_inner_loop_body(output_window_index, - input_gather_index, - output_gather_index); - }); - }; - - forEachIndex(result_shape, start_indices_iteration_space.index_base, - start_indices_iteration_space.index_count, - start_indices_iteration_space.index_incr, - gather_outer_loop_body); - - return result; -} - spu::Value DynamicUpdateSlice(SPUContext *ctx, const spu::Value &operand, const spu::Value &update, absl::Span start_indices) { diff --git a/libspu/kernel/hlo/indexing.h b/libspu/kernel/hlo/indexing.h index 1086e05a..61854fb4 100644 --- a/libspu/kernel/hlo/indexing.h +++ b/libspu/kernel/hlo/indexing.h @@ -22,20 +22,6 @@ class SPUContext; namespace spu::kernel::hlo { -struct GatherConfig { - spu::Sizes sliceSizes; - int64_t indexVectorDim; - spu::Axes offsetDims; - spu::Axes collapsedSliceDims; - spu::Axes startIndexMap; -}; - -// This is ported from -// https://github.com/tensorflow/tensorflow/blob/bf4c6ad46dac1f7f69911e2bfc48e141a39b40af/tensorflow/compiler/xla/service/hlo_evaluator.cc#L1774 -spu::Value Gather(SPUContext *ctx, const spu::Value &operand, - const spu::Value &start_indices, const GatherConfig &config, - const Shape &result_shape); - spu::Value DynamicUpdateSlice(SPUContext *ctx, const spu::Value &operand, const spu::Value &update, absl::Span start_indices); diff --git a/libspu/spu.proto b/libspu/spu.proto index d83c93b0..0a05809e 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -341,7 +341,7 @@ message TTPBeaverConfig { ////////////////////////////////////////////////////////////////////////// enum SourceIRType { XLA = 0; - MLIR_HLO = 1; + STABLEHLO = 1; } message CompilationSource { diff --git a/spu/libspu.cc b/spu/libspu.cc index e86e3179..4b3a0d40 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -152,6 +152,7 @@ void BindLink(py::module& m) { .def_readwrite("server_ssl_opts", &ContextDesc::server_ssl_opts) .def_readwrite("link_type", &ContextDesc::link_type) .def_readwrite("retry_opts", &ContextDesc::retry_opts) + .def_readwrite("disable_msg_seq_id", &ContextDesc::disable_msg_seq_id) .def( "add_party", [](ContextDesc& desc, std::string id, std::string host) { diff --git a/spu/utils/distributed.py b/spu/utils/distributed.py index b59e7389..035103eb 100644 --- a/spu/utils/distributed.py +++ 
b/spu/utils/distributed.py @@ -537,7 +537,6 @@ def builtin_spu_init( return desc = libspu.link.Desc() desc.recv_timeout_ms = 100 * 1000 # 100 seconds - desc.throttle_window_size = 0 # disable throttle desc.http_max_payload_size = 32 * 1024 * 1024 # Default set link payload to 32M for rank, addr in enumerate(addrs): desc.add_party(f"r{rank}", addr) @@ -839,8 +838,20 @@ class TorchFunction(Device.Function): def __init__(self, device: Device, pyfunc: Callable): super().__init__(device, pyfunc) + self.state_dict = None - def __call__(self, *args, **kwargs): + def _place_state_dict(self, state_dict): + # place arguments onto this device. + def place(obj): + if not isinstance(obj, Device.Object): + return obj + return Device.move(obj, self.device) + + return tree_map(place, state_dict) + + def __call__(self, state_dict, *args, **kwargs): + # place state_dict + self.state_dict = self._place_state_dict(state_dict) args, kwargs = self.device._place_arguments(*args, **kwargs) # now, all object are either PyObject or SPU.DeviceObject @@ -871,12 +882,17 @@ def get_share_ref(idx, obj): builtin_fetch_meta, results[0] ) - ret_flat = [ + ret = [ SPU.Object(self.device, share_refs, *meta) for share_refs, meta in zip(zip(*results), metas) ] - return tree_unflatten(out_tree, ret_flat) + from torch.utils import _pytree as pytree + + if out_tree is not None: + out_spec = pytree.treespec_loads(out_tree) + ret = pytree.tree_unflatten(ret, out_spec) + return ret def dump_pphlo(self, *args, **kwargs): args, kwargs = self.device._place_arguments(*args, **kwargs) @@ -884,38 +900,32 @@ def dump_pphlo(self, *args, **kwargs): return executable.code.decode('utf-8') def _compile_torch_func(self, fn, *args, **kwargs): + import torch + def mock_parameters(obj: Union[SPU.Object, np.ndarray]): if isinstance(obj, SPU.Object): - return np.zeros(shape=obj.shape, dtype=obj.dtype) + zeros = np.zeros(shape=obj.shape, dtype=obj.dtype) + return torch.from_numpy(zeros) else: assert not isinstance(obj, Device.Object) return obj - mock_args, mock_kwargs = tree_map(mock_parameters, (args, kwargs)) + assert isinstance( + fn, torch.nn.Module + ), "currently only torch.nn.Module is supported" - args_flat, _ = jax.tree_util.tree_flatten((args, kwargs)) + mock_args, mock_kwargs = tree_map(mock_parameters, (args, kwargs)) - fn_name = repr(fn) + exported_fn = torch._export.export(fn, args=mock_args, kwargs=mock_kwargs) - in_vis = [ - arg.vtype - if isinstance(arg, SPU.Object) - else spu_pb2.Visibility.VIS_PUBLIC - for arg in args_flat - ] - in_names = [f'{id(fn_name)}-in{idx}' for idx in range(len(args_flat))] - - def outputNameGen(out_flat: List): - return [f'{id(fn_name)}-out{idx}' for idx in range(len(out_flat))] + args_flat, _ = jax.tree_util.tree_flatten((args, kwargs)) + m_args_flat, _ = jax.tree_util.tree_flatten((mock_args, mock_kwargs)) - executable, output_tree = spu_fe.compile( - spu_fe.Kind.Torch, - fn, - mock_args, - mock_kwargs, - in_names, - in_vis, - outputNameGen, + executable, output_tree, args_flat = spu_fe.torch_compile( + exported_fn, + args_flat, + m_args_flat, + state_dict=self.state_dict, ) return executable, args_flat, output_tree diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index de7e04c8..5fecf875 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
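With the `TorchFunction` rework above, a module's parameters now travel as an explicit `state_dict` argument that is placed onto the device, while tracing happens against zero-filled mocks of the SPU arguments. A sketch of the intended call shape (the `ppd` usage is an assumption based on spu's distributed examples, and `torch.export.export` is the public counterpart of the private `torch._export.export` used above):

```python
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

model = TinyNet()
x = torch.randn(1, 4)

# _compile_torch_func traces a zero-filled torch mock of every SPU arg;
# the public export entry point (torch >= 2.1) looks like this:
exported = torch.export.export(model, (x,))

# with spu's distributed runtime, parameters travel via state_dict,
# now the first argument of the device function (hypothetical usage):
# result = ppd.device("SPU")(model)(model.state_dict(), x)
```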
# See the License for the specific language governing permissions and +import collections import functools from enum import Enum from threading import Lock -from typing import Callable, Dict, Iterable, List -from numpy import ndarray +from typing import Callable, Dict, Iterable, List, Union from cachetools import LRUCache, cached +from numpy import ndarray from .. import api as spu_api from .. import spu_pb2 @@ -94,8 +95,8 @@ def _jax_compilation( fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict ): import jax - from jax._src.xla_bridge import _backend_lock, _backends, register_backend_factory from jax._src.lib import xla_client, xla_extension_version + from jax._src.xla_bridge import _backend_lock, _backends, register_backend_factory # Register interpreter backend since we don't want any cpu/gpu/tpu specific optimization if xla_extension_version < 164: @@ -242,50 +243,14 @@ def compile( output = cf.structured_outputs output_names = outputNameGen(cf.outputs) - elif kind == Kind.Torch: - import jax - import torch - import torch_mlir - from torch_mlir._mlir_libs._mlir.ir import Attribute, Context - - assert isinstance( - fn, torch.nn.Module - ), "currently only torch.nn.Module is supported" - - # convert numpy.ndarray to torch tensor as torch_mlir required - arg_tensors = [torch.Tensor(arg) for arg in m_args] - # get mlir module - module = torch_mlir.compile( - fn, arg_tensors, output_type=torch_mlir.OutputType.MHLO - ) - # get mlir func op of torch.nn.Module.forward function - func_op = module.body.operations[0] - # rename func name from 'forward' to 'main' - with Context(): - func_op.attributes["sym_name"] = Attribute.parse('"main"') - - # parse output_num from func op signature string - func_sig = func_op.attributes["function_type"] - output_num = len(str(func_sig).split("->")[1].split(",")) - # get mhlo - ir_text = bytes(str(module), 'utf-8') - # mock output - output = [0] * output_num - output_names = outputNameGen(output) - output = tuple(output) if output_num > 1 else output[0] - _, output = jax.tree_util.tree_flatten(output) else: raise NameError(f"Unknown frontend type {kind}") source = spu_pb2.CompilationSource() source.ir_txt = ir_text + source.ir_type = spu_pb2.SourceIRType.XLA source.input_visibility.extend(input_vis) - if kind in [Kind.JAX, Kind.Tensorflow]: - source.ir_type = spu_pb2.SourceIRType.XLA - name = fn.func.__name__ if isinstance(fn, functools.partial) else fn.__name__ - elif kind == Kind.Torch: - source.ir_type = spu_pb2.SourceIRType.MLIR_HLO - name = repr(fn) + name = fn.func.__name__ if isinstance(fn, functools.partial) else fn.__name__ mlir = spu_api.compile(source, copts) executable = spu_pb2.ExecutableProto( name=name, @@ -294,3 +259,71 @@ def compile( code=mlir, ) return executable, output + + +def torch_compile( + fn: Callable, + args_flat: List, + m_args_flat: List, + state_dict: collections.OrderedDict(), + copts=spu_pb2.CompilerOptions(), +): + import os + import torch + from torch_xla import stablehlo + from torch_xla.stablehlo import VariableType + + from . 
import distributed + + assert isinstance( + fn, torch.export.ExportedProgram + ), "input should be an exported torch model" + + os.environ['PJRT_DEVICE'] = 'CPU' + options = stablehlo.StableHLOExportOptions() + options.override_tracing_arguments = m_args_flat + shlo = stablehlo.exported_program_to_stablehlo(fn, options) + method = shlo._name_to_stablehlo["forward"] + ir_str = method.text + ir_text = bytes(ir_str, 'utf-8') + + name = fn.module()._get_name() + output_names = [ + f'{id(name)}-out{idx}' for idx in range(len(fn.graph_signature.user_outputs)) + ] + output_tree = method.meta.output_pytree_spec + + source = spu_pb2.CompilationSource() + source.ir_txt = ir_text + source.ir_type = spu_pb2.SourceIRType.STABLEHLO + + state_dict_idx = {k: i for i, k in enumerate(shlo._bundle.state_dict.keys())} + state_dict_list = list(state_dict.values()) + args_params_flat = [] + for loc in method.meta.input_locations: + if loc.type_ == VariableType.PARAMETER: + args_params_flat.append(state_dict_list[state_dict_idx[loc.name]]) + elif loc.type_ == VariableType.INPUT_ARG: + args_params_flat.append(args_flat[loc.position]) + else: + raise RuntimeError( + 'Currently only torch models with named parameters and buffers are supported' + ) + input_names = [f'{id(name)}-in{idx}' for idx in range(len(args_params_flat))] + + source.input_visibility.extend( + [ + arg.vtype + if isinstance(arg, distributed.SPU.Object) + else spu_pb2.Visibility.VIS_PUBLIC + for arg in args_params_flat + ] + ) + mlir = spu_api.compile(source, copts) + executable = spu_pb2.ExecutableProto( + name=name, + input_names=input_names, + output_names=output_names, + code=mlir, + ) + return executable, output_tree, args_params_flat diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 92e58128..828295bc 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -83,7 +83,6 @@ def __call__(self, executable, *flat_args): ] lctx_desc = libspu.link.Desc() - lctx_desc.throttle_window_size = 0 # disable throttle for rank in range(self.wsize): lctx_desc.add_party(f"id_{rank}", f"thread_{rank}")
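For reference, the new `torch_compile` path boils down to: export the module, lower the `ExportedProgram` to StableHLO with `torch_xla`, and hand the IR text to `spu_api.compile` with visibilities resolved per input location. A minimal standalone sketch of the lowering step (API names mirror the calls above; availability depends on the installed `torch_xla` version, and `get_stablehlo_text` is assumed from its public docs rather than taken from this diff):

```python
import os
import torch
from torch_xla import stablehlo

os.environ['PJRT_DEVICE'] = 'CPU'  # trace on the CPU PjRt device, as above

model = torch.nn.Linear(4, 2).eval()
x = torch.randn(1, 4)

exported = torch.export.export(model, (x,))
shlo = stablehlo.exported_program_to_stablehlo(exported)
# the text of the 'forward' method is what spu_api.compile receives
print(shlo.get_stablehlo_text('forward')[:400])
```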