From 67983a83856e08a19afe4fcd730df34d12692a7c Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Thu, 26 Sep 2024 22:55:00 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E3=80=90Hackathon=207th=20No.35=E3=80=91?= =?UTF-8?q?=E4=B8=BA=20Paddle=20=E4=BB=A3=E7=A0=81=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=96=B0=E5=A2=9E=20API=20=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E8=A7=84=E5=88=99=EF=BC=88=E7=AC=AC=202=20=E7=BB=84=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paconvert/api_mapping.json | 138 +++++++++++++++++++++++++--- paconvert/api_matcher.py | 156 ++++++++++++++++++++++++++++++++ tests/test_Tensor_addbmm_.py | 82 +++++++++++++++++ tests/test_Tensor_addcdiv_.py | 100 ++++++++++++++++++++ tests/test_Tensor_addmv_.py | 82 +++++++++++++++++ tests/test_Tensor_addr_.py | 81 +++++++++++++++++ tests/test_Tensor_baddbmm_.py | 82 +++++++++++++++++ tests/test_Tensor_copysign.py | 45 +++++---- tests/test_Tensor_copysign_.py | 68 ++++++++++++++ tests/test_Tensor_erfc_.py | 50 ++++++++++ tests/test_Tensor_fix_.py | 62 +++++++++++++ tests/test_Tensor_fmod_.py | 70 ++++++++++++++ tests/test_Tensor_sinc_.py | 40 ++++++++ tests/test_Tensor_t_.py | 30 ++++++ tests/test_Tensor_transpose_.py | 64 +++++++++++++ tests/test_Tensor_xlogy.py | 21 +++++ tests/test_Tensor_xlogy_.py | 62 +++++++++++++ 17 files changed, 1196 insertions(+), 37 deletions(-) create mode 100644 tests/test_Tensor_addbmm_.py create mode 100644 tests/test_Tensor_addcdiv_.py create mode 100644 tests/test_Tensor_addmv_.py create mode 100644 tests/test_Tensor_addr_.py create mode 100644 tests/test_Tensor_baddbmm_.py create mode 100644 tests/test_Tensor_copysign_.py create mode 100644 tests/test_Tensor_erfc_.py create mode 100644 tests/test_Tensor_fix_.py create mode 100644 tests/test_Tensor_fmod_.py create mode 100644 tests/test_Tensor_sinc_.py create mode 100644 tests/test_Tensor_t_.py create mode 100644 tests/test_Tensor_transpose_.py create mode 100644 tests/test_Tensor_xlogy_.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 10232f952..a59687701 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -300,7 +300,17 @@ "alpha" ] }, - "torch.Tensor.addbmm_": {}, + "torch.Tensor.addbmm_": { + "Matcher": "AddBmm_Matcher", + "min_input_args": 2, + "args_list": [ + "batch1", + "batch2", + "*", + "beta", + "alpha" + ] + }, "torch.Tensor.addcdiv": { "Matcher": "AddCDivMatcher", "min_input_args": 2, @@ -311,7 +321,16 @@ "value" ] }, - "torch.Tensor.addcdiv_": {}, + "torch.Tensor.addcdiv_": { + "Matcher": "AddCDiv_Matcher", + "min_input_args": 2, + "args_list": [ + "tensor1", + "tensor2", + "*", + "value" + ] + }, "torch.Tensor.addcmul": { "Matcher": "AddCMulMatcher", "min_input_args": 2, @@ -376,7 +395,18 @@ "alpha" ] }, - "torch.Tensor.addmv_": {}, + "torch.Tensor.addmv_": { + "Matcher": "AddMR_Matcher", + "paddle_api": "paddle.mm", + "min_input_args": 2, + "args_list": [ + "mat", + "vec", + "*", + "beta", + "alpha" + ] + }, "torch.Tensor.addr": { "Matcher": "AddMRMatcher", "paddle_api": "paddle.outer", @@ -389,7 +419,18 @@ "alpha" ] }, - "torch.Tensor.addr_": {}, + "torch.Tensor.addr_": { + "Matcher": "AddMR_Matcher", + "paddle_api": "paddle.outer", + "min_input_args": 2, + "args_list": [ + "vec1", + "vec2", + "*", + "beta", + "alpha" + ] + }, "torch.Tensor.adjoint": { "Matcher": "AdjointMatcher", "min_input_args": 0 @@ -635,7 +676,18 @@ "alpha" ] }, - "torch.Tensor.baddbmm_": {}, + "torch.Tensor.baddbmm_": { + "Matcher": "AddMR_Matcher", + "paddle_api": "paddle.bmm", + "min_input_args": 2, + "args_list": [ + "batch1", + "batch2", + "*", + "beta", + "alpha" + ] + }, "torch.Tensor.bernoulli": { "Matcher": "TensorFunc2PaddleFunc", "paddle_api": "paddle.bernoulli", @@ -949,8 +1001,28 @@ "non_blocking": "" } }, - "torch.Tensor.copysign": {}, - "torch.Tensor.copysign_": {}, + "torch.Tensor.copysign": { + "Matcher": "Num2TensorBinaryMatcher", + "paddle_api": "paddle.Tensor.copysign", + "min_input_args": 1, + "args_list": [ + "other" + ], + "kwargs_change": { + "other": "y" + } + }, + "torch.Tensor.copysign_": { + "Matcher": "Num2TensorBinaryMatcher", + "paddle_api": "paddle.Tensor.copysign_", + "min_input_args": 1, + "args_list": [ + "other" + ], + "kwargs_change": { + "other": "y" + } + }, "torch.Tensor.corrcoef": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.corrcoef", @@ -1349,7 +1421,10 @@ "Matcher": "ErfCMatcher", "min_input_args": 0 }, - "torch.Tensor.erfc_": {}, + "torch.Tensor.erfc_": { + "Matcher": "ErfC_Matcher", + "min_input_args": 0 + }, "torch.Tensor.erfinv": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.erfinv", @@ -1439,7 +1514,11 @@ "paddle_api": "paddle.Tensor.trunc", "min_input_args": 0 }, - "torch.Tensor.fix_": {}, + "torch.Tensor.fix_": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.trunc_", + "min_input_args": 0 + }, "torch.Tensor.flatten": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.flatten", @@ -1567,7 +1646,17 @@ "other": "y" } }, - "torch.Tensor.fmod_": {}, + "torch.Tensor.fmod_": { + "Matcher": "Num2TensorBinaryMatcher", + "paddle_api": "paddle.Tensor.mod_", + "min_input_args": 1, + "args_list": [ + "other" + ], + "kwargs_change": { + "other": "y" + } + }, "torch.Tensor.frac": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.frac", @@ -3362,7 +3451,10 @@ "Matcher": "SincMatcher", "min_input_args": 0 }, - "torch.Tensor.sinc_": {}, + "torch.Tensor.sinc_": { + "Matcher": "SincMatcher", + "min_input_args": 0 + }, "torch.Tensor.sinh": { "Matcher": "UnchangeMatcher", "min_input_args": 0 @@ -3607,7 +3699,11 @@ "paddle_api": "paddle.Tensor.t", "min_input_args": 0 }, - "torch.Tensor.t_": {}, + "torch.Tensor.t_": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.t_", + "min_input_args": 0 + }, "torch.Tensor.take": { "Matcher": "TensorTakeMatcher", "paddle_api": "paddle.Tensor.take", @@ -3733,7 +3829,15 @@ "dim1" ] }, - "torch.Tensor.transpose_": {}, + "torch.Tensor.transpose_": { + "Matcher": "TensorTranspose_Matcher", + "paddle_api": "paddle.Tensor.transpose_", + "min_input_args": 2, + "args_list": [ + "dim0", + "dim1" + ] + }, "torch.Tensor.triangular_solve": { "Matcher": "TensorTriangularSolveMatcher", "paddle_api": "paddle.linalg.triangular_solve", @@ -3992,7 +4096,13 @@ "other" ] }, - "torch.Tensor.xlogy_": {}, + "torch.Tensor.xlogy_": { + "Matcher": "XLogY_Matcher", + "min_input_args": 1, + "args_list": [ + "other" + ] + }, "torch.Tensor.zero_": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.zero_", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4b4600203..7ac63629d 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -741,6 +741,35 @@ def generate_code(self, kwargs): return code +class TensorTranspose_Matcher(BaseMatcher): + def generate_aux_code(self): + API_TEMPLATE = textwrap.dedent( + """ + def transpose_aux_func(dims,dim0, dim1): + perm = list(range(dims)) + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + return perm + """ + ) + + return API_TEMPLATE + + def generate_code(self, kwargs): + self.write_aux_code() + API_TEMPLATE = textwrap.dedent( + """ + {}.transpose_(perm=paddle_aux.transpose_aux_func({}.ndim,{}, {})) + """ + ) + code = API_TEMPLATE.format( + self.paddleClass, + self.paddleClass, + kwargs["dim0"], + kwargs["dim1"], + ) + return code + + class BroadcastShapesMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): if len(args) == 1 and isinstance(args[0], ast.Starred): @@ -2208,6 +2237,26 @@ def generate_code(self, kwargs): return code +class AddCDiv_Matcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + + if "value" not in kwargs: + kwargs["value"] = 1 + + API_TEMPLATE = textwrap.dedent( + """ + {}.add_({} * {} / {}) + """ + ) + code = API_TEMPLATE.format( + kwargs["input"], kwargs["value"], kwargs["tensor1"], kwargs["tensor2"] + ) + + return code + + class IsNonzeroMatcher(BaseMatcher): def generate_code(self, kwargs): API_TEMPLATE = textwrap.dedent( @@ -2421,6 +2470,21 @@ def generate_code(self, kwargs): return code +class ErfC_Matcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + + API_TEMPLATE = textwrap.dedent( + """ + paddle.erf_({}).multiply_(paddle.to_tensor(-1.)).add_(paddle.to_tensor(1.)) + """ + ) + code = API_TEMPLATE.format(kwargs["input"]) + + return code + + class SpecialErfcxMatcher(BaseMatcher): def generate_code(self, kwargs): if "out" in kwargs and kwargs["out"] != "None": @@ -2446,6 +2510,9 @@ def generate_code(self, kwargs): if "input" not in kwargs: kwargs["input"] = self.paddleClass + if "other" in kwargs: + kwargs["other"] = "paddle.to_tensor({})".format(kwargs.pop("other")) + if "out" in kwargs and kwargs["out"] != "None": API_TEMPLATE = textwrap.dedent( """ @@ -2464,6 +2531,24 @@ def generate_code(self, kwargs): return code +class XLogY_Matcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + + if "other" in kwargs: + kwargs["other"] = f"paddle.to_tensor({kwargs.pop('other')})" + + API_TEMPLATE = textwrap.dedent( + """ + {}.multiply_(paddle.log({})) + """ + ) + code = API_TEMPLATE.format(kwargs["input"], kwargs["other"]) + + return code + + class Exp2Matcher(BaseMatcher): def generate_code(self, kwargs): if "out" in kwargs and kwargs["out"] != "None": @@ -2746,6 +2831,46 @@ def generate_code(self, kwargs): return code +class AddMR_Matcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + + params1 = ["mat1", "mat", "vec1", "batch1"] + params2 = ["mat2", "vec", "vec2", "batch2"] + param1, param2 = None, None + for i, param in enumerate(params1): + if param in kwargs: + param1 = kwargs[params1[i]] + param2 = kwargs[params2[i]] + + if "beta" in kwargs: + kwargs[ + "beta" + ] = f"paddle.to_tensor({kwargs.pop('beta')}, dtype={kwargs['input']}.dtype)" + else: + kwargs["beta"] = f"paddle.to_tensor(1, dtype={kwargs['input']}.dtype)" + + if "alpha" not in kwargs: + kwargs["alpha"] = 1 + + API_TEMPLATE = textwrap.dedent( + """ + {}.multiply_({}).add_({}*{}({}, {})) + """ + ) + code = API_TEMPLATE.format( + kwargs["input"], + kwargs["beta"], + kwargs["alpha"], + self.get_paddle_api(), + param1, + param2, + ) + + return code + + class AddBmmMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: @@ -2788,6 +2913,37 @@ def generate_code(self, kwargs): return code +class AddBmm_Matcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + + if "beta" in kwargs: + kwargs[ + "beta" + ] = f"paddle.to_tensor({kwargs.pop('beta')}, dtype={kwargs['input']}.dtype)" + else: + kwargs["beta"] = f"paddle.to_tensor(1, dtype={kwargs['input']}.dtype)" + + if "alpha" not in kwargs: + kwargs["alpha"] = 1 + + API_TEMPLATE = textwrap.dedent( + """ + {}.multiply_({}).add_({}*paddle.sum(paddle.bmm({}, {}), axis=0)) + """ + ) + code = API_TEMPLATE.format( + kwargs["input"], + kwargs["beta"], + kwargs["alpha"], + kwargs["batch1"], + kwargs["batch2"], + ) + + return code + + class CholeskyInverseMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: diff --git a/tests/test_Tensor_addbmm_.py b/tests/test_Tensor_addbmm_.py new file mode 100644 index 000000000..6e6bc2ce6 --- /dev/null +++ b/tests/test_Tensor_addbmm_.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.addbmm_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + result = input.addbmm_(a, b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + result = input.addbmm_(a, b, beta=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + result = input.addbmm_(batch1=a, batch2=b, beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + result = input.addbmm_(batch1=torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]), batch2=torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]), beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + result = input.addbmm_(beta=3, alpha=3, batch2=b, batch1=a) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_addcdiv_.py b/tests/test_Tensor_addcdiv_.py new file mode 100644 index 000000000..e4208a666 --- /dev/null +++ b/tests/test_Tensor_addcdiv_.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.addcdiv_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor1 = torch.tensor([1., 2., 3.]) + tensor2 = torch.tensor([4., 5., 6.]) + input = torch.tensor([7., 8., 9.]) + result = input.addcdiv_(tensor1, tensor2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor1 = torch.tensor([1., 2., 3.]) + tensor2 = torch.tensor([4., 5., 6.]) + input = torch.tensor([7., 8., 9.]) + result = input.addcdiv_(tensor1, tensor2, value=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor1 = torch.tensor([1., 2., 3.]) + tensor2 = torch.tensor([4., 5., 6.]) + input = torch.tensor([7., 8., 9.]) + value = 5.0 + result = input.addcdiv_(tensor1, tensor2=tensor2, value=value) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor1 = torch.tensor([1., 2., 3.]) + tensor2 = torch.tensor([4., 5., 6.]) + input = torch.tensor([7., 8., 9.]) + value = 5 + result = input.addcdiv_(tensor1, tensor2, value=value) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor1 = torch.tensor([1., 2., 3.]) + tensor2 = torch.tensor([4., 5., 6.]) + input = torch.tensor([7., 8., 9.]) + value = 5 + result = input.addcdiv_(tensor1=tensor1, tensor2=tensor2, value=value) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + tensor1 = torch.tensor([1., 2., 3.]) + tensor2 = torch.tensor([4., 5., 6.]) + input = torch.tensor([7., 8., 9.]) + value = 5 + result = input.addcdiv_(value=value, tensor2=tensor2, tensor1=tensor1) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_addmv_.py b/tests/test_Tensor_addmv_.py new file mode 100644 index 000000000..27c7ba658 --- /dev/null +++ b/tests/test_Tensor_addmv_.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.addmv_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([1., 2.]) + result = input.addmv_(a, b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([1., 2.]) + result = input.addmv_(a, b, beta=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([1., 2.]) + result = input.addmv_(mat=a, vec=b, beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1., 2.]) + result = input.addmv_(mat=torch.tensor([[1., 2., 3.], [4., 5., 6.]]), vec=torch.tensor([1., 2., 3.]), beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([1., 2.]) + result = input.addmv_(alpha=3, mat=a, beta=3, vec=b) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_addr_.py b/tests/test_Tensor_addr_.py new file mode 100644 index 000000000..1890b0844 --- /dev/null +++ b/tests/test_Tensor_addr_.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.addr_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([4., 5., 6.]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([[1., 2., 3.], [7., 8., 9.], [10., 11., 12.]]) + result = input.addr_(a, b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([4., 5., 6.]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([[1., 2., 3.], [7., 8., 9.], [10., 11., 12.]]) + result = input.addr_(a, b, beta=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([4., 5., 6.]) + b = torch.tensor([1., 2., 3.]) + input = torch.tensor([[1., 2., 3.], [7., 8., 9.], [10., 11., 12.]]) + result = input.addr_(vec1=a, vec2=b, beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1., 2.], [10., 11.]]) + result = input.addr_(vec1=torch.tensor([4., 5.]), vec2=torch.tensor([1., 2.]), beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([4., 5.]) + b = torch.tensor([1., 2.]) + input = torch.tensor([[1., 2.], [10., 11.]]) + result = input.addr_(vec1=a, alpha=3, vec2=b, beta=3) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_baddbmm_.py b/tests/test_Tensor_baddbmm_.py new file mode 100644 index 000000000..52bba2b94 --- /dev/null +++ b/tests/test_Tensor_baddbmm_.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.baddbmm_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]]) + result = input.baddbmm_(a, b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]]) + result = input.baddbmm_(a, b, beta=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]]) + result = input.baddbmm_(batch1=a, batch2=b, beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]]) + result = input.baddbmm_(batch1=torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]), batch2=torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]), beta=3, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[4., 5., 6.], [1., 2., 3.]]]) + b = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]) + input = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]]) + result = input.baddbmm_(beta=3, batch1=a, batch2=b, alpha=3) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_copysign.py b/tests/test_Tensor_copysign.py index 022107168..28d09dd03 100644 --- a/tests/test_Tensor_copysign.py +++ b/tests/test_Tensor_copysign.py @@ -23,33 +23,26 @@ def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - a = torch.randn(5) - result = torch.copysign(a, 1) + a = torch.tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + result = a.copysign(1.) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - a = torch.randn(4, 4) - b = torch.randn(4) - result = torch.copysign(a, b) + a = torch.tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + b = torch.tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + result = a.copysign(b) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): @@ -58,12 +51,18 @@ def test_case_3(): import torch a = torch.tensor([1.]) b = torch.tensor([-0.]) - result = torch.copysign(a, b) + result = a.copysign(b) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", + obj.run(pytorch_code, ["result"]) + + +# paddle.Tensor.copysign not support type promote and x/y must have same dtype +def _test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([3., 2, 1]).copysign(other=torch.tensor([2])) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_copysign_.py b/tests/test_Tensor_copysign_.py new file mode 100644 index 000000000..9eff4d416 --- /dev/null +++ b/tests/test_Tensor_copysign_.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.copysign_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + result = a.copysign_(1.) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + b = torch.tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + result = a.copysign_(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1.]) + b = torch.tensor([-0.]) + result = a.copysign_(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# paddle.Tensor.copysign_ not support type promote and x/y must have same dtype +def _test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([3., 2, 1]).copysign_(other=torch.tensor([2])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_erfc_.py b/tests/test_Tensor_erfc_.py new file mode 100644 index 000000000..2a17e12e1 --- /dev/null +++ b/tests/test_Tensor_erfc_.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.erfc_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., -3., -4., 5.]) + a.erfc_() + """ + ) + obj.run(pytorch_code, ["a"], rtol=1.0e-5, atol=1.0e-8) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[1., 2., -3., -4., 5.], [1., 2., -3., -4., 5.]]) + a.erfc_() + """ + ) + obj.run(pytorch_code, ["a"], rtol=1.0e-5, atol=1.0e-8) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2., -3., -4., 5.]).erfc_() + """ + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5, atol=1.0e-8) diff --git a/tests/test_Tensor_fix_.py b/tests/test_Tensor_fix_.py new file mode 100644 index 000000000..322acf693 --- /dev/null +++ b/tests/test_Tensor_fix_.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.fix_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([3.4742, 0.5466, -0.8008, -0.9079]) + result = input.fix_() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([3.4742, 0.5466, -0.8008, -0.9079]).fix_() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([3.4742, 0.5466, -0.8008, -0.9079]) + result = x.fix_() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([3, 0, 5, -9]).fix_() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_fmod_.py b/tests/test_Tensor_fmod_.py new file mode 100644 index 000000000..c38248047 --- /dev/null +++ b/tests/test_Tensor_fmod_.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.fmod_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., 3., 4., 5.]) + result = a.fmod_(1.5) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([3., 2, 1, 1, 2, 3]).fmod_(2.) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([3., 2, 1, 1, 2, 3]).fmod_(other=torch.tensor([2.])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[3., 2, 1], [1, 2, 3]]).fmod_(other=torch.tensor([2., 3, 1])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# paddle.Tensor.mod_ not support type promote and x/y must have same dtype +def _test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[3., 2, 1], [1, 2, 3]]).fmod_(other=torch.tensor([2, 3, 1])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_sinc_.py b/tests/test_Tensor_sinc_.py new file mode 100644 index 000000000..e3ec0d4c5 --- /dev/null +++ b/tests/test_Tensor_sinc_.py @@ -0,0 +1,40 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.sinc_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([ 0.5950,-0.0872, 0, -0.2972]) + result = a.sinc_() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([ 0.5950,-0.0872, 0, -0.2972]).sinc_() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_t_.py b/tests/test_Tensor_t_.py new file mode 100644 index 000000000..e0f6e72f1 --- /dev/null +++ b/tests/test_Tensor_t_.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.t_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + result = a.t_() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_transpose_.py b/tests/test_Tensor_transpose_.py new file mode 100644 index 000000000..dcb1f0767 --- /dev/null +++ b/tests/test_Tensor_transpose_.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.transpose_", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + result = a.transpose_(dim0=0, dim1=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + result = a.transpose_(0, 1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + result = a.transpose_(dim1=0, dim0=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + list_a = [a,a] + result = [x.transpose_(dim1=0, dim0=1) for x in list_a ] + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_xlogy.py b/tests/test_Tensor_xlogy.py index ed05ca586..19bdcbf47 100644 --- a/tests/test_Tensor_xlogy.py +++ b/tests/test_Tensor_xlogy.py @@ -49,3 +49,24 @@ def test_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2. ,3.]).xlogy(other=3.0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., 3., 4., 5.]) + result = a.xlogy(3.0) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_xlogy_.py b/tests/test_Tensor_xlogy_.py new file mode 100644 index 000000000..2fcfb82f8 --- /dev/null +++ b/tests/test_Tensor_xlogy_.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.xlogy_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., 3., 4., 5.]) + b = torch.tensor([1., 2., 3., 4., 5.]) + result = a.xlogy_(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2., 3., 4., 5.]).xlogy_(other=torch.tensor([1., 2., 3., 4., 5.])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2. ,3.]).xlogy_(other=3.0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., 3., 4., 5.]) + result = a.xlogy_(3.0) + """ + ) + obj.run(pytorch_code, ["result"]) From e5e038f3e81f101a205d6bc550c99214b922f12e Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sat, 28 Sep 2024 21:51:24 +0800 Subject: [PATCH 2/5] update TensorTransposeMatcher --- paconvert/api_mapping.json | 2 +- paconvert/api_matcher.py | 33 ++------------------------------- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index a59687701..24475813e 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3830,7 +3830,7 @@ ] }, "torch.Tensor.transpose_": { - "Matcher": "TensorTranspose_Matcher", + "Matcher": "TensorTransposeMatcher", "paddle_api": "paddle.Tensor.transpose_", "min_input_args": 2, "args_list": [ diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 7ac63629d..8bf497d1f 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -729,40 +729,11 @@ def generate_code(self, kwargs): self.write_aux_code() API_TEMPLATE = textwrap.dedent( """ - {}.transpose(perm=paddle_aux.transpose_aux_func({}.ndim,{}, {})) + {}(perm=paddle_aux.transpose_aux_func({}.ndim,{}, {})) """ ) code = API_TEMPLATE.format( - self.paddleClass, - self.paddleClass, - kwargs["dim0"], - kwargs["dim1"], - ) - return code - - -class TensorTranspose_Matcher(BaseMatcher): - def generate_aux_code(self): - API_TEMPLATE = textwrap.dedent( - """ - def transpose_aux_func(dims,dim0, dim1): - perm = list(range(dims)) - perm[dim0], perm[dim1] = perm[dim1], perm[dim0] - return perm - """ - ) - - return API_TEMPLATE - - def generate_code(self, kwargs): - self.write_aux_code() - API_TEMPLATE = textwrap.dedent( - """ - {}.transpose_(perm=paddle_aux.transpose_aux_func({}.ndim,{}, {})) - """ - ) - code = API_TEMPLATE.format( - self.paddleClass, + self.get_paddle_api(), self.paddleClass, kwargs["dim0"], kwargs["dim1"], From 8cbde9127bad4bfaa73a7aa73c23daf44a60a6a3 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sun, 29 Sep 2024 17:21:12 +0800 Subject: [PATCH 3/5] update AddMRMatcherAddMR_Matcher --- paconvert/api_matcher.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 8bf497d1f..dc2e49b8b 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2755,14 +2755,6 @@ def generate_code(self, kwargs): if "input" not in kwargs: kwargs["input"] = self.paddleClass - params1 = ["mat1", "mat", "vec1", "batch1"] - params2 = ["mat2", "vec", "vec2", "batch2"] - param1, param2 = None, None - for i, param in enumerate(params1): - if param in kwargs: - param1 = kwargs[params1[i]] - param2 = kwargs[params2[i]] - if "beta" not in kwargs: kwargs["beta"] = 1 @@ -2780,8 +2772,8 @@ def generate_code(self, kwargs): kwargs["input"], kwargs["alpha"], self.get_paddle_api(), - param1, - param2, + kwargs["vec1"], + kwargs["vec2"], kwargs["out"], ) else: @@ -2795,8 +2787,8 @@ def generate_code(self, kwargs): kwargs["input"], kwargs["alpha"], self.get_paddle_api(), - param1, - param2, + kwargs["vec1"], + kwargs["vec2"], ) return code @@ -2807,14 +2799,6 @@ def generate_code(self, kwargs): if "input" not in kwargs: kwargs["input"] = self.paddleClass - params1 = ["mat1", "mat", "vec1", "batch1"] - params2 = ["mat2", "vec", "vec2", "batch2"] - param1, param2 = None, None - for i, param in enumerate(params1): - if param in kwargs: - param1 = kwargs[params1[i]] - param2 = kwargs[params2[i]] - if "beta" in kwargs: kwargs[ "beta" @@ -2835,8 +2819,8 @@ def generate_code(self, kwargs): kwargs["beta"], kwargs["alpha"], self.get_paddle_api(), - param1, - param2, + kwargs["vec1"], + kwargs["vec2"], ) return code From 20d0577cf4d4bf322a750ee41556789caf2a9d48 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sun, 29 Sep 2024 21:07:02 +0800 Subject: [PATCH 4/5] revert 'update AddMRMatcherAddMR_Matcher' --- paconvert/api_matcher.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index dc2e49b8b..e5e9835b6 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2755,6 +2755,14 @@ def generate_code(self, kwargs): if "input" not in kwargs: kwargs["input"] = self.paddleClass + params1 = ["mat1", "mat", "vec1", "batch1"] + params2 = ["mat2", "vec", "vec2", "batch2"] + param1, param2 = None, None + for i, param in enumerate(params1): + if param in kwargs: + param1 = kwargs[params1[i]] + param2 = kwargs[params2[i]] + if "beta" not in kwargs: kwargs["beta"] = 1 @@ -2787,8 +2795,8 @@ def generate_code(self, kwargs): kwargs["input"], kwargs["alpha"], self.get_paddle_api(), - kwargs["vec1"], - kwargs["vec2"], + param1, + param2, ) return code @@ -2799,6 +2807,14 @@ def generate_code(self, kwargs): if "input" not in kwargs: kwargs["input"] = self.paddleClass + params1 = ["mat1", "mat", "vec1", "batch1"] + params2 = ["mat2", "vec", "vec2", "batch2"] + param1, param2 = None, None + for i, param in enumerate(params1): + if param in kwargs: + param1 = kwargs[params1[i]] + param2 = kwargs[params2[i]] + if "beta" in kwargs: kwargs[ "beta" @@ -2819,8 +2835,8 @@ def generate_code(self, kwargs): kwargs["beta"], kwargs["alpha"], self.get_paddle_api(), - kwargs["vec1"], - kwargs["vec2"], + param1, + param2, ) return code From f436ddfbbbeb4c5a0911277f2bb4931aa79f9f3d Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sun, 29 Sep 2024 21:28:57 +0800 Subject: [PATCH 5/5] fix ci --- paconvert/api_matcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index e5e9835b6..8bf497d1f 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2780,8 +2780,8 @@ def generate_code(self, kwargs): kwargs["input"], kwargs["alpha"], self.get_paddle_api(), - kwargs["vec1"], - kwargs["vec2"], + param1, + param2, kwargs["out"], ) else: