repo-sync-2024-01-19T11:33:44+0800
anakinxc committed Jan 19, 2024
1 parent 832398b commit 6ae3684
Showing 48 changed files with 527 additions and 1,625 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@
>
> please add your unreleased change here.
- [Feature] Support more generic Torch model inference
- [Improvement] Optimize one-time setup for yacl ot
- [Improvement] Optimize sort performance

81 changes: 22 additions & 59 deletions docs/reference/pphlo_op_doc.md
@@ -18,9 +18,9 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>edge_padding_low</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>edge_padding_high</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>interior_padding</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>edge_padding_low</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>edge_padding_high</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>interior_padding</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>
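
For readers unfamiliar with these attributes, the per-dimension triple (`edge_padding_low`, `edge_padding_high`, `interior_padding`) mirrors the XLA pad convention, which `jax.lax.pad` exposes directly. The sketch below is a plain-JAX reading aid for that convention only, not SPU or PPHLO API code:

```python
# Plain-JAX illustration of the (low, high, interior) padding convention these
# attributes encode; a reading aid only, not SPU or PPHLO API code.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1, 5, dtype=jnp.float32)   # [1., 2., 3., 4.]
# One (edge_padding_low, edge_padding_high, interior_padding) triple per dim.
padded = lax.pad(x, jnp.float32(0), [(1, 2, 1)])
# low=1 pad before, high=2 pads after, interior=1 pad between elements:
# [0., 1., 0., 2., 0., 3., 0., 4., 0., 0.]
print(padded)
```

The same three-numbers-per-dimension reading applies either way; this commit only changes the MLIR storage format of the attribute from `DenseIntElementsAttr` to `DenseI64ArrayAttr`.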

#### Operands:
@@ -135,9 +135,9 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_dilations</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>window_dilations</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>onehot_index</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
</table>

@@ -213,7 +213,7 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>broadcast_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>broadcast_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -455,7 +455,7 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>dimension_numbers</code></td><td>::mlir::pphlo::ConvDimensionNumbersAttr</td><td>Structure of dimension information for conv op</td></tr>
<tr><td><code>feature_group_count</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
<tr><td><code>batch_group_count</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
@@ -658,7 +658,7 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>slice_sizes</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>slice_sizes</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -837,43 +837,6 @@ Effects: MemoryEffects::Effect{}
| :-----: | ----------- |
| `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values

### `pphlo.gather` (pphlo::GatherOp)

_Gather operator_

Stitches together several slices of `operand` from offsets specified in
`start_indices` (each slice at a potentially different runtime offset).

See https://www.tensorflow.org/xla/operation_semantics#gather.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>dimension_numbers</code></td><td>::mlir::pphlo::GatherDimensionNumbersAttr</td><td>Attribute that models the dimension information for gather</td></tr>
<tr><td><code>slice_sizes</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>indices_are_sorted</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values
| `start_indices` | statically shaped tensor of public integer type or secret integer type values

#### Results:

| Result | Description |
| :----: | ----------- |
&laquo;unnamed&raquo; | statically shaped tensor of PPHlo public type or PPHlo secret type values
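
The gather description above (removed from the op doc in this commit) can be hard to picture from prose alone. The NumPy sketch below illustrates the basic idea — each start index selects a fixed-size slice of `operand`, and the slices are stitched into one result. It shows the general concept only, not the exact `dimension_numbers`-driven behaviour of `pphlo.gather`:

```python
# Minimal NumPy sketch of gather-style slicing for the simplest 1-D case:
# each entry of start_indices selects a slice of `operand` of size slice_size,
# and the slices are stacked into one result. Illustration only.
import numpy as np

operand = np.arange(10)             # [0, 1, ..., 9]
start_indices = np.array([0, 4, 7])
slice_size = 3

result = np.stack([operand[i:i + slice_size] for i in start_indices])
# [[0, 1, 2],
#  [4, 5, 6],
#  [7, 8, 9]]
print(result)
```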

### `pphlo.greater_equal` (pphlo::GreaterEqualOp)

_Greater_equal comparison operator_
@@ -1185,8 +1148,8 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -1491,7 +1454,7 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlock, SingleBloc

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -1522,9 +1485,9 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlock, SingleBloc

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_dilations</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>window_dilations</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -1630,7 +1593,7 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -1745,8 +1708,8 @@ Traits: RecursiveMemoryEffects

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>window_dimensions</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>window_strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
@@ -1986,9 +1949,9 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>start_indices</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>limit_indices</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>strides</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>start_indices</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>limit_indices</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
<tr><td><code>strides</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
Expand Down Expand Up @@ -2136,7 +2099,7 @@ Effects: MemoryEffects::Effect{}

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>permutation</code></td><td>::mlir::DenseIntElementsAttr</td><td>64-bit signless integer elements attribute</td></tr>
<tr><td><code>permutation</code></td><td>::mlir::DenseI64ArrayAttr</td><td>i64 dense array attribute</td></tr>
</table>

#### Operands:
3 changes: 2 additions & 1 deletion examples/python/ml/BUILD.bazel
@@ -40,7 +40,8 @@ py_test(
"//examples/python/ml/stax_mnist_classifier",
"//examples/python/ml/stax_nn",
"//examples/python/ml/tf_experiment",
"//examples/python/ml/torch_experiment",
"//examples/python/ml/torch_lr_experiment",
"//examples/python/ml/torch_resnet_experiment",
"//spu/utils:distributed",
],
)
3 changes: 2 additions & 1 deletion examples/python/ml/README.md
@@ -27,4 +27,5 @@ library, and private inference of a pre-trained ResNet-50 model based on [Micros
* [jraph_gnn](jraph_gnn/): Private training of a [graph convolutional network](https://arxiv.org/abs/1609.02907) model with
[Jraph](https://github.com/deepmind/jraph).
* [tf_experiment](tf_experiment/): Private training of a logistic regression model with TensorFlow (**experimental**).
* [torch_experiment](torch_experiment/): Private inference of a linear regression model with PyTorch (**experimental**).
* [torch_lr_experiment](torch_lr_experiment/): Private inference of a logistic regression model with PyTorch (**experimental**).
* [torch_resnet_experiment](torch_resnet_experiment/): Private inference of a [ResNet](https://arxiv.org/abs/1512.03385) model with PyTorch (**experimental**).
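
As a reading aid for the new `torch_resnet_experiment` entry, the sketch below shows the plain-PyTorch side of such a ResNet inference setup: a pretrained torchvision ResNet and an input batch. The model variant, the dummy input, and the variable names are assumptions for illustration; the SPU inference helper (`run_inference_on_spu`, referenced in `ml_test.py`) belongs to the example itself and is not reproduced here.

```python
# Plain-PyTorch sketch of the model/input setup a ResNet inference example
# typically needs. The SPU-side run_inference_on_spu helper referenced in
# ml_test.py is part of the example and is not reproduced here.
import torch
import torchvision

# Pretrained ResNet-18 in eval mode (the actual example may use another variant).
resnet = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT
)
resnet.eval()

# A dummy 224x224 RGB batch standing in for a real preprocessed image.
input_batch = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    logits = resnet(input_batch)
label = int(logits.argmax(dim=1))
print(label)  # with a real image, this is an ImageNet class index
```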
2 changes: 1 addition & 1 deletion examples/python/ml/haiku_lstm/README.md
@@ -9,7 +9,7 @@ This example comes from Haiku official github repo:
1. Install dependencies

```sh
pip install -r requirements.txt
pip install -r ../requirements.txt
```

2. Launch SPU backend runtime
2 changes: 0 additions & 2 deletions examples/python/ml/haiku_lstm/requirements.txt

This file was deleted.

2 changes: 1 addition & 1 deletion examples/python/ml/jraph_gnn/README.md
@@ -9,7 +9,7 @@ This example comes from Jraph official github repo:
1. Install dependencies

```sh
pip install -r requirements.txt
pip install -r ../requirements.txt
```

2. Set runtime configuration
2 changes: 0 additions & 2 deletions examples/python/ml/jraph_gnn/requirements.txt

This file was deleted.

22 changes: 15 additions & 7 deletions examples/python/ml/ml_test.py
@@ -213,20 +213,27 @@ def test_tf_experiment(self):
score = tf_experiment.run_fit_manual_grad_spu()
self.assertGreater(score, 0.9)

def test_torch_experiment(self):
from examples.python.ml.torch_experiment import torch_experiment
def test_torch_lr_experiment(self):
from examples.python.ml.torch_lr_experiment import torch_lr_experiment

model = torch_experiment.LinearRegression()
torch_experiment.train(model)
score = torch_experiment.run_inference_on_spu(model)
model = torch_lr_experiment.LinearRegression()
torch_lr_experiment.train(model)
score = torch_lr_experiment.run_inference_on_spu(model)
self.assertGreater(score, 0.9)

def test_torch_resnet_experiment(self):
from examples.python.ml.torch_resnet_experiment import torch_resnet_experiment

model = torch_resnet_experiment.resnet
image = torch_resnet_experiment.input_batch
label = torch_resnet_experiment.run_inference_on_spu(model, image)
self.assertEqual(label, 258)

def test_save_and_load_model(self):
from examples.python.ml.jax_lr import jax_lr

score = jax_lr.save_and_load_model()
self.assertGreater(score, 0.9)
pass


def suite():
@@ -246,7 +253,8 @@ def suite():
suite.addTest(UnitTests('test_save_and_load_model'))
# should put JAX tests above
suite.addTest(UnitTests('test_tf_experiment'))
suite.addTest(UnitTests('test_torch_experiment'))
suite.addTest(UnitTests('test_torch_lr_experiment'))
# suite.addTest(UnitTests('test_torch_resnet_experiment'))
return suite


7 changes: 7 additions & 0 deletions examples/python/ml/requirements.txt
@@ -0,0 +1,7 @@
dm-haiku==0.0.10
plotnine
jraph
optax==0.1.7
torch==2.1.0
torch_xla==2.1.0
torchvision
24 changes: 0 additions & 24 deletions examples/python/ml/torch_experiment/README.md

This file was deleted.

@@ -1,4 +1,4 @@
# Copyright 2023 Ant Group Co., Ltd.
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@ load("@rules_python//python:defs.bzl", "py_binary")
package(default_visibility = ["//visibility:public"])

py_binary(
name = "torch_experiment",
srcs = ["torch_experiment.py"],
name = "torch_lr_experiment",
srcs = ["torch_lr_experiment.py"],
data = [
"//examples/python/conf",
],
23 changes: 23 additions & 0 deletions examples/python/ml/torch_lr_experiment/README.md
@@ -0,0 +1,23 @@
# Torch Example

This example demonstrates how to use SPU to run private inference on PyTorch models.

**Note**: SPU's support for PyTorch is currently **experimental**.

1. Install the third-party dependency [PyTorch/XLA](https://github.com/pytorch/xla).

```sh
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

2. Launch SPU backend runtime

```sh
bazel run -c opt //examples/python/utils:nodectl -- up
```

3. Run `torch_lr_experiment` example

```sh
bazel run -c opt //examples/python/ml/torch_lr_experiment
```
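
For context, `ml_test.py` drives this example through `LinearRegression()`, `train(model)`, and `run_inference_on_spu(model)`. The sketch below is a plain-PyTorch guess at what such a model and training loop look like — synthetic data, hypothetical layer sizes, and no SPU code, since the SPU inference path is provided by the example itself:

```python
# Plain-PyTorch sketch of the kind of model the example trains (names follow
# ml_test.py: LinearRegression, train). Data and layer sizes are synthetic
# placeholders; the SPU helper run_inference_on_spu is not shown here.
import torch
from torch import nn


class LinearRegression(nn.Module):
    def __init__(self, n_features: int = 30):
        super().__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x):
        # Sigmoid output, matching the logistic-regression framing in the docs.
        return torch.sigmoid(self.linear(x))


def train(model: nn.Module, n_samples: int = 256, n_features: int = 30, epochs: int = 100):
    # Synthetic binary-classification data as a stand-in for the example's dataset.
    x = torch.randn(n_samples, n_features)
    y = (x.sum(dim=1, keepdim=True) > 0).float()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.BCELoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()


if __name__ == "__main__":
    model = LinearRegression()
    train(model)
```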