diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 10a3423..2598cfe 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,12 +1,11 @@ name: Build on: - pull_request: - branches: - - main - push: - branches: - - main + workflow_run: + workflows: [ "Test" ] + types: + - completed + branches: [ main ] workflow_dispatch: branches: - main @@ -44,11 +43,11 @@ jobs: --extra-index-url https://download.pytorch.org/whl/nightly/cpu mkdir -p ${{ github.sha }} - mv wheelhouse/SharkPy*.whl ${{ github.sha }}/ + mv wheelhouse/PI*.whl ${{ github.sha }}/ - name: Upload an artifact uses: actions/upload-artifact@v3 - if: github.event_name == 'push' + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' with: if-no-files-found: error name: build_artifact @@ -62,7 +61,7 @@ jobs: needs: [ build ] - if: ${{ github.event_name == 'push' }} + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' steps: - name: Checkout uses: actions/checkout@v2 @@ -76,10 +75,10 @@ jobs: - name: Set up a release page id: setup_release run: | - SHARKPY_VERSION=$(python setup.py --version) - tag_name="$SHARKPY_VERSION" - release_title="SharkPy $SHARKPY_VERSION" - echo "SharkPy $SHARKPY_VERSION created at $(date)" > body.md + PI_VERSION=$(python setup.py --version) + tag_name="$PI_VERSION" + release_title="PI $PI_VERSION" + echo "PI $PI_VERSION created at $(date)" > body.md echo "tag_name=${tag_name}" >> $GITHUB_OUTPUT echo "release_title=${release_title}" >> $GITHUB_OUTPUT @@ -88,7 +87,7 @@ jobs: with: artifacts: "${{ github.sha }}/*.whl" bodyFile: body.md - token: "${{ secrets.SHARK_PY_CI }}" + token: "${{ secrets.PI_CI }}" tag: "${{ steps.setup_release.outputs.tag_name }}" name: "${{ steps.setup_release.outputs.release_title }}" removeArtifacts: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..6593b8a --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,53 @@ +name: Test + +on: + pull_request: + branches: + - main + - nn_module + push: + branches: + - main + - nn_module + workflow_dispatch: + branches: + - main + - nn_module + +jobs: + + test-against-torch-mlir: + + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + arch: [ x86_64 ] + python_version: [ "3.10" ] + + steps: + - name: Checkout + uses: actions/checkout@v2 + +# - name: Install linux system packages +# run: | +# sudo apt-get update +# sudo apt-get -y install ninja-build cmake clang + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + + - name: Install + run: | + pip install . \ + --pre torch-mlir torchvision \ + -f https://llvm.github.io/torch-mlir/package-index/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + -v + + - name: Test vs. 
torch-mlir + run: | + PYTHONPATH=tests/torch_mlir python tests/torch_mlir/main.py diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..17fe0db --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,62 @@ +cmake_minimum_required(VERSION 3.13.4) + +if (POLICY CMP0068) + cmake_policy(SET CMP0068 NEW) + set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) +endif () + +if (POLICY CMP0075) + cmake_policy(SET CMP0075 NEW) +endif () + +if (POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif () + +if (POLICY CMP0116) + cmake_policy(SET CMP0116 NEW) +endif () + +project(PI LANGUAGES CXX C) + +set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) + +set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to") + +find_package(MLIR REQUIRED CONFIG) + +message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) +set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) + +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) +link_directories(${LLVM_BUILD_LIBRARY_DIR}) +add_definitions(${LLVM_DEFINITIONS}) + +##################################### Bindings path hacks + +include(MLIRDetectPythonEnv) +include(AddMLIRPython) +mlir_configure_python_dev_packages() +mlir_detect_pybind11_install() + +set(PYTHON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cpp_ext) # --src-root +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) +# set(MLIR_TABLEGEN_EXE "" CACHE STRING "Path to mlir-tablegen") +# message(STATUS "MLIR_TABLEGEN_EXE: ${MLIR_TABLEGEN_EXE}") +set(MLIR_INCLUDE_TESTS 0) + +pybind11_add_module(_mlir cpp_ext/MainModule.cpp cpp_ext/TensorValue.cpp cpp_ext/TorchTypes.cpp) +#target_link_libraries(_mlir PRIVATE MLIRIR MLIRSupport MLIRCAPIInterfaces MLIRCAPIIR) + diff --git a/README.md b/README.md index cd65eb5..931bcf7 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -- [SharkPy](#sharkpy) +- [PI](#PI) - [Installing](#installing) - [Minimal example](#minimal-example) - [Moderately interesting example](#moderately-interesting-example) @@ -8,7 +8,7 @@ image
-# SharkPy +# PI Early days of a Python frontend for MLIR. @@ -25,10 +25,10 @@ pip install . \ and you're good to go. -Alternatively, you can install the [latest released wheel](https://github.com/nod-ai/SharkPy/releases/latest): +Alternatively, you can install the [latest released wheel](https://github.com/nod-ai/PI/releases/latest): ```shell -pip install https://github.com/nod-ai/SharkPy/releases/latest/download/SharkPy-$CURRENT_VERSION-py3-none-any.whl \ +pip install https://github.com/nod-ai/PI/releases/latest/download/PI-$CURRENT_VERSION-py3-none-any.whl \ --pre torch-mlir torchvision \ -f https://llvm.github.io/torch-mlir/package-index/ \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu @@ -39,7 +39,7 @@ pip install https://github.com/nod-ai/SharkPy/releases/latest/download/SharkPy-$ [simple_kernels.py](./tests/simple_kernels.py) (in [tests](./tests)) looks like this ```python -from shark.dialects import memref, linalg +from pi.dialects import memref, linalg def saxpy(a: float, b: float): A = memref.AllocaOp((10, 30)) @@ -233,29 +233,47 @@ func.func private @saxpy(%arg0: f64, %arg1: f64) -> memref<10x20xf64> { Preliminary support for the `torch-mlir` dialect is available: ```python -def torch_ops(): - f64 = F64Type.get() - z = torch.ConstantFloatOp(value=FloatAttr.get(f64, 256.0)) - attr = DenseFPElementsAttr(Attribute.parse("dense<0.0> : tensor<3x5xf32>")) - a = torch.ValueTensorLiteralOp(attr) - b = torch.ValueTensorLiteralOp(attr) - c = torch.AtenAddTensorOp(a.result.type, a.result, b.result, z) - return c +class MyConv2d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 1, 3) + + def forward(self, x): + y = self.conv(x) + z = y + y + w = z * z + return w ``` lowers to ```mlir -func.func private @torch_ops() -> !torch.vtensor<[3,5],f32> { - %float2.560000e02 = torch.constant.float 2.560000e+02 - %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.vtensor<[3,5],f32> - %1 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.vtensor<[3,5],f32> - %2 = torch.aten.add.Tensor %0, %1, %float2.560000e02 : - !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.float -> !torch.vtensor<[3,5],f32> - return %2 : !torch.vtensor<[3,5],f32> +module { + func.func private @simple_conv2d() -> !torch.vtensor { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x3x32x32xf32>) : !torch.vtensor<[1,3,32,32],f32> + %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32> + %2 = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x3x3x3xf32>) : !torch.vtensor<[1,3,3,3],f32> + %int1 = torch.constant.int 1 + %int1_0 = torch.constant.int 1 + %3 = torch.prim.ListConstruct %int1, %int1_0 : (!torch.int, !torch.int) -> !torch.list + %int0 = torch.constant.int 0 + %int0_1 = torch.constant.int 0 + %4 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list + %int1_2 = torch.constant.int 1 + %int1_3 = torch.constant.int 1 + %5 = torch.prim.ListConstruct %int1_2, %int1_3 : (!torch.int, !torch.int) -> !torch.list + %int1_4 = torch.constant.int 1 + %6 = torch.aten.conv2d %0, %2, %1, %3, %4, %5, %int1_4 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[1,3,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.int -> !torch.vtensor + %7 = "torch.constant.number"() {value = 1 : i64} : () -> !torch.number + %8 = torch.aten.add.Tensor %6, %6, %7 : !torch.vtensor, !torch.vtensor, !torch.number -> !torch.vtensor + %9 = 
torch.aten.mul.Tensor %8, %8 : !torch.vtensor, !torch.vtensor -> !torch.vtensor + return %9 : !torch.vtensor + } } ``` +This is very rough right now; to get a rough idea of the current status check the [latest tests](https://github.com/nod-ai/PI/actions?query=branch%3Ann_module+) on the `nn_module` branch. + # Build Wheel ```shell diff --git a/cpp_ext/IRModule.h b/cpp_ext/IRModule.h new file mode 100644 index 0000000..bf7679d --- /dev/null +++ b/cpp_ext/IRModule.h @@ -0,0 +1,158 @@ +//===- IRModules.h - IR Submodules of pybind module -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H +#define MLIR_BINDINGS_PYTHON_IRMODULES_H + +#include +#include +#include + +#include +#include +#include + +#include "mlir-c/AffineExpr.h" +#include "mlir-c/AffineMap.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" + +namespace py = pybind11; + +namespace mlir::python { + +class PyOperation; + +/// Template for a reference to a concrete type which captures a python +/// reference to its underlying python object. +template +class PyObjectRef { +public: + PyObjectRef(T *referrent, pybind11::object object) + : referrent(referrent), object(std::move(object)) { + assert(this->referrent && "cannot construct PyObjectRef with null referrent"); + assert(this->object && "cannot construct PyObjectRef with null object"); + } + PyObjectRef(PyObjectRef &&other) + : referrent(other.referrent), object(std::move(other.object)) { + other.referrent = nullptr; + assert(!other.object); + } + PyObjectRef(const PyObjectRef &other) + : referrent(other.referrent), object(other.object /* copies */) {} + ~PyObjectRef() = default; + + T *operator->() { + assert(referrent && object); + return referrent; + } + pybind11::object getObject() { + assert(referrent && object); + return object; + } + explicit operator bool() const { return referrent && object; } + +private: + T *referrent; + pybind11::object object; +}; + +/// Wrapper around MlirContext. 
+class PyMlirContext { +public: + PyMlirContext() = delete; + PyMlirContext(const PyMlirContext &) = delete; + PyMlirContext(PyMlirContext &&) = delete; + + explicit PyMlirContext(MlirContext context) : context(context){}; + + MlirContext context; + friend class PyModule; + friend class PyOperation; +}; + +using PyMlirContextRef = PyObjectRef; + +class BaseContextObject { +public: + explicit BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { + assert(this->contextRef && "context object constructed with null context ref"); + } + PyMlirContextRef contextRef; +}; + +class PyOperation : public BaseContextObject { +public: + PyOperation &getOperation() { return *this; } + PyOperation(PyMlirContextRef contextRef, MlirOperation operation) : BaseContextObject(std::move(contextRef)), operation(operation){}; + + pybind11::handle handle; + MlirOperation operation; + pybind11::object parentKeepAlive; + bool attached = true; + bool valid = true; + + friend class PyOperationBase; + friend class PySymbolTable; +}; + +using PyOperationRef = PyObjectRef; + +class PyValue { +public: + PyValue(PyOperationRef parentOperation, MlirValue value) + : parentOperation(std::move(parentOperation)), value(value) {} + explicit operator MlirValue() const { return value; } + +private: + PyOperationRef parentOperation; + MlirValue value; +}; + +struct PyType : public BaseContextObject { + PyType(PyMlirContextRef contextRef, MlirType type) + : BaseContextObject(std::move(contextRef)), type(type) {} + explicit operator MlirType() const { return type; } + [[nodiscard]] MlirType get() const { return type; } + + + MlirType type; +}; + + + +template +struct PyConcreteType : public BaseTy { +// using ClassTy = pybind11::class_; + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + + static DerivedTy createFromCapsule_(py::capsule& capsule) { + MlirType rawType = {capsule.get_pointer()}; + if (mlirTypeIsNull(rawType)) + throw py::error_already_set(); + + MlirContext ctx = mlirTypeGetContext(rawType); + auto *unownedContextWrapper = new PyMlirContext(ctx); + auto pyCtxRef = py::reinterpret_steal(mlirPythonContextToCapsule(ctx)); + assert(pyCtxRef && "cast to py::object failed"); + auto ctxRef = PyMlirContextRef(unownedContextWrapper, std::move(pyCtxRef)); + + return {std::move(ctxRef), rawType}; + } + +}; + +void populateTorchTypes(py::module &m); + +}// namespace mlir::python + +#endif// MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/cpp_ext/MainModule.cpp b/cpp_ext/MainModule.cpp new file mode 100644 index 0000000..520958a --- /dev/null +++ b/cpp_ext/MainModule.cpp @@ -0,0 +1,33 @@ +#include "dylib.hpp" +#include "IRModule.h" +#include "TensorValue.h" +#include "TorchTypes.h" +#include "TorchTypesCAPI.h" +#include "mlir-c/Bindings/Python/Interop.h" + +#include + +namespace py = pybind11; +using namespace mlir::python; + +// no clue why but without this i get a missing symbol error +namespace llvm { +int DisableABIBreakingChecks = 1; +int EnableABIBreakingChecks = 0; +}// namespace llvm + +PYBIND11_MODULE(_mlir, m) { +// dylib lib1("_torchMlir.cpython-310-darwin.so", dylib::no_filename_decorations); + dylib lib2("TorchMLIRAggregateCAPI"); + +// if (!lib1.has_symbol("mlirValueIsAOpResult")) +// std::cerr << "symbol 'mlirValueIsAOpResult' not found in '_torchMlir' lib" << std::endl; + if (!lib2.has_symbol("mlirValueIsAOpResult")) + std::cerr << "symbol 'mlirValueIsAOpResult' not found in 'TorchMLIRAggregateCAPI' lib" << std::endl; + 
else + std::cerr << "found symbol 'mlirValueIsAOpResult' in 'TorchMLIRAggregateCAPI' lib" << std::endl; + + bindValues(m); + bindTypes(m); + bindTypeHelpers(m); +} diff --git a/cpp_ext/TensorValue.cpp b/cpp_ext/TensorValue.cpp new file mode 100644 index 0000000..e065439 --- /dev/null +++ b/cpp_ext/TensorValue.cpp @@ -0,0 +1,55 @@ +//===- TorchTypes.cpp - C Interface for torch types -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "TensorValue.h" +#include "TorchTypesCAPI.h" + +#include "mlir/CAPI/Support.h" + +using namespace mlir; +using namespace mlir::python; + +Torch_Tensor Torch_Tensor::createFromCapsule_(const py::capsule &capsule) { + MlirValue value = {capsule.get_pointer()}; + if (mlirValueIsNull(value)) + throw py::error_already_set(); + MlirOperation owner; + if (mlirValueIsAOpResult(value)) + owner = mlirOpResultGetOwner(value); + if (mlirValueIsABlockArgument(value)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); + if (mlirOperationIsNull(owner)) + throw py::error_already_set(); + + MlirContext ctx = mlirOperationGetContext(owner); + auto *unownedContextWrapper = new PyMlirContext(ctx); + auto pyCtxRef = py::reinterpret_steal(mlirPythonContextToCapsule(ctx)); + assert(pyCtxRef && "cast to py::object failed"); + auto ctxRef = PyMlirContextRef(unownedContextWrapper, std::move(pyCtxRef)); + + auto pyOpRef = py::reinterpret_steal(mlirPythonOperationToCapsule(owner)); + auto *unownedOperation = + new PyOperation(std::move(ctxRef), owner); + unownedOperation->handle = pyOpRef; + auto ownerRef = PyOperationRef(unownedOperation, std::move(pyOpRef)); + + return {ownerRef, value}; +} + +void bindValues(py::module &m) { + py::object value_ = + (py::object) py::module_::import("torch_mlir.ir").attr("Value"); + py::object op_result_ = + (py::object) py::module_::import("torch_mlir.ir").attr("OpResult"); + + py::class_(m, "_Torch_Tensor", value_) + .def(py::init<>([](const py::capsule &capsule) { + return Torch_Tensor::createFromCapsule_(capsule); + })); +} \ No newline at end of file diff --git a/cpp_ext/TensorValue.h b/cpp_ext/TensorValue.h new file mode 100644 index 0000000..ff0e2e9 --- /dev/null +++ b/cpp_ext/TensorValue.h @@ -0,0 +1,27 @@ +//===- TorchTypes.cpp - C Interface for torch types -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; +using namespace mlir::python; + + +struct Torch_Tensor : PyValue { + Torch_Tensor(PyOperationRef operationRef, MlirValue value) + : PyValue(std::move(operationRef), value) {} + + static Torch_Tensor createFromCapsule_(const py::capsule& capsule); +}; + +void bindValues(py::module &m); \ No newline at end of file diff --git a/cpp_ext/TorchTypes.cpp b/cpp_ext/TorchTypes.cpp new file mode 100644 index 0000000..51288b0 --- /dev/null +++ b/cpp_ext/TorchTypes.cpp @@ -0,0 +1,84 @@ +//===- TorchTypes.cpp - C Interface for torch types -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "TorchTypes.h" +#include "IRModule.h" +#include "TorchTypesCAPI.h" +#include "mlir-c/BuiltinAttributes.h" + +using namespace mlir; +using namespace mlir::python; + +void bindTypes(py::module &m) { + py::object type_ = + (py::object) py::module_::import("torch_mlir.ir").attr("Type"); + + py::class_(m, "_Torch_IntType", type_) + .def(py::init<>([](py::capsule capsule) { + return Torch_IntType::createFromCapsule_(capsule); + })); + + py::class_(m, "_Torch_BoolType", type_) + .def(py::init<>([](py::capsule capsule) { + return Torch_BoolType::createFromCapsule_(capsule); + })); + + py::class_(m, "_Torch_StringType", type_) + .def(py::init<>([](py::capsule capsule) { + return Torch_StringType::createFromCapsule_(capsule); + })); + + py::class_(m, "_Torch_FloatType", type_) + .def(py::init<>([](py::capsule capsule) { + return Torch_FloatType::createFromCapsule_(capsule); + })); + + py::class_(m, "_Torch_ValueTensorType", type_) + .def(py::init<>([](py::capsule capsule) { + return Torch_ValueTensorType::createFromCapsule_(capsule); + })); + + py::class_(m, "_Torch_NonValueTensorType", type_) + .def(py::init<>([](py::capsule capsule) { + return Torch_NonValueTensorType::createFromCapsule_(capsule); + })); +} + +void bindTypeHelpers(py::module &m) { + m.def( + "is_a_torch_int_type", [](const py::capsule &type) { + MlirType rawType = {type.get_pointer()}; + return torchMlirTypeIsATorchInt(rawType); + }); + m.def( + "is_a_torch_bool_type", [](const py::capsule &type) { + MlirType rawType = {type.get_pointer()}; + return torchMlirTypeIsATorchBool(rawType); + }); + m.def( + "is_a_torch_string_type", [](const py::capsule &type) { + MlirType rawType = {type.get_pointer()}; + return torchMlirTypeIsATorchString(rawType); + }); + m.def( + "is_a_torch_float_type", [](const py::capsule &type) { + MlirType rawType = {type.get_pointer()}; + return torchMlirTypeIsATorchFloat(rawType); + }); + m.def( + "is_a_torch_value_tensor_type", [](const py::capsule &type) { + MlirType rawType = {type.get_pointer()}; + return torchMlirTypeIsATorchValueTensor(rawType); + }); + m.def( + "is_a_torch_nonvalue_tensor_type", [](const py::capsule &type) { + MlirType rawType = {type.get_pointer()}; + return torchMlirTypeIsATorchNonValueTensor(rawType); + }); +} diff --git a/cpp_ext/TorchTypes.h b/cpp_ext/TorchTypes.h new file mode 100644 index 0000000..bc3f99b --- /dev/null +++ 
b/cpp_ext/TorchTypes.h @@ -0,0 +1,51 @@ +//===- TorchTypes.cpp - C Interface for torch types -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "TorchTypesCAPI.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; +using namespace mlir::python; + +struct Torch_IntType : public PyConcreteType { + Torch_IntType(PyMlirContextRef contextRef, MlirType t) + : PyConcreteType(std::move(contextRef), t) {} +}; + +struct Torch_BoolType : public PyConcreteType { + Torch_BoolType(PyMlirContextRef contextRef, MlirType t) + : PyConcreteType(std::move(contextRef), t) {} +}; + +struct Torch_StringType : public PyConcreteType { + Torch_StringType(PyMlirContextRef contextRef, MlirType t) + : PyConcreteType(std::move(contextRef), t) {} +}; + +struct Torch_FloatType : public PyConcreteType { + Torch_FloatType(PyMlirContextRef contextRef, MlirType t) + : PyConcreteType(std::move(contextRef), t) {} +}; + +struct Torch_ValueTensorType : public PyConcreteType { + Torch_ValueTensorType(PyMlirContextRef contextRef, MlirType t) + : PyConcreteType(std::move(contextRef), t) {} +}; + +struct Torch_NonValueTensorType : public PyConcreteType { + Torch_NonValueTensorType(PyMlirContextRef contextRef, MlirType t) + : PyConcreteType(std::move(contextRef), t) {} +}; + +void bindTypes(py::module &m); +void bindTypeHelpers(py::module &m); \ No newline at end of file diff --git a/cpp_ext/TorchTypesCAPI.h b/cpp_ext/TorchTypesCAPI.h new file mode 100644 index 0000000..761ebea --- /dev/null +++ b/cpp_ext/TorchTypesCAPI.h @@ -0,0 +1,266 @@ +//===-- torch-mlir-c/TorchTypes.h - C API for torch types ---------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_C_TORCHTYPES_H +#define TORCHMLIR_C_TORCHTYPES_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// torch.nn.Module type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a torch.nn.Module type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t); + +/// Gets the !torch.nn.Module type of the specified class. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className); + +//===----------------------------------------------------------------------===// +// torch.optional type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.optional type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t); + +/// Gets the !torch.optional type with subtype T. 
+MLIR_CAPI_EXPORTED MlirType +torchMlirTorchOptionalTypeGet(MlirType containedType); + +//===----------------------------------------------------------------------===// +// torch.tuple type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.tuple type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchTuple(MlirType t); + +/// Gets the !torch.tuple type with contained types `containedTypes`. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, + MlirType const *containedTypes); + +//===----------------------------------------------------------------------===// +// torch.union type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.union type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchUnion(MlirType t); + +/// Gets the !torch.union type with contained types `containedTypes`. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, + MlirType const *containedTypes); + +//===----------------------------------------------------------------------===// +// torch.list type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.list type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t); + +/// Gets the !torch.list type with contained T. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); + +//===----------------------------------------------------------------------===// +// torch.Device type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.Device type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t); + +/// Gets the !torch.Device type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.Generator type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.Generator type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t); + +/// Gets the !torch.Generator type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.bool type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.bool type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t); + +/// Gets the !torch.bool type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.int type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.int type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t); + +/// Gets the !torch.int type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.float type. 
+//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.float type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t); + +/// Gets the !torch.float type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.LinearParams type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.LinearParams type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t); + +/// Gets the !torch.LinearParams type. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchLinearParamsTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.qint8 type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.qint8 type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t); + +/// Gets the !torch.qint8 type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.quint8 type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.quint8 type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t); + +/// Gets the !torch.quint8 type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.tensor type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.tensor type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t); + +/// Gets a !torch.tensor type. +/// +/// - `numSizes` having a value of -1 denotes an unranked tensor. +/// - `optionalSizes` is allowed to be null, meaning that no size +/// information is present (and `numSizes` is ignored in that case). - +/// `optionalDtype` is allowed to be null, meaning that no dtype +/// information is present. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGet( + MlirContext context, intptr_t numSizes, const int64_t *optionalSizes, + MlirType optionalDtype); + +/// Gets the !torch.tensor type with the least static information. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation( + MlirContext context); + +/// Gets the !torch.tensor type with the tensor attribute. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr); + +//===----------------------------------------------------------------------===// +// torch.vtensor type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.vtensor type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t); + +/// Gets a !torch.vtensor type. +/// +/// - `numSizes` having a value of -1 denotes an unranked tensor. +/// - `optionalSizes` is allowed to be null, meaning that no size +/// information is present (and `numSizes` is ignored in that case). +/// - `optionalDtype` is allowed to be null, meaning that no dtype +/// information is present. 
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet( + MlirContext context, intptr_t numSizes, const int64_t *optionalSizes, + MlirType optionalDtype); + +/// Gets the !torch.tensor type with the least static information. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context); + +/// Gets the !torch.vtensor type with the tensor attribute. +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr); + +//===----------------------------------------------------------------------===// +// !torch.none type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.none type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t); + +/// Gets the !torch.none type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !torch.str type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.str type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t); + +/// Gets the !torch.str type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !torch.any type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.any type. +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t); + +/// Gets the !torch.str type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !torch.number type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.number type. +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t); + +/// Gets the !torch.number type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !torch.dict type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.dict type. +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t); + +/// Gets the !torch.dict type. 
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType, + MlirType valueType); + +#ifdef __cplusplus +} +#endif + +#endif// TORCHMLIR_C_TORCHTYPES_H diff --git a/cpp_ext/dylib.hpp b/cpp_ext/dylib.hpp new file mode 100644 index 0000000..62a9b39 --- /dev/null +++ b/cpp_ext/dylib.hpp @@ -0,0 +1,298 @@ +/** + * @file dylib.hpp + * @version 2.1.0 + * @brief C++ cross-platform wrapper around dynamic loading of shared libraries + * @link https://github.com/martin-olivier/dylib + * + * @author Martin Olivier + * @copyright (c) 2022 Martin Olivier + * + * This library is released under MIT license + */ + +#pragma once + +#include +#include +#include + +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) +#define DYLIB_CPP17 +#include +#endif + +#if defined(_WIN32) || defined(_WIN64) +#define WIN32_LEAN_AND_MEAN +#include +#undef WIN32_LEAN_AND_MEAN +#else +#include +#endif + +#if defined(_WIN32) || defined(_WIN64) +#define DYLIB_WIN_MAC_OTHER(win_def, mac_def, other_def) win_def +#define DYLIB_WIN_OTHER(win_def, other_def) win_def +#elif defined(__APPLE__) +#define DYLIB_WIN_MAC_OTHER(win_def, mac_def, other_def) mac_def +#define DYLIB_WIN_OTHER(win_def, other_def) other_def +#else +#define DYLIB_WIN_MAC_OTHER(win_def, mac_def, other_def) other_def +#define DYLIB_WIN_OTHER(win_def, other_def) other_def +#endif + +/** + * The dylib class can hold a dynamic library instance and interact with it + * by getting its symbols like functions or global variables + */ +class dylib { +public: + struct filename_components { + static constexpr const char *prefix = DYLIB_WIN_OTHER("", "lib"); + static constexpr const char *suffix = DYLIB_WIN_MAC_OTHER(".dll", ".dylib", ".so"); + }; + using native_handle_type = DYLIB_WIN_OTHER(HINSTANCE, void *); + using native_symbol_type = DYLIB_WIN_OTHER(FARPROC, void *); + + static_assert(std::is_pointer::value, "Expecting HINSTANCE to be a pointer"); + static_assert(std::is_pointer::value, "Expecting FARPROC to be a pointer"); + + static constexpr bool add_filename_decorations = true; + static constexpr bool no_filename_decorations = false; + + /** + * This exception is raised when the library failed to load a dynamic library or a symbol + * + * @param message the error message + */ + class exception : public std::runtime_error { + public: + explicit exception(const std::string &message) : std::runtime_error(message) {} + }; + + /** + * This exception is raised when the library failed to load or encountered symbol resolution issues + * + * @param message the error message + */ + class load_error : public exception { + public: + explicit load_error(const std::string &message) : exception(message) {} + }; + + /** + * This exception is raised when the library failed to load a symbol + * + * @param message the error message + */ + class symbol_error : public exception { + public: + explicit symbol_error(const std::string &message) : exception(message) {} + }; + + dylib(const dylib&) = delete; + dylib& operator=(const dylib&) = delete; + + dylib(dylib &&other) noexcept : m_handle(other.m_handle) { + other.m_handle = nullptr; + } + + dylib& operator=(dylib &&other) noexcept { + if (this != &other) + std::swap(m_handle, other.m_handle); + return *this; + } + + /** + * @brief Loads a dynamic library + * + * @throws dylib::load_error if the library could not be opened (including + * the case of the library file not being found) + * + * @param dir_path the directory path where is located the dynamic library you want to load + * @param name 
the name of the dynamic library to load + * @param decorations add os decorations to the library name + */ + ///@{ + dylib(const char *dir_path, const char *lib_name, bool decorations = add_filename_decorations) { + if (!dir_path || !lib_name) + throw std::invalid_argument("Null parameter"); + + std::string final_name = lib_name; + std::string final_path = dir_path; + + if (decorations) + final_name = filename_components::prefix + final_name + filename_components::suffix; + + if (final_path != "" && final_path.find_last_of('/') != final_path.size() - 1) + final_path += '/'; + + m_handle = open((final_path + final_name).c_str()); + + if (!m_handle) + throw load_error("Could not load library \"" + final_path + final_name + "\"\n" + get_error_description()); + } + + dylib(const std::string &dir_path, const std::string &lib_name, bool decorations = add_filename_decorations) + : dylib(dir_path.c_str(), lib_name.c_str(), decorations) {} + + dylib(const std::string &dir_path, const char *lib_name, bool decorations = add_filename_decorations) + : dylib(dir_path.c_str(), lib_name, decorations) {} + + dylib(const char *dir_path, const std::string &lib_name, bool decorations = add_filename_decorations) + : dylib(dir_path, lib_name.c_str(), decorations) {} + + explicit dylib(const std::string &lib_name, bool decorations = add_filename_decorations) + : dylib("", lib_name.c_str(), decorations) {} + + explicit dylib(const char *lib_name, bool decorations = add_filename_decorations) + : dylib("", lib_name, decorations) {} + +#ifdef DYLIB_CPP17 + explicit dylib(const std::filesystem::path &lib_path) + : dylib("", lib_path.string().c_str(), no_filename_decorations) {} + + dylib(const std::filesystem::path &dir_path, const std::string &lib_name, bool decorations = add_filename_decorations) + : dylib(dir_path.string().c_str(), lib_name.c_str(), decorations) {} + + dylib(const std::filesystem::path &dir_path, const char *lib_name, bool decorations = add_filename_decorations) + : dylib(dir_path.string().c_str(), lib_name, decorations) {} +#endif + ///@} + + ~dylib() { + if (m_handle) + close(m_handle); + } + + /** + * Get a symbol from the dynamic library currently loaded in the object + * + * @throws dylib::symbol_error if the symbol could not be found + * + * @param symbol_name the symbol name to get from the dynamic library + * + * @return a pointer to the requested symbol + */ + native_symbol_type get_symbol(const char *symbol_name) const { + if (!symbol_name) + throw std::invalid_argument("Null parameter"); + if (!m_handle) + throw std::logic_error("The dynamic library handle is null"); + + auto symbol = locate_symbol(m_handle, symbol_name); + + if (symbol == nullptr) + throw symbol_error("Could not get symbol \"" + std::string(symbol_name) + "\"\n" + get_error_description()); + return symbol; + } + + native_symbol_type get_symbol(const std::string &symbol_name) const { + return get_symbol(symbol_name.c_str()); + } + + /** + * Get a function from the dynamic library currently loaded in the object + * + * @throws dylib::symbol_error if the symbol could not be found + * + * @param T the template argument must be the function prototype to get + * @param symbol_name the symbol name of a function to get from the dynamic library + * + * @return a pointer to the requested function + */ + template + T *get_function(const char *symbol_name) const { + return reinterpret_cast(get_symbol(symbol_name)); + } + + template + T *get_function(const std::string &symbol_name) const { + return 
get_function(symbol_name.c_str()); + } + + /** + * Get a variable from the dynamic library currently loaded in the object + * + * @throws dylib::symbol_error if the symbol could not be found + * + * @param T the template argument must be the type of the variable to get + * @param symbol_name the symbol name of a variable to get from the dynamic library + * + * @return a reference to the requested variable + */ + template + T &get_variable(const char *symbol_name) const { + return *reinterpret_cast(get_symbol(symbol_name)); + } + + template + T &get_variable(const std::string &symbol_name) const { + return get_variable(symbol_name.c_str()); + } + + /** + * Check if a symbol exists in the currently loaded dynamic library. + * This method will return false if no dynamic library is currently loaded + * or if the symbol name is nullptr + * + * @param symbol_name the symbol name to look for + * + * @return true if the symbol exists in the dynamic library, false otherwise + */ + bool has_symbol(const char *symbol_name) const noexcept { + if (!m_handle || !symbol_name) + return false; + return locate_symbol(m_handle, symbol_name) != nullptr; + } + + bool has_symbol(const std::string &symbol) const noexcept { + return has_symbol(symbol.c_str()); + } + + /** + * @return the dynamic library handle + */ + native_handle_type native_handle() noexcept { + return m_handle; + } + +protected: + native_handle_type m_handle{nullptr}; + + static native_handle_type open(const char *path) noexcept { +#if defined(_WIN32) || defined(_WIN64) + return LoadLibraryA(path); +#else + return dlopen(path, RTLD_NOW | RTLD_LOCAL); +#endif + } + + static native_symbol_type locate_symbol(native_handle_type lib, const char *name) noexcept { + return DYLIB_WIN_OTHER(GetProcAddress, dlsym)(lib, name); + } + + static void close(native_handle_type lib) noexcept { + DYLIB_WIN_OTHER(FreeLibrary, dlclose)(lib); + } + + static std::string get_error_description() noexcept { +#if defined(_WIN32) || defined(_WIN64) + constexpr const size_t buf_size = 512; + auto error_code = GetLastError(); + if (!error_code) + return "Unknown error (GetLastError failed)"; + char description[512]; + auto lang = MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US); + const DWORD length = + FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM, nullptr, error_code, lang, description, buf_size, nullptr); + return (length == 0) ? "Unknown error (FormatMessage failed)" : description; +#else + auto description = dlerror(); + return (description == nullptr) ? 
"Unknown error (dlerror failed)" : description; +#endif + } +}; + +#undef DYLIB_WIN_MAC_OTHER +#undef DYLIB_WIN_OTHER +#undef DYLIB_CPP17 diff --git a/pi/__init__.py b/pi/__init__.py new file mode 100644 index 0000000..ef8035e --- /dev/null +++ b/pi/__init__.py @@ -0,0 +1,43 @@ +import logging + +logger = logging.getLogger(__name__) +# noinspection PyUnresolvedReferences +from .dialects import patch_meta_path_non_context + +if __name__ == "pi": + # prevent double patching of path during testing + # where we've already patched torch -> pi + patch_meta_path_non_context() +else: + logger.debug(f"reimporting pi as {__name__}") + +# this has to go before the above (otherwise torch extensions won't be picked up) +# noinspection PyUnresolvedReferences +from torch_mlir import ir +import torch_mlir + +assert ( + len(torch_mlir.dialects.torch.AtenConv2dOp.__bases__) > 1 +), "failed to import torch dialect extensions" + +from ._tensor import * +from .types_ import * +from .dialects._torch_wrappers import * +from ._ops import _OpNamespace + +ops = _OpNamespace("ops") +_nn = _OpNamespace("_nn") +_C = _OpNamespace("_C") +_VF = _OpNamespace("_VF") +special = _OpNamespace("special") +linalg = _OpNamespace("linalg") + + +from . import nn as nn + + +def manual_seed(*_, **__): + return + + +DEBUG = True diff --git a/pi/_ops.py b/pi/_ops.py new file mode 100644 index 0000000..830758a --- /dev/null +++ b/pi/_ops.py @@ -0,0 +1,83 @@ +import types + +from .dialects import _torch_wrappers + +all_ops = {o: _torch_wrappers.__dict__[o] for o in _torch_wrappers.__all__} + + +class _OpNamespace(types.ModuleType): + def __init__(self, name): + super(_OpNamespace, self).__init__("pi." + name) + self.name = name + self._dir = [] + + def __iter__(self): + return iter(self._dir) + + def __getattr__(self, op_name): + # It is not a valid op_name when __file__ is passed in + # if op_name == "__file__": + # return "pi.ops" + # elif op_name == "__origin__": + # raise AttributeError() + + # namespace_name = self.name + # qualified_op_name = "{}::{}".format(namespace_name, op_name) + op_name = op_name.split(".")[-1] + + if op_name in all_ops: + return all_ops[op_name] + else: + return _OpNamespace(op_name) + + def __call__(self, *args, **kwargs): + if self.name in all_ops: + return all_ops[self.name](*args, **kwargs) + else: + raise NotImplementedError(self.name) + + # TODO(max): resolve overloads correctly here + # Get the op `my_namespace::my_op` if available. This will also check + # for overloads and raise an exception if there are more than one. + # try: + # op, overload_names = pi._C._jit_get_operation(qualified_op_name) + # except RuntimeError as e: + # # Turn this into AttributeError so getattr(obj, key, default) + # # works (this is called by TorchScript with __origin__) + # raise AttributeError( + # f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" + # ) from e + # + # # let the script frontend know that op is identical to the builtin op + # # with qualified_op_name + # pi.jit._builtins._register_builtin(op, qualified_op_name) + # op.__module__ = self.__module__ + "." + namespace_name + # opoverloadpacket = OpOverloadPacket( + # qualified_op_name, op_name, op, overload_names + # ) + # opoverloadpacket.__module__ = self.__module__ + "." 
+ namespace_name + # # cache the opoverloadpacket to ensure that each op corresponds to + # # a unique OpOverloadPacket object + # setattr(self, op_name, opoverloadpacket) + # self._dir.append(op_name) + # return opoverloadpacket + + +# class _PyOpNamespace(_OpNamespace): +# def __init__(self): +# super(_PyOpNamespace, self).__init__("pi.ops") +# self.pyop_namespace = all_ops + + +# class _Ops(types.ModuleType): +# def __init__(self, name): +# super(_Ops, self).__init__(name) +# +# def __getattr__(self, name): +# # Check if the name is a pyop +# if name in self.pyops.pyop_namespace: +# return self.pyops.pyop_namespace[name] +# +# namespace = _OpNamespace(name) +# setattr(self, name, namespace) +# return namespace diff --git a/pi/_tensor.py b/pi/_tensor.py new file mode 100644 index 0000000..42bc4cf --- /dev/null +++ b/pi/_tensor.py @@ -0,0 +1,2216 @@ +from __future__ import annotations +import warnings +from typing import Tuple, Optional, Any + +# noinspection PyUnresolvedReferences +import numpy as np +from torch_mlir.dialects import torch as torch_dialect +from torch_mlir.dialects._ods_common import get_op_result_or_value +from torch_mlir.ir import ( + DenseElementsAttr, +) +from torch_mlir.ir import ( + Value as MLIRValue, +) + +import pi +from .types_ import dtype as pi_dtype + + +class TorchTensorWrapper(type): + # def __new__(mcs, name, bases, class_dict): + # for k, f in class_dict.items(): + # if k in {"__init__", "__hash__", "_version", "value", "__class__", "type"}: + # continue + # if inspect.isfunction(f) and not isinstance(f, property): + # def run_on_actual_value(*args, **kwargs): + # self = args[0] + # return f((self.value, *args[1:]), **kwargs) + # + # class_dict[k] = run_on_actual_value + # return type.__new__(mcs, name, bases, class_dict) + + def __subclasscheck__(cls, subclass): + print(cls, subclass) + return False + + @classmethod + def __instancecheck__(cls, instance): + try: + return instance.is_pi_tensor + except: + return False + + +class Tensor(metaclass=TorchTensorWrapper): + @property + def is_pi_tensor(self): + return True + + @property + def __class__(self): + return MLIRValue + + @property + def type(self): + return self._value.type + + @property + def value(self): + return self._value + + def __init__(self, tensor: MLIRValue): + self._value = get_op_result_or_value(tensor) + + def abs(self): + raise NotImplementedError + + def absolute(self): + raise NotImplementedError + + def absolute_(self): + raise NotImplementedError + + def abs_(self): + raise NotImplementedError + + def acos(self): + raise NotImplementedError + + def acosh(self): + raise NotImplementedError + + def acosh_(self): + raise NotImplementedError + + def acos_(self): + raise NotImplementedError + + def add(self, other, *args, **kwargs): + raise NotImplementedError + + def addbmm(self, batch1, batch2, *args, **kwargs): + raise NotImplementedError + + def addbmm_(self, batch1, batch2, *args, **kwargs): + raise NotImplementedError + + def addcdiv(self, tensor1, tensor2, *args, **kwargs): + raise NotImplementedError + + def addcdiv_(self, tensor1, tensor2, *args, **kwargs): + raise NotImplementedError + + def addcmul(self, tensor1, tensor2, *args, **kwargs): + raise NotImplementedError + + def addcmul_(self, tensor1, tensor2, *args, **kwargs): + raise NotImplementedError + + def addmm(self, mat1, mat2, *args, **kwargs): + raise NotImplementedError + + def addmm_(self, mat1, mat2, *args, **kwargs): + raise NotImplementedError + + def addmv(self, mat, vec, *args, **kwargs): + raise 
NotImplementedError + + def addmv_(self, mat, vec, *args, **kwargs): + raise NotImplementedError + + def addr(self, vec1, vec2, *args, **kwargs): + raise NotImplementedError + + def addr_(self, vec1, vec2, *args, **kwargs): + raise NotImplementedError + + def add_(self, other, *args, **kwargs): + raise NotImplementedError + + def adjoint(self): + raise NotImplementedError + + def align_as(self, other): + raise NotImplementedError + + def align_to(self, *args, **kwargs): + raise NotImplementedError + + def all(self, dim=None, keepdim=False): + raise NotImplementedError + + def allclose(self, other, rtol=1, *args, **kwargs): + raise NotImplementedError + + def amax(self, dim=None, keepdim=False): + raise NotImplementedError + + def amin(self, dim=None, keepdim=False): + raise NotImplementedError + + def aminmax(self, *args, **kwargs): + raise NotImplementedError + + def angle(self): + raise NotImplementedError + + def any(self, dim=None, keepdim=False): + raise NotImplementedError + + def apply_(self, callable): + raise NotImplementedError + + def arccos(self): + raise NotImplementedError + + def arccosh(self, *args, **kwargs): + raise NotImplementedError + + def arccosh_(self, *args, **kwargs): + raise NotImplementedError + + def arccos_(self): + raise NotImplementedError + + def arcsin(self): + raise NotImplementedError + + def arcsinh(self): + raise NotImplementedError + + def arcsinh_(self): + raise NotImplementedError + + def arcsin_(self): + raise NotImplementedError + + def arctan(self): + raise NotImplementedError + + def arctan2(self, other): + raise NotImplementedError + + def arctan2_(self, *args, **kwargs): + raise NotImplementedError + + def arctanh(self): + raise NotImplementedError + + def arctanh_(self, other): + raise NotImplementedError + + def arctan_(self): + raise NotImplementedError + + def argmax(self, dim=None, keepdim=False): + raise NotImplementedError + + def argmin(self, dim=None, keepdim=False): + raise NotImplementedError + + def argsort(self, dim=-1, descending=False): + raise NotImplementedError + + def argwhere(self): + raise NotImplementedError + + def asin(self): + raise NotImplementedError + + def asinh(self): + raise NotImplementedError + + def asinh_(self): + raise NotImplementedError + + def asin_(self): + raise NotImplementedError + + def as_strided(self, size, stride, storage_offset=None): + raise NotImplementedError + + def as_strided_(self, *args, **kwargs): + raise NotImplementedError + + def as_strided_scatter(self, src, size, stride, storage_offset=None): + raise NotImplementedError + + def as_subclass(self, cls): + raise NotImplementedError + + def atan(self): + raise NotImplementedError + + def atan2(self, other): + raise NotImplementedError + + def atan2_(self, other): + raise NotImplementedError + + def atanh(self): + raise NotImplementedError + + def atanh_(self, other): + raise NotImplementedError + + def atan_(self): + raise NotImplementedError + + def baddbmm(self, batch1, batch2, *args, **kwargs): + raise NotImplementedError + + def baddbmm_(self, batch1, batch2, *args, **kwargs): + raise NotImplementedError + + def bernoulli(self, *args, **kwargs): + raise NotImplementedError + + def bernoulli_(self, p=0.5, *args, **kwargs): + raise NotImplementedError + + def bfloat16(self, memory_format=None): + raise NotImplementedError + + def bincount(self, weights=None, minlength=0): + raise NotImplementedError + + def bitwise_and(self): + raise NotImplementedError + + def bitwise_and_(self): + raise NotImplementedError + + def 
bitwise_left_shift(self, other): + raise NotImplementedError + + def bitwise_left_shift_(self, other): + raise NotImplementedError + + def bitwise_not(self): + raise NotImplementedError + + def bitwise_not_(self): + raise NotImplementedError + + def bitwise_or(self): + raise NotImplementedError + + def bitwise_or_(self): + raise NotImplementedError + + def bitwise_right_shift(self, other): + raise NotImplementedError + + def bitwise_right_shift_(self, other): + raise NotImplementedError + + def bitwise_xor(self): + raise NotImplementedError + + def bitwise_xor_(self): + raise NotImplementedError + + def bmm(self, batch2): + raise NotImplementedError + + def bool(self, memory_format=None): + raise NotImplementedError + + def broadcast_to(self, shape): + raise NotImplementedError + + def byte(self, memory_format=None): + raise NotImplementedError + + def cauchy_(self, median=0, sigma=1, *args, **kwargs): + raise NotImplementedError + + def ccol_indices(self, *args, **kwargs): + raise NotImplementedError + + def cdouble(self, memory_format=None): + raise NotImplementedError + + def ceil(self): + raise NotImplementedError + + def ceil_(self): + raise NotImplementedError + + def cfloat(self, memory_format=None): + raise NotImplementedError + + def chalf(self, memory_format=None): + raise NotImplementedError + + def char(self, memory_format=None): + raise NotImplementedError + + def cholesky(self, upper=False): + raise NotImplementedError + + def cholesky_inverse(self, upper=False): + raise NotImplementedError + + def cholesky_solve(self, input2, upper=False): + raise NotImplementedError + + def chunk(self, chunks, dim=0): + raise NotImplementedError + + def clamp(self, min=None, max=None): + raise NotImplementedError + + def clamp_(self, min=None, max=None): + raise NotImplementedError + + def clamp_max(self, *args, **kwargs): + raise NotImplementedError + + def clamp_max_(self, *args, **kwargs): + raise NotImplementedError + + def clamp_min(self, *args, **kwargs): + raise NotImplementedError + + def clamp_min_(self, *args, **kwargs): + raise NotImplementedError + + def clip(self, min=None, max=None): + raise NotImplementedError + + def clip_(self, min=None, max=None): + raise NotImplementedError + + def clone(self, *args, **kwargs): + raise NotImplementedError + + def coalesce(self): + raise NotImplementedError + + def col_indices(self): + raise NotImplementedError + + def conj(self): + raise NotImplementedError + + def conj_physical(self): + raise NotImplementedError + + def conj_physical_(self): + raise NotImplementedError + + def contiguous(self, memory_format=None): + raise NotImplementedError + + def copysign(self, other): + raise NotImplementedError + + def copysign_(self, other): + raise NotImplementedError + + def copy_(self, src, non_blocking=False): + raise NotImplementedError + + def corrcoef(self): + raise NotImplementedError + + def cos(self): + raise NotImplementedError + + def cosh(self): + raise NotImplementedError + + def cosh_(self): + raise NotImplementedError + + def cos_(self): + raise NotImplementedError + + def count_nonzero(self, dim=None): + raise NotImplementedError + + def cov(self, *args, **kwargs): + raise NotImplementedError + + def cpu(self, memory_format=None): + raise NotImplementedError + + def cross(self, other, dim=None): + raise NotImplementedError + + def crow_indices(self): + raise NotImplementedError + + def cuda(self, device=None, non_blocking=False, memory_format=None): + raise NotImplementedError + + def cummax(self, dim): + raise NotImplementedError 
+ + def cummin(self, dim): + raise NotImplementedError + + def cumprod(self, dim, dtype=None): + raise NotImplementedError + + def cumprod_(self, dim, dtype=None): + raise NotImplementedError + + def cumsum(self, dim, dtype=None): + raise NotImplementedError + + def cumsum_(self, dim, dtype=None): + raise NotImplementedError + + def data_ptr(self): + + return 0 + + def deg2rad(self): + raise NotImplementedError + + def deg2rad_(self): + raise NotImplementedError + + def dense_dim(self): + + return 0 + + def dequantize(self): + raise NotImplementedError + + def det(self): + raise NotImplementedError + + def detach(self, *args, **kwargs): + raise NotImplementedError + + def detach_(self, *args, **kwargs): + raise NotImplementedError + + def diag(self, diagonal=0): + raise NotImplementedError + + def diagflat(self, offset=0): + raise NotImplementedError + + def diagonal(self, offset=0, dim1=0, dim2=1): + raise NotImplementedError + + def diagonal_scatter(self, src, offset=0, dim1=0, dim2=1): + raise NotImplementedError + + def diag_embed(self, offset=0, dim1=-2, dim2=-1): + raise NotImplementedError + + def diff(self, n=1, dim=-1, prepend=None, append=None): + raise NotImplementedError + + def digamma(self): + raise NotImplementedError + + def digamma_(self): + raise NotImplementedError + + def dim(self): + + return 0 + + def dist(self, other, p=2): + raise NotImplementedError + + def div(self, value, *args, **kwargs): + raise NotImplementedError + + def divide(self, value, *args, **kwargs): + raise NotImplementedError + + def divide_(self, value, *args, **kwargs): + raise NotImplementedError + + def div_(self, value, *args, **kwargs): + raise NotImplementedError + + def dot(self, other): + raise NotImplementedError + + def double(self, memory_format=None): + raise NotImplementedError + + def dsplit(self, split_size_or_sections): + raise NotImplementedError + + def element_size(self): + + return 0 + + def eq(self, other): + raise NotImplementedError + + def equal(self, other): + + return False + + def eq_(self, other): + raise NotImplementedError + + def erf(self): + raise NotImplementedError + + def erfc(self): + raise NotImplementedError + + def erfc_(self): + raise NotImplementedError + + def erfinv(self): + raise NotImplementedError + + def erfinv_(self): + raise NotImplementedError + + def erf_(self): + raise NotImplementedError + + def exp(self): + raise NotImplementedError + + def exp2(self): + raise NotImplementedError + + def exp2_(self): + raise NotImplementedError + + def expand(self, *sizes): + raise NotImplementedError + + def expand_as(self, other): + raise NotImplementedError + + def expm1(self): + raise NotImplementedError + + def expm1_(self): + raise NotImplementedError + + def exponential_(self, lambd=1, *args, **kwargs): + raise NotImplementedError + + def exp_(self): + raise NotImplementedError + + def fill_(self, value): + raise NotImplementedError + + def fill_diagonal_(self, fill_value, wrap=False): + raise NotImplementedError + + def fix(self): + raise NotImplementedError + + def fix_(self): + raise NotImplementedError + + def flatten(self, start_dim=0, end_dim=-1): + raise NotImplementedError + + def flip(self, dims): + raise NotImplementedError + + def fliplr(self): + raise NotImplementedError + + def flipud(self): + raise NotImplementedError + + def float(self, memory_format=None): + raise NotImplementedError + + def float_power(self, exponent): + raise NotImplementedError + + def float_power_(self, exponent): + raise NotImplementedError + + def floor(self): + 
raise NotImplementedError + + def floor_(self): + raise NotImplementedError + + def floor_divide(self, value): + raise NotImplementedError + + def floor_divide_(self, value): + raise NotImplementedError + + def fmax(self, other): + raise NotImplementedError + + def fmin(self, other): + raise NotImplementedError + + def fmod(self, divisor): + raise NotImplementedError + + def fmod_(self, divisor): + raise NotImplementedError + + def frac(self): + raise NotImplementedError + + def frac_(self): + raise NotImplementedError + + def frexp(self, input): + raise NotImplementedError + + def gather(self, dim, index): + raise NotImplementedError + + def gcd(self, other): + raise NotImplementedError + + def gcd_(self, other): + raise NotImplementedError + + def ge(self, other): + raise NotImplementedError + + def geometric_(self, p, *args, **kwargs): + raise NotImplementedError + + def geqrf(self): + raise NotImplementedError + + def ger(self, vec2): + raise NotImplementedError + + def get_device(self): + raise NotImplementedError + + def ge_(self, other): + raise NotImplementedError + + def greater(self, other): + raise NotImplementedError + + def greater_(self, other): + raise NotImplementedError + + def greater_equal(self, other): + raise NotImplementedError + + def greater_equal_(self, other): + raise NotImplementedError + + def gt(self, other): + raise NotImplementedError + + def gt_(self, other): + raise NotImplementedError + + def half(self, memory_format=None): + raise NotImplementedError + + def hardshrink(self, lambd=0.5): + raise NotImplementedError + + def has_names(self, *args, **kwargs): + raise NotImplementedError + + def heaviside(self, values): + raise NotImplementedError + + def heaviside_(self, values): + raise NotImplementedError + + def histc(self, bins=100, min=0, max=0): + raise NotImplementedError + + def histogram(self, input, bins, *args, **kwargs): + raise NotImplementedError + + def hsplit(self, split_size_or_sections): + raise NotImplementedError + + def hypot(self, other): + raise NotImplementedError + + def hypot_(self, other): + raise NotImplementedError + + def i0(self): + raise NotImplementedError + + def i0_(self): + raise NotImplementedError + + def igamma(self, other): + raise NotImplementedError + + def igammac(self, other): + raise NotImplementedError + + def igammac_(self, other): + raise NotImplementedError + + def igamma_(self, other): + raise NotImplementedError + + def index_add(self, dim, index, source, *args, **kwargs): + raise NotImplementedError + + def index_add_(self, dim, index, source, *args, **kwargs): + raise NotImplementedError + + def index_copy(self, dim, index, tensor2): + raise NotImplementedError + + def index_copy_(self, dim, index, tensor): + raise NotImplementedError + + def index_fill(self, dim, index, value): + raise NotImplementedError + + def index_fill_(self, dim, index, value): + raise NotImplementedError + + def index_put(self, indices, values, accumulate=False): + raise NotImplementedError + + def index_put_(self, indices, values, accumulate=False): + raise NotImplementedError + + def index_reduce(self, *args, **kwargs): + raise NotImplementedError + + def index_reduce_(self, dim, index, source, reduce, *args, **kwargs): + raise NotImplementedError + + def index_select(self, dim, index): + raise NotImplementedError + + def indices(self): + raise NotImplementedError + + def inner(self, other): + raise NotImplementedError + + def int(self, memory_format=None): + raise NotImplementedError + + def int_repr(self): + raise 
NotImplementedError + + def inverse(self): + raise NotImplementedError + + def ipu(self, device=None, non_blocking=False, memory_format=None): + raise NotImplementedError + + def isclose(self, other, rtol=1, *args, **kwargs): + raise NotImplementedError + + def isfinite(self): + raise NotImplementedError + + def isinf(self): + raise NotImplementedError + + def isnan(self): + raise NotImplementedError + + def isneginf(self): + raise NotImplementedError + + def isposinf(self): + raise NotImplementedError + + def isreal(self): + raise NotImplementedError + + def istft( + self, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + normalized=False, + onesided=True, + length=None, + ): + raise NotImplementedError + + def is_coalesced(self): + + return False + + def is_complex(self): + + return False + + def is_conj(self): + + return False + + def is_contiguous(self, memory_format=None): + + return False + + def is_distributed(self, *args, **kwargs): + raise NotImplementedError + + def is_floating_point(self): + + return False + + def is_inference(self): + + return False + + def is_neg(self): + + return False + + def is_nonzero(self, *args, **kwargs): + raise NotImplementedError + + def is_pinned(self, *args, **kwargs): + raise NotImplementedError + + def is_same_size(self, *args, **kwargs): + raise NotImplementedError + + def is_set_to(self, tensor): + + return False + + def is_signed(self): + + return False + + def item(self): + + return 0 + + def kron(self, other): + raise NotImplementedError + + def kthvalue(self, k, dim=None, keepdim=False): + raise NotImplementedError + + def lcm(self, other): + raise NotImplementedError + + def lcm_(self, other): + raise NotImplementedError + + def ldexp(self, other): + raise NotImplementedError + + def ldexp_(self, other): + raise NotImplementedError + + def le(self, other): + raise NotImplementedError + + def lerp(self, end, weight): + raise NotImplementedError + + def lerp_(self, end, weight): + raise NotImplementedError + + def less(self, *args, **kwargs): + raise NotImplementedError + + def less_(self, other): + raise NotImplementedError + + def less_equal(self, other): + raise NotImplementedError + + def less_equal_(self, other): + raise NotImplementedError + + def le_(self, other): + raise NotImplementedError + + def lgamma(self): + raise NotImplementedError + + def lgamma_(self): + raise NotImplementedError + + def log(self): + raise NotImplementedError + + def log10(self): + raise NotImplementedError + + def log10_(self): + raise NotImplementedError + + def log1p(self): + raise NotImplementedError + + def log1p_(self): + raise NotImplementedError + + def log2(self): + raise NotImplementedError + + def log2_(self): + raise NotImplementedError + + def logaddexp(self, other): + raise NotImplementedError + + def logaddexp2(self, other): + raise NotImplementedError + + def logcumsumexp(self, dim): + raise NotImplementedError + + def logdet(self): + raise NotImplementedError + + def logical_and(self): + raise NotImplementedError + + def logical_and_(self): + raise NotImplementedError + + def logical_not(self): + raise NotImplementedError + + def logical_not_(self): + raise NotImplementedError + + def logical_or(self): + raise NotImplementedError + + def logical_or_(self): + raise NotImplementedError + + def logical_xor(self): + raise NotImplementedError + + def logical_xor_(self): + raise NotImplementedError + + def logit(self): + raise NotImplementedError + + def logit_(self): + raise NotImplementedError + + def 
logsumexp(self, dim, keepdim=False): + raise NotImplementedError + + def log_(self): + raise NotImplementedError + + def log_normal_(self, mean=1, std=2, *args, **kwargs): + raise NotImplementedError + + def log_softmax(self, *args, **kwargs): + raise NotImplementedError + + def long(self, memory_format=None): + raise NotImplementedError + + def lt(self, other): + raise NotImplementedError + + def lt_(self, other): + raise NotImplementedError + + def lu_solve(self, LU_data, LU_pivots): + raise NotImplementedError + + def map2_(self, *args, **kwargs): + raise NotImplementedError + + def map_(self, tensor, callable): + raise NotImplementedError + + def masked_fill(self, mask, value): + raise NotImplementedError + + def masked_fill_(self, mask, value): + raise NotImplementedError + + def masked_scatter(self, mask, tensor): + raise NotImplementedError + + def masked_scatter_(self, mask, source): + raise NotImplementedError + + def masked_select(self, mask): + raise NotImplementedError + + def matmul(self, tensor2): + raise NotImplementedError + + def matrix_exp(self): + raise NotImplementedError + + def matrix_power(self, n): + raise NotImplementedError + + def max(self, dim=None, keepdim=False): + raise NotImplementedError + + def maximum(self, other): + raise NotImplementedError + + def mean(self, dim=None, keepdim=False, *args, **kwargs): + raise NotImplementedError + + def median(self, dim=None, keepdim=False): + raise NotImplementedError + + def min(self, dim=None, keepdim=False): + raise NotImplementedError + + def minimum(self, other): + raise NotImplementedError + + def mm(self, mat2): + raise NotImplementedError + + def mode(self, dim=None, keepdim=False): + raise NotImplementedError + + def moveaxis(self, source, destination): + raise NotImplementedError + + def movedim(self, source, destination): + raise NotImplementedError + + def msort(self): + raise NotImplementedError + + def mul(self, value): + raise NotImplementedError + + def multinomial(self, num_samples, replacement=False, *args, **kwargs): + raise NotImplementedError + + def multiply(self, value): + raise NotImplementedError + + def multiply_(self, value): + raise NotImplementedError + + def mul_(self, value): + raise NotImplementedError + + def mv(self, vec): + raise NotImplementedError + + def mvlgamma(self, p): + raise NotImplementedError + + def mvlgamma_(self, p): + raise NotImplementedError + + def nanmean(self, dim=None, keepdim=False, *args, **kwargs): + raise NotImplementedError + + def nanmedian(self, dim=None, keepdim=False): + raise NotImplementedError + + def nanquantile(self, q, dim=None, keepdim=False, *args, **kwargs): + raise NotImplementedError + + def nansum(self, dim=None, keepdim=False, dtype=None): + raise NotImplementedError + + def nan_to_num(self, nan=0.0, posinf=None, neginf=None): + raise NotImplementedError + + def nan_to_num_(self, nan=0.0, posinf=None, neginf=None): + raise NotImplementedError + + def narrow(self, dimension, start, length): + raise NotImplementedError + + def narrow_copy(self, dimension, start, length): + raise NotImplementedError + + def ndimension(self): + + return 0 + + def ne(self, other): + raise NotImplementedError + + def neg(self): + raise NotImplementedError + + def negative(self): + raise NotImplementedError + + def negative_(self): + raise NotImplementedError + + def neg_(self): + raise NotImplementedError + + def nelement(self): + + return 0 + + def new(self, *args, **kwargs): + raise NotImplementedError + + def new_empty(self, size, *args, **kwargs): + raise 
NotImplementedError + + def new_empty_strided( + self, + size, + stride, + dtype=None, + device=None, + requires_grad=False, + layout=None, + pin_memory=False, + ): + raise NotImplementedError + + def new_full(self, size, fill_value, *args, **kwargs): + raise NotImplementedError + + def new_ones(self, size, *args, **kwargs): + raise NotImplementedError + + def new_tensor(self, data, *args, **kwargs): + raise NotImplementedError + + def new_zeros(self, size, *args, **kwargs): + raise NotImplementedError + + def nextafter(self, other): + raise NotImplementedError + + def nextafter_(self, other): + raise NotImplementedError + + def ne_(self, other): + raise NotImplementedError + + def nonzero(self): + raise NotImplementedError + + def norm(self, p=2, dim=None, keepdim=False): + raise NotImplementedError + + def normal_(self, mean=0, std=1, *args, **kwargs): + raise NotImplementedError + + def not_equal(self, other): + raise NotImplementedError + + def not_equal_(self, other): + raise NotImplementedError + + def numel(self): + + return 0 + + def numpy(self, *args, **kwargs): + raise NotImplementedError + + def orgqr(self, input2): + raise NotImplementedError + + def ormqr(self, input2, input3, left=True, transpose=False): + raise NotImplementedError + + def outer(self, vec2): + raise NotImplementedError + + def permute(self, *dims): + raise NotImplementedError + + def pinverse(self): + raise NotImplementedError + + def pin_memory(self): + raise NotImplementedError + + def polygamma(self, n): + raise NotImplementedError + + def polygamma_(self, n): + raise NotImplementedError + + def positive(self): + raise NotImplementedError + + def pow(self, exponent): + raise NotImplementedError + + def pow_(self, exponent): + raise NotImplementedError + + def prelu(self, *args, **kwargs): + raise NotImplementedError + + def prod(self, dim=None, keepdim=False, dtype=None): + raise NotImplementedError + + def put(self, input, index, source, accumulate=False): + raise NotImplementedError + + def put_(self, index, source, accumulate=False): + raise NotImplementedError + + def qr(self, some=True): + raise NotImplementedError + + def qscheme(self): + raise NotImplementedError + + def quantile(self, q, dim=None, keepdim=False, *args, **kwargs): + raise NotImplementedError + + def q_per_channel_axis(self): + + return 0 + + def q_per_channel_scales(self): + raise NotImplementedError + + def q_per_channel_zero_points( + self, + ): + raise NotImplementedError + + def q_scale(self): + + return 0.0 + + def q_zero_point(self): + + return 0 + + def rad2deg(self): + raise NotImplementedError + + def rad2deg_(self): + raise NotImplementedError + + def random_(self, from_=0, to=None, *args, **kwargs): + raise NotImplementedError + + def ravel(self): + raise NotImplementedError + + def reciprocal(self): + raise NotImplementedError + + def reciprocal_(self): + raise NotImplementedError + + def record_stream(self, stream): + raise NotImplementedError + + def refine_names(self, *args, **kwargs): + raise NotImplementedError + + def relu(self, *args, **kwargs): + raise NotImplementedError + + def relu_(self, *args, **kwargs): + raise NotImplementedError + + def remainder(self, divisor): + raise NotImplementedError + + def remainder_(self, divisor): + raise NotImplementedError + + def rename(self, *args, **kwargs): + raise NotImplementedError + + def rename_(self, *args, **kwargs): + raise NotImplementedError + + def renorm(self, p, dim, maxnorm): + raise NotImplementedError + + def renorm_(self, p, dim, maxnorm): + raise 
NotImplementedError + + def repeat(self, *sizes): + raise NotImplementedError + + def repeat_interleave(self, repeats, dim=None, *args, **kwargs): + raise NotImplementedError + + def requires_grad_(self, requires_grad=True): + raise NotImplementedError + + def reshape(self, *shape): + raise NotImplementedError + + def reshape_as(self, other): + raise NotImplementedError + + def resize_(self, *sizes, memory_format=None): + raise NotImplementedError + + def resize_as_(self, tensor, memory_format=None): + raise NotImplementedError + + def resize_as_sparse_(self, *args, **kwargs): + raise NotImplementedError + + def resolve_conj(self): + raise NotImplementedError + + def resolve_neg(self): + raise NotImplementedError + + def retain_grad(self): + raise NotImplementedError + + def roll(self, shifts, dims): + raise NotImplementedError + + def rot90(self, k, dims): + raise NotImplementedError + + def round(self, decimals=0): + raise NotImplementedError + + def round_(self, decimals=0): + raise NotImplementedError + + def row_indices(self, *args, **kwargs): + raise NotImplementedError + + def rsqrt(self): + raise NotImplementedError + + def rsqrt_(self): + raise NotImplementedError + + def scatter(self, dim, index, src): + raise NotImplementedError + + def scatter_(self, dim, index, src, reduce=None): + raise NotImplementedError + + def scatter_add(self, dim, index, src): + raise NotImplementedError + + def scatter_add_(self, dim, index, src): + raise NotImplementedError + + def scatter_reduce(self, dim, index, src, reduce, *args, **kwargs): + raise NotImplementedError + + def scatter_reduce_(self, dim, index, src, reduce, *args, **kwargs): + raise NotImplementedError + + def select(self, dim, index): + raise NotImplementedError + + def select_scatter(self, src, dim, index): + raise NotImplementedError + + def set_(self, source=None, storage_offset=0, size=None, stride=None): + raise NotImplementedError + + def sgn(self): + raise NotImplementedError + + def sgn_(self): + raise NotImplementedError + + def short(self, memory_format=None): + raise NotImplementedError + + def sigmoid(self): + raise NotImplementedError + + def sigmoid_(self): + raise NotImplementedError + + def sign(self): + raise NotImplementedError + + def signbit(self): + raise NotImplementedError + + def sign_(self): + raise NotImplementedError + + def sin(self): + raise NotImplementedError + + def sinc(self): + raise NotImplementedError + + def sinc_(self): + raise NotImplementedError + + def sinh(self): + raise NotImplementedError + + def sinh_(self): + raise NotImplementedError + + def sin_(self): + raise NotImplementedError + + def size(self, dim=None): + raise NotImplementedError + + def slice_scatter(self, src, dim=0, start=None, end=None, step=1): + raise NotImplementedError + + def slogdet(self): + raise NotImplementedError + + def smm(self, mat): + raise NotImplementedError + + def softmax(self, *args, **kwargs): + raise NotImplementedError + + def sort(self, dim=-1, descending=False): + raise NotImplementedError + + def sparse_dim(self): + + return 0 + + def sparse_mask(self, mask): + raise NotImplementedError + + def sparse_resize_(self, size, sparse_dim, dense_dim): + raise NotImplementedError + + def sparse_resize_and_clear_(self, size, sparse_dim, dense_dim): + raise NotImplementedError + + def split(self, *args, **kwargs): + raise NotImplementedError + + def split_with_sizes(self, *args, **kwargs): + raise NotImplementedError + + def sqrt(self): + raise NotImplementedError + + def sqrt_(self): + raise 
NotImplementedError + + def square(self): + raise NotImplementedError + + def square_(self): + raise NotImplementedError + + def squeeze(self, dim=None): + raise NotImplementedError + + def squeeze_(self, dim=None): + raise NotImplementedError + + def sspaddmm(self, mat1, mat2, *args, **kwargs): + raise NotImplementedError + + def std(self, dim, unbiased=True, keepdim=False): + raise NotImplementedError + + def stft( + self, + frame_length, + hop, + fft_size=None, + return_onesided=True, + window=None, + pad_end=0, + ): + raise NotImplementedError + + def storage_offset(self): + + return 0 + + def stride(self, dim): + + return () + + def sub(self, other, *args, **kwargs): + raise NotImplementedError + + def subtract(self, other, *args, **kwargs): + raise NotImplementedError + + def subtract_(self, other, *args, **kwargs): + raise NotImplementedError + + def sub_(self, other, *args, **kwargs): + raise NotImplementedError + + def sum(self, dim=None, keepdim=False, dtype=None): + raise NotImplementedError + + def sum_to_size(self, *size): + raise NotImplementedError + + def svd(self, some=True, compute_uv=True): + raise NotImplementedError + + def swapaxes(self, axis0, axis1): + raise NotImplementedError + + def swapaxes_(self, axis0, axis1): + raise NotImplementedError + + def swapdims(self, dim0, dim1): + raise NotImplementedError + + def swapdims_(self, dim0, dim1): + raise NotImplementedError + + def symeig(self, eigenvectors=False, upper=True): + raise NotImplementedError + + def t(self): + raise NotImplementedError + + def take(self, indices): + raise NotImplementedError + + def take_along_dim(self, indices, dim): + raise NotImplementedError + + def tan(self): + raise NotImplementedError + + def tanh(self): + raise NotImplementedError + + def tanh_(self): + raise NotImplementedError + + def tan_(self): + raise NotImplementedError + + def tensor_split(self, indices_or_sections, dim=0): + raise NotImplementedError + + def tile(self, *reps): + raise NotImplementedError + + def to(self, *args, **kwargs): + raise NotImplementedError + + def tolist(self): + raise NotImplementedError + + def topk(self, k, dim=None, largest=True, sorted=True): + raise NotImplementedError + + def to_dense(self): + raise NotImplementedError + + def to_mkldnn(self): + raise NotImplementedError + + def to_padded_tensor(self, padding, output_size=None): + raise NotImplementedError + + def to_sparse(self, sparseDims): + raise NotImplementedError + + def to_sparse_bsc(self, blocksize): + raise NotImplementedError + + def to_sparse_bsr(self, blocksize): + raise NotImplementedError + + def to_sparse_csc(self): + raise NotImplementedError + + def to_sparse_csr(self): + raise NotImplementedError + + def trace(self): + raise NotImplementedError + + def transpose(self, dim0, dim1): + raise NotImplementedError + + def transpose_(self, dim0, dim1): + raise NotImplementedError + + def triangular_solve(self, A, upper=True, transpose=False, unitriangular=False): + raise NotImplementedError + + def tril(self, diagonal=0): + raise NotImplementedError + + def tril_(self, diagonal=0): + raise NotImplementedError + + def triu(self, diagonal=0): + raise NotImplementedError + + def triu_(self, diagonal=0): + raise NotImplementedError + + def true_divide(self, value): + raise NotImplementedError + + def true_divide_(self, value): + raise NotImplementedError + + def trunc(self): + raise NotImplementedError + + def trunc_(self): + raise NotImplementedError + + # def type( + # self, dtype=None, non_blocking=False, **kwargs + # ): + # 
return "" + + def type_as(self, tensor): + raise NotImplementedError + + def t_(self): + raise NotImplementedError + + def unbind(self, dim=0): + raise NotImplementedError + + def unflatten(self, *args, **kwargs): + raise NotImplementedError + + def unfold(self, dimension, size, step): + raise NotImplementedError + + def uniform_(self, from_=0, to=1): + raise NotImplementedError + + def unsafe_chunk(self, chunks, dim=0): + raise NotImplementedError + + def unsafe_split(self, split_size, dim=0): + raise NotImplementedError + + def unsafe_split_with_sizes(self, *args, **kwargs): + raise NotImplementedError + + def unsqueeze(self, dim): + raise NotImplementedError + + def unsqueeze_(self, dim): + raise NotImplementedError + + def values(self): + raise NotImplementedError + + def var(self, dim, unbiased=True, keepdim=False): + raise NotImplementedError + + def vdot(self, other): + raise NotImplementedError + + def view(self, *shape): + raise NotImplementedError + + def view_as(self, other): + raise NotImplementedError + + def vsplit(self, split_size_or_sections): + raise NotImplementedError + + def where(self, condition, y): + raise NotImplementedError + + def xlogy(self, other): + raise NotImplementedError + + def xlogy_(self, other): + raise NotImplementedError + + def xpu(self, device=None, non_blocking=False, memory_format=None): + raise NotImplementedError + + def zero_(self): + raise NotImplementedError + + def _addmm_activation(self, *args, **kwargs): + raise NotImplementedError + + def _autocast_to_full_precision(self, *args, **kwargs): + raise NotImplementedError + + def _autocast_to_reduced_precision(self, *args, **kwargs): + raise NotImplementedError + + def _coalesced_(self, *args, **kwargs): + raise NotImplementedError + + def _conj(self, *args, **kwargs): + raise NotImplementedError + + def _conj_physical(self, *args, **kwargs): + raise NotImplementedError + + def _dimI(self, *args, **kwargs): + raise NotImplementedError + + def _dimV(self, *args, **kwargs): + raise NotImplementedError + + def _fix_weakref(self, *args, **kwargs): + raise NotImplementedError + + def _indices(self, *args, **kwargs): + raise NotImplementedError + + def _is_view(self, *args, **kwargs): + raise NotImplementedError + + def _is_zerotensor(self, *args, **kwargs): + raise NotImplementedError + + def _make_subclass(self, *args, **kwargs): + raise NotImplementedError + + def _make_wrapper_subclass(self, *args, **kwargs): + raise NotImplementedError + + def _neg_view(self, *args, **kwargs): + raise NotImplementedError + + def _nested_tensor_size(self, *args, **kwargs): + raise NotImplementedError + + def _nnz(self, *args, **kwargs): + raise NotImplementedError + + def _storage(self, *args, **kwargs): + raise NotImplementedError + + def _to_dense(self, *args, **kwargs): + raise NotImplementedError + + def _values(self, *args, **kwargs): + raise NotImplementedError + + def __add__(self, *args, **kwargs): + return pi.add_Tensor(self, args[0]) + + def __and__(self, *args, **kwargs): + raise NotImplementedError + + def __bool__(self, *args, **kwargs): + raise NotImplementedError + + def __complex__(self, *args, **kwargs): + raise NotImplementedError + + def __delitem__(self, *args, **kwargs): + raise NotImplementedError + + def __div__(self, *args, **kwargs): + raise NotImplementedError + + def __eq__(self, *args, **kwargs): + raise NotImplementedError + + def __float__(self, *args, **kwargs): + raise NotImplementedError + + def __floordiv__(self, *args, **kwargs): + raise NotImplementedError + + def 
__getitem__(self, *args, **kwargs): + raise NotImplementedError + + def __ge__(self, *args, **kwargs): + raise NotImplementedError + + def __gt__(self, *args, **kwargs): + raise NotImplementedError + + # + # def __iadd__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __iand__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __idiv__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __ifloordiv__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __ilshift__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __imod__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __imul__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __index__(self, *args, **kwargs): + # raise NotImplementedError + # + def __int__(self, *args, **kwargs): + raise NotImplementedError + + def __invert__(self, *args, **kwargs): + raise NotImplementedError + + # + # def __ior__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __irshift__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __isub__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __ixor__(self, *args, **kwargs): + # raise NotImplementedError + # + def __len__(self, *args, **kwargs): + raise NotImplementedError + + def __le__(self, *args, **kwargs): + raise NotImplementedError + + def __long__(self, *args, **kwargs): + raise NotImplementedError + + # + # def __lshift__(self, *args, **kwargs): + # raise NotImplementedError + # + def __lt__(self, *args, **kwargs): + raise NotImplementedError + + # + # def __matmul__(self, *args, **kwargs): + # raise NotImplementedError + # + def __mod__(self, *args, **kwargs): + raise NotImplementedError + + # + def __mul__(self, *args, **kwargs): + return pi.mul_Tensor(self, args[0]) + + def __ne__(self, *args, **kwargs): + raise NotImplementedError + + def __nonzero__(self, *args, **kwargs): + raise NotImplementedError + + def __or__(self, *args, **kwargs): + raise NotImplementedError + + # def __radd__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __rand__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __rmul__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __ror__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __rshift__(self, *args, **kwargs): + # raise NotImplementedError + # + # def __rxor__(self, *args, **kwargs): + # raise NotImplementedError + # + def __setitem__(self, *args, **kwargs): + raise NotImplementedError + + def __sub__(self, *args, **kwargs): + raise NotImplementedError + + def __truediv__(self, *args, **kwargs): + raise NotImplementedError + + # def __xor__(self, *args, **kwargs): + # raise NotImplementedError + + # data = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # device = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # dtype = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # grad = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # grad_fn = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # H = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # imag = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_cpu = 
property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_cuda = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_ipu = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_leaf = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_meta = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_mkldnn = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_mps = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_nested = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_ort = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_quantized = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_sparse = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_sparse_csr = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_vulkan = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # is_xpu = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # layout = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # mH = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # mT = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # name = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # names = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # ndim = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # output_nr = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # real = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # requires_grad = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # retains_grad = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # shape = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # T = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # volatile = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # _backward_hooks = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # _base = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # _cdata = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # _grad = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # _grad_fn = property( + # lambda self: object(), lambda self, v: None, lambda self: None + # ) # default + # + # 
_has_symbolic_sizes_strides = property(
+    # lambda self: object(), lambda self, v: None, lambda self: None
+    # ) # default
+    #
+    # _python_dispatch = property(
+    # lambda self: object(), lambda self, v: None, lambda self: None
+    # ) # default
+    #
+    # _version = property(
+    # lambda self: object(), lambda self, v: None, lambda self: None
+    # ) # default
+
+
+def from_numpy(arr: np.ndarray):
+    from pi import DEBUG
+
+    if DEBUG:
+        arr = np.ones_like(arr, dtype=np.float32)
+    attr = DenseElementsAttr.get(arr)
+    vt = Tensor(torch_dialect.ValueTensorLiteralOp(attr))
+    return vt
+
+
+def empty(shape: Tuple[int, ...], dtype: "pi.dtype" = None, **kwargs) -> Tensor:
+    if np.prod(shape) == 0:
+        return Tensor(None)
+    else:
+        if dtype is not None:
+            dtype = dtype.to_np_type()
+
+        return from_numpy(np.empty(shape, dtype))
+
+
+def randint(low: int, high: int, size: Tuple[int, ...]) -> Tensor:
+    return from_numpy(np.random.randint(low, high, size))
+
+
+def randn(*size: Tuple[int, ...]) -> Tensor:
+    return from_numpy(np.random.randn(*size))
+
+
+def uniform(low: float, high: float, size: Tuple[int, ...]) -> Tensor:
+    return from_numpy(np.random.uniform(low, high, size))
+
+
+def rand(*size: Tuple[int, ...], **kwargs) -> Tensor:
+    dtype = kwargs.get("dtype", None)
+    if dtype is not None:
+        dtype = dtype.to_np_type()
+    return from_numpy(np.random.rand(*size))
+
+
+def ones(*size: Tuple[int, ...], **kwargs) -> Tensor:
+    # dtype: "pi.dtype" = None, _device: Any = None
+    dtype = kwargs.get("dtype", None)
+    if dtype is not None:
+        dtype = dtype.to_np_type()
+    return from_numpy(np.ones(size, dtype=dtype))
+
+
+def zeros(*size: Tuple[int, ...], **kwargs) -> Tensor:
+    dtype = kwargs.get("dtype", None)
+    if dtype is not None:
+        dtype = dtype.to_np_type()
+    return from_numpy(np.zeros(size, dtype))
+
+
+def tensor(data: Any, dtype: Optional["pi.dtype"] = None) -> Tensor:
+    if dtype is not None:
+        dtype = dtype.to_np_type()
+
+    return from_numpy(np.array(data, dtype=dtype))
+
+
+def LongTensor(data: Any) -> Tensor:
+    return from_numpy(np.array(data, dtype=pi_dtype.int64.to_np_type()))
+
+
+def clone(x: Tensor, **kwargs):
+    # TODO(max): is this necessary?
+    warnings.warn(f"not actually cloning")
+    return x
+
+
+__all__ = [
+    "from_numpy",
+    "empty",
+    "randint",
+    "randn",
+    "rand",
+    "uniform",
+    "ones",
+    "tensor",
+    "Tensor",
+    "LongTensor",
+    "zeros",
+]
diff --git a/shark/compiler/__init__.py b/pi/compiler/__init__.py
similarity index 100%
rename from shark/compiler/__init__.py
rename to pi/compiler/__init__.py
diff --git a/pi/compiler/annotations.py b/pi/compiler/annotations.py
new file mode 100644
index 0000000..d82f152
--- /dev/null
+++ b/pi/compiler/annotations.py
@@ -0,0 +1,102 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+import functools
+import inspect
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import pi
+from torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen import (
+    get_ods_type,
+)
+from torch_mlir import ir
+
+PI_EXPORT_ATTR_NAME = "_PI_EXPORT"
+PI_ARG_ANNOTATIONS_ATTR_NAME = "_PI_ARG_ANNOTATIONS"
+
+
+def export(fn):
+    # setattr(fn, PI_EXPORT_ATTR_NAME, True)
+    return fn
+
+
+ArgAnnotation = Union[type, Tuple[List[int], pi.dtype]]
+
+
+class TensorPlaceholder:
+    def __init__(self, shape: List[int], dtype: pi.dtype):
+        self.shape = shape
+        self.dtype = dtype
+
+    def to_value_tensor_type(self):
+        dtype = self.dtype.to_mlir_type()
+        type = ir.Type.parse(
+            f"!torch.vtensor<[{','.join(map(str, self.shape))}],{dtype}>"
+        )
+        return type
+
+    def to(self, dtype: pi.dtype):
+        self.dtype = dtype
+        return self
+
+    def type(self, dtype):
+        return self.to(dtype)
+
+    def bool(self):
+        return self.to(pi.dtype.bool)
+
+    def double(self):
+        self.dtype = pi.dtype.float64
+        return self
+
+
+def annotations_to_placeholders(
+    args: List[str], annotations: List[Optional[ArgAnnotation]]
+) -> OrderedDict:
+    placeholders = OrderedDict()
+    for annotation, arg in zip(annotations, args):
+        # Skip the "self" annotation.
+        if annotation is None:
+            assert arg == "self"
+            continue
+        shape, dtype, value_tensor = annotation
+        assert value_tensor, f"non-value tensors not supported {arg}"
+        placeholders[arg] = TensorPlaceholder(annotation[0], annotation[1])
+    return placeholders
+
+
+# TODO: Replace with py3 extended argument annotations when available.
+# See https://www.python.org/dev/peps/pep-0593/
+def annotate_args(annotations: List[Optional[ArgAnnotation]]):
+    def decorator(fn):
+        arg_spec = inspect.getfullargspec(fn)
+        placeholders = annotations_to_placeholders(arg_spec.args, annotations)
+        setattr(fn, "__placeholders__", placeholders)
+        return fn
+
+    return decorator
+
+
+def convert_annotations_to_placeholders(forward_method):
+    """Converts the annotations on a forward method into tensor placeholders.
+
+    These placeholders are suitable for being passed to `torch_mlir.compile`.
+    """
+    annotations = getattr(forward_method, PI_ARG_ANNOTATIONS_ATTR_NAME)
+    placeholders = []
+    # Skip the "self" annotation.
+ for annotation in annotations[1:]: + placeholders.append(TensorPlaceholder(annotation[0], annotation[1])) + return placeholders + + +def pipile(annotations: List[Optional[ArgAnnotation]]): + def actual_decorator(func): + func = export(func) + if len(annotations): + func = annotate_args(annotations)(func) + return func + + return actual_decorator diff --git a/shark/compiler/compiler.py b/pi/compiler/compiler.py similarity index 67% rename from shark/compiler/compiler.py rename to pi/compiler/compiler.py index db77661..178a1b0 100644 --- a/shark/compiler/compiler.py +++ b/pi/compiler/compiler.py @@ -1,21 +1,18 @@ +import os + from torch_mlir import ir # noinspection PyUnresolvedReferences -from torch_mlir.dialects import ( - arith, - linalg, - math, - memref, - torch as torch_dialect -) +from torch_mlir.dialects import arith, linalg, math, memref, torch as torch_dialect # noinspection PyUnresolvedReferences -from shark.dialects import affine_ +from pi.dialects import affine_ -from shark.compiler.tracing.trace import trace +from pi.compiler.tracing.trace import trace def mlir_trace(script_path): + assert os.path.isabs(script_path), f"script path must be absolute {script_path}" top_mlir_context = ir.Context() mlir_location = ir.Location.unknown(context=top_mlir_context) with top_mlir_context, mlir_location: diff --git a/shark/compiler/config.py b/pi/compiler/config.py similarity index 100% rename from shark/compiler/config.py rename to pi/compiler/config.py diff --git a/shark/compiler/tracing/__init__.py b/pi/compiler/tracing/__init__.py similarity index 100% rename from shark/compiler/tracing/__init__.py rename to pi/compiler/tracing/__init__.py diff --git a/pi/compiler/tracing/handlers.py b/pi/compiler/tracing/handlers.py new file mode 100644 index 0000000..30ea345 --- /dev/null +++ b/pi/compiler/tracing/handlers.py @@ -0,0 +1,116 @@ +import ast +from enum import Enum +from typing import List, cast + +from pyccolo import fast, TraceEvent +from pyccolo.extra_builtins import ( + TRACING_ENABLED, + make_guard_name, + EMIT_EVENT, +) +from pyccolo.fast import make_composite_condition, make_test +from pyccolo.stmt_inserter import StatementInserter + + +def _handle_class_body( + self, + node: ast.ClassDef, + orig_body: List[ast.AST], +) -> List[ast.AST]: + classdef_copy = cast( + ast.ClassDef, + self.orig_to_copy_mapping[id(node)], + ) + if self.global_guards_enabled: + classdef_copy = self._global_nonlocal_stripper.visit(classdef_copy) + class_guard = make_guard_name(classdef_copy) + self.register_guard(class_guard) + else: + class_guard = None + docstring = [] + if ( + len(orig_body) > 0 + and isinstance(orig_body[0], ast.Expr) + and isinstance(orig_body[0].value, ast.Str) + ): + docstring = [orig_body.pop(0)] + if len(orig_body) == 0: + return docstring + with fast.location_of(classdef_copy): + if self.global_guards_enabled: + ret = [ + fast.If( + test=make_composite_condition( + [ + make_test(TRACING_ENABLED), + make_test(class_guard), + self.emit( + TraceEvent.before_class_body, + node, + ret=fast.NameConstant(True), + ) + if self.handler_predicate_by_event[ + TraceEvent.before_class_body + ](classdef_copy) + else None, + ] + ), + body=orig_body, + orelse=classdef_copy.body + if len(docstring) == 0 + else classdef_copy.body[len(docstring) :], # noqa: E203 + ), + ] + return docstring + ret + + +def generic_visit(self, node): + if self.is_tracing_disabled_context(node): + return node + for name, field in ast.iter_fields(node): + if isinstance(field, ast.AST): + setattr(node, name, 
self.visit(field)) + elif isinstance(field, list): + new_field = [] + future_imports = [] + if isinstance(node, ast.Module) and name == "body": + node_copy = self.get_copy_node(node) + if self.handler_predicate_by_event[TraceEvent.init_module](node_copy): + with fast.location_of(node): + new_field.extend( + fast.parse( + f'{EMIT_EVENT}("{TraceEvent.init_module.name}", ' + + f"{id(node_copy)})" + ).body + ) + for inner_node in field: + if isinstance(inner_node, ast.stmt): + if ( + isinstance(inner_node, ast.ImportFrom) + and inner_node.module == "__future__" + ): + future_imports.append(inner_node) + else: + new_field.extend(self._handle_stmt(node, name, inner_node)) + elif isinstance(inner_node, ast.AST): + new_field.append(self.visit(inner_node)) + else: + new_field.append(inner_node) + new_field = future_imports + new_field + if name == "body": + if isinstance(node, ast.Module): + new_field = self._handle_module_body(node, new_field) + elif isinstance(node, (ast.For, ast.While)): + new_field = self._handle_loop_body(node, new_field) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + new_field = self._handle_function_body(node, new_field) + elif isinstance(node, ast.ClassDef): + new_field = self._handle_class_body(node, new_field) + setattr(node, name, new_field) + else: + continue + return node + + +StatementInserter.generic_visit = generic_visit +StatementInserter._handle_class_body = _handle_class_body diff --git a/shark/compiler/tracing/trace.py b/pi/compiler/tracing/trace.py similarity index 91% rename from shark/compiler/tracing/trace.py rename to pi/compiler/tracing/trace.py index c5ebc6d..e5a8e5d 100644 --- a/shark/compiler/tracing/trace.py +++ b/pi/compiler/tracing/trace.py @@ -1,3 +1,4 @@ +import operator import sys from collections import namedtuple @@ -5,17 +6,24 @@ import ctypes import inspect import os + +from pi.compiler.annotations import PI_EXPORT_ATTR_NAME + +# noinspection PyUnresolvedReferences +import pi.compiler.tracing.handlers import pyccolo as pyc import traceback import types from contextlib import contextmanager from pathlib import Path -from pyccolo import TraceEvent, AstRewriter, fast +from pyccolo import TraceEvent, AstRewriter, fast, register_raw_handler from pyccolo.emit_event import _TRACER_STACK from runpy import run_module from typing import Optional, Union, Tuple, List +import pi from torch_mlir import ir + # this needs to be all of the dialects that will be used in the user scripts # (in order to register ops) # noinspection PyUnresolvedReferences @@ -29,7 +37,7 @@ from torch_mlir.dialects._ods_common import get_op_result_or_value from torch_mlir.ir import Type as MLIRType, IntegerType, F64Type -from shark.dialects import affine_, value_ +from pi.dialects import affine_, value_ # from uncompyle6 import main @@ -102,10 +110,19 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: for n in node.body: self.visit(n) if not isinstance(node.body[-1], ast.Return): - return_ = ast.Return(value=ast.Constant(value=None)) + return_ = ast.Return( + value=None if node.name == "__init__" else ast.Constant(value=None), + lineno=node.body[-1].lineno + 1, + col_offset=node.col_offset + len("return "), + ) node.body.append(return_) return node + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + for n in node.body: + self.visit(n) + return node + class BreakIf(ast.NodeTransformer): def visit_If(self, node: ast.If): @@ -158,11 +175,13 @@ def update_frame_locals(frame, updates: dict): class MLIRTracer(pyc.BaseTracer): def 
__init__( self, + script_path: str, mlir_context: ir.Context, mlir_location: ir.Location, mlir_module: ir.Module, ): super().__init__() + self.script_path = script_path self.mlir_context = mlir_context self.mlir_location = mlir_location self.mlir_module = mlir_module @@ -186,9 +205,10 @@ def __init__( mlir_context=self.mlir_context or self.mlir_context, scope_name="module", ) - self.if_bodies_executed = set() # dirty dirty hack - self.compares_executed = {} + self.if_bodies_executed = set() + self.binops_executed = {} + self.fn_to_node = {} def enter_mlir_block_scope( self, @@ -290,10 +310,22 @@ def should_propagate_handler_exception( return True def should_instrument_file(self, filename: str) -> bool: - return True + return filename == self.script_path # handlers + @TraceEvent.before_class_body + def handle_before_class_body( + self, + old_ret, + node, + frame: types.FrameType, + event, + guard_for_spec, + **_, + ): + return False + @pyc.before_function_body def handle_before_function_body( self, @@ -366,6 +398,8 @@ def handle_after_return( inputs=func_type.inputs, results=[mlir_return_val.type] ) func_op.attributes["function_type"] = ir.TypeAttr.get(canonical_func_type) + if isinstance(mlir_return_val, pi.Tensor): + mlir_return_val = mlir_return_val.value func_dialect.ReturnOp((mlir_return_val,)) else: func_dialect.ReturnOp(()) @@ -383,6 +417,11 @@ def handle_exit_module( guard_for_spec, **kwargs, ): + for loc_name, loc in frame.f_locals.items(): + if inspect.isfunction(loc) and getattr( + loc, PI_EXPORT_ATTR_NAME, False + ): + print(loc) self.exit_mlir_block_scope(scope_name="module") ast_rewriter_cls = MLIRRewriter @@ -459,7 +498,7 @@ def handle_before_binop( ): def eval_op(x, y): hash = id(x), id(y) - if hash not in self.compares_executed: + if hash not in self.binops_executed: x, y = map( lambda v: self.get_or_make_mlir_constant(v) if isinstance(v, (float, int, bool)) @@ -471,8 +510,12 @@ def eval_op(x, y): else: # ...god damn it op = node.op.__class__.__name__.lower().replace("mult", "mul") - self.compares_executed[hash] = getattr(value_, op)(x, y) - return self.compares_executed[hash] + if isinstance(x, pi.Tensor): + assert isinstance(y, pi.Tensor) + self.binops_executed[hash] = getattr(operator, op)(x, y) + else: + self.binops_executed[hash] = getattr(value_, op)(x, y) + return self.binops_executed[hash] return eval_op @@ -564,7 +607,12 @@ def get_script_as_module(script: str) -> str: def trace(script_path, mlir_context, mlir_location, mlir_module) -> ir.Module: module_to_run = get_script_as_module(script_path) - with MLIRTracer(mlir_context, mlir_location, mlir_module): + with MLIRTracer( + script_path, + mlir_context, + mlir_location, + mlir_module, + ): run_module(module_to_run) return mlir_module diff --git a/pi/compiler/tracing/trace_events.py b/pi/compiler/tracing/trace_events.py new file mode 100644 index 0000000..e11de88 --- /dev/null +++ b/pi/compiler/tracing/trace_events.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +import ast +from enum import Enum + +from pyccolo import fast + + +class TraceEvent(Enum): + before_import = "before_import" + init_module = "init_module" + exit_module = "exit_module" + after_import = "after_import" + + before_stmt = "before_stmt" + after_stmt = "after_stmt" + after_module_stmt = "after_module_stmt" + after_expr_stmt = "after_expr_stmt" + + load_name = "load_name" + + before_for_loop_body = "before_for_loop_body" + after_for_loop_iter = "after_for_loop_iter" + before_while_loop_body = "before_while_loop_body" + after_while_loop_iter = 
"after_while_loop_iter" + + before_attribute_load = "before_attribute_load" + before_attribute_store = "before_attribute_store" + before_attribute_del = "before_attribute_del" + after_attribute_load = "after_attribute_load" + before_subscript_load = "before_subscript_load" + before_subscript_store = "before_subscript_store" + before_subscript_del = "before_subscript_del" + after_subscript_load = "after_subscript_load" + + before_subscript_slice = "before_subscript_slice" + after_subscript_slice = "after_subscript_slice" + _load_saved_slice = "_load_saved_slice" + + before_load_complex_symbol = "before_load_complex_symbol" + after_load_complex_symbol = "after_load_complex_symbol" + + after_if_test = "after_if_test" + after_while_test = "after_while_test" + + before_lambda = "before_lambda" + after_lambda = "after_lambda" + + before_call = "before_call" + after_call = "after_call" + before_argument = "before_argument" + after_argument = "after_argument" + before_return = "before_return" + after_return = "after_return" + + before_dict_literal = "before_dict_literal" + after_dict_literal = "after_dict_literal" + before_list_literal = "before_list_literal" + after_list_literal = "after_list_literal" + before_set_literal = "before_set_literal" + after_set_literal = "after_set_literal" + before_tuple_literal = "before_tuple_literal" + after_tuple_literal = "after_tuple_literal" + + dict_key = "dict_key" + dict_value = "dict_value" + list_elt = "list_elt" + set_elt = "set_elt" + tuple_elt = "tuple_elt" + + before_assign_rhs = "before_assign_rhs" + after_assign_rhs = "after_assign_rhs" + before_augassign_rhs = "before_augassign_rhs" + after_augassign_rhs = "after_augassign_rhs" + + before_function_body = "before_function_body" + after_function_execution = "after_function_execution" + + before_class_body = "before_class_body" + + before_lambda_body = "before_lambda_body" + after_lambda_body = "after_lambda_body" + + left_binop_arg = "left_binop_arg" + right_binop_arg = "right_binop_arg" + before_binop = "before_binop" + after_binop = "after_binop" + + left_compare_arg = "left_compare_arg" + compare_arg = "compare_arg" + before_compare = "before_compare" + after_compare = "after_compare" + + after_comprehension_if = "after_comprehension_if" + after_comprehension_elt = "after_comprehension_elt" + after_dict_comprehension_key = "after_dict_comprehension_key" + after_dict_comprehension_value = "after_dict_comprehension_value" + + ellipses = "ellipses" + + line = "line" + call = "call" + return_ = "return" + exception = "exception" + opcode = "opcode" + + # these are included for completeness but will probably not be used + c_call = "c_call" + c_return = "c_return" + c_exception = "c_exception" + + def __str__(self): + return self.value + + def __repr__(self): + return "<" + str(self) + ">" + + def __call__(self, handler=None, **kwargs): + # this will be filled by tracer.py + ... 
+ + def to_ast(self): + return fast.Constant(self.name) + + +SYS_TRACE_EVENTS = { + TraceEvent.line, + TraceEvent.call, + TraceEvent.return_, + TraceEvent.exception, + TraceEvent.opcode, +} + + +BEFORE_EXPR_EVENTS = { + TraceEvent.before_argument, + TraceEvent.before_assign_rhs, + TraceEvent.before_augassign_rhs, + TraceEvent.before_binop, + TraceEvent.before_compare, + TraceEvent.before_dict_literal, + TraceEvent.before_lambda, + TraceEvent.before_list_literal, + TraceEvent.before_load_complex_symbol, + TraceEvent.before_return, + TraceEvent.before_set_literal, + TraceEvent.before_subscript_slice, + TraceEvent.before_tuple_literal, +} + + +AST_TO_EVENT_MAPPING = { + ast.stmt: TraceEvent.after_stmt, + ast.Assign: TraceEvent.after_assign_rhs, + ast.Module: TraceEvent.init_module, + ast.Name: TraceEvent.load_name, + ast.Attribute: TraceEvent.before_attribute_load, + ast.Subscript: TraceEvent.before_subscript_load, + ast.Call: TraceEvent.before_call, + ast.Dict: TraceEvent.after_dict_literal, + ast.List: TraceEvent.after_list_literal, + ast.Tuple: TraceEvent.after_tuple_literal, + ast.Set: TraceEvent.after_set_literal, + ast.Return: TraceEvent.after_return, + ast.BinOp: TraceEvent.after_binop, + ast.Compare: TraceEvent.after_compare, +} diff --git a/pi/dialects/__init__.py b/pi/dialects/__init__.py new file mode 100644 index 0000000..2e00f34 --- /dev/null +++ b/pi/dialects/__init__.py @@ -0,0 +1,189 @@ +import sys + +import logging +import threading +from contextlib import contextmanager +from dataclasses import dataclass +from importlib.abc import MetaPathFinder +from importlib.machinery import SourceFileLoader, ModuleSpec +from importlib.util import find_spec, spec_from_loader +from pathlib import Path +from typing import Generator, Callable, Dict, List + + +logger = logging.getLogger(__name__) + + +@dataclass(order=True, frozen=True) +class ImportOverload: + name: str + origin: Path + is_package: bool + submodule_search_locations: List[Path] = None + + def __post_init__(self): + if self.is_package and self.submodule_search_locations is None: + assert ( + self.origin.name == "__init__.py" + ), f"default search path for {self.name} isn't a package: {self.origin}" + object.__setattr__(self, "submodule_search_locations", [self.origin.parent]) + + +_base_overloads = [ + ImportOverload( + "torch_mlir.dialects._arith_ops_ext", + Path(__file__).parent / "_arith_ops_ext.py", + False, + ), + ImportOverload( + "torch_mlir.dialects._memref_ops_ext", + Path(__file__).parent / "_memref_ops_ext.py", + False, + ), + ImportOverload( + "torch_mlir.dialects._torch_ops_ext_custom", + Path(__file__).parent / "_torch_ops_ext_custom.py", + False, + ), + ImportOverload( + "torch_mlir.dialects._torch_ops_ext", + Path(__file__).parent / "_torch_ops_ext.py", + False, + ), + ImportOverload( + "pyccolo.trace_events", + Path(__file__).parent.parent / "compiler" / "tracing" / "trace_events.py", + False, + ), +] + +BASE_OVERLOADS: Dict[str, ImportOverload] = {i.name: i for i in _base_overloads} + + +# this is based on the birdseye finder (which uses import hooks based on MacroPy's): +# https://github.com/alexmojaki/birdseye/blob/9974af715b1801f9dd99fef93ff133d0ab5223af/birdseye/import_hook.py +class Overloader(MetaPathFinder): + def __init__(self, overloads) -> None: + self.tracers = None + self._thread = threading.current_thread() + self.overloads: Dict[str, ImportOverload] = overloads + + @contextmanager + def _clear_preceding_finders(self) -> Generator[None, None, None]: + """ + Clear all preceding finders from 
sys.meta_path, and restore them afterwards. + """ + orig_finders = sys.meta_path + try: + sys.meta_path = sys.meta_path[sys.meta_path.index(self) + 1 :] # noqa: E203 + yield + finally: + sys.meta_path = orig_finders + + def _find_plain_spec(self, fullname, path, target): + """Try to find the original module using all the + remaining meta_path finders.""" + spec = None + self_seen = False + for finder in sys.meta_path: + if finder is self: + self_seen = True + continue + elif not self_seen or "pytest" in finder.__module__: + # when testing with pytest, it installs a finder that for + # some yet unknown reasons makes birdseye + # fail. For now it will just avoid using it and pass to + # the next one + continue + if hasattr(finder, "find_spec"): + spec = finder.find_spec(fullname, path, target=target) + elif hasattr(finder, "load_module"): + spec = spec_from_loader(fullname, finder) + + if spec is not None and spec.origin != "builtin": + return spec + + def find_spec(self, fullname, path=None, target=None): + logger.debug(f"finding spec for {fullname=} {path=} {target=}") + + if threading.current_thread() is not self._thread: + return None + if target is None: + with self._clear_preceding_finders(): + spec = find_spec(fullname, path) + else: + spec = self._find_plain_spec(fullname, path, target) + + if fullname not in self.overloads: + if spec is None or not ( + hasattr(spec.loader, "get_source") and callable(spec.loader.get_source) + ): # noqa: E128 + if fullname != "org": + # stdlib pickle.py at line 94 contains a ``from + # org.python.core for Jython which is always failing, + # of course + logger.debug("Failed finding spec for %s", fullname) + return None + + if not isinstance(spec.loader, SourceFileLoader): + return None + return spec + + logger.debug("patching spec for %s", fullname) + + overload = self.overloads[fullname] + new_path = str(overload.origin) + source_file_loader = SourceFileLoader(fullname, new_path) + spec = ModuleSpec( + name=fullname, + loader=source_file_loader, + origin=new_path, + is_package=overload.is_package, + ) + if overload.is_package: + spec.submodule_search_locations = [ + str(p) for p in overload.submodule_search_locations + ] + spec.has_location = True + return spec + + +def patch_meta_path_non_context(overloads=None) -> Callable: + if overloads is None: + overloads = BASE_OVERLOADS + orig_meta_path_entry = None + + def cleanup_callback(): + if orig_meta_path_entry is None: + del sys.meta_path[0] + else: + sys.meta_path[0] = orig_meta_path_entry + + if len(sys.meta_path) > 0 and isinstance(sys.meta_path[0], Overloader): + orig_meta_path_entry = sys.meta_path[0] + sys.meta_path[0] = Overloader(overloads) + else: + sys.meta_path.insert(0, Overloader(overloads)) + + return cleanup_callback + + +@contextmanager +def patch_meta_path(overloads=None) -> Generator[None, None, None]: + cleanup_callback = None + try: + cleanup_callback = patch_meta_path_non_context(overloads) + yield + finally: + if cleanup_callback is not None: + cleanup_callback() + + +def remove_modules(pred: Callable): + to_delete = [] + for mod in sys.modules: + if pred(mod): + logger.debug(f"removing from sys.modules {mod}") + to_delete.append(mod) + for mod in to_delete: + del sys.modules[mod] diff --git a/shark/dialects/_affine_ops_ext.py b/pi/dialects/_affine_ops_ext.py similarity index 87% rename from shark/dialects/_affine_ops_ext.py rename to pi/dialects/_affine_ops_ext.py index d28741c..da50434 100644 --- a/shark/dialects/_affine_ops_ext.py +++ b/pi/dialects/_affine_ops_ext.py @@ -1,19 
+1,13 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from torch_mlir.ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - from typing import Optional, Sequence, Union -from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, +from torch_mlir.dialects._ods_common import ( + get_op_result_or_value, + get_op_results_or_values, ) +from torch_mlir.ir import * class AffineForOp: @@ -86,8 +80,8 @@ def __init__( loc: user-visible location of the operation. ip: insertion point. """ - memref_resolved = _get_op_result_or_value(memref) - indices_resolved = [] if indices is None else _get_op_results_or_values(indices) + memref_resolved = get_op_result_or_value(memref) + indices_resolved = [] if indices is None else get_op_results_or_values(indices) return_type = MemRefType(memref_resolved.type).element_type super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip) diff --git a/shark/dialects/_arith_ops_ext.py b/pi/dialects/_arith_ops_ext.py similarity index 94% rename from shark/dialects/_arith_ops_ext.py rename to pi/dialects/_arith_ops_ext.py index 6465ce4..7b11c2b 100644 --- a/shark/dialects/_arith_ops_ext.py +++ b/pi/dialects/_arith_ops_ext.py @@ -1,13 +1,16 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from torch_mlir.dialects._ods_common import ( + get_default_loc_context, + get_op_result_or_value, +) +from torch_mlir.ir import * +from typing import Any, List, Union try: - from torch_mlir.ir import * - from ._ods_common import get_op_result_or_value, get_default_loc_context - from shark.dialects.value_ import _Value + from pi.dialects.value_ import _Value - from typing import Any, List, Union except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e diff --git a/shark/dialects/_memref_ops_ext.py b/pi/dialects/_memref_ops_ext.py similarity index 98% rename from shark/dialects/_memref_ops_ext.py rename to pi/dialects/_memref_ops_ext.py index bbb5e7a..ad8bdd3 100644 --- a/shark/dialects/_memref_ops_ext.py +++ b/pi/dialects/_memref_ops_ext.py @@ -5,7 +5,7 @@ try: from torch_mlir.ir import * from ._ods_common import get_op_result_or_value, get_op_results_or_values - from shark.dialects.value_ import _Value + from pi.dialects.value_ import _Value except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e diff --git a/pi/dialects/_torch_ops_ext.py b/pi/dialects/_torch_ops_ext.py new file mode 100644 index 0000000..c12fe1f --- /dev/null +++ b/pi/dialects/_torch_ops_ext.py @@ -0,0 +1,8744 @@ +try: + # from pi import Tensor, Number + from torch_mlir.ir import * + from torch_mlir.dialects._ods_common import ( + get_default_loc_context, + get_op_result_or_value, + get_op_results_or_values, + ) + from ._torch_ops_ext_custom import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import List, Optional, Any + + +class AtenTanhOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + 
assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTanhOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenTanh_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTanh_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenHardtanhOp: + def __init__(self, self_: Value, min_val: "Number", max_val: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(min_val): + min_val = torch_dialect.ConstantNumberOp(min_val) + else: + min_val = get_op_result_or_value(min_val) + assert str(min_val.type) in {'!torch.float', '!torch.int'}, f'`min_val` should be a !torch.number but is {type(min_val).__module__}.{type(min_val).__name__}' + + if not is_mlir_value(max_val): + max_val = torch_dialect.ConstantNumberOp(max_val) + else: + max_val = get_op_result_or_value(max_val) + assert str(max_val.type) in {'!torch.float', '!torch.int'}, f'`max_val` should be a !torch.number but is {type(max_val).__module__}.{type(max_val).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenHardtanhOp, self).__init__(result_type, self_, min_val, max_val, loc=loc, ip=ip) + + +class AtenHardtanh_Op: + def __init__(self, self_: Value, min_val: "Number", max_val: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(min_val): + min_val = torch_dialect.ConstantNumberOp(min_val) + else: + min_val = get_op_result_or_value(min_val) + assert str(min_val.type) in {'!torch.float', '!torch.int'}, f'`min_val` should be a !torch.number but is {type(min_val).__module__}.{type(min_val).__name__}' + + if not is_mlir_value(max_val): + max_val = torch_dialect.ConstantNumberOp(max_val) + else: + max_val = get_op_result_or_value(max_val) + assert str(max_val.type) in {'!torch.float', '!torch.int'}, f'`max_val` should be a !torch.number but is {type(max_val).__module__}.{type(max_val).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenHardtanh_Op, self).__init__(result_type, self_, min_val, max_val, loc=loc, ip=ip) + + +class AtenReluOp: + def __init__(self, self_: Value, *, 
loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenReluOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRelu_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRelu_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRelu6Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRelu6Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRelu6_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRelu6_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLeakyReluOp: + def __init__(self, self_: Value, negative_slope: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(negative_slope): + negative_slope = torch_dialect.ConstantNumberOp(negative_slope) + else: + negative_slope = get_op_result_or_value(negative_slope) + assert str(negative_slope.type) in {'!torch.float', '!torch.int'}, f'`negative_slope` should be a !torch.number but is {type(negative_slope).__module__}.{type(negative_slope).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLeakyReluOp, self).__init__(result_type, self_, negative_slope, loc=loc, ip=ip) + + +class AtenLeakyRelu_Op: + def __init__(self, self_: Value, negative_slope: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` 
should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(negative_slope): + negative_slope = torch_dialect.ConstantNumberOp(negative_slope) + else: + negative_slope = get_op_result_or_value(negative_slope) + assert str(negative_slope.type) in {'!torch.float', '!torch.int'}, f'`negative_slope` should be a !torch.number but is {type(negative_slope).__module__}.{type(negative_slope).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLeakyRelu_Op, self).__init__(result_type, self_, negative_slope, loc=loc, ip=ip) + + +class AtenLogOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLog_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLog_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSigmoidOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSigmoidOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSigmoid_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSigmoid_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenHardsigmoidOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenHardsigmoidOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenHardsigmoid_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + 
self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenHardsigmoid_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenHardswishOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenHardswishOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenHardswish_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenHardswish_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenErfOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenErfOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenErf_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenErf_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSiluOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSiluOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSilu_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSilu_Op, self).__init__(result_type, 
self_, loc=loc, ip=ip) + + +class AtenSinOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSinOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSin_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSin_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenExpOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExpOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenExp_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExp_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenExpm1Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExpm1Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenExpm1_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExpm1_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenCosOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert 
str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCosOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenCos_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCos_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAtan2Op: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAtan2Op, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenAtan2_Op: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAtan2_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenNegOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNegOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenNeg_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNeg_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenFloorOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFloorOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenFloor_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFloor_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenCeilOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCeilOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenCeil_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCeil_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenBitwiseNotOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBitwiseNotOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenBitwiseNot_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBitwiseNot_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenDivTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not 
is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDivTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenDiv_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDiv_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogicalOrOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalOrOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogicalOr_Op: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalOr_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + 
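+# Rough usage sketch (illustrative only): the classes in this file are
+# __init__ mix-ins that the generated torch dialect ops pick up once the
+# import overloads in pi/dialects/__init__.py are installed, so callers
+# normally construct ops through torch_mlir.dialects.torch rather than
+# instantiating these classes directly. The snippet assumes a torch-mlir
+# build whose Python bindings expose torch_mlir.ir and torch_mlir.dialects.torch:
+#
+#   from torch_mlir import ir
+#   from torch_mlir.dialects import torch as torch_dialect
+#
+#   with ir.Context(), ir.Location.unknown():
+#       module = ir.Module.create()
+#       with ir.InsertionPoint(module.body):
+#           # x: an existing !torch.vtensor-typed ir.Value
+#           y = torch_dialect.AtenTanhOp(x)
+#           z = torch_dialect.AtenLogicalOrOp(y, x)
+
+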
+class AtenLogicalAndOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalAndOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogicalAnd_Op: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalAnd_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogicalXorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalXorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogicalXor_Op: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = 
Type.parse("!torch.vtensor") + super(AtenLogicalXor_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogicalNotOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalNotOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLogicalNot_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogicalNot_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLerpTensorOp: + def __init__(self, self_: Value, end: Value, weight: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(end): + assert is_mlir_value(end), f'`end` should be a Value but is {type(end).__module__}.{type(end).__name__}' + else: + end = get_op_result_or_value(end) + assert str(end.type).startswith("!torch.vtensor"), f'`end` should be a torch.vtensor but is {type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLerpTensorOp, self).__init__(result_type, self_, end, weight, loc=loc, ip=ip) + + +class AtenLerp_TensorOp: + def __init__(self, self_: Value, end: Value, weight: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(end): + assert is_mlir_value(end), f'`end` should be a Value but is {type(end).__module__}.{type(end).__name__}' + else: + end = get_op_result_or_value(end) + assert str(end.type).startswith("!torch.vtensor"), f'`end` should be a torch.vtensor but is {type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a 
torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLerp_TensorOp, self).__init__(result_type, self_, end, weight, loc=loc, ip=ip) + + +class AtenEqTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenEqTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenEq_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenEq_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGtTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGtTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGt_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = 
get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGt_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGeTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGeTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGe_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGe_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLtTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLtTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLt_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), 
f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLt_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLeTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLeTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLe_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLe_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenNeTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNeTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenNe_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNe_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenDivScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDivScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenDiv_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDiv_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenNeScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNeScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenNe_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is 
{type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNe_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenEqScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenEqScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenEq_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenEq_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGtScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGtScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGt_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + 
+ if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGt_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGeScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGeScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenGe_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGe_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLtScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLtScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLt_ScalarOp: + def __init__(self, self_: Value, 
other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLt_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLeScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLeScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLe_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLe_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenFmodScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFmodScalarOp, 
self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenFmod_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFmod_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenMaskedFillScalarOp: + def __init__(self, self_: Value, mask: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mask): + assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}' + else: + mask = get_op_result_or_value(mask) + assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaskedFillScalarOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip) + + +class AtenMaskedFill_ScalarOp: + def __init__(self, self_: Value, mask: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mask): + assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}' + else: + mask = get_op_result_or_value(mask) + assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaskedFill_ScalarOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip) + + +class 
AtenMaskedFillTensorOp: + def __init__(self, self_: Value, mask: Value, value: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mask): + assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}' + else: + mask = get_op_result_or_value(mask) + assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}' + + if not is_mlir_value(value): + assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}' + else: + value = get_op_result_or_value(value) + assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaskedFillTensorOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip) + + +class AtenMaskedFill_TensorOp: + def __init__(self, self_: Value, mask: Value, value: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mask): + assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}' + else: + mask = get_op_result_or_value(mask) + assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}' + + if not is_mlir_value(value): + assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}' + else: + value = get_op_result_or_value(value) + assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaskedFill_TensorOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip) + + +class AtenClampOp: + def __init__(self, self_: Value, min: Optional["Number"], max: Optional["Number"], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(min): + if min is not None: + min = torch_dialect.ConstantNumberOp(min) + else: + min = torch_dialect.ConstantNoneOp() + else: + min = get_op_result_or_value(min) + assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}' + + if not is_mlir_value(max): + if max is not None: + max = torch_dialect.ConstantNumberOp(max) + else: + max = 
torch_dialect.ConstantNoneOp() + else: + max = get_op_result_or_value(max) + assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenClampOp, self).__init__(result_type, self_, min, max, loc=loc, ip=ip) + + +class AtenClamp_Op: + def __init__(self, self_: Value, min: Optional["Number"], max: Optional["Number"], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(min): + if min is not None: + min = torch_dialect.ConstantNumberOp(min) + else: + min = torch_dialect.ConstantNoneOp() + else: + min = get_op_result_or_value(min) + assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}' + + if not is_mlir_value(max): + if max is not None: + max = torch_dialect.ConstantNumberOp(max) + else: + max = torch_dialect.ConstantNoneOp() + else: + max = get_op_result_or_value(max) + assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenClamp_Op, self).__init__(result_type, self_, min, max, loc=loc, ip=ip) + + +class AtenClampMinOp: + def __init__(self, self_: Value, min: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(min): + min = torch_dialect.ConstantNumberOp(min) + else: + min = get_op_result_or_value(min) + assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenClampMinOp, self).__init__(result_type, self_, min, loc=loc, ip=ip) + + +class AtenClampMin_Op: + def __init__(self, self_: Value, min: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(min): + min = torch_dialect.ConstantNumberOp(min) + else: + min = get_op_result_or_value(min) + assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenClampMin_Op, self).__init__(result_type, self_, min, loc=loc, ip=ip) + + +class AtenClampMaxOp: + def __init__(self, self_: Value, max: "Number", *, loc=None, 
ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(max): + max = torch_dialect.ConstantNumberOp(max) + else: + max = get_op_result_or_value(max) + assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenClampMaxOp, self).__init__(result_type, self_, max, loc=loc, ip=ip) + + +class AtenClampMax_Op: + def __init__(self, self_: Value, max: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(max): + max = torch_dialect.ConstantNumberOp(max) + else: + max = get_op_result_or_value(max) + assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenClampMax_Op, self).__init__(result_type, self_, max, loc=loc, ip=ip) + + +class AtenLog2Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLog2Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLog2_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLog2_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSqrtOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSqrtOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSqrt_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + 
self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSqrt_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLog1pOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLog1pOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenLog1p_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLog1p_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRsqrtOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRsqrtOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRsqrt_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRsqrt_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAbsOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAbsOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAbs_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAbs_Op, self).__init__(result_type, self_, loc=loc, ip=ip) 
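A minimal sketch of how the wrappers above are meant to be driven (not part of the patch). It assumes the extension classes in this file get merged into the generated torch-dialect OpViews the way upstream torch-mlir handles its *_ops_ext.py files, that creating an ir.Context from the bundled torch_mlir package registers the torch dialect, and that single-result ops expose the usual .result property; the function name scale_and_clamp is purely illustrative.

    from torch_mlir import ir
    from torch_mlir.dialects import func, torch as torch_dialect

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        vtensor = ir.Type.parse("!torch.vtensor<[2,2],f32>")
        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(vtensor)
            def scale_and_clamp(x):
                # 2.0 is a plain Python float; the wrapper wraps it in torch.constant.number.
                doubled = torch_dialect.AtenMulScalarOp(x, 2.0)
                # min/max may be Python numbers or None (None becomes torch.constant.none).
                return torch_dialect.AtenClampOp(doubled.result, 0.0, 1.0).result

        print(module)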
+ + +class AtenReciprocalOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenReciprocalOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenReciprocal_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenReciprocal_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenBitwiseAndTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBitwiseAndTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenBitwiseAnd_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBitwiseAnd_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenBitwiseOrTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is 
{type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBitwiseOrTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenBitwiseOr_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBitwiseOr_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenThresholdOp: + def __init__(self, self_: Value, threshold: "Number", value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(threshold): + threshold = torch_dialect.ConstantNumberOp(threshold) + else: + threshold = get_op_result_or_value(threshold) + assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenThresholdOp, self).__init__(result_type, self_, threshold, value, loc=loc, ip=ip) + + +class AtenThreshold_Op: + def __init__(self, self_: Value, threshold: "Number", value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(threshold): + threshold = torch_dialect.ConstantNumberOp(threshold) + else: + threshold = get_op_result_or_value(threshold) + assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert 
str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenThreshold_Op, self).__init__(result_type, self_, threshold, value, loc=loc, ip=ip) + + +class AtenSquareOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSquareOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSquare_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSquare_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenUnsqueezeOp: + def __init__(self, self_: Value, dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenUnsqueezeOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip) + + +class AtenUnsqueeze_Op: + def __init__(self, self_: Value, dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenUnsqueeze_Op, self).__init__(result_type, self_, dim, loc=loc, ip=ip) + + +class AtenZeroOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = 
Type.parse("!torch.vtensor") + super(AtenZeroOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenZero_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenZero_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenFillScalarOp: + def __init__(self, self_: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFillScalarOp, self).__init__(result_type, self_, value, loc=loc, ip=ip) + + +class AtenFill_ScalarOp: + def __init__(self, self_: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFill_ScalarOp, self).__init__(result_type, self_, value, loc=loc, ip=ip) + + +class AtenFillTensorOp: + def __init__(self, self_: Value, value: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(value): + assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}' + else: + value = get_op_result_or_value(value) + assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFillTensorOp, self).__init__(result_type, self_, value, loc=loc, ip=ip) + + +class AtenFill_TensorOp: + def __init__(self, self_: Value, value: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert 
is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(value): + assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}' + else: + value = get_op_result_or_value(value) + assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFill_TensorOp, self).__init__(result_type, self_, value, loc=loc, ip=ip) + + +class AtenDivTensorModeOp: + def __init__(self, self_: Value, other: Value, rounding_mode: Optional[str], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(rounding_mode): + if rounding_mode is not None: + rounding_mode = torch_dialect.ConstantStrOp(rounding_mode) + else: + rounding_mode = torch_dialect.ConstantNoneOp() + else: + rounding_mode = get_op_result_or_value(rounding_mode) + assert str(rounding_mode.type) == '!torch.str', f'`rounding_mode` should be a !torch.str but is {type(rounding_mode).__module__}.{type(rounding_mode).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDivTensorModeOp, self).__init__(result_type, self_, other, rounding_mode, loc=loc, ip=ip) + + +class AtenDiv_TensorModeOp: + def __init__(self, self_: Value, other: Value, rounding_mode: Optional[str], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(rounding_mode): + if rounding_mode is not None: + rounding_mode = torch_dialect.ConstantStrOp(rounding_mode) + else: + rounding_mode = torch_dialect.ConstantNoneOp() + else: + rounding_mode = get_op_result_or_value(rounding_mode) + assert str(rounding_mode.type) == '!torch.str', f'`rounding_mode` should be a !torch.str but is {type(rounding_mode).__module__}.{type(rounding_mode).__name__}' + + result_type = 
Type.parse("!torch.vtensor") + super(AtenDiv_TensorModeOp, self).__init__(result_type, self_, other, rounding_mode, loc=loc, ip=ip) + + +class AtenMulTensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMulTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenMul_TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMul_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenAddTensorOp: + def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddTensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenAdd_TensorOp: + def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is 
{type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAdd_TensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenSubTensorOp: + def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSubTensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenSub_TensorOp: + def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSub_TensorOp, self).__init__(result_type, self_, other, alpha, 
loc=loc, ip=ip) + + +class AtenAddScalarOp: + def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenAdd_ScalarOp: + def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAdd_ScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenSubScalarOp: + def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type 
= Type.parse("!torch.vtensor") + super(AtenSubScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenSub_ScalarOp: + def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSub_ScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenMulScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMulScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenMul_ScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMul_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenAddcmulOp: + def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = 
get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(tensor1): + assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}' + else: + tensor1 = get_op_result_or_value(tensor1) + assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}' + + if not is_mlir_value(tensor2): + assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}' + else: + tensor2 = get_op_result_or_value(tensor2) + assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddcmulOp, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip) + + +class AtenAddcmul_Op: + def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(tensor1): + assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}' + else: + tensor1 = get_op_result_or_value(tensor1) + assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}' + + if not is_mlir_value(tensor2): + assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}' + else: + tensor2 = get_op_result_or_value(tensor2) + assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddcmul_Op, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip) + + +class AtenAddcdivOp: + def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(tensor1): + assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}' + else: + tensor1 = get_op_result_or_value(tensor1) + assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}' + + if not is_mlir_value(tensor2): + assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}' + else: + tensor2 = get_op_result_or_value(tensor2) + assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddcdivOp, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip) + + +class AtenAddcdiv_Op: + def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(tensor1): + assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}' + else: + tensor1 = get_op_result_or_value(tensor1) + assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}' + + if not is_mlir_value(tensor2): + assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}' + else: + tensor2 = get_op_result_or_value(tensor2) + assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddcdiv_Op, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip) + + +class AtenMaximumOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert 
str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaximumOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenMinimumOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMinimumOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenMishOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMishOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRsubScalarOp: + def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRsubScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip) + + +class AtenGeluOp: + def __init__(self, self_: Value, approximate: str, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(approximate): + approximate = torch_dialect.ConstantStrOp(approximate) + else: + approximate 
= get_op_result_or_value(approximate) + assert str(approximate.type) == '!torch.str', f'`approximate` should be a !torch.str but is {type(approximate).__module__}.{type(approximate).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGeluOp, self).__init__(result_type, self_, approximate, loc=loc, ip=ip) + + +class AtenPowTensorScalarOp: + def __init__(self, self_: Value, exponent: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(exponent): + exponent = torch_dialect.ConstantNumberOp(exponent) + else: + exponent = get_op_result_or_value(exponent) + assert str(exponent.type) in {'!torch.float', '!torch.int'}, f'`exponent` should be a !torch.number but is {type(exponent).__module__}.{type(exponent).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenPowTensorScalarOp, self).__init__(result_type, self_, exponent, loc=loc, ip=ip) + + +class AtenPowTensorTensorOp: + def __init__(self, self_: Value, exponent: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(exponent): + assert is_mlir_value(exponent), f'`exponent` should be a Value but is {type(exponent).__module__}.{type(exponent).__name__}' + else: + exponent = get_op_result_or_value(exponent) + assert str(exponent.type).startswith("!torch.vtensor"), f'`exponent` should be a torch.vtensor but is {type(exponent).__module__}.{type(exponent).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenPowTensorTensorOp, self).__init__(result_type, self_, exponent, loc=loc, ip=ip) + + +class AtenThresholdBackwardOp: + def __init__(self, grad_output: Value, self_: Value, threshold: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(threshold): + threshold = torch_dialect.ConstantNumberOp(threshold) + else: + threshold = get_op_result_or_value(threshold) + assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}' + + result_type = 
Type.parse("!torch.vtensor") + super(AtenThresholdBackwardOp, self).__init__(result_type, grad_output, self_, threshold, loc=loc, ip=ip) + + +class AtenFloorDivideOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFloorDivideOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenSoftplusOp: + def __init__(self, self_: Value, beta: "Number", threshold: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(beta): + beta = torch_dialect.ConstantNumberOp(beta) + else: + beta = get_op_result_or_value(beta) + assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}' + + if not is_mlir_value(threshold): + threshold = torch_dialect.ConstantNumberOp(threshold) + else: + threshold = get_op_result_or_value(threshold) + assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSoftplusOp, self).__init__(result_type, self_, beta, threshold, loc=loc, ip=ip) + + +class AtenPreluOp: + def __init__(self, self_: Value, weight: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenPreluOp, self).__init__(result_type, self_, weight, loc=loc, ip=ip) + + +class AtenTriuOp: + def __init__(self, self_: Value, diagonal: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = 
get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(diagonal): + diagonal = torch_dialect.ConstantIntOp(diagonal) + else: + diagonal = get_op_result_or_value(diagonal) + assert str(diagonal.type) == '!torch.int', f'`diagonal` should be a !torch.int but is {type(diagonal).__module__}.{type(diagonal).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTriuOp, self).__init__(result_type, self_, diagonal, loc=loc, ip=ip) + + +class AtenTriu_Op: + def __init__(self, self_: Value, diagonal: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(diagonal): + diagonal = torch_dialect.ConstantIntOp(diagonal) + else: + diagonal = get_op_result_or_value(diagonal) + assert str(diagonal.type) == '!torch.int', f'`diagonal` should be a !torch.int but is {type(diagonal).__module__}.{type(diagonal).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTriu_Op, self).__init__(result_type, self_, diagonal, loc=loc, ip=ip) + + +class AtenRoundOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRoundOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenRound_Op: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRound_Op, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenIndexPutHackedTwinOp: + def __init__(self, self_: Value, indices: List[Value], values: Value, accumulate: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(indices): + indices = torch_dialect.PrimListConstructOp(indices) + else: + indices = get_op_result_or_value(indices) + assert str(indices.type) == '!torch.list', f'`indices` should be a !torch.list but is {type(indices).__module__}.{type(indices).__name__}' + + if not is_mlir_value(values): + assert is_mlir_value(values), f'`values` should be a Value but is 
{type(values).__module__}.{type(values).__name__}' + else: + values = get_op_result_or_value(values) + assert str(values.type).startswith("!torch.vtensor"), f'`values` should be a torch.vtensor but is {type(values).__module__}.{type(values).__name__}' + + if not is_mlir_value(accumulate): + accumulate = torch_dialect.ConstantBoolOp(accumulate) + else: + accumulate = get_op_result_or_value(accumulate) + assert str(accumulate.type) == '!torch.bool', f'`accumulate` should be a !torch.bool but is {type(accumulate).__module__}.{type(accumulate).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenIndexPutHackedTwinOp, self).__init__(result_type, self_, indices, values, accumulate, loc=loc, ip=ip) + + +class AtenIndexPut_HackedTwinOp: + def __init__(self, self_: Value, indices: List[Value], values: Value, accumulate: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(indices): + indices = torch_dialect.PrimListConstructOp(indices) + else: + indices = get_op_result_or_value(indices) + assert str(indices.type) == '!torch.list', f'`indices` should be a !torch.list but is {type(indices).__module__}.{type(indices).__name__}' + + if not is_mlir_value(values): + assert is_mlir_value(values), f'`values` should be a Value but is {type(values).__module__}.{type(values).__name__}' + else: + values = get_op_result_or_value(values) + assert str(values.type).startswith("!torch.vtensor"), f'`values` should be a torch.vtensor but is {type(values).__module__}.{type(values).__name__}' + + if not is_mlir_value(accumulate): + accumulate = torch_dialect.ConstantBoolOp(accumulate) + else: + accumulate = get_op_result_or_value(accumulate) + assert str(accumulate.type) == '!torch.bool', f'`accumulate` should be a !torch.bool but is {type(accumulate).__module__}.{type(accumulate).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenIndexPut_HackedTwinOp, self).__init__(result_type, self_, indices, values, accumulate, loc=loc, ip=ip) + + +class AtenLinearOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert 
str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLinearOp, self).__init__(result_type, input, weight, bias, loc=loc, ip=ip) + + +class AtenMmOp: + def __init__(self, self_: Value, mat2: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mat2): + assert is_mlir_value(mat2), f'`mat2` should be a Value but is {type(mat2).__module__}.{type(mat2).__name__}' + else: + mat2 = get_op_result_or_value(mat2) + assert str(mat2.type).startswith("!torch.vtensor"), f'`mat2` should be a torch.vtensor but is {type(mat2).__module__}.{type(mat2).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMmOp, self).__init__(result_type, self_, mat2, loc=loc, ip=ip) + + +class AtenAddmmOp: + def __init__(self, self_: Value, mat1: Value, mat2: Value, beta: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mat1): + assert is_mlir_value(mat1), f'`mat1` should be a Value but is {type(mat1).__module__}.{type(mat1).__name__}' + else: + mat1 = get_op_result_or_value(mat1) + assert str(mat1.type).startswith("!torch.vtensor"), f'`mat1` should be a torch.vtensor but is {type(mat1).__module__}.{type(mat1).__name__}' + + if not is_mlir_value(mat2): + assert is_mlir_value(mat2), f'`mat2` should be a Value but is {type(mat2).__module__}.{type(mat2).__name__}' + else: + mat2 = get_op_result_or_value(mat2) + assert str(mat2.type).startswith("!torch.vtensor"), f'`mat2` should be a torch.vtensor but is {type(mat2).__module__}.{type(mat2).__name__}' + + if not is_mlir_value(beta): + beta = torch_dialect.ConstantNumberOp(beta) + else: + beta = get_op_result_or_value(beta) + assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAddmmOp, self).__init__(result_type, self_, mat1, mat2, beta, alpha, loc=loc, ip=ip) + + +class AtenMatmulOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert 
is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMatmulOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenMvOp: + def __init__(self, self_: Value, vec: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(vec): + assert is_mlir_value(vec), f'`vec` should be a Value but is {type(vec).__module__}.{type(vec).__name__}' + else: + vec = get_op_result_or_value(vec) + assert str(vec.type).startswith("!torch.vtensor"), f'`vec` should be a torch.vtensor but is {type(vec).__module__}.{type(vec).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMvOp, self).__init__(result_type, self_, vec, loc=loc, ip=ip) + + +class AtenConv2dOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], groups: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert 
str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConv2dOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, groups, loc=loc, ip=ip) + + +class AtenConvTranspose1dOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], output_padding: List[int], groups: int, dilation: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert 
str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConvTranspose1dOp, self).__init__(result_type, input, weight, bias, stride, padding, output_padding, groups, dilation, loc=loc, ip=ip) + + +class AtenConvTranspose2dInputOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], output_padding: List[int], groups: int, dilation: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConvTranspose2dInputOp, self).__init__(result_type, input, 
weight, bias, stride, padding, output_padding, groups, dilation, loc=loc, ip=ip) + + +class AtenConvTranspose3dInputOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], output_padding: List[int], groups: int, dilation: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConvTranspose3dInputOp, self).__init__(result_type, input, weight, bias, stride, padding, output_padding, groups, dilation, loc=loc, ip=ip) + + +class AtenConvolutionOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], 
transposed: bool, output_padding: List[int], groups: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(transposed): + transposed = torch_dialect.ConstantBoolOp(transposed) + else: + transposed = get_op_result_or_value(transposed) + assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConvolutionOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, loc=loc, ip=ip) + + +class AtenConvolutionOverrideableOp: + def __init__(self, input: Value, weight: Value, 
bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(transposed): + transposed = torch_dialect.ConstantBoolOp(transposed) + else: + transposed = get_op_result_or_value(transposed) + assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConvolutionOverrideableOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, loc=loc, ip=ip) 
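# ---------------------------------------------------------------------------
# The convolution wrappers above (and the other Aten*Op classes in this file)
# all follow one pattern: arguments that are already MLIR Values are unwrapped
# with get_op_result_or_value and type-checked against the expected !torch.*
# type; plain Python values (int, float, bool, str, List[int], None) are
# materialized as torch.constant.* / torch.prim.ListConstruct ops; and the
# result type is left as the unrefined !torch.vtensor before delegating to the
# generated op builder.
#
# Minimal usage sketch (runnable as a standalone script), assuming (a) this
# file is wired in as the torch dialect's op-extension module so that
# torch_dialect.AtenConv2dOp dispatches to the __init__ defined above, and
# (b) the stock torch-mlir Python bindings (torch_mlir.ir,
# torch_mlir.dialects.func, torch.register_dialect) are available -- none of
# that wiring is shown in this diff, so treat the imports as assumptions:
#
#   from torch_mlir.ir import Context, InsertionPoint, Location, Module, Type
#   from torch_mlir.dialects import func
#   from torch_mlir.dialects import torch as torch_dialect
#
#   with Context() as ctx, Location.unknown():
#       torch_dialect.register_dialect(ctx)
#       module = Module.create()
#       with InsertionPoint(module.body):
#           vt = Type.parse("!torch.vtensor")
#           f = func.FuncOp("conv2d_example", ([vt, vt], [vt]))
#           with InsertionPoint(f.add_entry_block()):
#               input_, weight = f.entry_block.arguments
#               # Plain Python ints/lists/None are coerced by the __init__
#               # above into torch.constant.int, torch.prim.ListConstruct and
#               # torch.constant.none operands.
#               conv = torch_dialect.AtenConv2dOp(input_, weight, None,
#                                                 stride=[1, 1], padding=[0, 0],
#                                                 dilation=[1, 1], groups=1)
#               func.ReturnOp([conv.result])
#       print(module)
#
# The unrefined !torch.vtensor result type is deliberate: shape and dtype are
# presumably recovered later by torch-mlir's type-refinement passes rather
# than at op-construction time.
# ---------------------------------------------------------------------------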
+ + +class Aten_ConvolutionOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(transposed): + transposed = torch_dialect.ConstantBoolOp(transposed) + else: + transposed = get_op_result_or_value(transposed) + assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + if not is_mlir_value(benchmark): + benchmark = 
torch_dialect.ConstantBoolOp(benchmark) + else: + benchmark = get_op_result_or_value(benchmark) + assert str(benchmark.type) == '!torch.bool', f'`benchmark` should be a !torch.bool but is {type(benchmark).__module__}.{type(benchmark).__name__}' + + if not is_mlir_value(deterministic): + deterministic = torch_dialect.ConstantBoolOp(deterministic) + else: + deterministic = get_op_result_or_value(deterministic) + assert str(deterministic.type) == '!torch.bool', f'`deterministic` should be a !torch.bool but is {type(deterministic).__module__}.{type(deterministic).__name__}' + + if not is_mlir_value(cudnn_enabled): + cudnn_enabled = torch_dialect.ConstantBoolOp(cudnn_enabled) + else: + cudnn_enabled = get_op_result_or_value(cudnn_enabled) + assert str(cudnn_enabled.type) == '!torch.bool', f'`cudnn_enabled` should be a !torch.bool but is {type(cudnn_enabled).__module__}.{type(cudnn_enabled).__name__}' + + if not is_mlir_value(allow_tf32): + allow_tf32 = torch_dialect.ConstantBoolOp(allow_tf32) + else: + allow_tf32 = get_op_result_or_value(allow_tf32) + assert str(allow_tf32.type) == '!torch.bool', f'`allow_tf32` should be a !torch.bool but is {type(allow_tf32).__module__}.{type(allow_tf32).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_ConvolutionOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, loc=loc, ip=ip) + + +class Aten_ConvolutionDeprecatedOp: + def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a 
!torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(transposed): + transposed = torch_dialect.ConstantBoolOp(transposed) + else: + transposed = get_op_result_or_value(transposed) + assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + if not is_mlir_value(benchmark): + benchmark = torch_dialect.ConstantBoolOp(benchmark) + else: + benchmark = get_op_result_or_value(benchmark) + assert str(benchmark.type) == '!torch.bool', f'`benchmark` should be a !torch.bool but is {type(benchmark).__module__}.{type(benchmark).__name__}' + + if not is_mlir_value(deterministic): + deterministic = torch_dialect.ConstantBoolOp(deterministic) + else: + deterministic = get_op_result_or_value(deterministic) + assert str(deterministic.type) == '!torch.bool', f'`deterministic` should be a !torch.bool but is {type(deterministic).__module__}.{type(deterministic).__name__}' + + if not is_mlir_value(cudnn_enabled): + cudnn_enabled = torch_dialect.ConstantBoolOp(cudnn_enabled) + else: + cudnn_enabled = get_op_result_or_value(cudnn_enabled) + assert str(cudnn_enabled.type) == '!torch.bool', f'`cudnn_enabled` should be a !torch.bool but is {type(cudnn_enabled).__module__}.{type(cudnn_enabled).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_ConvolutionDeprecatedOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, loc=loc, ip=ip) + + +class AtenRollOp: + def __init__(self, self_: Value, shifts: List[int], dims: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(shifts): + shifts = list(map(torch_dialect.ConstantIntOp, shifts)) + shifts = torch_dialect.PrimListConstructOp(shifts) + else: + shifts = get_op_result_or_value(shifts) + assert str(shifts.type) == '!torch.list', f'`shifts` should be a !torch.list but is {type(shifts).__module__}.{type(shifts).__name__}' + + if not is_mlir_value(dims): + dims = 
list(map(torch_dialect.ConstantIntOp, dims)) + dims = torch_dialect.PrimListConstructOp(dims) + else: + dims = get_op_result_or_value(dims) + assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRollOp, self).__init__(result_type, self_, shifts, dims, loc=loc, ip=ip) + + +class AtenConvolutionBackwardOverrideableOp: + def __init__(self, grad_output: Value, input: Value, weight: Value, stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(transposed): + transposed = torch_dialect.ConstantBoolOp(transposed) + else: + transposed = get_op_result_or_value(transposed) + assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}' + + if not is_mlir_value(output_padding): + output_padding = list(map(torch_dialect.ConstantIntOp, output_padding)) + output_padding = torch_dialect.PrimListConstructOp(output_padding) + else: + output_padding = get_op_result_or_value(output_padding) + assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is 
{type(output_padding).__module__}.{type(output_padding).__name__}' + + if not is_mlir_value(groups): + groups = torch_dialect.ConstantIntOp(groups) + else: + groups = get_op_result_or_value(groups) + assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}' + + if not is_mlir_value(output_mask): + output_mask = list(map(torch_dialect.ConstantBoolOp, output_mask)) + output_mask = torch_dialect.PrimListConstructOp(output_mask) + else: + output_mask = get_op_result_or_value(output_mask) + # should be bool[] + pass + + grad_input_type = Type.parse("!torch.vtensor") + grad_weight_type = Type.parse("!torch.vtensor") + grad_bias_type = Type.parse("!torch.vtensor") + super(AtenConvolutionBackwardOverrideableOp, self).__init__(grad_input_type, grad_weight_type, grad_bias_type, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, loc=loc, ip=ip) + + +class AtenFlipOp: + def __init__(self, self_: Value, dims: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dims): + dims = list(map(torch_dialect.ConstantIntOp, dims)) + dims = torch_dialect.PrimListConstructOp(dims) + else: + dims = get_op_result_or_value(dims) + assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFlipOp, self).__init__(result_type, self_, dims, loc=loc, ip=ip) + + +class AtenNativeBatchNormOp: + def __init__(self, input: Value, weight: Optional[Value], bias: Optional[Value], running_mean: Optional[Value], running_var: Optional[Value], training: bool, momentum: float, eps: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(running_mean): + if running_mean is not None: + assert is_mlir_value(running_mean), f'`running_mean` should be a Value but is 
{type(running_mean).__module__}.{type(running_mean).__name__}' + else: + running_mean = torch_dialect.ConstantNoneOp() + else: + running_mean = get_op_result_or_value(running_mean) + assert str(running_mean.type).startswith("!torch.vtensor"), f'`running_mean` should be a torch.vtensor but is {type(running_mean).__module__}.{type(running_mean).__name__}' + + if not is_mlir_value(running_var): + if running_var is not None: + assert is_mlir_value(running_var), f'`running_var` should be a Value but is {type(running_var).__module__}.{type(running_var).__name__}' + else: + running_var = torch_dialect.ConstantNoneOp() + else: + running_var = get_op_result_or_value(running_var) + assert str(running_var.type).startswith("!torch.vtensor"), f'`running_var` should be a torch.vtensor but is {type(running_var).__module__}.{type(running_var).__name__}' + + if not is_mlir_value(training): + training = torch_dialect.ConstantBoolOp(training) + else: + training = get_op_result_or_value(training) + assert str(training.type) == '!torch.bool', f'`training` should be a !torch.bool but is {type(training).__module__}.{type(training).__name__}' + + if not is_mlir_value(momentum): + momentum = torch_dialect.ConstantFloatOp(momentum) + else: + momentum = get_op_result_or_value(momentum) + assert str(momentum.type) == '!torch.float', f'`momentum` should be a !torch.float but is {type(momentum).__module__}.{type(momentum).__name__}' + + if not is_mlir_value(eps): + eps = torch_dialect.ConstantFloatOp(eps) + else: + eps = get_op_result_or_value(eps) + assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + result2_type = Type.parse("!torch.vtensor") + super(AtenNativeBatchNormOp, self).__init__(result0_type, result1_type, result2_type, input, weight, bias, running_mean, running_var, training, momentum, eps, loc=loc, ip=ip) + + +class AtenBatchNormOp: + def __init__(self, input: Value, weight: Optional[Value], bias: Optional[Value], running_mean: Optional[Value], running_var: Optional[Value], training: bool, momentum: float, eps: float, cudnn_enabled: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(running_mean): + if running_mean is not None: + 
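+ # Optional tensor operands such as `running_mean` follow one shared pattern in this file:
+ # a Python None is materialized with torch_dialect.ConstantNoneOp, an existing MLIR Value
+ # is unwrapped via get_op_result_or_value and checked to print as a !torch.vtensor, and
+ # any other Python object fails the assert below.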
assert is_mlir_value(running_mean), f'`running_mean` should be a Value but is {type(running_mean).__module__}.{type(running_mean).__name__}' + else: + running_mean = torch_dialect.ConstantNoneOp() + else: + running_mean = get_op_result_or_value(running_mean) + assert str(running_mean.type).startswith("!torch.vtensor"), f'`running_mean` should be a torch.vtensor but is {type(running_mean).__module__}.{type(running_mean).__name__}' + + if not is_mlir_value(running_var): + if running_var is not None: + assert is_mlir_value(running_var), f'`running_var` should be a Value but is {type(running_var).__module__}.{type(running_var).__name__}' + else: + running_var = torch_dialect.ConstantNoneOp() + else: + running_var = get_op_result_or_value(running_var) + assert str(running_var.type).startswith("!torch.vtensor"), f'`running_var` should be a torch.vtensor but is {type(running_var).__module__}.{type(running_var).__name__}' + + if not is_mlir_value(training): + training = torch_dialect.ConstantBoolOp(training) + else: + training = get_op_result_or_value(training) + assert str(training.type) == '!torch.bool', f'`training` should be a !torch.bool but is {type(training).__module__}.{type(training).__name__}' + + if not is_mlir_value(momentum): + momentum = torch_dialect.ConstantFloatOp(momentum) + else: + momentum = get_op_result_or_value(momentum) + assert str(momentum.type) == '!torch.float', f'`momentum` should be a !torch.float but is {type(momentum).__module__}.{type(momentum).__name__}' + + if not is_mlir_value(eps): + eps = torch_dialect.ConstantFloatOp(eps) + else: + eps = get_op_result_or_value(eps) + assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}' + + if not is_mlir_value(cudnn_enabled): + cudnn_enabled = torch_dialect.ConstantBoolOp(cudnn_enabled) + else: + cudnn_enabled = get_op_result_or_value(cudnn_enabled) + assert str(cudnn_enabled.type) == '!torch.bool', f'`cudnn_enabled` should be a !torch.bool but is {type(cudnn_enabled).__module__}.{type(cudnn_enabled).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBatchNormOp, self).__init__(result_type, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled, loc=loc, ip=ip) + + +class AtenLayerNormOp: + def __init__(self, input: Value, normalized_shape: List[int], weight: Optional[Value], bias: Optional[Value], eps: float, cudnn_enable: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(normalized_shape): + normalized_shape = list(map(torch_dialect.ConstantIntOp, normalized_shape)) + normalized_shape = torch_dialect.PrimListConstructOp(normalized_shape) + else: + normalized_shape = get_op_result_or_value(normalized_shape) + assert str(normalized_shape.type) == '!torch.list', f'`normalized_shape` should be a !torch.list but is {type(normalized_shape).__module__}.{type(normalized_shape).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + 
else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(eps): + eps = torch_dialect.ConstantFloatOp(eps) + else: + eps = get_op_result_or_value(eps) + assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}' + + if not is_mlir_value(cudnn_enable): + cudnn_enable = torch_dialect.ConstantBoolOp(cudnn_enable) + else: + cudnn_enable = get_op_result_or_value(cudnn_enable) + assert str(cudnn_enable.type) == '!torch.bool', f'`cudnn_enable` should be a !torch.bool but is {type(cudnn_enable).__module__}.{type(cudnn_enable).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLayerNormOp, self).__init__(result_type, input, normalized_shape, weight, bias, eps, cudnn_enable, loc=loc, ip=ip) + + +class AtenNativeLayerNormOp: + def __init__(self, input: Value, normalized_shape: List[int], weight: Optional[Value], bias: Optional[Value], eps: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(normalized_shape): + normalized_shape = list(map(torch_dialect.ConstantIntOp, normalized_shape)) + normalized_shape = torch_dialect.PrimListConstructOp(normalized_shape) + else: + normalized_shape = get_op_result_or_value(normalized_shape) + assert str(normalized_shape.type) == '!torch.list', f'`normalized_shape` should be a !torch.list but is {type(normalized_shape).__module__}.{type(normalized_shape).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(eps): + eps = torch_dialect.ConstantFloatOp(eps) + else: + eps = get_op_result_or_value(eps) + assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + result2_type = 
Type.parse("!torch.vtensor") + super(AtenNativeLayerNormOp, self).__init__(result0_type, result1_type, result2_type, input, normalized_shape, weight, bias, eps, loc=loc, ip=ip) + + +class AtenMaxPool2dOp: + def __init__(self, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(kernel_size): + kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size)) + kernel_size = torch_dialect.PrimListConstructOp(kernel_size) + else: + kernel_size = get_op_result_or_value(kernel_size) + assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(ceil_mode): + ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode) + else: + ceil_mode = get_op_result_or_value(ceil_mode) + assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaxPool2dOp, self).__init__(result_type, self_, kernel_size, stride, padding, dilation, ceil_mode, loc=loc, ip=ip) + + +class AtenMaxPool2dWithIndicesOp: + def __init__(self, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(kernel_size): + kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size)) + kernel_size = torch_dialect.PrimListConstructOp(kernel_size) + else: + kernel_size = get_op_result_or_value(kernel_size) + assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a 
!torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(ceil_mode): + ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode) + else: + ceil_mode = get_op_result_or_value(ceil_mode) + assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + super(AtenMaxPool2dWithIndicesOp, self).__init__(result0_type, result1_type, self_, kernel_size, stride, padding, dilation, ceil_mode, loc=loc, ip=ip) + + +class AtenMaxPool2dWithIndicesBackwardOp: + def __init__(self, grad_output: Value, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(kernel_size): + kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size)) + kernel_size = torch_dialect.PrimListConstructOp(kernel_size) + else: + kernel_size = get_op_result_or_value(kernel_size) + assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = 
torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(dilation): + dilation = list(map(torch_dialect.ConstantIntOp, dilation)) + dilation = torch_dialect.PrimListConstructOp(dilation) + else: + dilation = get_op_result_or_value(dilation) + assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}' + + if not is_mlir_value(ceil_mode): + ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode) + else: + ceil_mode = get_op_result_or_value(ceil_mode) + assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}' + + if not is_mlir_value(indices): + assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}' + else: + indices = get_op_result_or_value(indices) + assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaxPool2dWithIndicesBackwardOp, self).__init__(result_type, grad_output, self_, kernel_size, stride, padding, dilation, ceil_mode, indices, loc=loc, ip=ip) + + +class AtenAvgPool2dOp: + def __init__(self, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(kernel_size): + kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size)) + kernel_size = torch_dialect.PrimListConstructOp(kernel_size) + else: + kernel_size = get_op_result_or_value(kernel_size) + assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(padding): + padding = list(map(torch_dialect.ConstantIntOp, padding)) + padding = torch_dialect.PrimListConstructOp(padding) + else: + padding = get_op_result_or_value(padding) + assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}' + + if not is_mlir_value(ceil_mode): + ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode) + else: + ceil_mode = get_op_result_or_value(ceil_mode) + assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}' + + if not is_mlir_value(count_include_pad): + count_include_pad = 
torch_dialect.ConstantBoolOp(count_include_pad) + else: + count_include_pad = get_op_result_or_value(count_include_pad) + assert str(count_include_pad.type) == '!torch.bool', f'`count_include_pad` should be a !torch.bool but is {type(count_include_pad).__module__}.{type(count_include_pad).__name__}' + + if not is_mlir_value(divisor_override): + if divisor_override is not None: + divisor_override = torch_dialect.ConstantIntOp(divisor_override) + else: + divisor_override = torch_dialect.ConstantNoneOp() + else: + divisor_override = get_op_result_or_value(divisor_override) + assert str(divisor_override.type) == '!torch.int', f'`divisor_override` should be a !torch.int but is {type(divisor_override).__module__}.{type(divisor_override).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAvgPool2dOp, self).__init__(result_type, self_, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, loc=loc, ip=ip) + + +class AtenSoftmaxIntOp: + def __init__(self, self_: Value, dim: int, dtype: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(dtype): + if dtype is not None: + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = torch_dialect.ConstantNoneOp() + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSoftmaxIntOp, self).__init__(result_type, self_, dim, dtype, loc=loc, ip=ip) + + +class AtenLogSoftmaxIntOp: + def __init__(self, self_: Value, dim: int, dtype: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(dtype): + if dtype is not None: + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = torch_dialect.ConstantNoneOp() + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogSoftmaxIntOp, self).__init__(result_type, self_, dim, dtype, loc=loc, ip=ip) + + +class Aten_LogSoftmaxOp: + def __init__(self, self_: Value, dim: int, half_to_float: bool, *, loc=None, ip=None): + from torch_mlir.dialects 
import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(half_to_float): + half_to_float = torch_dialect.ConstantBoolOp(half_to_float) + else: + half_to_float = get_op_result_or_value(half_to_float) + assert str(half_to_float.type) == '!torch.bool', f'`half_to_float` should be a !torch.bool but is {type(half_to_float).__module__}.{type(half_to_float).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_LogSoftmaxOp, self).__init__(result_type, self_, dim, half_to_float, loc=loc, ip=ip) + + +class AtenAdaptiveAvgPool2dOp: + def __init__(self, self_: Value, output_size: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(output_size): + output_size = list(map(torch_dialect.ConstantIntOp, output_size)) + output_size = torch_dialect.PrimListConstructOp(output_size) + else: + output_size = get_op_result_or_value(output_size) + assert str(output_size.type) == '!torch.list', f'`output_size` should be a !torch.list but is {type(output_size).__module__}.{type(output_size).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAdaptiveAvgPool2dOp, self).__init__(result_type, self_, output_size, loc=loc, ip=ip) + + +class AtenTopkOp: + def __init__(self, self_: Value, k: int, dim: int, largest: bool, sorted: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(k): + k = torch_dialect.ConstantIntOp(k) + else: + k = get_op_result_or_value(k) + assert str(k.type) == '!torch.int', f'`k` should be a !torch.int but is {type(k).__module__}.{type(k).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(largest): + largest = torch_dialect.ConstantBoolOp(largest) + else: + largest = get_op_result_or_value(largest) + assert str(largest.type) == '!torch.bool', f'`largest` should be a !torch.bool but is {type(largest).__module__}.{type(largest).__name__}' + + if not is_mlir_value(sorted): + sorted = torch_dialect.ConstantBoolOp(sorted) + else: + sorted = get_op_result_or_value(sorted) + assert 
str(sorted.type) == '!torch.bool', f'`sorted` should be a !torch.bool but is {type(sorted).__module__}.{type(sorted).__name__}' + + values_type = Type.parse("!torch.vtensor") + indices_type = Type.parse("!torch.vtensor") + super(AtenTopkOp, self).__init__(values_type, indices_type, self_, k, dim, largest, sorted, loc=loc, ip=ip) + + +class AtenTransposeIntOp: + def __init__(self, self_: Value, dim0: int, dim1: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim0): + dim0 = torch_dialect.ConstantIntOp(dim0) + else: + dim0 = get_op_result_or_value(dim0) + assert str(dim0.type) == '!torch.int', f'`dim0` should be a !torch.int but is {type(dim0).__module__}.{type(dim0).__name__}' + + if not is_mlir_value(dim1): + dim1 = torch_dialect.ConstantIntOp(dim1) + else: + dim1 = get_op_result_or_value(dim1) + assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTransposeIntOp, self).__init__(result_type, self_, dim0, dim1, loc=loc, ip=ip) + + +class AtenPermuteOp: + def __init__(self, self_: Value, dims: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dims): + dims = list(map(torch_dialect.ConstantIntOp, dims)) + dims = torch_dialect.PrimListConstructOp(dims) + else: + dims = get_op_result_or_value(dims) + assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenPermuteOp, self).__init__(result_type, self_, dims, loc=loc, ip=ip) + + +class AtenBmmOp: + def __init__(self, self_: Value, mat2: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mat2): + assert is_mlir_value(mat2), f'`mat2` should be a Value but is {type(mat2).__module__}.{type(mat2).__name__}' + else: + mat2 = get_op_result_or_value(mat2) + assert str(mat2.type).startswith("!torch.vtensor"), f'`mat2` should be a torch.vtensor but is {type(mat2).__module__}.{type(mat2).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBmmOp, self).__init__(result_type, self_, mat2, loc=loc, ip=ip) + + +class AtenCumsumOp: + def __init__(self, self_: Value, dim: int, dtype: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` 
should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(dtype): + if dtype is not None: + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = torch_dialect.ConstantNoneOp() + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCumsumOp, self).__init__(result_type, self_, dim, dtype, loc=loc, ip=ip) + + +class AtenFloorDivideScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFloorDivideScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenLogsumexpOp: + def __init__(self, self_: Value, dim: List[int], keepdim: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = list(map(torch_dialect.ConstantIntOp, dim)) + dim = torch_dialect.PrimListConstructOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.list', f'`dim` should be a !torch.list but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(keepdim): + keepdim = torch_dialect.ConstantBoolOp(keepdim) + else: + keepdim = get_op_result_or_value(keepdim) + assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLogsumexpOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip) + + +class Aten__And__TensorOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten__And__TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class Aten_SoftmaxOp: + def __init__(self, self_: Value, dim: int, half_to_float: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(half_to_float): + half_to_float = torch_dialect.ConstantBoolOp(half_to_float) + else: + half_to_float = get_op_result_or_value(half_to_float) + assert str(half_to_float.type) == '!torch.bool', f'`half_to_float` should be a !torch.bool but is {type(half_to_float).__module__}.{type(half_to_float).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_SoftmaxOp, self).__init__(result_type, self_, dim, half_to_float, loc=loc, ip=ip) + + +class AtenMeanOp: + def __init__(self, self_: Value, dtype: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dtype): + if dtype is not None: + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = torch_dialect.ConstantNoneOp() + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMeanOp, self).__init__(result_type, self_, dtype, loc=loc, ip=ip) + + +class AtenStdOp: + def __init__(self, self_: Value, unbiased: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(unbiased): + unbiased = torch_dialect.ConstantBoolOp(unbiased) + else: + unbiased = get_op_result_or_value(unbiased) + assert str(unbiased.type) == '!torch.bool', f'`unbiased` should be a !torch.bool but is {type(unbiased).__module__}.{type(unbiased).__name__}' + + result_type = Type.parse("!torch.vtensor") + 
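+ # As elsewhere in this file, the result is declared with the generic, unranked
+ # "!torch.vtensor" type parsed above; shape and dtype are presumably recovered later by
+ # torch-mlir's type refinement rather than being encoded here.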
super(AtenStdOp, self).__init__(result_type, self_, unbiased, loc=loc, ip=ip) + + +class AtenVarOp: + def __init__(self, self_: Value, unbiased: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(unbiased): + unbiased = torch_dialect.ConstantBoolOp(unbiased) + else: + unbiased = get_op_result_or_value(unbiased) + assert str(unbiased.type) == '!torch.bool', f'`unbiased` should be a !torch.bool but is {type(unbiased).__module__}.{type(unbiased).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenVarOp, self).__init__(result_type, self_, unbiased, loc=loc, ip=ip) + + +class AtenVarMeanOp: + def __init__(self, self_: Value, unbiased: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(unbiased): + unbiased = torch_dialect.ConstantBoolOp(unbiased) + else: + unbiased = get_op_result_or_value(unbiased) + assert str(unbiased.type) == '!torch.bool', f'`unbiased` should be a !torch.bool but is {type(unbiased).__module__}.{type(unbiased).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + super(AtenVarMeanOp, self).__init__(result0_type, result1_type, self_, unbiased, loc=loc, ip=ip) + + +class AtenNllLossForwardOp: + def __init__(self, self_: Value, target: Value, weight: Optional[Value], reduction: int, ignore_index: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(target): + assert is_mlir_value(target), f'`target` should be a Value but is {type(target).__module__}.{type(target).__name__}' + else: + target = get_op_result_or_value(target) + assert str(target.type).startswith("!torch.vtensor"), f'`target` should be a torch.vtensor but is {type(target).__module__}.{type(target).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(reduction): + reduction = torch_dialect.ConstantIntOp(reduction) + else: + reduction = get_op_result_or_value(reduction) + assert str(reduction.type) == '!torch.int', f'`reduction` should be a !torch.int but is 
{type(reduction).__module__}.{type(reduction).__name__}' + + if not is_mlir_value(ignore_index): + ignore_index = torch_dialect.ConstantIntOp(ignore_index) + else: + ignore_index = get_op_result_or_value(ignore_index) + assert str(ignore_index.type) == '!torch.int', f'`ignore_index` should be a !torch.int but is {type(ignore_index).__module__}.{type(ignore_index).__name__}' + + output_type = Type.parse("!torch.vtensor") + total_weight_type = Type.parse("!torch.vtensor") + super(AtenNllLossForwardOp, self).__init__(output_type, total_weight_type, self_, target, weight, reduction, ignore_index, loc=loc, ip=ip) + + +class AtenNllLossBackwardOp: + def __init__(self, grad_output: Value, self_: Value, target: Value, weight: Optional[Value], reduction: int, ignore_index: int, total_weight: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(target): + assert is_mlir_value(target), f'`target` should be a Value but is {type(target).__module__}.{type(target).__name__}' + else: + target = get_op_result_or_value(target) + assert str(target.type).startswith("!torch.vtensor"), f'`target` should be a torch.vtensor but is {type(target).__module__}.{type(target).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(reduction): + reduction = torch_dialect.ConstantIntOp(reduction) + else: + reduction = get_op_result_or_value(reduction) + assert str(reduction.type) == '!torch.int', f'`reduction` should be a !torch.int but is {type(reduction).__module__}.{type(reduction).__name__}' + + if not is_mlir_value(ignore_index): + ignore_index = torch_dialect.ConstantIntOp(ignore_index) + else: + ignore_index = get_op_result_or_value(ignore_index) + assert str(ignore_index.type) == '!torch.int', f'`ignore_index` should be a !torch.int but is {type(ignore_index).__module__}.{type(ignore_index).__name__}' + + if not is_mlir_value(total_weight): + assert is_mlir_value(total_weight), f'`total_weight` should be a Value but is {type(total_weight).__module__}.{type(total_weight).__name__}' + else: + total_weight = get_op_result_or_value(total_weight) + assert str(total_weight.type).startswith("!torch.vtensor"), f'`total_weight` should be a torch.vtensor but is {type(total_weight).__module__}.{type(total_weight).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNllLossBackwardOp, self).__init__(result_type, 
grad_output, self_, target, weight, reduction, ignore_index, total_weight, loc=loc, ip=ip) + + +class AtenBincountOp: + def __init__(self, self_: Value, weights: Optional[Value], minlength: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(weights): + if weights is not None: + assert is_mlir_value(weights), f'`weights` should be a Value but is {type(weights).__module__}.{type(weights).__name__}' + else: + weights = torch_dialect.ConstantNoneOp() + else: + weights = get_op_result_or_value(weights) + assert str(weights.type).startswith("!torch.vtensor"), f'`weights` should be a torch.vtensor but is {type(weights).__module__}.{type(weights).__name__}' + + if not is_mlir_value(minlength): + minlength = torch_dialect.ConstantIntOp(minlength) + else: + minlength = get_op_result_or_value(minlength) + assert str(minlength.type) == '!torch.int', f'`minlength` should be a !torch.int but is {type(minlength).__module__}.{type(minlength).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBincountOp, self).__init__(result_type, self_, weights, minlength, loc=loc, ip=ip) + + +class AtenFrobeniusNormDimOp: + def __init__(self, self_: Value, dim: List[int], keepdim: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = list(map(torch_dialect.ConstantIntOp, dim)) + dim = torch_dialect.PrimListConstructOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.list', f'`dim` should be a !torch.list but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(keepdim): + keepdim = torch_dialect.ConstantBoolOp(keepdim) + else: + keepdim = get_op_result_or_value(keepdim) + assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFrobeniusNormDimOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip) + + +class AtenMseLossOp: + def __init__(self, self_: Value, target: Value, reduction: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(target): + assert is_mlir_value(target), f'`target` should be a Value but is {type(target).__module__}.{type(target).__name__}' + else: + target = get_op_result_or_value(target) + assert str(target.type).startswith("!torch.vtensor"), f'`target` should be a torch.vtensor but 
is {type(target).__module__}.{type(target).__name__}' + + if not is_mlir_value(reduction): + reduction = torch_dialect.ConstantIntOp(reduction) + else: + reduction = get_op_result_or_value(reduction) + assert str(reduction.type) == '!torch.int', f'`reduction` should be a !torch.int but is {type(reduction).__module__}.{type(reduction).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMseLossOp, self).__init__(result_type, self_, target, reduction, loc=loc, ip=ip) + + +class AtenUpsampleNearest2dBackwardOp: + def __init__(self, grad_output: Value, output_size: List[int], input_size: List[int], scales_h: Optional[float], scales_w: Optional[float], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(output_size): + output_size = list(map(torch_dialect.ConstantIntOp, output_size)) + output_size = torch_dialect.PrimListConstructOp(output_size) + else: + output_size = get_op_result_or_value(output_size) + assert str(output_size.type) == '!torch.list', f'`output_size` should be a !torch.list but is {type(output_size).__module__}.{type(output_size).__name__}' + + if not is_mlir_value(input_size): + input_size = list(map(torch_dialect.ConstantIntOp, input_size)) + input_size = torch_dialect.PrimListConstructOp(input_size) + else: + input_size = get_op_result_or_value(input_size) + assert str(input_size.type) == '!torch.list', f'`input_size` should be a !torch.list but is {type(input_size).__module__}.{type(input_size).__name__}' + + if not is_mlir_value(scales_h): + if scales_h is not None: + scales_h = torch_dialect.ConstantFloatOp(scales_h) + else: + scales_h = torch_dialect.ConstantNoneOp() + else: + scales_h = get_op_result_or_value(scales_h) + assert str(scales_h.type) == '!torch.float', f'`scales_h` should be a !torch.float but is {type(scales_h).__module__}.{type(scales_h).__name__}' + + if not is_mlir_value(scales_w): + if scales_w is not None: + scales_w = torch_dialect.ConstantFloatOp(scales_w) + else: + scales_w = torch_dialect.ConstantNoneOp() + else: + scales_w = get_op_result_or_value(scales_w) + assert str(scales_w.type) == '!torch.float', f'`scales_w` should be a !torch.float but is {type(scales_w).__module__}.{type(scales_w).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenUpsampleNearest2dBackwardOp, self).__init__(result_type, grad_output, output_size, input_size, scales_h, scales_w, loc=loc, ip=ip) + + +class AtenConstantPadNdOp: + def __init__(self, self_: Value, pad: List[int], value: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(pad): + pad = list(map(torch_dialect.ConstantIntOp, pad)) + pad = torch_dialect.PrimListConstructOp(pad) + else: + pad = get_op_result_or_value(pad) + 
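+ # Plain Python int lists (the branch above) are lowered element-wise with
+ # torch_dialect.ConstantIntOp and gathered into a list value by
+ # torch_dialect.PrimListConstructOp; a pre-existing MLIR Value for `pad` must instead
+ # print exactly as '!torch.list', which is what the assert immediately below enforces.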
assert str(pad.type) == '!torch.list', f'`pad` should be a !torch.list but is {type(pad).__module__}.{type(pad).__name__}' + + if not is_mlir_value(value): + value = torch_dialect.ConstantNumberOp(value) + else: + value = get_op_result_or_value(value) + assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenConstantPadNdOp, self).__init__(result_type, self_, pad, value, loc=loc, ip=ip) + + +class AtenPadOp: + def __init__(self, self_: Value, pad: List[int], mode: str, value: Optional[float], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(pad): + pad = list(map(torch_dialect.ConstantIntOp, pad)) + pad = torch_dialect.PrimListConstructOp(pad) + else: + pad = get_op_result_or_value(pad) + assert str(pad.type) == '!torch.list', f'`pad` should be a !torch.list but is {type(pad).__module__}.{type(pad).__name__}' + + if not is_mlir_value(mode): + mode = torch_dialect.ConstantStrOp(mode) + else: + mode = get_op_result_or_value(mode) + assert str(mode.type) == '!torch.str', f'`mode` should be a !torch.str but is {type(mode).__module__}.{type(mode).__name__}' + + if not is_mlir_value(value): + if value is not None: + value = torch_dialect.ConstantFloatOp(value) + else: + value = torch_dialect.ConstantNoneOp() + else: + value = get_op_result_or_value(value) + assert str(value.type) == '!torch.float', f'`value` should be a !torch.float but is {type(value).__module__}.{type(value).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenPadOp, self).__init__(result_type, self_, pad, mode, value, loc=loc, ip=ip) + + +class AtenSqueezeDimOp: + def __init__(self, self_: Value, dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSqueezeDimOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip) + + +class AtenSqueezeOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSqueezeOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenFlattenUsingIntsOp: + def 
__init__(self, self_: Value, start_dim: int, end_dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(start_dim): + start_dim = torch_dialect.ConstantIntOp(start_dim) + else: + start_dim = get_op_result_or_value(start_dim) + assert str(start_dim.type) == '!torch.int', f'`start_dim` should be a !torch.int but is {type(start_dim).__module__}.{type(start_dim).__name__}' + + if not is_mlir_value(end_dim): + end_dim = torch_dialect.ConstantIntOp(end_dim) + else: + end_dim = get_op_result_or_value(end_dim) + assert str(end_dim.type) == '!torch.int', f'`end_dim` should be a !torch.int but is {type(end_dim).__module__}.{type(end_dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFlattenUsingIntsOp, self).__init__(result_type, self_, start_dim, end_dim, loc=loc, ip=ip) + + +class AtenDimOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + super(AtenDimOp, self).__init__(self_, loc=loc, ip=ip) + + +class AtenSizeOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.list") + super(AtenSizeOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenBoolTensorOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(AtenBoolTensorOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenIsFloatingPointOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + super(AtenIsFloatingPointOp, self).__init__(self_, loc=loc, ip=ip) + + +class Aten_ShapeAsTensorOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_ShapeAsTensorOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAllOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAllOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAllBoolOp: + def __init__(self, self_: List[bool], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = list(map(torch_dialect.ConstantBoolOp, self_)) + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + # should be bool[] + pass + + super(AtenAllBoolOp, self).__init__(self_, loc=loc, ip=ip) + + +class AtenAnyOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAnyOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAnyDimOp: + def __init__(self, self_: Value, dim: int, keepdim: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(keepdim): + keepdim = torch_dialect.ConstantBoolOp(keepdim) + else: + keepdim = get_op_result_or_value(keepdim) + assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAnyDimOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip) + + +class AtenArangeStartOutOp: + def __init__(self, start: "Number", end: "Number", step: "Number", out: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(start): + start = torch_dialect.ConstantNumberOp(start) + else: + start = get_op_result_or_value(start) + assert str(start.type) in {'!torch.float', '!torch.int'}, f'`start` should be a !torch.number but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(end): + end = torch_dialect.ConstantNumberOp(end) + else: + end = get_op_result_or_value(end) + assert str(end.type) in {'!torch.float', '!torch.int'}, f'`end` should be a !torch.number but is 
{type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantNumberOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) in {'!torch.float', '!torch.int'}, f'`step` should be a !torch.number but is {type(step).__module__}.{type(step).__name__}' + + if not is_mlir_value(out): + assert is_mlir_value(out), f'`out` should be a Value but is {type(out).__module__}.{type(out).__name__}' + else: + out = get_op_result_or_value(out) + assert str(out.type).startswith("!torch.vtensor"), f'`out` should be a torch.vtensor but is {type(out).__module__}.{type(out).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenArangeStartOutOp, self).__init__(result_type, start, end, step, out, loc=loc, ip=ip) + + +class AtenArgmaxOp: + def __init__(self, self_: Value, dim: Optional[int], keepdim: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + if dim is not None: + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = torch_dialect.ConstantNoneOp() + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(keepdim): + keepdim = torch_dialect.ConstantBoolOp(keepdim) + else: + keepdim = get_op_result_or_value(keepdim) + assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenArgmaxOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip) + + +class AtenBucketizeTensorOp: + def __init__(self, self_: Value, boundaries: Value, out_int32: bool, right: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(boundaries): + assert is_mlir_value(boundaries), f'`boundaries` should be a Value but is {type(boundaries).__module__}.{type(boundaries).__name__}' + else: + boundaries = get_op_result_or_value(boundaries) + assert str(boundaries.type).startswith("!torch.vtensor"), f'`boundaries` should be a torch.vtensor but is {type(boundaries).__module__}.{type(boundaries).__name__}' + + if not is_mlir_value(out_int32): + out_int32 = torch_dialect.ConstantBoolOp(out_int32) + else: + out_int32 = get_op_result_or_value(out_int32) + assert str(out_int32.type) == '!torch.bool', f'`out_int32` should be a !torch.bool but is {type(out_int32).__module__}.{type(out_int32).__name__}' + + if not is_mlir_value(right): + right = torch_dialect.ConstantBoolOp(right) + else: + right = get_op_result_or_value(right) + assert str(right.type) == '!torch.bool', f'`right` should be a !torch.bool but is {type(right).__module__}.{type(right).__name__}' + + result_type = 
Type.parse("!torch.vtensor") + super(AtenBucketizeTensorOp, self).__init__(result_type, self_, boundaries, out_int32, right, loc=loc, ip=ip) + + +class AtenCloneOp: + def __init__(self, self_: Value, memory_format: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(memory_format): + if memory_format is not None: + memory_format = torch_dialect.ConstantIntOp(memory_format) + else: + memory_format = torch_dialect.ConstantNoneOp() + else: + memory_format = get_op_result_or_value(memory_format) + assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCloneOp, self).__init__(result_type, self_, memory_format, loc=loc, ip=ip) + + +class AtenLiftFreshCopyOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenLiftFreshCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenContiguousOp: + def __init__(self, self_: Value, memory_format: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(memory_format): + memory_format = torch_dialect.ConstantIntOp(memory_format) + else: + memory_format = get_op_result_or_value(memory_format) + assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenContiguousOp, self).__init__(result_type, self_, memory_format, loc=loc, ip=ip) + + +class AtenCopyOp: + def __init__(self, self_: Value, src: Value, non_blocking: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is 
{type(src).__module__}.{type(src).__name__}' + + if not is_mlir_value(non_blocking): + non_blocking = torch_dialect.ConstantBoolOp(non_blocking) + else: + non_blocking = get_op_result_or_value(non_blocking) + assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCopyOp, self).__init__(result_type, self_, src, non_blocking, loc=loc, ip=ip) + + +class AtenCopy_Op: + def __init__(self, self_: Value, src: Value, non_blocking: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}' + + if not is_mlir_value(non_blocking): + non_blocking = torch_dialect.ConstantBoolOp(non_blocking) + else: + non_blocking = get_op_result_or_value(non_blocking) + assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCopy_Op, self).__init__(result_type, self_, src, non_blocking, loc=loc, ip=ip) + + +class AtenDetachOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDetachOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenEmbeddingOp: + def __init__(self, weight: Value, indices: Value, padding_idx: int, scale_grad_by_freq: bool, sparse: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(indices): + assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}' + else: + indices = get_op_result_or_value(indices) + assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}' + + if not is_mlir_value(padding_idx): + padding_idx = torch_dialect.ConstantIntOp(padding_idx) + else: + padding_idx = get_op_result_or_value(padding_idx) + assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is 
{type(padding_idx).__module__}.{type(padding_idx).__name__}' + + if not is_mlir_value(scale_grad_by_freq): + scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq) + else: + scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq) + assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}' + + if not is_mlir_value(sparse): + sparse = torch_dialect.ConstantBoolOp(sparse) + else: + sparse = get_op_result_or_value(sparse) + assert str(sparse.type) == '!torch.bool', f'`sparse` should be a !torch.bool but is {type(sparse).__module__}.{type(sparse).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenEmbeddingOp, self).__init__(result_type, weight, indices, padding_idx, scale_grad_by_freq, sparse, loc=loc, ip=ip) + + +class AtenEmbeddingBagPaddingIdxOp: + def __init__(self, weight: Value, indices: Value, offsets: Value, scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[Value], include_last_offset: bool, padding_idx: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(indices): + assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}' + else: + indices = get_op_result_or_value(indices) + assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}' + + if not is_mlir_value(offsets): + assert is_mlir_value(offsets), f'`offsets` should be a Value but is {type(offsets).__module__}.{type(offsets).__name__}' + else: + offsets = get_op_result_or_value(offsets) + assert str(offsets.type).startswith("!torch.vtensor"), f'`offsets` should be a torch.vtensor but is {type(offsets).__module__}.{type(offsets).__name__}' + + if not is_mlir_value(scale_grad_by_freq): + scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq) + else: + scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq) + assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}' + + if not is_mlir_value(mode): + mode = torch_dialect.ConstantIntOp(mode) + else: + mode = get_op_result_or_value(mode) + assert str(mode.type) == '!torch.int', f'`mode` should be a !torch.int but is {type(mode).__module__}.{type(mode).__name__}' + + if not is_mlir_value(sparse): + sparse = torch_dialect.ConstantBoolOp(sparse) + else: + sparse = get_op_result_or_value(sparse) + assert str(sparse.type) == '!torch.bool', f'`sparse` should be a !torch.bool but is {type(sparse).__module__}.{type(sparse).__name__}' + + if not is_mlir_value(per_sample_weights): + if per_sample_weights is not None: + assert is_mlir_value(per_sample_weights), f'`per_sample_weights` should be a Value but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}' + else: + per_sample_weights = torch_dialect.ConstantNoneOp() + else: + per_sample_weights = 
get_op_result_or_value(per_sample_weights) + assert str(per_sample_weights.type).startswith("!torch.vtensor"), f'`per_sample_weights` should be a torch.vtensor but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}' + + if not is_mlir_value(include_last_offset): + include_last_offset = torch_dialect.ConstantBoolOp(include_last_offset) + else: + include_last_offset = get_op_result_or_value(include_last_offset) + assert str(include_last_offset.type) == '!torch.bool', f'`include_last_offset` should be a !torch.bool but is {type(include_last_offset).__module__}.{type(include_last_offset).__name__}' + + if not is_mlir_value(padding_idx): + if padding_idx is not None: + padding_idx = torch_dialect.ConstantIntOp(padding_idx) + else: + padding_idx = torch_dialect.ConstantNoneOp() + else: + padding_idx = get_op_result_or_value(padding_idx) + assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is {type(padding_idx).__module__}.{type(padding_idx).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + result2_type = Type.parse("!torch.vtensor") + result3_type = Type.parse("!torch.vtensor") + super(AtenEmbeddingBagPaddingIdxOp, self).__init__(result0_type, result1_type, result2_type, result3_type, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, loc=loc, ip=ip) + + +class Aten_EmbeddingBagOp: + def __init__(self, weight: Value, indices: Value, offsets: Value, scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[Value], include_last_offset: bool, padding_idx: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(weight): + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(indices): + assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}' + else: + indices = get_op_result_or_value(indices) + assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}' + + if not is_mlir_value(offsets): + assert is_mlir_value(offsets), f'`offsets` should be a Value but is {type(offsets).__module__}.{type(offsets).__name__}' + else: + offsets = get_op_result_or_value(offsets) + assert str(offsets.type).startswith("!torch.vtensor"), f'`offsets` should be a torch.vtensor but is {type(offsets).__module__}.{type(offsets).__name__}' + + if not is_mlir_value(scale_grad_by_freq): + scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq) + else: + scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq) + assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}' + + if not is_mlir_value(mode): + mode = torch_dialect.ConstantIntOp(mode) + else: + mode = get_op_result_or_value(mode) + assert str(mode.type) == '!torch.int', f'`mode` should be a !torch.int but is {type(mode).__module__}.{type(mode).__name__}' + + if not is_mlir_value(sparse): + sparse = torch_dialect.ConstantBoolOp(sparse) + 
else: + sparse = get_op_result_or_value(sparse) + assert str(sparse.type) == '!torch.bool', f'`sparse` should be a !torch.bool but is {type(sparse).__module__}.{type(sparse).__name__}' + + if not is_mlir_value(per_sample_weights): + if per_sample_weights is not None: + assert is_mlir_value(per_sample_weights), f'`per_sample_weights` should be a Value but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}' + else: + per_sample_weights = torch_dialect.ConstantNoneOp() + else: + per_sample_weights = get_op_result_or_value(per_sample_weights) + assert str(per_sample_weights.type).startswith("!torch.vtensor"), f'`per_sample_weights` should be a torch.vtensor but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}' + + if not is_mlir_value(include_last_offset): + include_last_offset = torch_dialect.ConstantBoolOp(include_last_offset) + else: + include_last_offset = get_op_result_or_value(include_last_offset) + assert str(include_last_offset.type) == '!torch.bool', f'`include_last_offset` should be a !torch.bool but is {type(include_last_offset).__module__}.{type(include_last_offset).__name__}' + + if not is_mlir_value(padding_idx): + padding_idx = torch_dialect.ConstantIntOp(padding_idx) + else: + padding_idx = get_op_result_or_value(padding_idx) + assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is {type(padding_idx).__module__}.{type(padding_idx).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + result2_type = Type.parse("!torch.vtensor") + result3_type = Type.parse("!torch.vtensor") + super(Aten_EmbeddingBagOp, self).__init__(result0_type, result1_type, result2_type, result3_type, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, loc=loc, ip=ip) + + +class AtenExpandOp: + def __init__(self, self_: Value, size: List[int], implicit: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(implicit): + implicit = torch_dialect.ConstantBoolOp(implicit) + else: + implicit = get_op_result_or_value(implicit) + assert str(implicit.type) == '!torch.bool', f'`implicit` should be a !torch.bool but is {type(implicit).__module__}.{type(implicit).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExpandOp, self).__init__(result_type, self_, size, implicit, loc=loc, ip=ip) + + +class AtenExpandAsOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExpandAsOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenBroadcastToOp: + def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBroadcastToOp, self).__init__(result_type, self_, size, loc=loc, ip=ip) + + +class AtenIndexTensorHackedTwinOp: + def __init__(self, self_: Value, indices: List[Value], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(indices): + indices = torch_dialect.PrimListConstructOp(indices) + else: + indices = get_op_result_or_value(indices) + assert str(indices.type) == '!torch.list', f'`indices` should be a !torch.list but is {type(indices).__module__}.{type(indices).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenIndexTensorHackedTwinOp, self).__init__(result_type, self_, indices, loc=loc, ip=ip) + + +class AtenIndexSelectOp: + def __init__(self, self_: Value, dim: int, index: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(index): + assert is_mlir_value(index), f'`index` should be a Value but is {type(index).__module__}.{type(index).__name__}' + else: + index = get_op_result_or_value(index) + assert str(index.type).startswith("!torch.vtensor"), f'`index` should be a torch.vtensor but is {type(index).__module__}.{type(index).__name__}' + + result_type = Type.parse("!torch.vtensor") 
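+        # Usage sketch (editorial comment; everything except this class and torch_dialect is an
+        # illustrative assumption): with an active MLIR Context/Location/InsertionPoint one might
+        # build   sel = AtenIndexSelectOp(input_vtensor, 1, index_vtensor)
+        # where `input_vtensor` and `index_vtensor` are existing !torch.vtensor Values; `dim` is
+        # passed as a plain int and converted to torch.constant.int by the branch above, and the
+        # result carries the unrefined !torch.vtensor type parsed just above.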
+ super(AtenIndexSelectOp, self).__init__(result_type, self_, dim, index, loc=loc, ip=ip) + + +class AtenItemOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + super(AtenItemOp, self).__init__(self_, loc=loc, ip=ip) + + +class AtenMaskedSelectOp: + def __init__(self, self_: Value, mask: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(mask): + assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}' + else: + mask = get_op_result_or_value(mask) + assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaskedSelectOp, self).__init__(result_type, self_, mask, loc=loc, ip=ip) + + +class AtenNumelOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + super(AtenNumelOp, self).__init__(self_, loc=loc, ip=ip) + + +class AtenRepeatOp: + def __init__(self, self_: Value, repeats: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(repeats): + repeats = list(map(torch_dialect.ConstantIntOp, repeats)) + repeats = torch_dialect.PrimListConstructOp(repeats) + else: + repeats = get_op_result_or_value(repeats) + assert str(repeats.type) == '!torch.list', f'`repeats` should be a !torch.list but is {type(repeats).__module__}.{type(repeats).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenRepeatOp, self).__init__(result_type, self_, repeats, loc=loc, ip=ip) + + +class AtenReshapeOp: + def __init__(self, self_: Value, shape: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(shape): + shape = list(map(torch_dialect.ConstantIntOp, shape)) + shape = 
torch_dialect.PrimListConstructOp(shape) + else: + shape = get_op_result_or_value(shape) + assert str(shape.type) == '!torch.list', f'`shape` should be a !torch.list but is {type(shape).__module__}.{type(shape).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenReshapeOp, self).__init__(result_type, self_, shape, loc=loc, ip=ip) + + +class Aten_ReshapeAliasOp: + def __init__(self, self_: Value, size: List[int], stride: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_ReshapeAliasOp, self).__init__(result_type, self_, size, stride, loc=loc, ip=ip) + + +class AtenResize_Op: + def __init__(self, self_: Value, size: List[int], memory_format: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(memory_format): + if memory_format is not None: + memory_format = torch_dialect.ConstantIntOp(memory_format) + else: + memory_format = torch_dialect.ConstantNoneOp() + else: + memory_format = get_op_result_or_value(memory_format) + assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenResize_Op, self).__init__(result_type, self_, size, memory_format, loc=loc, ip=ip) + + +class AtenSelectIntOp: + def __init__(self, self_: Value, dim: int, index: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(index): + index = torch_dialect.ConstantIntOp(index) + else: + index = get_op_result_or_value(index) + assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSelectIntOp, self).__init__(result_type, self_, dim, index, loc=loc, ip=ip) + + +class AtenSizeIntOp: + def __init__(self, self_: Value, dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + super(AtenSizeIntOp, self).__init__(self_, dim, loc=loc, ip=ip) + + +class AtenStackOp: + def __init__(self, tensors: List[Value], dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(tensors): + tensors = torch_dialect.PrimListConstructOp(tensors) + else: + tensors = get_op_result_or_value(tensors) + assert str(tensors.type) == '!torch.list', f'`tensors` should be a !torch.list but is {type(tensors).__module__}.{type(tensors).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenStackOp, self).__init__(result_type, tensors, dim, loc=loc, ip=ip) + + +class AtenSumOp: + def __init__(self, self_: Value, dtype: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dtype): + if dtype is not None: + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = torch_dialect.ConstantNoneOp() + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSumOp, self).__init__(result_type, self_, dtype, loc=loc, ip=ip) + + +class AtenMaxOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert 
str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenMaxOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenMaxDimOp: + def __init__(self, self_: Value, dim: int, keepdim: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(keepdim): + keepdim = torch_dialect.ConstantBoolOp(keepdim) + else: + keepdim = get_op_result_or_value(keepdim) + assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}' + + values_type = Type.parse("!torch.vtensor") + indices_type = Type.parse("!torch.vtensor") + super(AtenMaxDimOp, self).__init__(values_type, indices_type, self_, dim, keepdim, loc=loc, ip=ip) + + +class AtenAmaxOp: + def __init__(self, self_: Value, dim: List[int], keepdim: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = list(map(torch_dialect.ConstantIntOp, dim)) + dim = torch_dialect.PrimListConstructOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.list', f'`dim` should be a !torch.list but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(keepdim): + keepdim = torch_dialect.ConstantBoolOp(keepdim) + else: + keepdim = get_op_result_or_value(keepdim) + assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAmaxOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip) + + +class AtenToDtypeOp: + def __init__(self, self_: Value, dtype: int, non_blocking: bool, copy: bool, memory_format: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dtype): + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + if not is_mlir_value(non_blocking): + 
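+            # Editorial comment (assumption): `dtype` above and `memory_format` below are plain
+            # Python ints -- presumably torch's ScalarType / MemoryFormat enum values -- and are
+            # wrapped as torch.constant.int ops, while `non_blocking` here and `copy` below become
+            # torch.constant.bool ops when given as plain Python bools.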
non_blocking = torch_dialect.ConstantBoolOp(non_blocking) + else: + non_blocking = get_op_result_or_value(non_blocking) + assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}' + + if not is_mlir_value(copy): + copy = torch_dialect.ConstantBoolOp(copy) + else: + copy = get_op_result_or_value(copy) + assert str(copy.type) == '!torch.bool', f'`copy` should be a !torch.bool but is {type(copy).__module__}.{type(copy).__name__}' + + if not is_mlir_value(memory_format): + if memory_format is not None: + memory_format = torch_dialect.ConstantIntOp(memory_format) + else: + memory_format = torch_dialect.ConstantNoneOp() + else: + memory_format = get_op_result_or_value(memory_format) + assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenToDtypeOp, self).__init__(result_type, self_, dtype, non_blocking, copy, memory_format, loc=loc, ip=ip) + + +class AtenToOtherOp: + def __init__(self, self_: Value, other: Value, non_blocking: bool, copy: bool, memory_format: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + if not is_mlir_value(non_blocking): + non_blocking = torch_dialect.ConstantBoolOp(non_blocking) + else: + non_blocking = get_op_result_or_value(non_blocking) + assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}' + + if not is_mlir_value(copy): + copy = torch_dialect.ConstantBoolOp(copy) + else: + copy = get_op_result_or_value(copy) + assert str(copy.type) == '!torch.bool', f'`copy` should be a !torch.bool but is {type(copy).__module__}.{type(copy).__name__}' + + if not is_mlir_value(memory_format): + if memory_format is not None: + memory_format = torch_dialect.ConstantIntOp(memory_format) + else: + memory_format = torch_dialect.ConstantNoneOp() + else: + memory_format = get_op_result_or_value(memory_format) + assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenToOtherOp, self).__init__(result_type, self_, other, non_blocking, copy, memory_format, loc=loc, ip=ip) + + +class AtenTypeAsOp: + def __init__(self, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor 
but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTypeAsOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenViewOp: + def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenViewOp, self).__init__(result_type, self_, size, loc=loc, ip=ip) + + +class Aten_UnsafeViewOp: + def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_UnsafeViewOp, self).__init__(result_type, self_, size, loc=loc, ip=ip) + + +class AtenWhereSelfOp: + def __init__(self, condition: Value, self_: Value, other: Value, *, loc=None, ip=None): + if not is_mlir_value(condition): + assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}' + else: + condition = get_op_result_or_value(condition) + assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}' + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is 
{type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenWhereSelfOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip) + + +class AtenWhereScalarOp: + def __init__(self, condition: Value, self_: "Number", other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(condition): + assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}' + else: + condition = get_op_result_or_value(condition) + assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}' + + if not is_mlir_value(self_): + self_ = torch_dialect.ConstantNumberOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) in {'!torch.float', '!torch.int'}, f'`self_` should be a !torch.number but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenWhereScalarOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip) + + +class AtenWhereScalarOtherOp: + def __init__(self, condition: Value, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(condition): + assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}' + else: + condition = get_op_result_or_value(condition) + assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}' + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenWhereScalarOtherOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip) + + +class AtenWhereScalarSelfOp: + def __init__(self, condition: Value, self_: "Number", other: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(condition): + assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}' + else: + condition = get_op_result_or_value(condition) + assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}' + + if not is_mlir_value(self_): + self_ = torch_dialect.ConstantNumberOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) in {'!torch.float', 
'!torch.int'}, f'`self_` should be a !torch.number but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}' + else: + other = get_op_result_or_value(other) + assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenWhereScalarSelfOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip) + + +class AtenSliceTensorOp: + def __init__(self, self_: Value, dim: int, start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(start): + if start is not None: + start = torch_dialect.ConstantIntOp(start) + else: + start = torch_dialect.ConstantNoneOp() + else: + start = get_op_result_or_value(start) + assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(end): + if end is not None: + end = torch_dialect.ConstantIntOp(end) + else: + end = torch_dialect.ConstantNoneOp() + else: + end = get_op_result_or_value(end) + assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSliceTensorOp, self).__init__(result_type, self_, dim, start, end, step, loc=loc, ip=ip) + + +class AtenLenTensorOp: + def __init__(self, t: Value, *, loc=None, ip=None): + if not is_mlir_value(t): + assert is_mlir_value(t), f'`t` should be a Value but is {type(t).__module__}.{type(t).__name__}' + else: + t = get_op_result_or_value(t) + assert str(t.type).startswith("!torch.vtensor"), f'`t` should be a torch.vtensor but is {type(t).__module__}.{type(t).__name__}' + + super(AtenLenTensorOp, self).__init__(t, loc=loc, ip=ip) + + +class AtenCpuOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCpuOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenGatherOp: + def __init__(self, self_: Value, dim: int, index: Value, sparse_grad: bool, *, loc=None, ip=None): + from torch_mlir.dialects 
import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(index): + assert is_mlir_value(index), f'`index` should be a Value but is {type(index).__module__}.{type(index).__name__}' + else: + index = get_op_result_or_value(index) + assert str(index.type).startswith("!torch.vtensor"), f'`index` should be a torch.vtensor but is {type(index).__module__}.{type(index).__name__}' + + if not is_mlir_value(sparse_grad): + sparse_grad = torch_dialect.ConstantBoolOp(sparse_grad) + else: + sparse_grad = get_op_result_or_value(sparse_grad) + assert str(sparse_grad.type) == '!torch.bool', f'`sparse_grad` should be a !torch.bool but is {type(sparse_grad).__module__}.{type(sparse_grad).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGatherOp, self).__init__(result_type, self_, dim, index, sparse_grad, loc=loc, ip=ip) + + +class AtenScatterAddOp: + def __init__(self, self_: Value, dim: int, index: Value, src: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(index): + assert is_mlir_value(index), f'`index` should be a Value but is {type(index).__module__}.{type(index).__name__}' + else: + index = get_op_result_or_value(index) + assert str(index.type).startswith("!torch.vtensor"), f'`index` should be a torch.vtensor but is {type(index).__module__}.{type(index).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenScatterAddOp, self).__init__(result_type, self_, dim, index, src, loc=loc, ip=ip) + + +class AtenIntImplicitOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(AtenIntImplicitOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenFloatImplicitOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert 
is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(AtenFloatImplicitOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenIntTensorOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(AtenIntTensorOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenFloatTensorOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(AtenFloatTensorOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenDropoutOp: + def __init__(self, input: Value, p: float, train: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(p): + p = torch_dialect.ConstantFloatOp(p) + else: + p = get_op_result_or_value(p) + assert str(p.type) == '!torch.float', f'`p` should be a !torch.float but is {type(p).__module__}.{type(p).__name__}' + + if not is_mlir_value(train): + train = torch_dialect.ConstantBoolOp(train) + else: + train = get_op_result_or_value(train) + assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDropoutOp, self).__init__(result_type, input, p, train, loc=loc, ip=ip) + + +class AtenDropout_Op: + def __init__(self, self_: Value, p: float, train: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(p): + p = torch_dialect.ConstantFloatOp(p) + else: + p = get_op_result_or_value(p) + assert str(p.type) == '!torch.float', f'`p` should be a !torch.float but is {type(p).__module__}.{type(p).__name__}' + + if not is_mlir_value(train): + train = torch_dialect.ConstantBoolOp(train) + else: + train = get_op_result_or_value(train) + assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDropout_Op, self).__init__(result_type, self_, p, train, loc=loc, ip=ip) + + +class AtenNativeDropoutOp: + def __init__(self, input: Value, p: float, train: 
Optional[bool], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(p): + p = torch_dialect.ConstantFloatOp(p) + else: + p = get_op_result_or_value(p) + assert str(p.type) == '!torch.float', f'`p` should be a !torch.float but is {type(p).__module__}.{type(p).__name__}' + + if not is_mlir_value(train): + if train is not None: + train = torch_dialect.ConstantBoolOp(train) + else: + train = torch_dialect.ConstantNoneOp() + else: + train = get_op_result_or_value(train) + assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}' + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + super(AtenNativeDropoutOp, self).__init__(result0_type, result1_type, input, p, train, loc=loc, ip=ip) + + +class AtenTOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenNumpyTOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNumpyTOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenBaddbmmOp: + def __init__(self, self_: Value, batch1: Value, batch2: Value, beta: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(batch1): + assert is_mlir_value(batch1), f'`batch1` should be a Value but is {type(batch1).__module__}.{type(batch1).__name__}' + else: + batch1 = get_op_result_or_value(batch1) + assert str(batch1.type).startswith("!torch.vtensor"), f'`batch1` should be a torch.vtensor but is {type(batch1).__module__}.{type(batch1).__name__}' + + if not is_mlir_value(batch2): + assert is_mlir_value(batch2), f'`batch2` should be a Value but is {type(batch2).__module__}.{type(batch2).__name__}' + else: + batch2 = get_op_result_or_value(batch2) + assert str(batch2.type).startswith("!torch.vtensor"), f'`batch2` should be a torch.vtensor but is {type(batch2).__module__}.{type(batch2).__name__}' + + if not 
is_mlir_value(beta): + beta = torch_dialect.ConstantNumberOp(beta) + else: + beta = get_op_result_or_value(beta) + assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBaddbmmOp, self).__init__(result_type, self_, batch1, batch2, beta, alpha, loc=loc, ip=ip) + + +class AtenBaddbmm_Op: + def __init__(self, self_: Value, batch1: Value, batch2: Value, beta: "Number", alpha: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(batch1): + assert is_mlir_value(batch1), f'`batch1` should be a Value but is {type(batch1).__module__}.{type(batch1).__name__}' + else: + batch1 = get_op_result_or_value(batch1) + assert str(batch1.type).startswith("!torch.vtensor"), f'`batch1` should be a torch.vtensor but is {type(batch1).__module__}.{type(batch1).__name__}' + + if not is_mlir_value(batch2): + assert is_mlir_value(batch2), f'`batch2` should be a Value but is {type(batch2).__module__}.{type(batch2).__name__}' + else: + batch2 = get_op_result_or_value(batch2) + assert str(batch2.type).startswith("!torch.vtensor"), f'`batch2` should be a torch.vtensor but is {type(batch2).__module__}.{type(batch2).__name__}' + + if not is_mlir_value(beta): + beta = torch_dialect.ConstantNumberOp(beta) + else: + beta = get_op_result_or_value(beta) + assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}' + + if not is_mlir_value(alpha): + alpha = torch_dialect.ConstantNumberOp(alpha) + else: + alpha = get_op_result_or_value(alpha) + assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenBaddbmm_Op, self).__init__(result_type, self_, batch1, batch2, beta, alpha, loc=loc, ip=ip) + + +class AtenFftFftOp: + def __init__(self, self_: Value, n: Optional[int], dim: int, norm: Optional[str], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(n): + if n is not None: + n = torch_dialect.ConstantIntOp(n) + else: + n = torch_dialect.ConstantNoneOp() + else: + n = get_op_result_or_value(n) + assert str(n.type) == '!torch.int', f'`n` should be a !torch.int but is {type(n).__module__}.{type(n).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + 
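+            # Descriptive note (editorial comment): the argument is already an
+            # MLIR Value/OpResult here, so it is unwrapped with
+            # get_op_result_or_value and its type is checked against !torch.int
+            # instead of materializing a new constant op.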
dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(norm): + if norm is not None: + norm = torch_dialect.ConstantStrOp(norm) + else: + norm = torch_dialect.ConstantNoneOp() + else: + norm = get_op_result_or_value(norm) + assert str(norm.type) == '!torch.str', f'`norm` should be a !torch.str but is {type(norm).__module__}.{type(norm).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenFftFftOp, self).__init__(result_type, self_, n, dim, norm, loc=loc, ip=ip) + + +class AtenAliasCopyOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAliasCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenAsStridedCopyOp: + def __init__(self, self_: Value, size: List[int], stride: List[int], storage_offset: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(storage_offset): + if storage_offset is not None: + storage_offset = torch_dialect.ConstantIntOp(storage_offset) + else: + storage_offset = torch_dialect.ConstantNoneOp() + else: + storage_offset = get_op_result_or_value(storage_offset) + assert str(storage_offset.type) == '!torch.int', f'`storage_offset` should be a !torch.int but is {type(storage_offset).__module__}.{type(storage_offset).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAsStridedCopyOp, self).__init__(result_type, self_, size, stride, storage_offset, loc=loc, ip=ip) + + +class AtenDiagonalCopyOp: + def __init__(self, self_: Value, offset: int, dim1: int, dim2: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(offset): + offset = torch_dialect.ConstantIntOp(offset) + 
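+            # Descriptive note (editorial comment): plain Python ints are wrapped
+            # as torch.constant.int ops (ConstantIntOp) above, while existing MLIR
+            # Values are unwrapped and type-checked in the branch below.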
else: + offset = get_op_result_or_value(offset) + assert str(offset.type) == '!torch.int', f'`offset` should be a !torch.int but is {type(offset).__module__}.{type(offset).__name__}' + + if not is_mlir_value(dim1): + dim1 = torch_dialect.ConstantIntOp(dim1) + else: + dim1 = get_op_result_or_value(dim1) + assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}' + + if not is_mlir_value(dim2): + dim2 = torch_dialect.ConstantIntOp(dim2) + else: + dim2 = get_op_result_or_value(dim2) + assert str(dim2.type) == '!torch.int', f'`dim2` should be a !torch.int but is {type(dim2).__module__}.{type(dim2).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDiagonalCopyOp, self).__init__(result_type, self_, offset, dim1, dim2, loc=loc, ip=ip) + + +class AtenExpandCopyOp: + def __init__(self, self_: Value, size: List[int], implicit: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(implicit): + implicit = torch_dialect.ConstantBoolOp(implicit) + else: + implicit = get_op_result_or_value(implicit) + assert str(implicit.type) == '!torch.bool', f'`implicit` should be a !torch.bool but is {type(implicit).__module__}.{type(implicit).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenExpandCopyOp, self).__init__(result_type, self_, size, implicit, loc=loc, ip=ip) + + +class AtenPermuteCopyOp: + def __init__(self, self_: Value, dims: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dims): + dims = list(map(torch_dialect.ConstantIntOp, dims)) + dims = torch_dialect.PrimListConstructOp(dims) + else: + dims = get_op_result_or_value(dims) + assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenPermuteCopyOp, self).__init__(result_type, self_, dims, loc=loc, ip=ip) + + +class Aten_ReshapeAliasCopyOp: + def __init__(self, self_: Value, size: List[int], stride: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is 
{type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_ReshapeAliasCopyOp, self).__init__(result_type, self_, size, stride, loc=loc, ip=ip) + + +class AtenSelectCopyIntOp: + def __init__(self, self_: Value, dim: int, index: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(index): + index = torch_dialect.ConstantIntOp(index) + else: + index = get_op_result_or_value(index) + assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSelectCopyIntOp, self).__init__(result_type, self_, dim, index, loc=loc, ip=ip) + + +class AtenDetachCopyOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDetachCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSliceCopyTensorOp: + def __init__(self, self_: Value, dim: int, start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(start): + if start is not None: + start = torch_dialect.ConstantIntOp(start) + else: + start = torch_dialect.ConstantNoneOp() + else: + start = get_op_result_or_value(start) + assert str(start.type) 
== '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(end): + if end is not None: + end = torch_dialect.ConstantIntOp(end) + else: + end = torch_dialect.ConstantNoneOp() + else: + end = get_op_result_or_value(end) + assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSliceCopyTensorOp, self).__init__(result_type, self_, dim, start, end, step, loc=loc, ip=ip) + + +class AtenSqueezeCopyOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSqueezeCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenSqueezeCopyDimOp: + def __init__(self, self_: Value, dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSqueezeCopyDimOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip) + + +class AtenTCopyOp: + def __init__(self, self_: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip) + + +class AtenTransposeCopyIntOp: + def __init__(self, self_: Value, dim0: int, dim1: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim0): + dim0 = torch_dialect.ConstantIntOp(dim0) + else: + dim0 = get_op_result_or_value(dim0) + assert str(dim0.type) == '!torch.int', f'`dim0` should be a !torch.int but is {type(dim0).__module__}.{type(dim0).__name__}' + + if 
not is_mlir_value(dim1): + dim1 = torch_dialect.ConstantIntOp(dim1) + else: + dim1 = get_op_result_or_value(dim1) + assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTransposeCopyIntOp, self).__init__(result_type, self_, dim0, dim1, loc=loc, ip=ip) + + +class AtenUnsqueezeCopyOp: + def __init__(self, self_: Value, dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenUnsqueezeCopyOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip) + + +class AtenViewCopyOp: + def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenViewCopyOp, self).__init__(result_type, self_, size, loc=loc, ip=ip) + + +class AtenViewCopyDtypeOp: + def __init__(self, self_: Value, dtype: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dtype): + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenViewCopyDtypeOp, self).__init__(result_type, self_, dtype, loc=loc, ip=ip) + + +class AtenUnfoldCopyOp: + def __init__(self, self_: Value, dimension: int, size: int, step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a 
torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dimension): + dimension = torch_dialect.ConstantIntOp(dimension) + else: + dimension = get_op_result_or_value(dimension) + assert str(dimension.type) == '!torch.int', f'`dimension` should be a !torch.int but is {type(dimension).__module__}.{type(dimension).__name__}' + + if not is_mlir_value(size): + size = torch_dialect.ConstantIntOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.int', f'`size` should be a !torch.int but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenUnfoldCopyOp, self).__init__(result_type, self_, dimension, size, step, loc=loc, ip=ip) + + +class AtenSelectScatterOp: + def __init__(self, self_: Value, src: Value, dim: int, index: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(index): + index = torch_dialect.ConstantIntOp(index) + else: + index = get_op_result_or_value(index) + assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSelectScatterOp, self).__init__(result_type, self_, src, dim, index, loc=loc, ip=ip) + + +class AtenSliceScatterOp: + def __init__(self, self_: Value, src: Value, dim: int, start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a 
!torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(start): + if start is not None: + start = torch_dialect.ConstantIntOp(start) + else: + start = torch_dialect.ConstantNoneOp() + else: + start = get_op_result_or_value(start) + assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(end): + if end is not None: + end = torch_dialect.ConstantIntOp(end) + else: + end = torch_dialect.ConstantNoneOp() + else: + end = get_op_result_or_value(end) + assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenSliceScatterOp, self).__init__(result_type, self_, src, dim, start, end, step, loc=loc, ip=ip) + + +class AtenDiagonalScatterOp: + def __init__(self, self_: Value, src: Value, offset: int, dim1: int, dim2: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}' + + if not is_mlir_value(offset): + offset = torch_dialect.ConstantIntOp(offset) + else: + offset = get_op_result_or_value(offset) + assert str(offset.type) == '!torch.int', f'`offset` should be a !torch.int but is {type(offset).__module__}.{type(offset).__name__}' + + if not is_mlir_value(dim1): + dim1 = torch_dialect.ConstantIntOp(dim1) + else: + dim1 = get_op_result_or_value(dim1) + assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}' + + if not is_mlir_value(dim2): + dim2 = torch_dialect.ConstantIntOp(dim2) + else: + dim2 = get_op_result_or_value(dim2) + assert str(dim2.type) == '!torch.int', f'`dim2` should be a !torch.int but is {type(dim2).__module__}.{type(dim2).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenDiagonalScatterOp, self).__init__(result_type, self_, src, offset, dim1, dim2, loc=loc, ip=ip) + + +class AtenAsStridedScatterOp: + def __init__(self, self_: Value, src: Value, size: List[int], stride: List[int], storage_offset: Optional[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(src): + assert is_mlir_value(src), f'`src` should be a Value but is 
{type(src).__module__}.{type(src).__name__}' + else: + src = get_op_result_or_value(src) + assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}' + + if not is_mlir_value(size): + size = list(map(torch_dialect.ConstantIntOp, size)) + size = torch_dialect.PrimListConstructOp(size) + else: + size = get_op_result_or_value(size) + assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}' + + if not is_mlir_value(stride): + stride = list(map(torch_dialect.ConstantIntOp, stride)) + stride = torch_dialect.PrimListConstructOp(stride) + else: + stride = get_op_result_or_value(stride) + assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}' + + if not is_mlir_value(storage_offset): + if storage_offset is not None: + storage_offset = torch_dialect.ConstantIntOp(storage_offset) + else: + storage_offset = torch_dialect.ConstantNoneOp() + else: + storage_offset = get_op_result_or_value(storage_offset) + assert str(storage_offset.type) == '!torch.int', f'`storage_offset` should be a !torch.int but is {type(storage_offset).__module__}.{type(storage_offset).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenAsStridedScatterOp, self).__init__(result_type, self_, src, size, stride, storage_offset, loc=loc, ip=ip) + + +class AtenUpsampleNearest2dOp: + def __init__(self, self_: Value, output_size: List[int], scales_h: Optional[float], scales_w: Optional[float], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(output_size): + output_size = list(map(torch_dialect.ConstantIntOp, output_size)) + output_size = torch_dialect.PrimListConstructOp(output_size) + else: + output_size = get_op_result_or_value(output_size) + assert str(output_size.type) == '!torch.list', f'`output_size` should be a !torch.list but is {type(output_size).__module__}.{type(output_size).__name__}' + + if not is_mlir_value(scales_h): + if scales_h is not None: + scales_h = torch_dialect.ConstantFloatOp(scales_h) + else: + scales_h = torch_dialect.ConstantNoneOp() + else: + scales_h = get_op_result_or_value(scales_h) + assert str(scales_h.type) == '!torch.float', f'`scales_h` should be a !torch.float but is {type(scales_h).__module__}.{type(scales_h).__name__}' + + if not is_mlir_value(scales_w): + if scales_w is not None: + scales_w = torch_dialect.ConstantFloatOp(scales_w) + else: + scales_w = torch_dialect.ConstantNoneOp() + else: + scales_w = get_op_result_or_value(scales_w) + assert str(scales_w.type) == '!torch.float', f'`scales_w` should be a !torch.float but is {type(scales_w).__module__}.{type(scales_w).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenUpsampleNearest2dOp, self).__init__(result_type, self_, output_size, scales_h, scales_w, loc=loc, ip=ip) + + +class Aten__Contains__IntListOp: + def __init__(self, l: List[int], item: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(l): + l = 
list(map(torch_dialect.ConstantIntOp, l)) + l = torch_dialect.PrimListConstructOp(l) + else: + l = get_op_result_or_value(l) + assert str(l.type) == '!torch.list', f'`l` should be a !torch.list but is {type(l).__module__}.{type(l).__name__}' + + if not is_mlir_value(item): + item = torch_dialect.ConstantIntOp(item) + else: + item = get_op_result_or_value(item) + assert str(item.type) == '!torch.int', f'`item` should be a !torch.int but is {type(item).__module__}.{type(item).__name__}' + + super(Aten__Contains__IntListOp, self).__init__(l, item, loc=loc, ip=ip) + + +class AtenCatOp: + def __init__(self, tensors: List[Value], dim: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(tensors): + tensors = torch_dialect.PrimListConstructOp(tensors) + else: + tensors = get_op_result_or_value(tensors) + assert str(tensors.type) == '!torch.list', f'`tensors` should be a !torch.list but is {type(tensors).__module__}.{type(tensors).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenCatOp, self).__init__(result_type, tensors, dim, loc=loc, ip=ip) + + +class AtenAppendTOp: + def __init__(self, self_: List[Value], el: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.list', f'`self_` should be a !torch.list but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(el): + assert is_mlir_value(el), f'`el` should be a Value but is {type(el).__module__}.{type(el).__name__}' + else: + el = get_op_result_or_value(el) + assert str(el.type).startswith("!torch.vtensor"), f'`el` should be a torch.vtensor but is {type(el).__module__}.{type(el).__name__}' + + result_type = Type.parse("!torch.list") + super(AtenAppendTOp, self).__init__(result_type, self_, el, loc=loc, ip=ip) + + +class AtenAddTOp: + def __init__(self, a: List[Value], b: List[Value], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.PrimListConstructOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.list', f'`a` should be a !torch.list but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.PrimListConstructOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.list', f'`b` should be a !torch.list but is {type(b).__module__}.{type(b).__name__}' + + result_type = Type.parse("!torch.list") + super(AtenAddTOp, self).__init__(result_type, a, b, loc=loc, ip=ip) + + +class AtenEqIntListOp: + def __init__(self, a: List[int], b: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = list(map(torch_dialect.ConstantIntOp, a)) + a = torch_dialect.PrimListConstructOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.list', f'`a` should be a !torch.list but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = list(map(torch_dialect.ConstantIntOp, b)) + b = torch_dialect.PrimListConstructOp(b) + else: + b = get_op_result_or_value(b) + assert 
str(b.type) == '!torch.list', f'`b` should be a !torch.list but is {type(b).__module__}.{type(b).__name__}' + + super(AtenEqIntListOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenListTOp: + def __init__(self, l: List[Value], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(l): + l = torch_dialect.PrimListConstructOp(l) + else: + l = get_op_result_or_value(l) + assert str(l.type) == '!torch.list', f'`l` should be a !torch.list but is {type(l).__module__}.{type(l).__name__}' + + result_type = Type.parse("!torch.list") + super(AtenListTOp, self).__init__(result_type, l, loc=loc, ip=ip) + + +class AtenSliceTOp: + def __init__(self, l: List[Value], start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(l): + l = torch_dialect.PrimListConstructOp(l) + else: + l = get_op_result_or_value(l) + assert str(l.type) == '!torch.list', f'`l` should be a !torch.list but is {type(l).__module__}.{type(l).__name__}' + + if not is_mlir_value(start): + if start is not None: + start = torch_dialect.ConstantIntOp(start) + else: + start = torch_dialect.ConstantNoneOp() + else: + start = get_op_result_or_value(start) + assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(end): + if end is not None: + end = torch_dialect.ConstantIntOp(end) + else: + end = torch_dialect.ConstantNoneOp() + else: + end = get_op_result_or_value(end) + assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}' + + result_type = Type.parse("!torch.list") + super(AtenSliceTOp, self).__init__(result_type, l, start, end, step, loc=loc, ip=ip) + + +class AtenInsertTOp: + def __init__(self, self_: List[Value], idx: int, el: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.list', f'`self_` should be a !torch.list but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(idx): + idx = torch_dialect.ConstantIntOp(idx) + else: + idx = get_op_result_or_value(idx) + assert str(idx.type) == '!torch.int', f'`idx` should be a !torch.int but is {type(idx).__module__}.{type(idx).__name__}' + + if not is_mlir_value(el): + assert is_mlir_value(el), f'`el` should be a Value but is {type(el).__module__}.{type(el).__name__}' + else: + el = get_op_result_or_value(el) + assert str(el.type).startswith("!torch.vtensor"), f'`el` should be a torch.vtensor but is {type(el).__module__}.{type(el).__name__}' + + super(AtenInsertTOp, self).__init__(self_, idx, el, loc=loc, ip=ip) + + +class AtenNeIntListOp: + def __init__(self, a: List[int], b: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = list(map(torch_dialect.ConstantIntOp, a)) + a = torch_dialect.PrimListConstructOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.list', f'`a` should be a !torch.list but is 
{type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = list(map(torch_dialect.ConstantIntOp, b)) + b = torch_dialect.PrimListConstructOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.list', f'`b` should be a !torch.list but is {type(b).__module__}.{type(b).__name__}' + + super(AtenNeIntListOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenAnyBoolOp: + def __init__(self, self_: List[bool], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = list(map(torch_dialect.ConstantBoolOp, self_)) + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + # should be bool[] + pass + + super(AtenAnyBoolOp, self).__init__(self_, loc=loc, ip=ip) + + +class AtenSortIntOp: + def __init__(self, self_: List[int], reverse: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = list(map(torch_dialect.ConstantIntOp, self_)) + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.list', f'`self_` should be a !torch.list but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(reverse): + reverse = torch_dialect.ConstantBoolOp(reverse) + else: + reverse = get_op_result_or_value(reverse) + assert str(reverse.type) == '!torch.bool', f'`reverse` should be a !torch.bool but is {type(reverse).__module__}.{type(reverse).__name__}' + + super(AtenSortIntOp, self).__init__(self_, reverse, loc=loc, ip=ip) + + +class AtenAddStrOp: + def __init__(self, a: str, b: str, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantStrOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.str', f'`a` should be a !torch.str but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantStrOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.str', f'`b` should be a !torch.str but is {type(b).__module__}.{type(b).__name__}' + + super(AtenAddStrOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenEqStrOp: + def __init__(self, a: str, b: str, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantStrOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.str', f'`a` should be a !torch.str but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantStrOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.str', f'`b` should be a !torch.str but is {type(b).__module__}.{type(b).__name__}' + + super(AtenEqStrOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenLenStrOp: + def __init__(self, s: str, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(s): + s = torch_dialect.ConstantStrOp(s) + else: + s = get_op_result_or_value(s) + assert str(s.type) == '!torch.str', f'`s` should be a !torch.str but is {type(s).__module__}.{type(s).__name__}' + + super(AtenLenStrOp, self).__init__(s, loc=loc, ip=ip) + + +class AtenStrOp: + def __init__(self, elem: Value, *, loc=None, ip=None): + if not is_mlir_value(elem): + assert is_mlir_value(elem), f'`elem` should be a Value but is 
{type(elem).__module__}.{type(elem).__name__}' + else: + elem = get_op_result_or_value(elem) + assert str(elem.type).startswith("!torch.vtensor"), f'`elem` should be a torch.vtensor but is {type(elem).__module__}.{type(elem).__name__}' + + super(AtenStrOp, self).__init__(elem, loc=loc, ip=ip) + + +class AtenJoinOp: + def __init__(self, self_: str, values: List[str], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = torch_dialect.ConstantStrOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.str', f'`self_` should be a !torch.str but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(values): + values = list(map(torch_dialect.ConstantStrOp, values)) + values = torch_dialect.PrimListConstructOp(values) + else: + values = get_op_result_or_value(values) + # should be str[] + pass + + super(AtenJoinOp, self).__init__(self_, values, loc=loc, ip=ip) + + +class AtenFloatScalarOp: + def __init__(self, a: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + super(AtenFloatScalarOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenFloatStrOp: + def __init__(self, a: str, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantStrOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.str', f'`a` should be a !torch.str but is {type(a).__module__}.{type(a).__name__}' + + super(AtenFloatStrOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenIntFloatOp: + def __init__(self, a: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + super(AtenIntFloatOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenIntScalarOp: + def __init__(self, a: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + super(AtenIntScalarOp, self).__init__(a, loc=loc, ip=ip) + + +class Aten__RangeLengthOp: + def __init__(self, lo: int, hi: int, step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(lo): + lo = torch_dialect.ConstantIntOp(lo) + else: + lo = get_op_result_or_value(lo) + assert str(lo.type) == '!torch.int', f'`lo` should be a !torch.int but is {type(lo).__module__}.{type(lo).__name__}' + + if not is_mlir_value(hi): + hi = torch_dialect.ConstantIntOp(hi) + else: + hi = get_op_result_or_value(hi) + assert str(hi.type) == '!torch.int', f'`hi` should be a !torch.int but is {type(hi).__module__}.{type(hi).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int 
but is {type(step).__module__}.{type(step).__name__}' + + super(Aten__RangeLengthOp, self).__init__(lo, hi, step, loc=loc, ip=ip) + + +class Aten__DeriveIndexOp: + def __init__(self, index: int, start: int, step: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(index): + index = torch_dialect.ConstantIntOp(index) + else: + index = get_op_result_or_value(index) + assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}' + + if not is_mlir_value(start): + start = torch_dialect.ConstantIntOp(start) + else: + start = get_op_result_or_value(start) + assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(step): + step = torch_dialect.ConstantIntOp(step) + else: + step = get_op_result_or_value(step) + assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}' + + super(Aten__DeriveIndexOp, self).__init__(index, start, step, loc=loc, ip=ip) + + +class AtenGtIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenGtIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenGeIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenGeIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenLtIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenLtIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenLeIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert 
str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenLeIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenNeIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenNeIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenEqIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenEqIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenFloordivIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenFloordivIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenRemainderIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenRemainderIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenRemainderScalarOp: + def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(other): + other = torch_dialect.ConstantNumberOp(other) + else: + other = get_op_result_or_value(other) + assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}' 
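
Every generated `__init__` in this file follows the same coercion pattern seen here: a plain Python value is materialized as the corresponding `torch.constant.*` op, while an argument that is already an MLIR value is unwrapped with `get_op_result_or_value` and type-checked. A minimal standalone sketch of that pattern, using the `is_mlir_value` helper added later in this patch; the function name `coerce_to_torch_int` and the hard-coded `!torch.int` check are illustrative only, not part of the patch:

```python
from torch_mlir.dialects import torch as torch_dialect
from torch_mlir.dialects._ods_common import get_op_result_or_value
from pi.dialects._torch_ops_ext_custom import is_mlir_value


def coerce_to_torch_int(x):
    # Hypothetical helper that isolates the branch the generated builders use.
    if not is_mlir_value(x):
        # Python int -> materialize a `torch.constant.int` op at the current
        # insertion point.
        return torch_dialect.ConstantIntOp(x)
    # Already an OpView/Operation/Value -> unwrap to the underlying Value and
    # check its torch type.
    x = get_op_result_or_value(x)
    assert str(x.type) == "!torch.int", f"expected !torch.int, got {x.type}"
    return x
```
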
+ + result_type = Type.parse("!torch.vtensor") + super(AtenRemainderScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip) + + +class AtenAddIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenAddIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenSubIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenSubIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenMulIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenMulIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenDivIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenDivIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenNegIntOp: + def __init__(self, a: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + super(AtenNegIntOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenLogIntOp: + def __init__(self, a: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + super(AtenLogIntOp, self).__init__(a, loc=loc, ip=ip) + + +class 
AtenAddFloatIntOp: + def __init__(self, a: float, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenAddFloatIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenSubFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenSubFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenMulFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenMulFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenDivFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenDivFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenNegFloatOp: + def __init__(self, a: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + super(AtenNegFloatOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenEqFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) 
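
For orientation, a hedged sketch of how one of these scalar-arithmetic builders could be driven once the extensions are attached to the generated op classes. It assumes an MLIR `Context` in which the torch dialect is available (how that context is set up is outside this diff) and relies only on the literal-wrapping behaviour shown above:

```python
from torch_mlir.ir import Context, Location, Module, InsertionPoint
from torch_mlir.dialects import torch as torch_dialect

# Assumption: a Context created through torch-mlir's bindings has the torch
# dialect registered; if not, it must be registered/loaded before this runs.
with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        # The Python floats are wrapped into torch.constant.float ops by the
        # ConstantFloatOp extension from this patch before the op is built.
        prod = torch_dialect.AtenMulFloatOp(2.0, 3.0)
    print(module)
```
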
+ assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenEqFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenGtFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenGtFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenGeFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenGeFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenLtFloatOp: + def __init__(self, a: float, b: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantFloatOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.float', f'`b` should be a !torch.float but is {type(b).__module__}.{type(b).__name__}' + + super(AtenLtFloatOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenLtFloatIntOp: + def __init__(self, a: float, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenLtFloatIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenGeFloatIntOp: + def __init__(self, a: float, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenGeFloatIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenNeFloatIntOp: + def __init__(self, a: float, b: 
int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenNeFloatIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenGtFloatIntOp: + def __init__(self, a: float, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(AtenGtFloatIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class Aten__And__BoolOp: + def __init__(self, a: bool, b: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantBoolOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.bool', f'`a` should be a !torch.bool but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantBoolOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.bool', f'`b` should be a !torch.bool but is {type(b).__module__}.{type(b).__name__}' + + super(Aten__And__BoolOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenNeBoolOp: + def __init__(self, a: bool, b: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantBoolOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.bool', f'`a` should be a !torch.bool but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantBoolOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.bool', f'`b` should be a !torch.bool but is {type(b).__module__}.{type(b).__name__}' + + super(AtenNeBoolOp, self).__init__(a, b, loc=loc, ip=ip) + + +class Aten__Is__Op: + def __init__(self, self_: Value, obj: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + # should be t1 + pass + + if not is_mlir_value(obj): + assert is_mlir_value(obj), f'`obj` should be a Value but is {type(obj).__module__}.{type(obj).__name__}' + else: + obj = get_op_result_or_value(obj) + # should be t2 + pass + + super(Aten__Is__Op, self).__init__(self_, obj, loc=loc, ip=ip) + + +class Aten__Isnot__Op: + def __init__(self, self_: Value, obj: Value, *, loc=None, ip=None): + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + # should be t1 + pass + + if not is_mlir_value(obj): + assert is_mlir_value(obj), f'`obj` should be 
a Value but is {type(obj).__module__}.{type(obj).__name__}' + else: + obj = get_op_result_or_value(obj) + # should be t2 + pass + + super(Aten__Isnot__Op, self).__init__(self_, obj, loc=loc, ip=ip) + + +class Aten__Not__Op: + def __init__(self, self_: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = torch_dialect.ConstantBoolOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.bool', f'`self_` should be a !torch.bool but is {type(self_).__module__}.{type(self_).__name__}' + + super(Aten__Not__Op, self).__init__(self_, loc=loc, ip=ip) + + +class AtenLenTOp: + def __init__(self, a: List[Value], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.PrimListConstructOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.list', f'`a` should be a !torch.list but is {type(a).__module__}.{type(a).__name__}' + + super(AtenLenTOp, self).__init__(a, loc=loc, ip=ip) + + +class Aten__Getitem__TOp: + def __init__(self, list_: List[Value], idx: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(list_): + list_ = torch_dialect.PrimListConstructOp(list_) + else: + list_ = get_op_result_or_value(list_) + assert str(list_.type) == '!torch.list', f'`list_` should be a !torch.list but is {type(list_).__module__}.{type(list_).__name__}' + + if not is_mlir_value(idx): + idx = torch_dialect.ConstantIntOp(idx) + else: + idx = get_op_result_or_value(idx) + assert str(idx.type) == '!torch.int', f'`idx` should be a !torch.int but is {type(idx).__module__}.{type(idx).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten__Getitem__TOp, self).__init__(result_type, list_, idx, loc=loc, ip=ip) + + +class Aten_SetItemTOp: + def __init__(self, l: List[Value], idx: int, el: Value, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(l): + l = torch_dialect.PrimListConstructOp(l) + else: + l = get_op_result_or_value(l) + assert str(l.type) == '!torch.list', f'`l` should be a !torch.list but is {type(l).__module__}.{type(l).__name__}' + + if not is_mlir_value(idx): + idx = torch_dialect.ConstantIntOp(idx) + else: + idx = get_op_result_or_value(idx) + assert str(idx.type) == '!torch.int', f'`idx` should be a !torch.int but is {type(idx).__module__}.{type(idx).__name__}' + + if not is_mlir_value(el): + assert is_mlir_value(el), f'`el` should be a Value but is {type(el).__module__}.{type(el).__name__}' + else: + el = get_op_result_or_value(el) + assert str(el.type).startswith("!torch.vtensor"), f'`el` should be a torch.vtensor but is {type(el).__module__}.{type(el).__name__}' + + result_type = Type.parse("!torch.list") + super(Aten_SetItemTOp, self).__init__(result_type, l, idx, el, loc=loc, ip=ip) + + +class AtenDivOp: + def __init__(self, a: "Number", b: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantNumberOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) in {'!torch.float', '!torch.int'}, f'`b` should be a !torch.number but is 
{type(b).__module__}.{type(b).__name__}' + + super(AtenDivOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenAddOp: + def __init__(self, a: "Number", b: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantNumberOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) in {'!torch.float', '!torch.int'}, f'`b` should be a !torch.number but is {type(b).__module__}.{type(b).__name__}' + + super(AtenAddOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenSubOp: + def __init__(self, a: "Number", b: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantNumberOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) in {'!torch.float', '!torch.int'}, f'`b` should be a !torch.number but is {type(b).__module__}.{type(b).__name__}' + + super(AtenSubOp, self).__init__(a, b, loc=loc, ip=ip) + + +class AtenCeilScalarOp: + def __init__(self, a: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + super(AtenCeilScalarOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenSqrtIntOp: + def __init__(self, a: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + super(AtenSqrtIntOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenBoolFloatOp: + def __init__(self, a: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is {type(a).__module__}.{type(a).__name__}' + + super(AtenBoolFloatOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenBoolIntOp: + def __init__(self, a: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + super(AtenBoolIntOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenCeilFloatOp: + def __init__(self, a: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantFloatOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.float', f'`a` should be a !torch.float but is 
{type(a).__module__}.{type(a).__name__}' + + super(AtenCeilFloatOp, self).__init__(a, loc=loc, ip=ip) + + +class AtenNarrowOp: + def __init__(self, self_: Value, dim: int, start: int, length: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(start): + start = torch_dialect.ConstantIntOp(start) + else: + start = get_op_result_or_value(start) + assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}' + + if not is_mlir_value(length): + length = torch_dialect.ConstantIntOp(length) + else: + length = get_op_result_or_value(length) + assert str(length.type) == '!torch.int', f'`length` should be a !torch.int but is {type(length).__module__}.{type(length).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNarrowOp, self).__init__(result_type, self_, dim, start, length, loc=loc, ip=ip) + + +class AtenScalarImplicitOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(AtenScalarImplicitOp, self).__init__(a, loc=loc, ip=ip) + + +class Aten_SoftmaxBackwardDataOp: + def __init__(self, grad_output: Value, output: Value, dim: int, input_dtype: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(output): + assert is_mlir_value(output), f'`output` should be a Value but is {type(output).__module__}.{type(output).__name__}' + else: + output = get_op_result_or_value(output) + assert str(output.type).startswith("!torch.vtensor"), f'`output` should be a torch.vtensor but is {type(output).__module__}.{type(output).__name__}' + + if not is_mlir_value(dim): + dim = torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(input_dtype): + input_dtype = torch_dialect.ConstantIntOp(input_dtype) + else: + input_dtype = get_op_result_or_value(input_dtype) + assert str(input_dtype.type) == '!torch.int', f'`input_dtype` should be a !torch.int but is {type(input_dtype).__module__}.{type(input_dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + 
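
All tensor-valued builders in this file construct their result type the same way, by parsing the shape- and dtype-erased `!torch.vtensor` rather than a ranked, typed tensor. A small illustrative sketch of the two spellings; the ranked form below is an example of the refined syntax, not something this patch emits:

```python
from torch_mlir.ir import Context, Type

# Assumption: this Context has the torch dialect registered, so `!torch.*`
# types can be parsed.
with Context():
    erased = Type.parse("!torch.vtensor")               # what these builders use
    ranked = Type.parse("!torch.vtensor<[2,3],f32>")    # hypothetical refined form
    print(erased, ranked)
```
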
super(Aten_SoftmaxBackwardDataOp, self).__init__(result_type, grad_output, output, dim, input_dtype, loc=loc, ip=ip) + + +class AtenTanhBackwardOp: + def __init__(self, grad_output: Value, output: Value, *, loc=None, ip=None): + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(output): + assert is_mlir_value(output), f'`output` should be a Value but is {type(output).__module__}.{type(output).__name__}' + else: + output = get_op_result_or_value(output) + assert str(output.type).startswith("!torch.vtensor"), f'`output` should be a torch.vtensor but is {type(output).__module__}.{type(output).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenTanhBackwardOp, self).__init__(result_type, grad_output, output, loc=loc, ip=ip) + + +class AtenGeluBackwardOp: + def __init__(self, grad_output: Value, self_: Value, approximate: str, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(self_): + assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}' + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}' + + if not is_mlir_value(approximate): + approximate = torch_dialect.ConstantStrOp(approximate) + else: + approximate = get_op_result_or_value(approximate) + assert str(approximate.type) == '!torch.str', f'`approximate` should be a !torch.str but is {type(approximate).__module__}.{type(approximate).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenGeluBackwardOp, self).__init__(result_type, grad_output, self_, approximate, loc=loc, ip=ip) + + +class Aten_LogSoftmaxBackwardDataOp: + def __init__(self, grad_output: Value, output: Value, dim: int, input_dtype: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(output): + assert is_mlir_value(output), f'`output` should be a Value but is {type(output).__module__}.{type(output).__name__}' + else: + output = get_op_result_or_value(output) + assert str(output.type).startswith("!torch.vtensor"), f'`output` should be a torch.vtensor but is {type(output).__module__}.{type(output).__name__}' + + if not is_mlir_value(dim): + dim = 
torch_dialect.ConstantIntOp(dim) + else: + dim = get_op_result_or_value(dim) + assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}' + + if not is_mlir_value(input_dtype): + input_dtype = torch_dialect.ConstantIntOp(input_dtype) + else: + input_dtype = get_op_result_or_value(input_dtype) + assert str(input_dtype.type) == '!torch.int', f'`input_dtype` should be a !torch.int but is {type(input_dtype).__module__}.{type(input_dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(Aten_LogSoftmaxBackwardDataOp, self).__init__(result_type, grad_output, output, dim, input_dtype, loc=loc, ip=ip) + + +class AtenNativeLayerNormBackwardOp: + def __init__(self, grad_out: Value, input: Value, normalized_shape: List[int], mean: Value, rstd: Value, weight: Optional[Value], bias: Optional[Value], output_mask: List[bool], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_out): + assert is_mlir_value(grad_out), f'`grad_out` should be a Value but is {type(grad_out).__module__}.{type(grad_out).__name__}' + else: + grad_out = get_op_result_or_value(grad_out) + assert str(grad_out.type).startswith("!torch.vtensor"), f'`grad_out` should be a torch.vtensor but is {type(grad_out).__module__}.{type(grad_out).__name__}' + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(normalized_shape): + normalized_shape = list(map(torch_dialect.ConstantIntOp, normalized_shape)) + normalized_shape = torch_dialect.PrimListConstructOp(normalized_shape) + else: + normalized_shape = get_op_result_or_value(normalized_shape) + assert str(normalized_shape.type) == '!torch.list', f'`normalized_shape` should be a !torch.list but is {type(normalized_shape).__module__}.{type(normalized_shape).__name__}' + + if not is_mlir_value(mean): + assert is_mlir_value(mean), f'`mean` should be a Value but is {type(mean).__module__}.{type(mean).__name__}' + else: + mean = get_op_result_or_value(mean) + assert str(mean.type).startswith("!torch.vtensor"), f'`mean` should be a torch.vtensor but is {type(mean).__module__}.{type(mean).__name__}' + + if not is_mlir_value(rstd): + assert is_mlir_value(rstd), f'`rstd` should be a Value but is {type(rstd).__module__}.{type(rstd).__name__}' + else: + rstd = get_op_result_or_value(rstd) + assert str(rstd.type).startswith("!torch.vtensor"), f'`rstd` should be a torch.vtensor but is {type(rstd).__module__}.{type(rstd).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(bias): + if bias is not None: + assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}' + else: + bias = torch_dialect.ConstantNoneOp() + else: + bias = get_op_result_or_value(bias) + assert str(bias.type).startswith("!torch.vtensor"), 
f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}' + + if not is_mlir_value(output_mask): + output_mask = list(map(torch_dialect.ConstantBoolOp, output_mask)) + output_mask = torch_dialect.PrimListConstructOp(output_mask) + else: + output_mask = get_op_result_or_value(output_mask) + # should be bool[] + pass + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + result2_type = Type.parse("!torch.vtensor") + super(AtenNativeLayerNormBackwardOp, self).__init__(result0_type, result1_type, result2_type, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, loc=loc, ip=ip) + + +class AtenEmbeddingDenseBackwardOp: + def __init__(self, grad_output: Value, indices: Value, num_weights: int, padding_idx: int, scale_grad_by_freq: bool, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(indices): + assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}' + else: + indices = get_op_result_or_value(indices) + assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}' + + if not is_mlir_value(num_weights): + num_weights = torch_dialect.ConstantIntOp(num_weights) + else: + num_weights = get_op_result_or_value(num_weights) + assert str(num_weights.type) == '!torch.int', f'`num_weights` should be a !torch.int but is {type(num_weights).__module__}.{type(num_weights).__name__}' + + if not is_mlir_value(padding_idx): + padding_idx = torch_dialect.ConstantIntOp(padding_idx) + else: + padding_idx = get_op_result_or_value(padding_idx) + assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is {type(padding_idx).__module__}.{type(padding_idx).__name__}' + + if not is_mlir_value(scale_grad_by_freq): + scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq) + else: + scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq) + assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenEmbeddingDenseBackwardOp, self).__init__(result_type, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, loc=loc, ip=ip) + + +class AtenNativeBatchNormBackwardOp: + def __init__(self, grad_out: Value, input: Value, weight: Optional[Value], running_mean: Optional[Value], running_var: Optional[Value], save_mean: Optional[Value], save_invstd: Optional[Value], train: bool, eps: float, output_mask: List[bool], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_out): + assert is_mlir_value(grad_out), f'`grad_out` should be a Value but is {type(grad_out).__module__}.{type(grad_out).__name__}' + else: + grad_out = get_op_result_or_value(grad_out) + assert str(grad_out.type).startswith("!torch.vtensor"), f'`grad_out` should be a 
torch.vtensor but is {type(grad_out).__module__}.{type(grad_out).__name__}' + + if not is_mlir_value(input): + assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}' + else: + input = get_op_result_or_value(input) + assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}' + + if not is_mlir_value(weight): + if weight is not None: + assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}' + else: + weight = torch_dialect.ConstantNoneOp() + else: + weight = get_op_result_or_value(weight) + assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}' + + if not is_mlir_value(running_mean): + if running_mean is not None: + assert is_mlir_value(running_mean), f'`running_mean` should be a Value but is {type(running_mean).__module__}.{type(running_mean).__name__}' + else: + running_mean = torch_dialect.ConstantNoneOp() + else: + running_mean = get_op_result_or_value(running_mean) + assert str(running_mean.type).startswith("!torch.vtensor"), f'`running_mean` should be a torch.vtensor but is {type(running_mean).__module__}.{type(running_mean).__name__}' + + if not is_mlir_value(running_var): + if running_var is not None: + assert is_mlir_value(running_var), f'`running_var` should be a Value but is {type(running_var).__module__}.{type(running_var).__name__}' + else: + running_var = torch_dialect.ConstantNoneOp() + else: + running_var = get_op_result_or_value(running_var) + assert str(running_var.type).startswith("!torch.vtensor"), f'`running_var` should be a torch.vtensor but is {type(running_var).__module__}.{type(running_var).__name__}' + + if not is_mlir_value(save_mean): + if save_mean is not None: + assert is_mlir_value(save_mean), f'`save_mean` should be a Value but is {type(save_mean).__module__}.{type(save_mean).__name__}' + else: + save_mean = torch_dialect.ConstantNoneOp() + else: + save_mean = get_op_result_or_value(save_mean) + assert str(save_mean.type).startswith("!torch.vtensor"), f'`save_mean` should be a torch.vtensor but is {type(save_mean).__module__}.{type(save_mean).__name__}' + + if not is_mlir_value(save_invstd): + if save_invstd is not None: + assert is_mlir_value(save_invstd), f'`save_invstd` should be a Value but is {type(save_invstd).__module__}.{type(save_invstd).__name__}' + else: + save_invstd = torch_dialect.ConstantNoneOp() + else: + save_invstd = get_op_result_or_value(save_invstd) + assert str(save_invstd.type).startswith("!torch.vtensor"), f'`save_invstd` should be a torch.vtensor but is {type(save_invstd).__module__}.{type(save_invstd).__name__}' + + if not is_mlir_value(train): + train = torch_dialect.ConstantBoolOp(train) + else: + train = get_op_result_or_value(train) + assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}' + + if not is_mlir_value(eps): + eps = torch_dialect.ConstantFloatOp(eps) + else: + eps = get_op_result_or_value(eps) + assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}' + + if not is_mlir_value(output_mask): + output_mask = list(map(torch_dialect.ConstantBoolOp, output_mask)) + output_mask = torch_dialect.PrimListConstructOp(output_mask) + else: + output_mask = get_op_result_or_value(output_mask) + # 
should be bool[] + pass + + result0_type = Type.parse("!torch.vtensor") + result1_type = Type.parse("!torch.vtensor") + result2_type = Type.parse("!torch.vtensor") + super(AtenNativeBatchNormBackwardOp, self).__init__(result0_type, result1_type, result2_type, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, loc=loc, ip=ip) + + +class AtenNativeDropoutBackwardOp: + def __init__(self, grad_output: Value, mask: Value, scale: float, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(grad_output): + assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}' + else: + grad_output = get_op_result_or_value(grad_output) + assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}' + + if not is_mlir_value(mask): + assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}' + else: + mask = get_op_result_or_value(mask) + assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}' + + if not is_mlir_value(scale): + scale = torch_dialect.ConstantFloatOp(scale) + else: + scale = get_op_result_or_value(scale) + assert str(scale.type) == '!torch.float', f'`scale` should be a !torch.float but is {type(scale).__module__}.{type(scale).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(AtenNativeDropoutBackwardOp, self).__init__(result_type, grad_output, mask, scale, loc=loc, ip=ip) + + +class PrimLayoutOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(PrimLayoutOp, self).__init__(a, loc=loc, ip=ip) + + +class PrimTupleIndexOp: + def __init__(self, tup: Any, i: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(tup): + assert is_mlir_value(tup), f'`tup` should be a Value but is {type(tup).__module__}.{type(tup).__name__}' + else: + tup = get_op_result_or_value(tup) + assert str(tup.type) == '!torch.Any', f'`tup` should be a !torch.Any but is {type(tup).__module__}.{type(tup).__name__}' + + if not is_mlir_value(i): + i = torch_dialect.ConstantIntOp(i) + else: + i = get_op_result_or_value(i) + assert str(i.type) == '!torch.int', f'`i` should be a !torch.int but is {type(i).__module__}.{type(i).__name__}' + + super(PrimTupleIndexOp, self).__init__(tup, i, loc=loc, ip=ip) + + +class PrimDtypeOp: + def __init__(self, a: Value, *, loc=None, ip=None): + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + super(PrimDtypeOp, self).__init__(a, loc=loc, ip=ip) + + +class PrimNumToTensorScalarOp: + def __init__(self, a: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + 
else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(PrimNumToTensorScalarOp, self).__init__(result_type, a, loc=loc, ip=ip) + + +class PrimMinSelfIntOp: + def __init__(self, self_: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = list(map(torch_dialect.ConstantIntOp, self_)) + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.list', f'`self_` should be a !torch.list but is {type(self_).__module__}.{type(self_).__name__}' + + super(PrimMinSelfIntOp, self).__init__(self_, loc=loc, ip=ip) + + +class PrimMinIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(PrimMinIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class PrimMaxSelfIntOp: + def __init__(self, self_: List[int], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(self_): + self_ = list(map(torch_dialect.ConstantIntOp, self_)) + self_ = torch_dialect.PrimListConstructOp(self_) + else: + self_ = get_op_result_or_value(self_) + assert str(self_.type) == '!torch.list', f'`self_` should be a !torch.list but is {type(self_).__module__}.{type(self_).__name__}' + + super(PrimMaxSelfIntOp, self).__init__(self_, loc=loc, ip=ip) + + +class PrimMaxIntOp: + def __init__(self, a: int, b: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantIntOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) == '!torch.int', f'`a` should be a !torch.int but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(b): + b = torch_dialect.ConstantIntOp(b) + else: + b = get_op_result_or_value(b) + assert str(b.type) == '!torch.int', f'`b` should be a !torch.int but is {type(b).__module__}.{type(b).__name__}' + + super(PrimMaxIntOp, self).__init__(a, b, loc=loc, ip=ip) + + +class PrimRaiseExceptionOp: + def __init__(self, msg: str, cls: Optional[str], *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(msg): + msg = torch_dialect.ConstantStrOp(msg) + else: + msg = get_op_result_or_value(msg) + assert str(msg.type) == '!torch.str', f'`msg` should be a !torch.str but is {type(msg).__module__}.{type(msg).__name__}' + + if not is_mlir_value(cls): + if cls is not None: + cls = torch_dialect.ConstantStrOp(cls) + else: + cls = torch_dialect.ConstantNoneOp() + else: + cls = get_op_result_or_value(cls) + assert str(cls.type) == '!torch.str', f'`cls` should be a !torch.str but is {type(cls).__module__}.{type(cls).__name__}' + + super(PrimRaiseExceptionOp, self).__init__(msg, cls, loc=loc, ip=ip) + + +class PrimUninitializedOp: + def __init__(self, *, loc=None, ip=None): + super(PrimUninitializedOp, 
self).__init__(loc=loc, ip=ip) + + +class PrimUncheckedCastOp: + def __init__(self, x: Value, *, loc=None, ip=None): + if not is_mlir_value(x): + assert is_mlir_value(x), f'`x` should be a Value but is {type(x).__module__}.{type(x).__name__}' + else: + x = get_op_result_or_value(x) + assert str(x.type).startswith("!torch.vtensor"), f'`x` should be a torch.vtensor but is {type(x).__module__}.{type(x).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(PrimUncheckedCastOp, self).__init__(result_type, x, loc=loc, ip=ip) + + +class PrimAbsScalarOp: + def __init__(self, a: "Number", *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + a = torch_dialect.ConstantNumberOp(a) + else: + a = get_op_result_or_value(a) + assert str(a.type) in {'!torch.float', '!torch.int'}, f'`a` should be a !torch.number but is {type(a).__module__}.{type(a).__name__}' + + super(PrimAbsScalarOp, self).__init__(a, loc=loc, ip=ip) + + +class PrimsConvertElementTypeOp: + def __init__(self, a: Value, dtype: int, *, loc=None, ip=None): + from torch_mlir.dialects import torch as torch_dialect + + if not is_mlir_value(a): + assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}' + else: + a = get_op_result_or_value(a) + assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}' + + if not is_mlir_value(dtype): + dtype = torch_dialect.ConstantIntOp(dtype) + else: + dtype = get_op_result_or_value(dtype) + assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}' + + result_type = Type.parse("!torch.vtensor") + super(PrimsConvertElementTypeOp, self).__init__(result_type, a, dtype, loc=loc, ip=ip) + + diff --git a/pi/dialects/_torch_ops_ext_custom.py b/pi/dialects/_torch_ops_ext_custom.py new file mode 100644 index 0000000..5f03d35 --- /dev/null +++ b/pi/dialects/_torch_ops_ext_custom.py @@ -0,0 +1,113 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+    from torch_mlir.ir import *
+    from torch_mlir.dialects._ods_common import (
+        get_default_loc_context,
+        get_op_result_or_value,
+        get_op_results_or_values,
+    )
+    from ._torch_ops_ext import *
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+import re
+from typing import Any, Optional, Tuple, List, Union
+
+
+def is_mlir_value(v):
+    return isinstance(v, (OpView, Operation, Value, OpResultList))
+
+
+class ConstantFloatOp:
+    def __init__(self, value: float):
+        f64 = F64Type.get()
+        # f32 = F32Type.get()
+        super().__init__(FloatAttr.get(f64, value))
+
+
+class ConstantIntOp:
+    def __init__(self, value: int):
+        i64 = IntegerType.get_signless(64)
+        super().__init__(IntegerAttr.get(i64, value))
+
+
+class ConstantStrOp:
+    def __init__(self, value: str):
+        super().__init__(StringAttr.get(value))
+
+
+class ConstantBoolOp:
+    def __init__(self, value: bool):
+        i1 = IntegerType.get_signless(1)
+        super().__init__(IntegerAttr.get(i1, int(value)))
+
+
+class ConstantNumberOp:
+    def __init__(self, value: Union[int, float]):
+        if isinstance(value, int):
+            i64 = IntegerType.get_signless(64)
+            super().__init__(IntegerAttr.get(i64, value))
+        elif isinstance(value, float):
+            f64 = F64Type.get()
+            # f32 = F32Type.get()
+            super().__init__(FloatAttr.get(f64, value))
+        else:
+            raise Exception(f"unknown number type {value}")
+
+
+el_type_reg = re.compile(r"!torch\.(.*)")
+
+
+class PrimListConstructOp:
+    def __init__(
+        self,
+        elements,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if len(elements):
+            elements = get_op_results_or_values(elements)
+            el_type = get_op_result_or_value(elements[0]).type
+            el_type_str = el_type_reg.findall(str(el_type))[0]
+            res_type = Type.parse(f"!torch.list<{el_type_str}>")
+        else:
+            res_type = Type.parse(f"!torch.list")
+        super().__init__(res_type, elements, loc=loc, ip=ip)
+
+
+class PrimTupleConstructOp:
+    def __init__(
+        self,
+        elements,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if len(elements):
+            elements = get_op_results_or_values(elements)
+            el_types = ", ".join(
+                [el_type_reg.findall(str(e.type))[0] for e in elements]
+            )
+            res_type = Type.parse(f"!torch.tuple<{el_types}>")
+        else:
+            res_type = Type.parse(f"!torch.tuple")
+        super().__init__(res_type, elements, loc=loc, ip=ip)
+
+
+class AtenScalarImplicitOp:
+    def __init__(self, a: "pi.Tensor", *, loc=None, ip=None):
+        if not is_mlir_value(a):
+            assert is_mlir_value(
+                a
+            ), f"`a` should be a Tensor but is {type(a).__module__}.{type(a).__name__}"
+        else:
+            a = get_op_result_or_value(a)
+            assert str(a.type).startswith(
+                "!torch.vtensor"
+            ), f"`a` should be a Tensor but is {type(a).__module__}.{type(a).__name__}"
+
+        super(AtenScalarImplicitOp, self).__init__(a, loc=loc, ip=ip)
diff --git a/pi/dialects/_torch_wrappers.py b/pi/dialects/_torch_wrappers.py
new file mode 100644
index 0000000..65a291c
--- /dev/null
+++ b/pi/dialects/_torch_wrappers.py
@@ -0,0 +1,2161 @@
+from enum import Enum
+import builtins
+
+from .._tensor import Tensor
+from ..types_ import Number, is_a_torch_tensor
+from typing import List, Optional, Any, Tuple
+
+from torch_mlir.dialects import torch as torch_dialect
+from torch_mlir.dialects._ods_common import (
+    get_default_loc_context,
+    get_op_result_or_value,
+    get_op_results_or_values,
+)
+
+def tanh(self_: Tensor) -> Tensor:
+    assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}'
+    self_ =
self_.value + return Tensor(torch_dialect.AtenTanhOp(self_)) + +def tanh_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTanh_Op(self_)) + +def hardtanh(self_: Tensor, min_val: Number = -1, max_val: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenHardtanhOp(self_, min_val, max_val)) + +def hardtanh_(self_: Tensor, min_val: Number = -1, max_val: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenHardtanh_Op(self_, min_val, max_val)) + +def relu(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenReluOp(self_)) + +def relu_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRelu_Op(self_)) + +def relu6(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRelu6Op(self_)) + +def relu6_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRelu6_Op(self_)) + +def leaky_relu(self_: Tensor, negative_slope: Number = 0.01) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLeakyReluOp(self_, negative_slope)) + +def leaky_relu_(self_: Tensor, negative_slope: Number = 0.01) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLeakyRelu_Op(self_, negative_slope)) + +def log(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLogOp(self_)) + +def log_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLog_Op(self_)) + +def sigmoid(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSigmoidOp(self_)) + +def sigmoid_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` 
should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSigmoid_Op(self_)) + +def hardsigmoid(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenHardsigmoidOp(self_)) + +def hardsigmoid_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenHardsigmoid_Op(self_)) + +def hardswish(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenHardswishOp(self_)) + +def hardswish_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenHardswish_Op(self_)) + +def erf(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenErfOp(self_)) + +def erf_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenErf_Op(self_)) + +def silu(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSiluOp(self_)) + +def silu_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSilu_Op(self_)) + +def sin(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSinOp(self_)) + +def sin_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSin_Op(self_)) + +def exp(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenExpOp(self_)) + +def exp_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenExp_Op(self_)) + +def expm1(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + 
self_ = self_.value + return Tensor(torch_dialect.AtenExpm1Op(self_)) + +def expm1_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenExpm1_Op(self_)) + +def cos(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenCosOp(self_)) + +def cos_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenCos_Op(self_)) + +def atan2(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenAtan2Op(self_, other)) + +def atan2_(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenAtan2_Op(self_, other)) + +def neg(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenNegOp(self_)) + +def neg_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenNeg_Op(self_)) + +def floor(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFloorOp(self_)) + +def floor_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFloor_Op(self_)) + +def ceil(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenCeilOp(self_)) + +def ceil_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenCeil_Op(self_)) + +def bitwise_not(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = 
self_.value + return Tensor(torch_dialect.AtenBitwiseNotOp(self_)) + +def bitwise_not_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenBitwiseNot_Op(self_)) + +def div_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenDivTensorOp(self_, other)) + +def div__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenDiv_TensorOp(self_, other)) + +def logical_or(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLogicalOrOp(self_, other)) + +def logical_or_(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLogicalOr_Op(self_, other)) + +def logical_and(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLogicalAndOp(self_, other)) + +def logical_and_(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLogicalAnd_Op(self_, other)) + +def logical_xor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other 
= other.value + return Tensor(torch_dialect.AtenLogicalXorOp(self_, other)) + +def logical_xor_(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLogicalXor_Op(self_, other)) + +def logical_not(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLogicalNotOp(self_)) + +def logical_not_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLogicalNot_Op(self_)) + +def lerp_Tensor(self_: Tensor, end: Tensor, weight: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(end, Tensor), f'`end` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(end).__module__}.{type(end).__name__}' + end = end.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + return Tensor(torch_dialect.AtenLerpTensorOp(self_, end, weight)) + +def lerp__Tensor(self_: Tensor, end: Tensor, weight: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(end, Tensor), f'`end` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(end).__module__}.{type(end).__name__}' + end = end.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + return Tensor(torch_dialect.AtenLerp_TensorOp(self_, end, weight)) + +def eq_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenEqTensorOp(self_, other)) + +def eq__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenEq_TensorOp(self_, other)) + +def gt_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is 
{type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenGtTensorOp(self_, other)) + +def gt__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenGt_TensorOp(self_, other)) + +def ge_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenGeTensorOp(self_, other)) + +def ge__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenGe_TensorOp(self_, other)) + +def lt_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLtTensorOp(self_, other)) + +def lt__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLt_TensorOp(self_, other)) + +def le_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenLeTensorOp(self_, other)) + +def le__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return 
Tensor(torch_dialect.AtenLe_TensorOp(self_, other)) + +def ne_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenNeTensorOp(self_, other)) + +def ne__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenNe_TensorOp(self_, other)) + +def div_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDivScalarOp(self_, other)) + +def div__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDiv_ScalarOp(self_, other)) + +def ne_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenNeScalarOp(self_, other)) + +def ne__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenNe_ScalarOp(self_, other)) + +def eq_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenEqScalarOp(self_, other)) + +def eq__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenEq_ScalarOp(self_, other)) + +def gt_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenGtScalarOp(self_, other)) + +def gt__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenGt_ScalarOp(self_, other)) + +def ge_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = 
self_.value + return Tensor(torch_dialect.AtenGeScalarOp(self_, other)) + +def ge__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenGe_ScalarOp(self_, other)) + +def lt_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLtScalarOp(self_, other)) + +def lt__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLt_ScalarOp(self_, other)) + +def le_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLeScalarOp(self_, other)) + +def le__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLe_ScalarOp(self_, other)) + +def fmod_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFmodScalarOp(self_, other)) + +def fmod__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFmod_ScalarOp(self_, other)) + +def masked_fill_Scalar(self_: Tensor, mask: Tensor, value: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mask, Tensor), f'`mask` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mask).__module__}.{type(mask).__name__}' + mask = mask.value + return Tensor(torch_dialect.AtenMaskedFillScalarOp(self_, mask, value)) + +def masked_fill__Scalar(self_: Tensor, mask: Tensor, value: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mask, Tensor), f'`mask` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mask).__module__}.{type(mask).__name__}' + mask = mask.value + return Tensor(torch_dialect.AtenMaskedFill_ScalarOp(self_, mask, value)) + +def masked_fill_Tensor(self_: Tensor, mask: Tensor, value: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mask, Tensor), f'`mask` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mask).__module__}.{type(mask).__name__}' + mask = mask.value + 
assert isinstance(value, Tensor), f'`value` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(value).__module__}.{type(value).__name__}' + value = value.value + return Tensor(torch_dialect.AtenMaskedFillTensorOp(self_, mask, value)) + +def masked_fill__Tensor(self_: Tensor, mask: Tensor, value: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mask, Tensor), f'`mask` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mask).__module__}.{type(mask).__name__}' + mask = mask.value + assert isinstance(value, Tensor), f'`value` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(value).__module__}.{type(value).__name__}' + value = value.value + return Tensor(torch_dialect.AtenMaskedFill_TensorOp(self_, mask, value)) + +def clamp(self_: Tensor, min: Optional[Number] = None, max: Optional[Number] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenClampOp(self_, min, max)) + +def clamp_(self_: Tensor, min: Optional[Number] = None, max: Optional[Number] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenClamp_Op(self_, min, max)) + +def clamp_min(self_: Tensor, min: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenClampMinOp(self_, min)) + +def clamp_min_(self_: Tensor, min: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenClampMin_Op(self_, min)) + +def clamp_max(self_: Tensor, max: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenClampMaxOp(self_, max)) + +def clamp_max_(self_: Tensor, max: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenClampMax_Op(self_, max)) + +def log2(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLog2Op(self_)) + +def log2_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLog2_Op(self_)) + +def sqrt(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSqrtOp(self_)) + +def sqrt_(self_: 
Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSqrt_Op(self_)) + +def log1p(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLog1pOp(self_)) + +def log1p_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLog1p_Op(self_)) + +def rsqrt(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRsqrtOp(self_)) + +def rsqrt_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRsqrt_Op(self_)) + +def abs(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAbsOp(self_)) + +def abs_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAbs_Op(self_)) + +def reciprocal(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenReciprocalOp(self_)) + +def reciprocal_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenReciprocal_Op(self_)) + +def bitwise_and_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenBitwiseAndTensorOp(self_, other)) + +def bitwise_and__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenBitwiseAnd_TensorOp(self_, other)) + +def bitwise_or_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + 
self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenBitwiseOrTensorOp(self_, other)) + +def bitwise_or__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenBitwiseOr_TensorOp(self_, other)) + +def threshold(self_: Tensor, threshold: Number, value: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenThresholdOp(self_, threshold, value)) + +def threshold_(self_: Tensor, threshold: Number, value: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenThreshold_Op(self_, threshold, value)) + +def square(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSquareOp(self_)) + +def square_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSquare_Op(self_)) + +def unsqueeze(self_: Tensor, dim: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenUnsqueezeOp(self_, dim)) + +def unsqueeze_(self_: Tensor, dim: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenUnsqueeze_Op(self_, dim)) + +def zero(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenZeroOp(self_)) + +def zero_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenZero_Op(self_)) + +def fill_Scalar(self_: Tensor, value: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFillScalarOp(self_, value)) + +def fill__Scalar(self_: Tensor, value: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + 
return Tensor(torch_dialect.AtenFill_ScalarOp(self_, value)) + +def fill_Tensor(self_: Tensor, value: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(value, Tensor), f'`value` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(value).__module__}.{type(value).__name__}' + value = value.value + return Tensor(torch_dialect.AtenFillTensorOp(self_, value)) + +def fill__Tensor(self_: Tensor, value: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(value, Tensor), f'`value` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(value).__module__}.{type(value).__name__}' + value = value.value + return Tensor(torch_dialect.AtenFill_TensorOp(self_, value)) + +def div_Tensor_mode(self_: Tensor, other: Tensor, rounding_mode: Optional[str]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenDivTensorModeOp(self_, other, rounding_mode)) + +def div__Tensor_mode(self_: Tensor, other: Tensor, rounding_mode: Optional[str]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenDiv_TensorModeOp(self_, other, rounding_mode)) + +def mul_Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenMulTensorOp(self_, other)) + +def mul__Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenMul_TensorOp(self_, other)) + +def add_Tensor(self_: Tensor, other: Tensor, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenAddTensorOp(self_, other, alpha)) + +def add__Tensor(self_: Tensor, other: Tensor, alpha: Number = 1) -> 
Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenAdd_TensorOp(self_, other, alpha)) + +def sub_Tensor(self_: Tensor, other: Tensor, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenSubTensorOp(self_, other, alpha)) + +def sub__Tensor(self_: Tensor, other: Tensor, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenSub_TensorOp(self_, other, alpha)) + +def add_Scalar(self_: Tensor, other: Number, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAddScalarOp(self_, other, alpha)) + +def add__Scalar(self_: Tensor, other: Number, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAdd_ScalarOp(self_, other, alpha)) + +def sub_Scalar(self_: Tensor, other: Number, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSubScalarOp(self_, other, alpha)) + +def sub__Scalar(self_: Tensor, other: Number, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSub_ScalarOp(self_, other, alpha)) + +def mul_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenMulScalarOp(self_, other)) + +def mul__Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenMul_ScalarOp(self_, other)) + +def addcmul(self_: Tensor, tensor1: Tensor, tensor2: Tensor, value: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = 
self_.value + assert isinstance(tensor1, Tensor), f'`tensor1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor1).__module__}.{type(tensor1).__name__}' + tensor1 = tensor1.value + assert isinstance(tensor2, Tensor), f'`tensor2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor2).__module__}.{type(tensor2).__name__}' + tensor2 = tensor2.value + return Tensor(torch_dialect.AtenAddcmulOp(self_, tensor1, tensor2, value)) + +def addcmul_(self_: Tensor, tensor1: Tensor, tensor2: Tensor, value: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(tensor1, Tensor), f'`tensor1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor1).__module__}.{type(tensor1).__name__}' + tensor1 = tensor1.value + assert isinstance(tensor2, Tensor), f'`tensor2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor2).__module__}.{type(tensor2).__name__}' + tensor2 = tensor2.value + return Tensor(torch_dialect.AtenAddcmul_Op(self_, tensor1, tensor2, value)) + +def addcdiv(self_: Tensor, tensor1: Tensor, tensor2: Tensor, value: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(tensor1, Tensor), f'`tensor1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor1).__module__}.{type(tensor1).__name__}' + tensor1 = tensor1.value + assert isinstance(tensor2, Tensor), f'`tensor2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor2).__module__}.{type(tensor2).__name__}' + tensor2 = tensor2.value + return Tensor(torch_dialect.AtenAddcdivOp(self_, tensor1, tensor2, value)) + +def addcdiv_(self_: Tensor, tensor1: Tensor, tensor2: Tensor, value: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(tensor1, Tensor), f'`tensor1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor1).__module__}.{type(tensor1).__name__}' + tensor1 = tensor1.value + assert isinstance(tensor2, Tensor), f'`tensor2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(tensor2).__module__}.{type(tensor2).__name__}' + tensor2 = tensor2.value + return Tensor(torch_dialect.AtenAddcdiv_Op(self_, tensor1, tensor2, value)) + +def maximum(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenMaximumOp(self_, other)) + +def minimum(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenMinimumOp(self_, other)) + +def 
mish(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenMishOp(self_)) + +def rsub_Scalar(self_: Tensor, other: Number, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRsubScalarOp(self_, other, alpha)) + +def gelu(self_: Tensor, approximate: str = "none") -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenGeluOp(self_, approximate)) + +def pow_Tensor_Scalar(self_: Tensor, exponent: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenPowTensorScalarOp(self_, exponent)) + +def pow_Tensor_Tensor(self_: Tensor, exponent: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(exponent, Tensor), f'`exponent` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(exponent).__module__}.{type(exponent).__name__}' + exponent = exponent.value + return Tensor(torch_dialect.AtenPowTensorTensorOp(self_, exponent)) + +def threshold_backward(grad_output: Tensor, self_: Tensor, threshold: Number) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenThresholdBackwardOp(grad_output, self_, threshold)) + +def floor_divide(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenFloorDivideOp(self_, other)) + +def softplus(self_: Tensor, beta: Number = 1, threshold: Number = 20) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSoftplusOp(self_, beta, threshold)) + +def prelu(self_: Tensor, weight: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + return Tensor(torch_dialect.AtenPreluOp(self_, weight)) + +def triu(self_: Tensor, 
diagonal: int = 0) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTriuOp(self_, diagonal)) + +def triu_(self_: Tensor, diagonal: int = 0) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTriu_Op(self_, diagonal)) + +def round(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRoundOp(self_)) + +def round_(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRound_Op(self_)) + +def index_put_hacked_twin(self_: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert builtins.all(isinstance(t, Tensor) for t in indices) + indices = [t.value for t in indices] + assert isinstance(values, Tensor), f'`values` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(values).__module__}.{type(values).__name__}' + values = values.value + return Tensor(torch_dialect.AtenIndexPutHackedTwinOp(self_, indices, values, accumulate)) + +def index_put__hacked_twin(self_: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert builtins.all(isinstance(t, Tensor) for t in indices) + indices = [t.value for t in indices] + assert isinstance(values, Tensor), f'`values` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(values).__module__}.{type(values).__name__}' + values = values.value + return Tensor(torch_dialect.AtenIndexPut_HackedTwinOp(self_, indices, values, accumulate)) + +def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenLinearOp(input, weight, bias)) + +def mm(self_: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mat2, Tensor), f'`mat2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mat2).__module__}.{type(mat2).__name__}' + mat2 = mat2.value + return Tensor(torch_dialect.AtenMmOp(self_, 
mat2)) + +def addmm(self_: Tensor, mat1: Tensor, mat2: Tensor, beta: Number = 1, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mat1, Tensor), f'`mat1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mat1).__module__}.{type(mat1).__name__}' + mat1 = mat1.value + assert isinstance(mat2, Tensor), f'`mat2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mat2).__module__}.{type(mat2).__name__}' + mat2 = mat2.value + return Tensor(torch_dialect.AtenAddmmOp(self_, mat1, mat2, beta, alpha)) + +def matmul(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenMatmulOp(self_, other)) + +def mv(self_: Tensor, vec: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(vec, Tensor), f'`vec` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(vec).__module__}.{type(vec).__name__}' + vec = vec.value + return Tensor(torch_dialect.AtenMvOp(self_, vec)) + +def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenConv2dOp(input, weight, bias, stride, padding, dilation, groups)) + +def conv_transpose1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: List[int] = (1), padding: List[int] = (0), output_padding: List[int] = (0), groups: int = 1, dilation: List[int] = (1)) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenConvTranspose1dOp(input, weight, bias, stride, padding, output_padding, groups, dilation)) + +def conv_transpose2d_input(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 
1)) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenConvTranspose2dInputOp(input, weight, bias, stride, padding, output_padding, groups, dilation)) + +def conv_transpose3d_input(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: List[int] = (1, 1, 1), padding: List[int] = (0, 0, 0), output_padding: List[int] = (0, 0, 0), groups: int = 1, dilation: List[int] = (1, 1, 1)) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenConvTranspose3dInputOp(input, weight, bias, stride, padding, output_padding, groups, dilation)) + +def convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenConvolutionOp(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)) + +def convolution_overrideable(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenConvolutionOverrideableOp(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)) + +def _convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: List[int], padding: List[int], dilation: List[int], transposed: 
bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.Aten_ConvolutionOp(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)) + +def _convolution_deprecated(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.Aten_ConvolutionDeprecatedOp(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled)) + +def roll(self_: Tensor, shifts: List[int], dims: List[int] = ()) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRollOp(self_, shifts, dims)) + +def convolution_backward_overrideable(grad_output: Tensor, input: Tensor, weight: Tensor, stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + op_results = get_op_results_or_values(torch_dialect.AtenConvolutionBackwardOverrideableOp(grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def flip(self_: Tensor, dims: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return 
Tensor(torch_dialect.AtenFlipOp(self_, dims)) + +def native_batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: bool, momentum: float, eps: float) -> Tuple[Tensor, Tensor, Tensor]: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + if running_mean is not None: + assert isinstance(running_mean, Tensor), f'`running_mean` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(running_mean).__module__}.{type(running_mean).__name__}' + running_mean = running_mean.value + if running_var is not None: + assert isinstance(running_var, Tensor), f'`running_var` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(running_var).__module__}.{type(running_var).__name__}' + running_var = running_var.value + op_results = get_op_results_or_values(torch_dialect.AtenNativeBatchNormOp(input, weight, bias, running_mean, running_var, training, momentum, eps)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + if running_mean is not None: + assert isinstance(running_mean, Tensor), f'`running_mean` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(running_mean).__module__}.{type(running_mean).__name__}' + running_mean = running_mean.value + if running_var is not None: + assert isinstance(running_var, Tensor), f'`running_var` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(running_var).__module__}.{type(running_var).__name__}' + running_var = running_var.value + return Tensor(torch_dialect.AtenBatchNormOp(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)) + +def layer_norm(input: Tensor, normalized_shape: List[int], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = 
weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + return Tensor(torch_dialect.AtenLayerNormOp(input, normalized_shape, weight, bias, eps, cudnn_enable)) + +def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Optional[Tensor], bias: Optional[Tensor], eps: float) -> Tuple[Tensor, Tensor, Tensor]: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + op_results = get_op_results_or_values(torch_dialect.AtenNativeLayerNormOp(input, normalized_shape, weight, bias, eps)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def max_pool2d(self_: Tensor, kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenMaxPool2dOp(self_, kernel_size, stride, padding, dilation, ceil_mode)) + +def max_pool2d_with_indices(self_: Tensor, kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[Tensor, Tensor]: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + op_results = get_op_results_or_values(torch_dialect.AtenMaxPool2dWithIndicesOp(self_, kernel_size, stride, padding, dilation, ceil_mode)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def max_pool2d_with_indices_backward(grad_output: Tensor, self_: Tensor, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: Tensor) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(indices, Tensor), f'`indices` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(indices).__module__}.{type(indices).__name__}' + indices = indices.value + return Tensor(torch_dialect.AtenMaxPool2dWithIndicesBackwardOp(grad_output, self_, kernel_size, stride, padding, dilation, ceil_mode, indices)) + +def avg_pool2d(self_: Tensor, kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is 
{type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAvgPool2dOp(self_, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)) + +def softmax_int(self_: Tensor, dim: int, dtype: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.AtenSoftmaxIntOp(self_, dim, dtype)) + +def log_softmax_int(self_: Tensor, dim: int, dtype: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.AtenLogSoftmaxIntOp(self_, dim, dtype)) + +def _log_softmax(self_: Tensor, dim: int, half_to_float: bool) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.Aten_LogSoftmaxOp(self_, dim, half_to_float)) + +def adaptive_avg_pool2d(self_: Tensor, output_size: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAdaptiveAvgPool2dOp(self_, output_size)) + +def topk(self_: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[Tensor, Tensor]: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + op_results = get_op_results_or_values(torch_dialect.AtenTopkOp(self_, k, dim, largest, sorted)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def transpose_int(self_: Tensor, dim0: int, dim1: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTransposeIntOp(self_, dim0, dim1)) + +def permute(self_: Tensor, dims: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenPermuteOp(self_, dims)) + +def bmm(self_: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mat2, Tensor), f'`mat2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mat2).__module__}.{type(mat2).__name__}' + mat2 = mat2.value + return Tensor(torch_dialect.AtenBmmOp(self_, mat2)) + +def cumsum(self_: Tensor, dim: int, dtype: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return 
Tensor(torch_dialect.AtenCumsumOp(self_, dim, dtype)) + +def floor_divide_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFloorDivideScalarOp(self_, other)) + +def logsumexp(self_: Tensor, dim: List[int], keepdim: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLogsumexpOp(self_, dim, keepdim)) + +def __and___Tensor(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.Aten__And__TensorOp(self_, other)) + +def _softmax(self_: Tensor, dim: int, half_to_float: bool) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.Aten_SoftmaxOp(self_, dim, half_to_float)) + +def mean(self_: Tensor, dtype: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.AtenMeanOp(self_, dtype)) + +def std(self_: Tensor, unbiased: bool = True) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenStdOp(self_, unbiased)) + +def var(self_: Tensor, unbiased: bool = True) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenVarOp(self_, unbiased)) + +def var_mean(self_: Tensor, unbiased: bool = True) -> Tuple[Tensor, Tensor]: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + op_results = get_op_results_or_values(torch_dialect.AtenVarMeanOp(self_, unbiased)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def nll_loss_forward(self_: Tensor, target: Tensor, weight: Optional[Tensor], reduction: int, ignore_index: int) -> Tuple[Tensor, Tensor]: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(target, Tensor), f'`target` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(target).__module__}.{type(target).__name__}' + target = target.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = 
weight.value + op_results = get_op_results_or_values(torch_dialect.AtenNllLossForwardOp(self_, target, weight, reduction, ignore_index)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def nll_loss_backward(grad_output: Tensor, self_: Tensor, target: Tensor, weight: Optional[Tensor], reduction: int, ignore_index: int, total_weight: Tensor) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(target, Tensor), f'`target` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(target).__module__}.{type(target).__name__}' + target = target.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + assert isinstance(total_weight, Tensor), f'`total_weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(total_weight).__module__}.{type(total_weight).__name__}' + total_weight = total_weight.value + return Tensor(torch_dialect.AtenNllLossBackwardOp(grad_output, self_, target, weight, reduction, ignore_index, total_weight)) + +def bincount(self_: Tensor, weights: Optional[Tensor] = None, minlength: int = 0) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if weights is not None: + assert isinstance(weights, Tensor), f'`weights` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weights).__module__}.{type(weights).__name__}' + weights = weights.value + return Tensor(torch_dialect.AtenBincountOp(self_, weights, minlength)) + +def frobenius_norm_dim(self_: Tensor, dim: List[int], keepdim: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFrobeniusNormDimOp(self_, dim, keepdim)) + +def mse_loss(self_: Tensor, target: Tensor, reduction: int = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(target, Tensor), f'`target` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(target).__module__}.{type(target).__name__}' + target = target.value + return Tensor(torch_dialect.AtenMseLossOp(self_, target, reduction)) + +def upsample_nearest2d_backward(grad_output: Tensor, output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + return Tensor(torch_dialect.AtenUpsampleNearest2dBackwardOp(grad_output, output_size, input_size, scales_h, scales_w)) + +def constant_pad_nd(self_: Tensor, pad: List[int], value: Number = 0) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a 
{Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenConstantPadNdOp(self_, pad, value)) + +def pad(self_: Tensor, pad: List[int], mode: str = "constant", value: Optional[float] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenPadOp(self_, pad, mode, value)) + +def squeeze_dim(self_: Tensor, dim: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSqueezeDimOp(self_, dim)) + +def squeeze(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSqueezeOp(self_)) + +def flatten_using_ints(self_: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFlattenUsingIntsOp(self_, start_dim, end_dim)) + +def dim(self_: Tensor) -> int: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDimOp(self_)) + +def size(self_: Tensor) -> List[int]: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSizeOp(self_)) + +def Bool_Tensor(a: Tensor) -> bool: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.AtenBoolTensorOp(a)) + +def is_floating_point(self_: Tensor) -> bool: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenIsFloatingPointOp(self_)) + +def _shape_as_tensor(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.Aten_ShapeAsTensorOp(self_)) + +def all(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAllOp(self_)) + +def all_bool(self_: List[bool]) -> bool: + return Tensor(torch_dialect.AtenAllBoolOp(self_)) + +def any(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAnyOp(self_)) + +def any_dim(self_: Tensor, dim: int, keepdim: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is 
{type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAnyDimOp(self_, dim, keepdim)) + +def arange_start_out(start: Number, end: Number, step: Number = 1, out: Tensor= None) -> Tensor: + assert isinstance(out, Tensor), f'`out` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(out).__module__}.{type(out).__name__}' + out = out.value + return Tensor(torch_dialect.AtenArangeStartOutOp(start, end, step, out)) + +def argmax(self_: Tensor, dim: Optional[int] = None, keepdim: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenArgmaxOp(self_, dim, keepdim)) + +def bucketize_Tensor(self_: Tensor, boundaries: Tensor, out_int32: bool = False, right: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(boundaries, Tensor), f'`boundaries` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(boundaries).__module__}.{type(boundaries).__name__}' + boundaries = boundaries.value + return Tensor(torch_dialect.AtenBucketizeTensorOp(self_, boundaries, out_int32, right)) + +def clone(self_: Tensor, memory_format: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenCloneOp(self_, memory_format)) + +def lift_fresh_copy(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenLiftFreshCopyOp(self_)) + +def contiguous(self_: Tensor, memory_format: int = 0) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenContiguousOp(self_, memory_format)) + +def copy(self_: Tensor, src: Tensor, non_blocking: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenCopyOp(self_, src, non_blocking)) + +def copy_(self_: Tensor, src: Tensor, non_blocking: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenCopy_Op(self_, src, non_blocking)) + +def detach(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDetachOp(self_)) + +def embedding(weight: 
Tensor, indices: Tensor, padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> Tensor: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + assert isinstance(indices, Tensor), f'`indices` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(indices).__module__}.{type(indices).__name__}' + indices = indices.value + return Tensor(torch_dialect.AtenEmbeddingOp(weight, indices, padding_idx, scale_grad_by_freq, sparse)) + +def embedding_bag_padding_idx(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[Tensor], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + assert isinstance(indices, Tensor), f'`indices` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(indices).__module__}.{type(indices).__name__}' + indices = indices.value + assert isinstance(offsets, Tensor), f'`offsets` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(offsets).__module__}.{type(offsets).__name__}' + offsets = offsets.value + if per_sample_weights is not None: + assert isinstance(per_sample_weights, Tensor), f'`per_sample_weights` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}' + per_sample_weights = per_sample_weights.value + op_results = get_op_results_or_values(torch_dialect.AtenEmbeddingBagPaddingIdxOp(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def _embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + assert isinstance(indices, Tensor), f'`indices` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(indices).__module__}.{type(indices).__name__}' + indices = indices.value + assert isinstance(offsets, Tensor), f'`offsets` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(offsets).__module__}.{type(offsets).__name__}' + offsets = offsets.value + if per_sample_weights is not None: + assert isinstance(per_sample_weights, Tensor), f'`per_sample_weights` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}' + per_sample_weights = per_sample_weights.value + op_results = get_op_results_or_values(torch_dialect.Aten_EmbeddingBagOp(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def expand(self_: Tensor, size: List[int], implicit: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is 
{type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenExpandOp(self_, size, implicit)) + +def expand_as(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenExpandAsOp(self_, other)) + +def broadcast_to(self_: Tensor, size: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenBroadcastToOp(self_, size)) + +def index_Tensor_hacked_twin(self_: Tensor, indices: List[Tensor]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert builtins.all(isinstance(t, Tensor) for t in indices) + indices = [t.value for t in indices] + return Tensor(torch_dialect.AtenIndexTensorHackedTwinOp(self_, indices)) + +def index_select(self_: Tensor, dim: int, index: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(index, Tensor), f'`index` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(index).__module__}.{type(index).__name__}' + index = index.value + return Tensor(torch_dialect.AtenIndexSelectOp(self_, dim, index)) + +def item(self_: Tensor) -> Number: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenItemOp(self_)) + +def masked_select(self_: Tensor, mask: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(mask, Tensor), f'`mask` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mask).__module__}.{type(mask).__name__}' + mask = mask.value + return Tensor(torch_dialect.AtenMaskedSelectOp(self_, mask)) + +def numel(self_: Tensor) -> int: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenNumelOp(self_)) + +def repeat(self_: Tensor, repeats: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRepeatOp(self_, repeats)) + +def reshape(self_: Tensor, shape: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenReshapeOp(self_, shape)) + +def _reshape_alias(self_: Tensor, size: List[int], stride: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a 
{Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.Aten_ReshapeAliasOp(self_, size, stride)) + +def resize_(self_: Tensor, size: List[int], memory_format: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenResize_Op(self_, size, memory_format)) + +def select_int(self_: Tensor, dim: int, index: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSelectIntOp(self_, dim, index)) + +def size_int(self_: Tensor, dim: int) -> int: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSizeIntOp(self_, dim)) + +def stack(tensors: List[Tensor], dim: int = 0) -> Tensor: + assert builtins.all(isinstance(t, Tensor) for t in tensors) + tensors = [t.value for t in tensors] + return Tensor(torch_dialect.AtenStackOp(tensors, dim)) + +def sum(self_: Tensor, dtype: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.AtenSumOp(self_, dtype)) + +def max(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenMaxOp(self_)) + +def max_dim(self_: Tensor, dim: int, keepdim: bool = False) -> Tuple[Tensor, Tensor]: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + op_results = get_op_results_or_values(torch_dialect.AtenMaxDimOp(self_, dim, keepdim)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def amax(self_: Tensor, dim: List[int] = (), keepdim: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAmaxOp(self_, dim, keepdim)) + +def to_dtype(self_: Tensor, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.AtenToDtypeOp(self_, dtype, non_blocking, copy, memory_format)) + +def to_other(self_: Tensor, other: Tensor, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), 
f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenToOtherOp(self_, other, non_blocking, copy, memory_format)) + +def type_as(self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenTypeAsOp(self_, other)) + +def view(self_: Tensor, size: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenViewOp(self_, size)) + +def _unsafe_view(self_: Tensor, size: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.Aten_UnsafeViewOp(self_, size)) + +def where_self(condition: Tensor, self_: Tensor, other: Tensor) -> Tensor: + assert isinstance(condition, Tensor), f'`condition` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(condition).__module__}.{type(condition).__name__}' + condition = condition.value + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenWhereSelfOp(condition, self_, other)) + +def where_Scalar(condition: Tensor, self_: Number, other: Number) -> Tensor: + assert isinstance(condition, Tensor), f'`condition` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(condition).__module__}.{type(condition).__name__}' + condition = condition.value + return Tensor(torch_dialect.AtenWhereScalarOp(condition, self_, other)) + +def where_ScalarOther(condition: Tensor, self_: Tensor, other: Number) -> Tensor: + assert isinstance(condition, Tensor), f'`condition` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(condition).__module__}.{type(condition).__name__}' + condition = condition.value + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenWhereScalarOtherOp(condition, self_, other)) + +def where_ScalarSelf(condition: Tensor, self_: Number, other: Tensor) -> Tensor: + assert isinstance(condition, Tensor), f'`condition` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(condition).__module__}.{type(condition).__name__}' + condition = condition.value + assert isinstance(other, Tensor), f'`other` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(other).__module__}.{type(other).__name__}' + other = other.value + return Tensor(torch_dialect.AtenWhereScalarSelfOp(condition, self_, other)) + +def slice_Tensor(self_: Tensor, dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> Tensor: + assert isinstance(self_, 
Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSliceTensorOp(self_, dim, start, end, step)) + +def len_Tensor(t: Tensor) -> int: + assert isinstance(t, Tensor), f'`t` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(t).__module__}.{type(t).__name__}' + t = t.value + return Tensor(torch_dialect.AtenLenTensorOp(t)) + +def cpu(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenCpuOp(self_)) + +def gather(self_: Tensor, dim: int, index: Tensor, sparse_grad: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(index, Tensor), f'`index` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(index).__module__}.{type(index).__name__}' + index = index.value + return Tensor(torch_dialect.AtenGatherOp(self_, dim, index, sparse_grad)) + +def scatter_add(self_: Tensor, dim: int, index: Tensor, src: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(index, Tensor), f'`index` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(index).__module__}.{type(index).__name__}' + index = index.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenScatterAddOp(self_, dim, index, src)) + +def IntImplicit(a: Tensor) -> int: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.AtenIntImplicitOp(a)) + +def FloatImplicit(a: Tensor) -> float: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.AtenFloatImplicitOp(a)) + +def Int_Tensor(a: Tensor) -> int: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.AtenIntTensorOp(a)) + +def Float_Tensor(a: Tensor) -> float: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.AtenFloatTensorOp(a)) + +def dropout(input: Tensor, p: float, train: bool) -> Tensor: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + return Tensor(torch_dialect.AtenDropoutOp(input, p, train)) + +def dropout_(self_: Tensor, p: float, train: bool) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDropout_Op(self_, p, train)) + +def native_dropout(input: Tensor, p: float, train: Optional[bool]) -> 
Tuple[Tensor, Tensor]: + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + op_results = get_op_results_or_values(torch_dialect.AtenNativeDropoutOp(input, p, train)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def t(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTOp(self_)) + +def numpy_T(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenNumpyTOp(self_)) + +def baddbmm(self_: Tensor, batch1: Tensor, batch2: Tensor, beta: Number = 1, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(batch1, Tensor), f'`batch1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(batch1).__module__}.{type(batch1).__name__}' + batch1 = batch1.value + assert isinstance(batch2, Tensor), f'`batch2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(batch2).__module__}.{type(batch2).__name__}' + batch2 = batch2.value + return Tensor(torch_dialect.AtenBaddbmmOp(self_, batch1, batch2, beta, alpha)) + +def baddbmm_(self_: Tensor, batch1: Tensor, batch2: Tensor, beta: Number = 1, alpha: Number = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(batch1, Tensor), f'`batch1` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(batch1).__module__}.{type(batch1).__name__}' + batch1 = batch1.value + assert isinstance(batch2, Tensor), f'`batch2` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(batch2).__module__}.{type(batch2).__name__}' + batch2 = batch2.value + return Tensor(torch_dialect.AtenBaddbmm_Op(self_, batch1, batch2, beta, alpha)) + +def fft_fft(self_: Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenFftFftOp(self_, n, dim, norm)) + +def alias_copy(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAliasCopyOp(self_)) + +def as_strided_copy(self_: Tensor, size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenAsStridedCopyOp(self_, size, stride, storage_offset)) + +def diagonal_copy(self_: Tensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is 
{type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDiagonalCopyOp(self_, offset, dim1, dim2)) + +def expand_copy(self_: Tensor, size: List[int], implicit: bool = False) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenExpandCopyOp(self_, size, implicit)) + +def permute_copy(self_: Tensor, dims: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenPermuteCopyOp(self_, dims)) + +def _reshape_alias_copy(self_: Tensor, size: List[int], stride: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.Aten_ReshapeAliasCopyOp(self_, size, stride)) + +def select_copy_int(self_: Tensor, dim: int, index: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSelectCopyIntOp(self_, dim, index)) + +def detach_copy(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenDetachCopyOp(self_)) + +def slice_copy_Tensor(self_: Tensor, dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSliceCopyTensorOp(self_, dim, start, end, step)) + +def squeeze_copy(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSqueezeCopyOp(self_)) + +def squeeze_copy_dim(self_: Tensor, dim: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenSqueezeCopyDimOp(self_, dim)) + +def t_copy(self_: Tensor) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTCopyOp(self_)) + +def transpose_copy_int(self_: Tensor, dim0: int, dim1: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenTransposeCopyIntOp(self_, dim0, dim1)) + +def unsqueeze_copy(self_: Tensor, dim: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenUnsqueezeCopyOp(self_, dim)) + 
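+# The wrappers in this file all follow the same generated pattern: assert that each
+# tensor argument is a pi `Tensor`, unwrap it to its underlying MLIR value via
+# `.value`, build the corresponding `torch_dialect` op, and wrap the op result back
+# in a `Tensor`. A rough usage sketch (assuming an active MLIR context and
+# insertion point, and that `t` is already a pi `Tensor`):
+#
+#     y = unsqueeze_copy(t, 0)          # emits a torch.aten.unsqueeze_copy op
+#     z = transpose_copy_int(y, 0, 1)   # emits a torch.aten.transpose_copy.int op
+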
+def view_copy(self_: Tensor, size: List[int]) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenViewCopyOp(self_, size)) + +def view_copy_dtype(self_: Tensor, dtype: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.AtenViewCopyDtypeOp(self_, dtype)) + +def unfold_copy(self_: Tensor, dimension: int, size: int, step: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenUnfoldCopyOp(self_, dimension, size, step)) + +def select_scatter(self_: Tensor, src: Tensor, dim: int, index: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenSelectScatterOp(self_, src, dim, index)) + +def slice_scatter(self_: Tensor, src: Tensor, dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenSliceScatterOp(self_, src, dim, start, end, step)) + +def diagonal_scatter(self_: Tensor, src: Tensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenDiagonalScatterOp(self_, src, offset, dim1, dim2)) + +def as_strided_scatter(self_: Tensor, src: Tensor, size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + assert isinstance(src, Tensor), f'`src` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(src).__module__}.{type(src).__name__}' + src = src.value + return Tensor(torch_dialect.AtenAsStridedScatterOp(self_, src, size, stride, storage_offset)) + +def upsample_nearest2d(self_: Tensor, output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenUpsampleNearest2dOp(self_, output_size, 
scales_h, scales_w)) + +def __contains___int_list(l: List[int], item: int) -> bool: + return Tensor(torch_dialect.Aten__Contains__IntListOp(l, item)) + +def cat(tensors: List[Tensor], dim: int = 0) -> Tensor: + assert builtins.all(isinstance(t, Tensor) for t in tensors) + tensors = [t.value for t in tensors] + return Tensor(torch_dialect.AtenCatOp(tensors, dim)) + +def append_t(self_: List[Tensor], el: Tensor) -> List[Tensor]: + return Tensor(torch_dialect.AtenAppendTOp(self_, el)) + +def add_t(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: + return Tensor(torch_dialect.AtenAddTOp(a, b)) + +def eq_int_list(a: List[int], b: List[int]) -> bool: + return Tensor(torch_dialect.AtenEqIntListOp(a, b)) + +def list_t(l: List[Tensor]) -> List[Tensor]: + return Tensor(torch_dialect.AtenListTOp(l)) + +def slice_t(l: List[Tensor], start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[Tensor]: + return Tensor(torch_dialect.AtenSliceTOp(l, start, end, step)) + +def insert_t(self_: List[Tensor], idx: int, el: Tensor) -> None: + torch_dialect.AtenInsertTOp(self_, idx, el) + +def ne_int_list(a: List[int], b: List[int]) -> bool: + return Tensor(torch_dialect.AtenNeIntListOp(a, b)) + +def any_bool(self_: List[bool]) -> bool: + return Tensor(torch_dialect.AtenAnyBoolOp(self_)) + +def sort_int(self_: List[int], reverse: bool = False) -> None: + torch_dialect.AtenSortIntOp(self_, reverse) + +def add_str(a: str, b: str) -> str: + return Tensor(torch_dialect.AtenAddStrOp(a, b)) + +def eq_str(a: str, b: str) -> bool: + return Tensor(torch_dialect.AtenEqStrOp(a, b)) + +def len_str(s: str) -> int: + return Tensor(torch_dialect.AtenLenStrOp(s)) + +def str(elem: Tensor) -> str: + return Tensor(torch_dialect.AtenStrOp(elem)) + +def join(self_: str, values: List[str]) -> str: + return Tensor(torch_dialect.AtenJoinOp(self_, values)) + +def Float_Scalar(a: Number) -> float: + return Tensor(torch_dialect.AtenFloatScalarOp(a)) + +def Float_str(a: str) -> float: + return Tensor(torch_dialect.AtenFloatStrOp(a)) + +def Int_float(a: float) -> int: + return Tensor(torch_dialect.AtenIntFloatOp(a)) + +def Int_Scalar(a: Number) -> int: + return Tensor(torch_dialect.AtenIntScalarOp(a)) + +def __range_length(lo: int, hi: int, step: int) -> int: + return Tensor(torch_dialect.Aten__RangeLengthOp(lo, hi, step)) + +def __derive_index(index: int, start: int, step: int) -> int: + return Tensor(torch_dialect.Aten__DeriveIndexOp(index, start, step)) + +def gt_int(a: int, b: int) -> bool: + return Tensor(torch_dialect.AtenGtIntOp(a, b)) + +def ge_int(a: int, b: int) -> bool: + return Tensor(torch_dialect.AtenGeIntOp(a, b)) + +def lt_int(a: int, b: int) -> bool: + return Tensor(torch_dialect.AtenLtIntOp(a, b)) + +def le_int(a: int, b: int) -> bool: + return Tensor(torch_dialect.AtenLeIntOp(a, b)) + +def ne_int(a: int, b: int) -> bool: + return Tensor(torch_dialect.AtenNeIntOp(a, b)) + +def eq_int(a: int, b: int) -> bool: + return Tensor(torch_dialect.AtenEqIntOp(a, b)) + +def floordiv_int(a: int, b: int) -> int: + return Tensor(torch_dialect.AtenFloordivIntOp(a, b)) + +def remainder_int(a: int, b: int) -> int: + return Tensor(torch_dialect.AtenRemainderIntOp(a, b)) + +def remainder_Scalar(self_: Tensor, other: Number) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenRemainderScalarOp(self_, other)) + +def add_int(a: int, b: int) -> int: + return 
Tensor(torch_dialect.AtenAddIntOp(a, b)) + +def sub_int(a: int, b: int) -> int: + return Tensor(torch_dialect.AtenSubIntOp(a, b)) + +def mul_int(a: int, b: int) -> int: + return Tensor(torch_dialect.AtenMulIntOp(a, b)) + +def div_int(a: int, b: int) -> float: + return Tensor(torch_dialect.AtenDivIntOp(a, b)) + +def neg_int(a: int) -> int: + return Tensor(torch_dialect.AtenNegIntOp(a)) + +def log_int(a: int) -> float: + return Tensor(torch_dialect.AtenLogIntOp(a)) + +def add_float_int(a: float, b: int) -> float: + return Tensor(torch_dialect.AtenAddFloatIntOp(a, b)) + +def sub_float(a: float, b: float) -> float: + return Tensor(torch_dialect.AtenSubFloatOp(a, b)) + +def mul_float(a: float, b: float) -> float: + return Tensor(torch_dialect.AtenMulFloatOp(a, b)) + +def div_float(a: float, b: float) -> float: + return Tensor(torch_dialect.AtenDivFloatOp(a, b)) + +def neg_float(a: float) -> float: + return Tensor(torch_dialect.AtenNegFloatOp(a)) + +def eq_float(a: float, b: float) -> bool: + return Tensor(torch_dialect.AtenEqFloatOp(a, b)) + +def gt_float(a: float, b: float) -> bool: + return Tensor(torch_dialect.AtenGtFloatOp(a, b)) + +def ge_float(a: float, b: float) -> bool: + return Tensor(torch_dialect.AtenGeFloatOp(a, b)) + +def lt_float(a: float, b: float) -> bool: + return Tensor(torch_dialect.AtenLtFloatOp(a, b)) + +def lt_float_int(a: float, b: int) -> bool: + return Tensor(torch_dialect.AtenLtFloatIntOp(a, b)) + +def ge_float_int(a: float, b: int) -> bool: + return Tensor(torch_dialect.AtenGeFloatIntOp(a, b)) + +def ne_float_int(a: float, b: int) -> bool: + return Tensor(torch_dialect.AtenNeFloatIntOp(a, b)) + +def gt_float_int(a: float, b: int) -> bool: + return Tensor(torch_dialect.AtenGtFloatIntOp(a, b)) + +def __and___bool(a: bool, b: bool) -> bool: + return Tensor(torch_dialect.Aten__And__BoolOp(a, b)) + +def ne_bool(a: bool, b: bool) -> bool: + return Tensor(torch_dialect.AtenNeBoolOp(a, b)) + +def __is__(self_: Tensor, obj: Tensor) -> bool: + return Tensor(torch_dialect.Aten__Is__Op(self_, obj)) + +def __isnot__(self_: Tensor, obj: Tensor) -> bool: + return Tensor(torch_dialect.Aten__Isnot__Op(self_, obj)) + +def __not__(self_: bool) -> bool: + return Tensor(torch_dialect.Aten__Not__Op(self_)) + +def len_t(a: List[Tensor]) -> int: + return Tensor(torch_dialect.AtenLenTOp(a)) + +def __getitem___t(list_: List[Tensor], idx: int) -> Tensor: + return Tensor(torch_dialect.Aten__Getitem__TOp(list_, idx)) + +def _set_item_t(l: List[Tensor], idx: int, el: Tensor) -> List[Tensor]: + return Tensor(torch_dialect.Aten_SetItemTOp(l, idx, el)) + +def div(a: Number, b: Number) -> float: + return Tensor(torch_dialect.AtenDivOp(a, b)) + +def add(a: Number, b: Number) -> Number: + return Tensor(torch_dialect.AtenAddOp(a, b)) + +def sub(a: Number, b: Number) -> Number: + return Tensor(torch_dialect.AtenSubOp(a, b)) + +def ceil_Scalar(a: Number) -> Number: + return Tensor(torch_dialect.AtenCeilScalarOp(a)) + +def sqrt_int(a: int) -> float: + return Tensor(torch_dialect.AtenSqrtIntOp(a)) + +def Bool_float(a: float) -> bool: + return Tensor(torch_dialect.AtenBoolFloatOp(a)) + +def Bool_int(a: int) -> bool: + return Tensor(torch_dialect.AtenBoolIntOp(a)) + +def ceil_float(a: float) -> int: + return Tensor(torch_dialect.AtenCeilFloatOp(a)) + +def narrow(self_: Tensor, dim: int, start: int, length: int) -> Tensor: + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return 
Tensor(torch_dialect.AtenNarrowOp(self_, dim, start, length)) + +def ScalarImplicit(a: Tensor) -> Number: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.AtenScalarImplicitOp(a)) + +def _softmax_backward_data(grad_output: Tensor, output: Tensor, dim: int, input_dtype: int) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(output, Tensor), f'`output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(output).__module__}.{type(output).__name__}' + output = output.value + return Tensor(torch_dialect.Aten_SoftmaxBackwardDataOp(grad_output, output, dim, input_dtype)) + +def tanh_backward(grad_output: Tensor, output: Tensor) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(output, Tensor), f'`output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(output).__module__}.{type(output).__name__}' + output = output.value + return Tensor(torch_dialect.AtenTanhBackwardOp(grad_output, output)) + +def gelu_backward(grad_output: Tensor, self_: Tensor, approximate: str = "none") -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(self_, Tensor), f'`self_` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(self_).__module__}.{type(self_).__name__}' + self_ = self_.value + return Tensor(torch_dialect.AtenGeluBackwardOp(grad_output, self_, approximate)) + +def _log_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: int, input_dtype: int) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(output, Tensor), f'`output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(output).__module__}.{type(output).__name__}' + output = output.value + return Tensor(torch_dialect.Aten_LogSoftmaxBackwardDataOp(grad_output, output, dim, input_dtype)) + +def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]: + assert isinstance(grad_out, Tensor), f'`grad_out` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_out).__module__}.{type(grad_out).__name__}' + grad_out = grad_out.value + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + assert isinstance(mean, Tensor), f'`mean` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mean).__module__}.{type(mean).__name__}' + mean = mean.value + assert isinstance(rstd, Tensor), f'`rstd` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(rstd).__module__}.{type(rstd).__name__}' + rstd = rstd.value + if weight is not 
None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if bias is not None: + assert isinstance(bias, Tensor), f'`bias` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(bias).__module__}.{type(bias).__name__}' + bias = bias.value + op_results = get_op_results_or_values(torch_dialect.AtenNativeLayerNormBackwardOp(grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def embedding_dense_backward(grad_output: Tensor, indices: Tensor, num_weights: int, padding_idx: int, scale_grad_by_freq: bool) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(indices, Tensor), f'`indices` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(indices).__module__}.{type(indices).__name__}' + indices = indices.value + return Tensor(torch_dialect.AtenEmbeddingDenseBackwardOp(grad_output, indices, num_weights, padding_idx, scale_grad_by_freq)) + +def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]: + assert isinstance(grad_out, Tensor), f'`grad_out` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(grad_out).__module__}.{type(grad_out).__name__}' + grad_out = grad_out.value + assert isinstance(input, Tensor), f'`input` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(input).__module__}.{type(input).__name__}' + input = input.value + if weight is not None: + assert isinstance(weight, Tensor), f'`weight` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(weight).__module__}.{type(weight).__name__}' + weight = weight.value + if running_mean is not None: + assert isinstance(running_mean, Tensor), f'`running_mean` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(running_mean).__module__}.{type(running_mean).__name__}' + running_mean = running_mean.value + if running_var is not None: + assert isinstance(running_var, Tensor), f'`running_var` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(running_var).__module__}.{type(running_var).__name__}' + running_var = running_var.value + if save_mean is not None: + assert isinstance(save_mean, Tensor), f'`save_mean` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(save_mean).__module__}.{type(save_mean).__name__}' + save_mean = save_mean.value + if save_invstd is not None: + assert isinstance(save_invstd, Tensor), f'`save_invstd` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(save_invstd).__module__}.{type(save_invstd).__name__}' + save_invstd = save_invstd.value + op_results = get_op_results_or_values(torch_dialect.AtenNativeBatchNormBackwardOp(grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask)) + return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results]) + +def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float) -> Tensor: + assert isinstance(grad_output, Tensor), f'`grad_output` should be a 
{Tensor.__module__}.{Tensor.__name__} but is {type(grad_output).__module__}.{type(grad_output).__name__}' + grad_output = grad_output.value + assert isinstance(mask, Tensor), f'`mask` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(mask).__module__}.{type(mask).__name__}' + mask = mask.value + return Tensor(torch_dialect.AtenNativeDropoutBackwardOp(grad_output, mask, scale)) + +def prim_layout(a: Tensor) -> int: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.PrimLayoutOp(a)) + +def prim_TupleIndex(tup: Any, i: int) -> Any: + return Tensor(torch_dialect.PrimTupleIndexOp(tup, i)) + +def prim_dtype(a: Tensor) -> int: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + return Tensor(torch_dialect.PrimDtypeOp(a)) + +def prim_NumToTensor_Scalar(a: Number) -> Tensor: + return Tensor(torch_dialect.PrimNumToTensorScalarOp(a)) + +def prim_min_self_int(self_: List[int]) -> int: + return Tensor(torch_dialect.PrimMinSelfIntOp(self_)) + +def prim_min_int(a: int, b: int) -> int: + return Tensor(torch_dialect.PrimMinIntOp(a, b)) + +def prim_max_self_int(self_: List[int]) -> int: + return Tensor(torch_dialect.PrimMaxSelfIntOp(self_)) + +def prim_max_int(a: int, b: int) -> int: + return Tensor(torch_dialect.PrimMaxIntOp(a, b)) + +def prim_RaiseException(msg: str, cls: Optional[str] = None) -> None: + torch_dialect.PrimRaiseExceptionOp(msg, cls) + +def prim_Uninitialized() -> Any: + return Tensor(torch_dialect.PrimUninitializedOp()) + +def prim_unchecked_cast(x: Tensor) -> Tensor: + return Tensor(torch_dialect.PrimUncheckedCastOp(x)) + +def prim_abs_Scalar(a: Number) -> Number: + return Tensor(torch_dialect.PrimAbsScalarOp(a)) + +def prims_convert_element_type(a: Tensor, dtype: int) -> Tensor: + assert isinstance(a, Tensor), f'`a` should be a {Tensor.__module__}.{Tensor.__name__} but is {type(a).__module__}.{type(a).__name__}' + a = a.value + if dtype is not None and isinstance(dtype, Enum): + dtype = dtype.value + return Tensor(torch_dialect.PrimsConvertElementTypeOp(a, dtype)) + + + +__all__ = ['tanh', 'tanh_', 'hardtanh', 'hardtanh_', 'relu', 'relu_', 'relu6', 'relu6_', 'leaky_relu', 'leaky_relu_', 'log', 'log_', 'sigmoid', 'sigmoid_', 'hardsigmoid', 'hardsigmoid_', 'hardswish', 'hardswish_', 'erf', 'erf_', 'silu', 'silu_', 'sin', 'sin_', 'exp', 'exp_', 'expm1', 'expm1_', 'cos', 'cos_', 'atan2', 'atan2_', 'neg', 'neg_', 'floor', 'floor_', 'ceil', 'ceil_', 'bitwise_not', 'bitwise_not_', 'div_Tensor', 'div__Tensor', 'logical_or', 'logical_or_', 'logical_and', 'logical_and_', 'logical_xor', 'logical_xor_', 'logical_not', 'logical_not_', 'lerp_Tensor', 'lerp__Tensor', 'eq_Tensor', 'eq__Tensor', 'gt_Tensor', 'gt__Tensor', 'ge_Tensor', 'ge__Tensor', 'lt_Tensor', 'lt__Tensor', 'le_Tensor', 'le__Tensor', 'ne_Tensor', 'ne__Tensor', 'div_Scalar', 'div__Scalar', 'ne_Scalar', 'ne__Scalar', 'eq_Scalar', 'eq__Scalar', 'gt_Scalar', 'gt__Scalar', 'ge_Scalar', 'ge__Scalar', 'lt_Scalar', 'lt__Scalar', 'le_Scalar', 'le__Scalar', 'fmod_Scalar', 'fmod__Scalar', 'masked_fill_Scalar', 'masked_fill__Scalar', 'masked_fill_Tensor', 'masked_fill__Tensor', 'clamp', 'clamp_', 'clamp_min', 'clamp_min_', 'clamp_max', 'clamp_max_', 'log2', 'log2_', 'sqrt', 'sqrt_', 'log1p', 'log1p_', 'rsqrt', 'rsqrt_', 'abs', 'abs_', 'reciprocal', 'reciprocal_', 'bitwise_and_Tensor', 'bitwise_and__Tensor', 
'bitwise_or_Tensor', 'bitwise_or__Tensor', 'threshold', 'threshold_', 'square', 'square_', 'unsqueeze', 'unsqueeze_', 'zero', 'zero_', 'fill_Scalar', 'fill__Scalar', 'fill_Tensor', 'fill__Tensor', 'div_Tensor_mode', 'div__Tensor_mode', 'mul_Tensor', 'mul__Tensor', 'add_Tensor', 'add__Tensor', 'sub_Tensor', 'sub__Tensor', 'add_Scalar', 'add__Scalar', 'sub_Scalar', 'sub__Scalar', 'mul_Scalar', 'mul__Scalar', 'addcmul', 'addcmul_', 'addcdiv', 'addcdiv_', 'maximum', 'minimum', 'mish', 'rsub_Scalar', 'gelu', 'pow_Tensor_Scalar', 'pow_Tensor_Tensor', 'threshold_backward', 'floor_divide', 'softplus', 'prelu', 'triu', 'triu_', 'round', 'round_', 'index_put_hacked_twin', 'index_put__hacked_twin', 'linear', 'mm', 'addmm', 'matmul', 'mv', 'conv2d', 'conv_transpose1d', 'conv_transpose2d_input', 'conv_transpose3d_input', 'convolution', 'convolution_overrideable', '_convolution', '_convolution_deprecated', 'roll', 'convolution_backward_overrideable', 'flip', 'native_batch_norm', 'batch_norm', 'layer_norm', 'native_layer_norm', 'max_pool2d', 'max_pool2d_with_indices', 'max_pool2d_with_indices_backward', 'avg_pool2d', 'softmax_int', 'log_softmax_int', '_log_softmax', 'adaptive_avg_pool2d', 'topk', 'transpose_int', 'permute', 'bmm', 'cumsum', 'floor_divide_Scalar', 'logsumexp', '__and___Tensor', '_softmax', 'mean', 'std', 'var', 'var_mean', 'nll_loss_forward', 'nll_loss_backward', 'bincount', 'frobenius_norm_dim', 'mse_loss', 'upsample_nearest2d_backward', 'constant_pad_nd', 'pad', 'squeeze_dim', 'squeeze', 'flatten_using_ints', 'dim', 'size', 'Bool_Tensor', 'is_floating_point', '_shape_as_tensor', 'all', 'all_bool', 'any', 'any_dim', 'arange_start_out', 'argmax', 'bucketize_Tensor', 'clone', 'lift_fresh_copy', 'contiguous', 'copy', 'copy_', 'detach', 'embedding', 'embedding_bag_padding_idx', '_embedding_bag', 'expand', 'expand_as', 'broadcast_to', 'index_Tensor_hacked_twin', 'index_select', 'item', 'masked_select', 'numel', 'repeat', 'reshape', '_reshape_alias', 'resize_', 'select_int', 'size_int', 'stack', 'sum', 'max', 'max_dim', 'amax', 'to_dtype', 'to_other', 'type_as', 'view', '_unsafe_view', 'where_self', 'where_Scalar', 'where_ScalarOther', 'where_ScalarSelf', 'slice_Tensor', 'len_Tensor', 'cpu', 'gather', 'scatter_add', 'IntImplicit', 'FloatImplicit', 'Int_Tensor', 'Float_Tensor', 'dropout', 'dropout_', 'native_dropout', 't', 'numpy_T', 'baddbmm', 'baddbmm_', 'fft_fft', 'alias_copy', 'as_strided_copy', 'diagonal_copy', 'expand_copy', 'permute_copy', '_reshape_alias_copy', 'select_copy_int', 'detach_copy', 'slice_copy_Tensor', 'squeeze_copy', 'squeeze_copy_dim', 't_copy', 'transpose_copy_int', 'unsqueeze_copy', 'view_copy', 'view_copy_dtype', 'unfold_copy', 'select_scatter', 'slice_scatter', 'diagonal_scatter', 'as_strided_scatter', 'upsample_nearest2d', '__contains___int_list', 'cat', 'append_t', 'add_t', 'eq_int_list', 'list_t', 'slice_t', 'insert_t', 'ne_int_list', 'any_bool', 'sort_int', 'add_str', 'eq_str', 'len_str', 'str', 'join', 'Float_Scalar', 'Float_str', 'Int_float', 'Int_Scalar', '__range_length', '__derive_index', 'gt_int', 'ge_int', 'lt_int', 'le_int', 'ne_int', 'eq_int', 'floordiv_int', 'remainder_int', 'remainder_Scalar', 'add_int', 'sub_int', 'mul_int', 'div_int', 'neg_int', 'log_int', 'add_float_int', 'sub_float', 'mul_float', 'div_float', 'neg_float', 'eq_float', 'gt_float', 'ge_float', 'lt_float', 'lt_float_int', 'ge_float_int', 'ne_float_int', 'gt_float_int', '__and___bool', 'ne_bool', '__is__', '__isnot__', '__not__', 'len_t', '__getitem___t', '_set_item_t', 'div', 'add', 
'sub', 'ceil_Scalar', 'sqrt_int', 'Bool_float', 'Bool_int', 'ceil_float', 'narrow', 'ScalarImplicit', '_softmax_backward_data', 'tanh_backward', 'gelu_backward', '_log_softmax_backward_data', 'native_layer_norm_backward', 'embedding_dense_backward', 'native_batch_norm_backward', 'native_dropout_backward', 'prim_layout', 'prim_TupleIndex', 'prim_dtype', 'prim_NumToTensor_Scalar', 'prim_min_self_int', 'prim_min_int', 'prim_max_self_int', 'prim_max_int', 'prim_RaiseException', 'prim_Uninitialized', 'prim_unchecked_cast', 'prim_abs_Scalar', 'prims_convert_element_type'] diff --git a/shark/dialects/affine_.py b/pi/dialects/affine_.py similarity index 100% rename from shark/dialects/affine_.py rename to pi/dialects/affine_.py diff --git a/shark/dialects/value_.py b/pi/dialects/value_.py similarity index 99% rename from shark/dialects/value_.py rename to pi/dialects/value_.py index 7d5837e..9c1f71f 100644 --- a/shark/dialects/value_.py +++ b/pi/dialects/value_.py @@ -1,7 +1,7 @@ from typing import Any, List from torch_mlir.dialects import arith, math -from ._ods_common import get_op_result_or_value +from torch_mlir.dialects._ods_common import get_op_result_or_value from torch_mlir.dialects.linalg.opdsl.lang.emitter import ( _is_integer_type, _is_floating_point_type, diff --git a/pi/mlir_utils.py b/pi/mlir_utils.py new file mode 100644 index 0000000..0177424 --- /dev/null +++ b/pi/mlir_utils.py @@ -0,0 +1,112 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +import contextlib +from io import StringIO +import os +import sys +import tempfile + +from torch_mlir.passmanager import PassManager +from torch_mlir.ir import StringAttr + + +def get_module_name_for_debug_dump(module): + """Gets a name suitable for a debug dump. + + The name is not guaranteed to be unique. + """ + if not "torch.debug_module_name" in module.operation.attributes: + return "UnnammedModule" + return StringAttr(module.operation.attributes["torch.debug_module_name"]).value + + +class TorchMlirCompilerError(Exception): + def __init__(self, value: str): + super().__init__() + self.value = value + + def __str__(self) -> str: + return self.value + + +def run_pipeline_with_repro_report( + module, + pipeline: str, + description: str, + enable_ir_printing=False, + print_pipeline=False, +): + """Runs `pipeline` on `module`, with a nice repro report if it fails.""" + module_name = get_module_name_for_debug_dump(module) + try: + original_stderr = sys.stderr + sys.stderr = StringIO() + asm_for_error_report = module.operation.get_asm( + large_elements_limit=10, enable_debug_info=True + ) + # Lower module in place to make it ready for compiler backends. + with module.context: + pm = PassManager.parse(pipeline) + if print_pipeline: + print(pm) + if enable_ir_printing: + pm.enable_ir_printing() + pm.run(module) + except Exception as e: + print(e, file=sys.stderr) + filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir") + with open(filename, "w") as f: + f.write(asm_for_error_report) + debug_options = "-mlir-print-ir-after-all -mlir-disable-threading" + # Put something descriptive here even if description is empty. 
+ description = description or f"{module_name} compile" + + message = f"""\ + {description} failed with the following diagnostics: + {sys.stderr.getvalue()} + + For Torch-MLIR developers, the error can be reproduced with: + $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} + Add '{debug_options}' to get the IR dump for debugging purpose. + """ + trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")]) + raise TorchMlirCompilerError(trimmed_message) from None + finally: + sys.stderr = original_stderr + + +def lower_pi_to_linalg(module): + run_pipeline_with_repro_report( + module, + "builtin.module(" + + ",".join( + [ + # "builtin.module(torchscript-module-to-torch-backend-pipeline)", + # "torchscript-module-to-torch-backend-pipeline", + "symbol-dce", + "inline{default-pipeline= max-iterations=4}", + "torch-adjust-calling-conventions", + "torch-lower-to-backend-contract{decompose=true max-iterations=10}", + "torch-backend-to-linalg-on-tensors-backend-pipeline", + ] + ) + + ")", + "Lowering Torch MLIR -> Linalg", + enable_ir_printing=False, + ) + return module + + +@contextlib.contextmanager +def mlir_cm(enable_multithreading=False): + from torch_mlir.ir import Context, Location, Module, InsertionPoint + from torch_mlir.dialects import torch as torch_dialect + + with Context() as ctx, Location.unknown(): + ctx.enable_multithreading(enable_multithreading) + torch_dialect.register_dialect(ctx, True) + module = Module.create() + with InsertionPoint(module.body): + yield module diff --git a/pi/nn/__init__.py b/pi/nn/__init__.py new file mode 100644 index 0000000..270dceb --- /dev/null +++ b/pi/nn/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/pi/nn/_reduction.py b/pi/nn/_reduction.py new file mode 100644 index 0000000..52b2283 --- /dev/null +++ b/pi/nn/_reduction.py @@ -0,0 +1,54 @@ +from typing import Optional +import warnings + +# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h + + +def get_enum(reduction: str) -> int: + if reduction == "none": + ret = 0 + elif reduction == "mean": + ret = 1 + elif reduction == "elementwise_mean": + warnings.warn( + "reduction='elementwise_mean' is deprecated, please use reduction='mean' instead." + ) + ret = 1 + elif reduction == "sum": + ret = 2 + else: + ret = -1 # TODO: remove once JIT exceptions support control flow + raise ValueError("{} is not a valid value for reduction".format(reduction)) + return ret + + +# In order to support previous versions, accept boolean size_average and reduce +# and convert them into the new constants for now + + +# We use these functions in torch/legacy as well, in which case we'll silence the warning +def legacy_get_string( + size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True +) -> str: + warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." 
+ + if size_average is None: + size_average = True + if reduce is None: + reduce = True + + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + if emit_warning: + warnings.warn(warning.format(ret)) + return ret + + +def legacy_get_enum( + size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True +) -> int: + return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/pi/nn/common_types.py b/pi/nn/common_types.py new file mode 100644 index 0000000..ee134bb --- /dev/null +++ b/pi/nn/common_types.py @@ -0,0 +1,42 @@ +from typing import TypeVar, Union, Tuple, Optional +from .. import Tensor + +# Create some useful type aliases + +# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally +# broadcast to a tuple. +# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. +T = TypeVar("T") +_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] +_scalar_or_tuple_1_t = Union[T, Tuple[T]] +_scalar_or_tuple_2_t = Union[T, Tuple[T, T]] +_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] +_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] +_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] +_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] + +# For arguments which represent size parameters (eg, kernel size, padding) +_size_any_t = _scalar_or_tuple_any_t[int] +_size_1_t = _scalar_or_tuple_1_t[int] +_size_2_t = _scalar_or_tuple_2_t[int] +_size_3_t = _scalar_or_tuple_3_t[int] +_size_4_t = _scalar_or_tuple_4_t[int] +_size_5_t = _scalar_or_tuple_5_t[int] +_size_6_t = _scalar_or_tuple_6_t[int] + +# For arguments which represent optional size parameters (eg, adaptive pool parameters) +_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]] + +# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) +_ratio_2_t = _scalar_or_tuple_2_t[float] +_ratio_3_t = _scalar_or_tuple_3_t[float] +_ratio_any_t = _scalar_or_tuple_any_t[float] + +_tensor_list_t = _scalar_or_tuple_any_t[Tensor] + +# For the return value of max pooling operations that may or may not return indices. +# With the proposed 'Literal' feature to Python typing, it might be possible to +# eventually eliminate this. +_maybe_indices_t = _scalar_or_tuple_2_t[Tensor] diff --git a/pi/nn/functional.py b/pi/nn/functional.py new file mode 100644 index 0000000..800035a --- /dev/null +++ b/pi/nn/functional.py @@ -0,0 +1,2531 @@ +import warnings +from typing import List, Optional, Tuple, Union, Callable + +import math + +import pi +from . import _reduction as _Reduction +from .modules.utils import _list_with_default, _pair, _triple, _single +from ..types_ import ( + dtype, + BroadcastingList2, + boolean_dispatch, + BroadcastingList3, + BroadcastingList1, +) +from .. 
import _VF + +Tensor = pi.Tensor + +# conv1d = pi.conv1d + +conv2d = pi.conv2d + +# conv3d = pi.conv3d + +conv_transpose1d = pi.conv_transpose1d + +# conv_transpose2d = pi.conv_transpose2d + +# conv_transpose3d = pi.conv_transpose3d + +# conv_tbc = pi.conv_tbc + +# Pooling +# avg_pool1d = pi.avg_pool1d + +avg_pool2d = pi._nn.avg_pool2d + + +# avg_pool3d = pi._C._nn.avg_pool3d + + +def fractional_max_pool2d_with_indices( + input: Tensor, + kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + if output_size is None and output_ratio is None: + raise ValueError( + "fractional_max_pool2d requires specifying either " + "an output_size or an output_ratio" + ) + if output_size is None: + assert output_ratio is not None + _output_ratio = _pair(output_ratio) + output_size = [ + int(input.size(-2) * _output_ratio[0]), + int(input.size(-1) * _output_ratio[1]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 3 else input.size(0) + _random_samples = pi.rand( + n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device + ) + return pi._C._nn.fractional_max_pool2d( + input, kernel_size, output_size, _random_samples + ) + + +def _fractional_max_pool2d( + input: Tensor, + kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tensor: + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool2d_with_indices, + if_false=_fractional_max_pool2d, + module_name=__name__, + func_name="fractional_max_pool2d", +) + + +def fractional_max_pool3d_with_indices( + input: Tensor, + kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + if output_size is None and output_ratio is None: + raise ValueError( + "fractional_max_pool3d requires specifying either " + "an output_size or an output_ratio" + ) + if output_size is None: + assert output_ratio is not None + _output_ratio = _triple(output_ratio) + output_size = [ + int(input.size(-3) * _output_ratio[0]), + int(input.size(-2) * _output_ratio[1]), + int(input.size(-1) * _output_ratio[2]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 4 else input.size(0) + _random_samples = pi.rand( + n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device + ) + return pi._C._nn.fractional_max_pool3d( + input, kernel_size, output_size, _random_samples + ) + + +def _fractional_max_pool3d( + input: Tensor, + kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None, +) -> Tensor: + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool3d = boolean_dispatch( + 
arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool3d_with_indices, + if_false=_fractional_max_pool3d, + module_name=__name__, + func_name="fractional_max_pool3d", +) + + +def max_pool1d_with_indices( + input: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: + if stride is None: + stride = pi.jit.annotate(List[int], []) + return pi.max_pool1d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool1d( + input: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if stride is None: + stride = pi.jit.annotate(List[int], []) + return pi.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool1d_with_indices, + if_false=_max_pool1d, + module_name=__name__, + func_name="max_pool1d", +) + + +def max_pool2d_with_indices( + input: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: + return pi._nn.max_pool2d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool2d( + input: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + return pi.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool2d_with_indices, + if_false=_max_pool2d, + module_name=__name__, + func_name="max_pool2d", +) + + +def max_pool3d_with_indices( + input: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tuple[Tensor, Tensor]: + return pi._C._nn.max_pool3d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + +def _max_pool3d( + input: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False, +) -> Tensor: + if stride is None: + stride = pi.jit.annotate(List[int], []) + return pi.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool3d_with_indices, + if_false=_max_pool3d, + module_name=__name__, + func_name="max_pool3d", +) + + +def _unpool_output_size( + input: Tensor, + kernel_size: List[int], + stride: List[int], + padding: List[int], + output_size: Optional[List[int]], +) -> List[int]: + input_size = input.size() + default_size = 
pi.jit.annotate(List[int], []) + for d in range(len(kernel_size)): + default_size.append( + (input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] + - 2 * padding[d] + ) + if output_size is None: + ret = default_size + else: + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), len(kernel_size) + 2, len(output_size) + ) + ) + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + 'invalid output_size "{}" (dim {} must be between {} and {})'.format( + output_size, d, min_size, max_size + ) + ) + + ret = output_size + return ret + + +def max_unpool1d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + output_size: Optional[BroadcastingList1[int]] = None, +) -> Tensor: + kernel_size = _single(kernel_size) + if stride is not None: + _stride = _single(stride) + else: + _stride = kernel_size + padding = _single(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + if isinstance(output_size, list): + output_size = output_size + [1] + else: + output_size = output_size + (1,) + return pi._C._nn.max_unpool2d( + input.unsqueeze(-1), indices.unsqueeze(-1), output_size + ).squeeze(-1) + + +def max_unpool2d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + output_size: Optional[BroadcastingList2[int]] = None, +) -> Tensor: + kernel_size = _pair(kernel_size) + if stride is not None: + _stride = _pair(stride) + else: + _stride = kernel_size + padding = _pair(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return pi._C._nn.max_unpool2d(input, indices, output_size) + + +def max_unpool3d( + input: Tensor, + indices: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + output_size: Optional[BroadcastingList3[int]] = None, +) -> Tensor: + kernel_size = _triple(kernel_size) + if stride is not None: + _stride = _triple(stride) + else: + _stride = kernel_size + padding = _triple(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return pi._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) + + +def lp_pool2d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + kw, kh = _pair(kernel_size) + if stride is not None: + out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool2d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return (pi.sign(out) * relu(pi.abs(out))).mul(kw * kh).pow(1.0 / norm_type) + + +def lp_pool1d( + input: Tensor, + norm_type: Union[int, float], + kernel_size: int, + stride: Optional[BroadcastingList1[int]] = None, + ceil_mode: bool = False, +) -> Tensor: + if stride is not None: + out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = 
avg_pool1d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) + + return ( + (pi.sign(out) * relu(pi.abs(out))).mul(kernel_size).pow(1.0 / norm_type) + ) + + +def adaptive_max_pool1d_with_indices( + input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False +) -> Tuple[Tensor, Tensor]: + return pi.adaptive_max_pool1d(input, output_size) + + +def _adaptive_max_pool1d( + input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False +) -> Tensor: + return adaptive_max_pool1d_with_indices(input, output_size)[0] + + +adaptive_max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool1d_with_indices, + if_false=_adaptive_max_pool1d, + module_name=__name__, + func_name="adaptive_max_pool1d", +) + + +def adaptive_max_pool2d_with_indices( + input: Tensor, output_size: BroadcastingList2[int], return_indices: bool = False +) -> Tuple[Tensor, Tensor]: + output_size = _list_with_default(output_size, input.size()) + return pi._C._nn.adaptive_max_pool2d(input, output_size) + + +def _adaptive_max_pool2d( + input: Tensor, output_size: BroadcastingList2[int], return_indices: bool = False +) -> Tensor: + return adaptive_max_pool2d_with_indices(input, output_size)[0] + + +adaptive_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool2d_with_indices, + if_false=_adaptive_max_pool2d, + module_name=__name__, + func_name="adaptive_max_pool2d", +) + + +def adaptive_max_pool3d_with_indices( + input: Tensor, output_size: BroadcastingList3[int], return_indices: bool = False +) -> Tuple[Tensor, Tensor]: + output_size = _list_with_default(output_size, input.size()) + return pi._C._nn.adaptive_max_pool3d(input, output_size) + + +def _adaptive_max_pool3d( + input: Tensor, output_size: BroadcastingList3[int], return_indices: bool = False +) -> Tensor: + return adaptive_max_pool3d_with_indices(input, output_size)[0] + + +adaptive_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool3d_with_indices, + if_false=_adaptive_max_pool3d, + module_name=__name__, + func_name="adaptive_max_pool3d", +) + +# adaptive_avg_pool1d = pi.adaptive_avg_pool1d + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + _output_size = _list_with_default(output_size, input.size()) + return pi._nn.adaptive_avg_pool2d(input, _output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor: + _output_size = _list_with_default(output_size, input.size()) + return pi._C._nn.adaptive_avg_pool3d(input, _output_size) + + +# Activation functions +def dropout( + input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False +) -> Tensor: + if p < 0.0 or p > 1.0: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + return ( + _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + ) + + +def alpha_dropout( + input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False +) -> Tensor: + if p < 0.0 or p > 1.0: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + return ( + _VF.alpha_dropout_(input, p, training) + if inplace + else _VF.alpha_dropout(input, p, training) + ) + + +def dropout1d( + input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False +) -> Tensor: + if 
p < 0.0 or p > 1.0: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + inp_dim = input.dim() + if inp_dim not in (2, 3): + raise RuntimeError( + f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. " + "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " + "spatial dimension, a channel dimension, and an optional batch dimension " + "(i.e. 2D or 3D inputs)." + ) + + is_batched = inp_dim == 3 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + + return result + + +def dropout2d( + input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False +) -> Tensor: + if p < 0.0 or p > 1.0: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + inp_dim = input.dim() + if inp_dim not in (3, 4): + warn_msg = ( + f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout2d " + "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)." + ) + warnings.warn(warn_msg) + + # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing + # a 3D input will perform dropout1d behavior instead. This was done historically and the + # behavior is maintained here for now. + # See https://github.com/pytorch/pytorch/issues/77081 + if inp_dim == 3: + warnings.warn( + "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " + "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " + "is the channel dim. This behavior will change in a future release to interpret the " + "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " + "channel-wise dropout behavior, please switch to using dropout1d instead." + ) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + return result + + +def dropout3d( + input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False +) -> Tensor: + if p < 0.0 or p > 1.0: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + inp_dim = input.dim() + if inp_dim not in (4, 5): + warn_msg = ( + f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout3d " + "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)." 
+ ) + warnings.warn(warn_msg) + + is_batched = inp_dim == 5 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + return result + + +def feature_alpha_dropout( + input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False +) -> Tensor: + if p < 0.0 or p > 1.0: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + return ( + _VF.feature_alpha_dropout_(input, p, training) + if inplace + else _VF.feature_alpha_dropout(input, p, training) + ) + + +def _threshold( + input: Tensor, threshold: float, value: float, inplace: bool = False +) -> Tensor: + if inplace: + result = _VF.threshold_(input, threshold, value) + else: + result = _VF.threshold(input, threshold, value) + return result + + +threshold = _threshold + +threshold_ = _VF.threshold_ + + +def relu(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + result = pi.relu_(input) + else: + result = pi.relu(input) + return result + + +relu_ = pi.relu_ + + +def glu(input: Tensor, dim: int = -1) -> Tensor: + if input.dim() == 0: + raise RuntimeError( + "glu does not support scalars because halving size must be even" + ) + return pi._C._nn.glu(input, dim) + + +def hardtanh( + input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False +) -> Tensor: + if inplace: + result = pi._nn.hardtanh_(input, min_val, max_val) + else: + result = pi._nn.hardtanh(input, min_val, max_val) + return result + + +hardtanh_ = pi._nn.hardtanh_ + + +def relu6(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + result = pi._nn.relu6_(input) + else: + result = pi._nn.relu6(input) + return result + + +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: + if inplace: + result = pi._C._nn.elu_(input, alpha) + else: + result = pi._C._nn.elu(input, alpha) + return result + + +# elu_ = pi._C._nn.elu_ + + +def selu(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + result = pi.selu_(input) + else: + result = pi.selu(input) + return result + + +# selu_ = pi.selu_ + + +def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: + if inplace: + result = pi.celu_(input, alpha) + else: + result = pi.celu(input, alpha) + return result + + +# celu_ = pi.celu_ + + +def leaky_relu( + input: Tensor, negative_slope: float = 0.01, inplace: bool = False +) -> Tensor: + if inplace: + result = pi._nn.leaky_relu_(input, negative_slope) + else: + result = pi._nn.leaky_relu(input, negative_slope) + return result + + +leaky_relu_ = pi._nn.leaky_relu_ + + +# prelu = pi.prelu + + +def rrelu( + input: Tensor, + lower: float = 1.0 / 8, + upper: float = 1.0 / 3, + training: bool = False, + inplace: bool = False, +) -> Tensor: + if inplace: + result = pi.rrelu_(input, lower, upper, training) + else: + result = pi.rrelu(input, lower, upper, training) + return result + + +# rrelu_ = pi.rrelu_ + +# logsigmoid = pi._C._nn.log_sigmoid + +gelu = pi._C._nn.gelu + +# hardshrink = pi.hardshrink + + +def tanhshrink(input): + return input - input.tanh() + + +def softsign(input): + return input / (input.abs() + 1) + + +softplus = pi._nn.softplus + + +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + "Implicit dimension choice for {} has been deprecated. 
" + "Change the call to include dim=X as an argument.".format(name), + stacklevel=stacklevel, + ) + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + +def softmin( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[dtype] = None, +) -> Tensor: + if dim is None: + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) + if dtype is None: + ret = (-input).softmax(dim) + else: + ret = (-input).softmax(dim, dtype=dtype) + return ret + + +def softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[dtype] = None, +) -> Tensor: + if dim is None: + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.softmax(dim) + else: + ret = input.softmax(dim, dtype=dtype) + return ret + + +def gumbel_softmax( + logits: Tensor, + tau: float = 1, + hard: bool = False, + eps: float = 1e-10, + dim: int = -1, +) -> Tensor: + if eps != 1e-10: + warnings.warn("`eps` parameter is deprecated and has no effect.") + + gumbels = ( + -pi.empty_like(logits, memory_format=pi.legacy_contiguous_format) + .exponential_() + .log() + ) # ~Gumbel(0,1) + gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) + y_soft = gumbels.softmax(dim) + + if hard: + # Straight through. + index = y_soft.max(dim, keepdim=True)[1] + y_hard = pi.zeros_like( + logits, memory_format=pi.legacy_contiguous_format + ).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + # Reparametrization trick. + ret = y_soft + return ret + + +def log_softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[dtype] = None, +) -> Tensor: + if dim is None: + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.log_softmax(dim) + else: + ret = input.log_softmax(dim, dtype=dtype) + return ret + + +# softshrink = pi._C._nn.softshrink + + +def tanh(input): + return input.tanh() + + +def sigmoid(input): + return input.sigmoid() + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + return pi._C._nn.hardsigmoid_(input) + return pi._C._nn.hardsigmoid(input) + + +linear = pi._nn.linear + + +# bilinear = pi.bilinear + + +def silu(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + return pi._C._nn.silu_(input) + return pi._C._nn.silu(input) + + +def mish(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + return pi._C._nn.mish_(input) + return pi._C._nn.mish(input) + + +def hardswish(input: Tensor, inplace: bool = False) -> Tensor: + if inplace: + return pi._nn.hardswish_(input) + return pi._nn.hardswish(input) + + +def _no_grad_embedding_renorm_( + weight: Tensor, input: Tensor, max_norm: float, norm_type: float +) -> Tuple[Tensor, Tensor]: + pi.embedding_renorm_(weight.detach(), input, max_norm, norm_type) + + +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < weight.size( + 0 + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert padding_idx >= -weight.size( + 0 + ), "Padding_idx must be within num_embeddings" + padding_idx = weight.size(0) + padding_idx + else: + padding_idx = -1 + if max_norm is not None: + raise NotImplementedError + # Note [embedding_renorm contiguous] + # 
`embedding_renorm_` will call .contiguous() on input anyways, so we + # call it here and take advantage of the improved locality in the + # `embedding` call below too. + # input = input.contiguous() + # Note [embedding_renorm set_grad_enabled] + # XXX: equivalent to + # with pi.no_grad(): + # pi.embedding_renorm_ + # remove once script supports set_grad_enabled + # _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + return pi.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) + + +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, +) -> Tensor: + # Check for backward compatibility. + # Used to be embedding_bag(weight, input, ...) + # Now is embedding_bag(input, weight, ...) + if weight.dtype == pi.long and input.is_floating_point(): + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." + ) + weight, input = input, weight + + if per_sample_weights is not None and input.size() != per_sample_weights.size(): + raise ValueError( + "embedding_bag: If per_sample_weights ({}) is not None, " + "then it must have the same shape as the input ({})".format( + per_sample_weights.shape, input.shape + ) + ) + + if input.dim() == 2: + if offsets is not None: + type_str = "" + type_str = str(type(offsets)) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. However, found " + "offsets of type {}".format(type_str) + ) + offsets = pi.arange( + 0, input.numel(), input.size(1), dtype=input.dtype, device=input.device + ) + + input = input.reshape(-1) + if per_sample_weights is not None: + per_sample_weights = per_sample_weights.reshape(-1) + elif input.dim() == 1: + if offsets is None: + raise ValueError("offsets has to be a 1D Tensor but got None") + if offsets.dim() != 1: + raise ValueError("offsets has to be a 1D Tensor") + else: + raise ValueError( + "input has to be 1D or 2D Tensor," + " but got Tensor of dimension {}".format(input.dim()) + ) + if mode == "sum": + mode_enum = 0 + elif mode == "mean": + mode_enum = 1 + elif mode == "max": + mode_enum = 2 + + if scale_grad_by_freq: + raise ValueError( + "max mode does not support scaling the gradient by the frequency" + ) + + if sparse: + raise ValueError("max mode does not support sparse weights") + + else: + raise ValueError("mode has to be one of sum, mean or max") + + if max_norm is not None: + # XXX: equivalent to + # with pi.no_grad(): + # pi.nembedding_renorm_ + # remove once script supports set_grad_enabled + raise NotImplementedError + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + "(got mode='{}'). 
Please open a feature request on GitHub.".format(mode) + ) + + ret, _, _, _ = pi.embedding_bag( + weight, + input, + offsets, + scale_grad_by_freq, + mode_enum, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + return ret + + +def _verify_batch_size(size: List[int]) -> None: + # XXX: JIT script does not support the reduce from functools, and mul op is a + # builtin, which cannot be used as a value to a func yet, so rewrite this size + # check to a simple equivalent for loop + # + # TODO: make use of reduce like below when JIT is ready with the missing features: + # from operator import mul + # from functools import reduce + # + # if reduce(mul, size[2:], size[0]) == 1 + size_prods = size[0] + for i in range(len(size) - 2): + size_prods *= size[i + 2] + if size_prods == 1: + raise ValueError( + "Expected more than 1 value per channel when training, got input size {}".format( + size + ) + ) + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + if training: + _verify_batch_size(input.size()) + + return pi.batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + False, + ) + + +def _verify_spatial_size(size: List[int]) -> None: + # Verify that there is > 1 spatial element for instance norm calculation. + size_prods = 1 + for i in range(2, len(size)): + size_prods *= size[i] + if size_prods == 1: + raise ValueError( + "Expected more than 1 spatial element when training, got input size {}".format( + size + ) + ) + + +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + if use_input_stats: + _verify_spatial_size(input.size()) + return pi.instance_norm( + input, + weight, + bias, + running_mean, + running_var, + use_input_stats, + momentum, + eps, + pi.backends.cudnn.enabled, + ) + + +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + return pi.layer_norm(input, normalized_shape, weight, bias, eps, False) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + if input.dim() < 2: + raise RuntimeError( + f"Expected at least 2 dimensions for input tensor but received {input.dim()}" + ) + _verify_batch_size( + [input.size(0) * input.size(1) // num_groups, num_groups] + + list(input.size()[2:]) + ) + return pi.group_norm( + input, num_groups, weight, bias, eps, pi.backends.cudnn.enabled + ) + + +def local_response_norm( + input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0 +) -> Tensor: + dim = input.dim() + if dim < 3: + raise ValueError( + "Expected 3D or higher dimensionality \ + input (got {} dimensions)".format( + dim + ) + ) + + if input.numel() == 0: + return input + + div = input.mul(input).unsqueeze(1) + if dim == 3: + div = pad(div, (0, 0, size // 2, (size - 1) // 2)) + div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) + else: + raise NotImplementedError + # sizes = input.size() + # div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) + # div = 
pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) + # div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) + # div = div.view(sizes) + div = div.mul(alpha).add(k).pow(beta) + return input / div + + +# loss + + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: + return pi.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + blank, + _Reduction.get_enum(reduction), + zero_infinity, + ) + + +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return pi._nn.nll_loss_nd( + input, target, weight, _Reduction.get_enum(reduction), ignore_index + ) + + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + if reduction != "none" and reduction != "mean" and reduction != "sum": + ret = input + raise ValueError(reduction + " is not a valid value for reduction") + + ret = pi.poisson_nll_loss( + input, target, log_input, full, eps, _Reduction.get_enum(reduction) + ) + return ret + + +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor, + full: bool = False, + eps: float = 1e-6, + reduction: str = "mean", +) -> Tensor: + # Check var size + # If var.size == input.size, the case is heteroscedastic and no further checks are needed. + # Otherwise: + if var.size() != input.size(): + + # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2) + # -> unsqueeze var so that var.shape = (10, 2, 1) + # this is done so that broadcasting can happen in the loss calculation + if input.size()[:-1] == var.size(): + var = pi.unsqueeze(var, -1) + + # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1. + # This is also a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1) + elif ( + input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1 + ): # Heteroscedastic case + pass + + # If none of the above pass, then the size of var is incorrect. 
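+ # e.g. input.size = (10, 2, 3) with var.size = (10, 3) or (10,) matches neither case above and is rejected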
+ else: + raise ValueError("var is of incorrect size") + + # Check validity of reduction mode + if reduction != "none" and reduction != "mean" and reduction != "sum": + raise ValueError(reduction + " is not valid") + + # Entries of var must be non-negative + if pi.any(var < 0): + raise ValueError("var has negative entry/entries") + + # Clamp for stability + var = var.clone() + var.clamp_(min=eps) + + # Calculate the loss + loss = 0.5 * (pi.log(var) + (input - target) ** 2 / var) + if full: + loss += 0.5 * math.log(2 * math.pi) + + if reduction == "mean": + return loss.mean() + elif reduction == "sum": + return loss.sum() + else: + return loss + + +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release." + ) + + # special case for batchmean + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") + else: + reduction_enum = _Reduction.get_enum(reduction) + + reduced = pi.kl_div(input, target, reduction_enum, log_target=log_target) + + if reduction == "batchmean" and input.dim() != 0: + reduced = reduced / input.size()[0] + + return reduced + + +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0, +) -> Tensor: + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return pi._C._nn.cross_entropy_loss( + input, + target, + weight, + _Reduction.get_enum(reduction), + ignore_index, + label_smoothing, + ) + + +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if target.size() != input.size(): + raise ValueError( + "Using a target size ({}) that is different to the input size ({}) is deprecated. 
" + "Please ensure they have the same size.".format(target.size(), input.size()) + ) + + if weight is not None: + new_size = _infer_size(target.size(), weight.size()) + weight = weight.expand(new_size) + + return pi._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) + + +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + + if not (target.size() == input.size()): + raise ValueError( + "Target size ({}) must be the same as input size ({})".format( + target.size(), input.size() + ) + ) + + return pi.binary_cross_entropy_with_logits( + input, target, weight, pos_weight, reduction_enum + ) + + +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: + if not (target.size() == input.size()): + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format( + target.size(), input.size() + ), + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = pi.broadcast_tensors(input, target) + return pi._C._nn.smooth_l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), beta + ) + + +def huber_loss( + input: Tensor, + target: Tensor, + reduction: str = "mean", + delta: float = 1.0, +) -> Tensor: + if not (target.size() == input.size()): + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format( + target.size(), input.size() + ), + stacklevel=2, + ) + + expanded_input, expanded_target = pi.broadcast_tensors(input, target) + return pi._C._nn.huber_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), delta + ) + + +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if not (target.size() == input.size()): + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format( + target.size(), input.size() + ), + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = pi.broadcast_tensors(input, target) + return pi._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + + +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if not (target.size() == input.size()): + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. 
" + "Please ensure they have the same size.".format( + target.size(), input.size() + ), + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = pi.broadcast_tensors(input, target) + return pi._nn.mse_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) + + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if input1.dim() != input2.dim() or input1.dim() != target.dim(): + raise RuntimeError( + ( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + "input1: {}, input2: {}, target: {} ".format( + input1.size(), input2.size(), target.size() + ) + ) + ) + return pi.margin_ranking_loss(input1, input2, target, margin, reduction_enum) + + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return pi.hinge_embedding_loss(input, target, margin, reduction_enum) + + +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return pi._C._nn.multilabel_margin_loss(input, target, reduction_enum) + + +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return pi._C._nn.soft_margin_loss(input, target, reduction_enum) + + +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input)) + + if weight is not None: + loss = loss * weight + + class_dim = input.dim() - 1 + C = input.size(class_dim) + loss = loss.sum(dim=class_dim) / C # only return N loss values + + if reduction == "none": + ret = loss + elif reduction == "mean": + ret = loss.mean() + elif reduction == "sum": + ret = loss.sum() + else: + ret = input + raise ValueError(reduction + " is not valid") + return ret + + +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + 
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return pi.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) + + +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None: + if weight.dim() != 1: + raise ValueError("weight must be one-dimensional") + + return pi._C._nn.multi_margin_loss( + input, target, p, margin, weight, reduction_enum + ) + + +# pixel_shuffle = pi.pixel_shuffle + +# pixel_unshuffle = pi.pixel_unshuffle + +# channel_shuffle = pi.channel_shuffle + +# native_channel_shuffle = pi.native_channel_shuffle + + +def upsample( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): # noqa: F811 + + warnings.warn( + "nn.functional.upsample is deprecated. Use nn.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode, align_corners) + + +def interpolate( + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: F811 + + if mode in ("nearest", "area", "nearest-exact"): + if align_corners is not None: + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) + else: + if align_corners is None: + align_corners = False + + dim = input.dim() - 2 # Number of spatial dimensions. + + # Process size and scale_factor. Validate that exactly one is set. + # Validate its length if it is a list, or expand it if it is a scalar. + # After this block, exactly one of output_size and scale_factors will + # be non-None, and it will be a list (or tuple). + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + elif size is not None: + assert scale_factor is None + scale_factors = None + if isinstance(size, (list, tuple)): + if len(size) != dim: + raise ValueError( + "Input and output must have the same number of spatial dimensions, but got " + f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "output size in (o1, o2, ...,oK) format." + ) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError( + "Input and scale_factor must have the same number of spatial dimensions, but " + f"got input with spatial dimensions of {list(input.shape[2:])} and " + f"scale_factor of shape {scale_factor}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "scale_factor in (s1, s2, ...,sK) format." 
+ ) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError("either size or scale_factor should be defined") + + if ( + recompute_scale_factor is not None + and recompute_scale_factor + and size is not None + ): + raise ValueError( + "recompute_scale_factor is not meaningful with an explicit size." + ) + + # "area" mode always requires an explicit size rather than scale factor. + # Re-use the recompute_scale_factor code path. + if mode == "area" and output_size is None: + recompute_scale_factor = True + + if recompute_scale_factor is not None and recompute_scale_factor: + # We compute output_size here, then un-set scale_factors. + # The C++ code will recompute it based on the (integer) output size. + if not pi.jit.is_scripting() and pi._C._get_tracing_state(): + # make scale_factor a tensor in tracing so constant doesn't get baked in + output_size = [ + ( + pi.floor( + ( + input.size(i + 2).float() + * pi.tensor(scale_factors[i], dtype=pi.float32) + ).float() + ) + ) + for i in range(dim) + ] + else: + assert scale_factors is not None + output_size = [ + int(math.floor(float(input.size(i + 2)) * scale_factors[i])) + for i in range(dim) + ] + scale_factors = None + + if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + raise ValueError( + "Anti-alias option is only supported for bilinear and bicubic modes" + ) + + if input.dim() == 3 and mode == "nearest": + return pi._C._nn.upsample_nearest1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest": + return pi._C._nn.upsample_nearest2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest": + return pi._C._nn.upsample_nearest3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "nearest-exact": + return pi._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest-exact": + return pi._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest-exact": + return pi._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "area": + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == "area": + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == "area": + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) + + if input.dim() == 3 and mode == "linear": + assert align_corners is not None + return pi._C._nn.upsample_linear1d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bilinear": + assert align_corners is not None + if antialias: + return pi._C._nn._upsample_bilinear2d_aa( + input, output_size, align_corners, scale_factors + ) + return pi._C._nn.upsample_bilinear2d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 5 and mode == "trilinear": + assert align_corners is not None + return pi._C._nn.upsample_trilinear3d( + input, output_size, align_corners, scale_factors + ) + if input.dim() == 4 and mode == "bicubic": + assert align_corners is not None + if antialias: + return pi._C._nn._upsample_bicubic2d_aa( + input, output_size, align_corners, scale_factors + ) + return pi._C._nn.upsample_bicubic2d( + input, output_size, align_corners, scale_factors + ) + + if input.dim() == 3 and mode == "bilinear": + raise 
NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if input.dim() == 3 and mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if input.dim() == 4 and mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if input.dim() == 4 and mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if input.dim() == 5 and mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if input.dim() == 5 and mode == "bilinear": + raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact" + " (got {})".format(input.dim(), mode) + ) + + +def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 + + # DeprecationWarning is ignored by default + warnings.warn( + "nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode="nearest") + + +def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 + + # DeprecationWarning is ignored by default + warnings.warn( + "nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +GRID_SAMPLE_INTERPOLATION_MODES = { + "bilinear": 0, + "nearest": 1, + "bicubic": 2, +} + +GRID_SAMPLE_PADDING_MODES = { + "zeros": 0, + "border": 1, + "reflection": 2, +} + + +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + "nn.functional.grid_sample(): expected mode to be " + "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode) + ) + if ( + padding_mode != "zeros" + and padding_mode != "border" + and padding_mode != "reflection" + ): + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + "but got: '{}'".format(padding_mode) + ) + + if mode == "bilinear": + mode_enum = 0 + elif mode == "nearest": + mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 + + if padding_mode == "zeros": + padding_mode_enum = 0 + elif padding_mode == "border": + padding_mode_enum = 1 + else: # padding_mode == 'reflection' + padding_mode_enum = 2 + + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + return pi.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) + + +def affine_grid( + theta: Tensor, size: List[int], align_corners: Optional[bool] = None +) -> Tensor: + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." 
+ ) + align_corners = False + + # enforce floating point dtype on theta + if not theta.is_floating_point(): + raise ValueError( + "Expected theta to have floating point type, but got {}".format(theta.dtype) + ) + # check that shapes and sizes match + if len(size) == 4: + if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: + raise ValueError( + "Expected a batch of 2D affine matrices of shape Nx2x3 " + "for size {}. Got {}.".format(size, theta.shape) + ) + spatial_size = size[-2:] # spatial dimension sizes + elif len(size) == 5: + if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: + raise ValueError( + "Expected a batch of 3D affine matrices of shape Nx3x4 " + "for size {}. Got {}.".format(size, theta.shape) + ) + spatial_size = size[-3:] # spatial dimension sizes + else: + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + "Got size {}.".format(size) + ) + # check for empty span + if align_corners and min(spatial_size) == 1: + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) + elif min(size) <= 0: + raise ValueError("Expected non-zero, positive output size. Got {}".format(size)) + + return pi.affine_grid_generator(theta, size, align_corners) + + +pad = pi._nn.pad + + +# distance + + +# pairwise_distance = pi.pairwise_distance + +# pdist = pi.pdist + +# cosine_similarity = pi.cosine_similarity + +# one_hot = pi._C._nn.one_hot + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return pi.triplet_margin_loss( + anchor, positive, negative, margin, p, eps, swap, reduction_enum + ) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> Tensor: + # Check validity of reduction mode + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + # Check dimensions + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + if not (a_dim == p_dim and p_dim == n_dim): + raise RuntimeError( + ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ) + ) + + # Calculate loss + if distance_function is None: + distance_function = pi.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. 
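+ # e.g. with margin = 1.0, d(anchor, positive) = 0.8, d(anchor, negative) = 1.2 and d(positive, negative) = 0.3, the swap below uses 0.3 as the negative distance, giving clamp_min(1.0 + 0.8 - 0.3, 0) = 1.5 instead of 0.6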
+ if swap: + dist_swap = distance_function(positive, negative) + dist_neg = pi.minimum(dist_neg, dist_swap) + loss = pi.clamp_min(margin + dist_pos - dist_neg, 0) + + # Apply reduction + if reduction == "sum": + return pi.sum(loss) + elif reduction == "mean": + return pi.mean(loss) + else: # reduction == "none" + return loss + + +def normalize( + input: Tensor, + p: float = 2.0, + dim: int = 1, + eps: float = 1e-12, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) + return input / denom + else: + denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input) + return pi.div(input, denom, out=out) + + +def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: + assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) + + +def unfold( + input: Tensor, + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1, +) -> Tensor: + return pi._C._nn.im2col( + input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) + + +def fold( + input: Tensor, + output_size: BroadcastingList2[int], + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1, +) -> Tensor: + return pi._C._nn.col2im( + input, + _pair(output_size), + _pair(kernel_size), + _pair(dilation), + _pair(padding), + _pair(stride), + ) + + +# multihead attention + + +def _in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + b: Optional[Tensor] = None, +) -> List[Tensor]: + E = q.size(-1) + if k is v: + if q is k: + # self-attention + return linear(q, w, b).chunk(3, dim=-1) + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == ( + Eq, + Eq, + ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == ( + Eq, + Ek, + ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == ( + Eq, + Ev, + ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == ( + Eq, + ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + Eq, + ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + Eq, + ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +# _scaled_dot_product_attention = pi._C._nn._scaled_dot_product_attention + + +def _mha_shape_check( + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + num_heads: int, +): + # Verifies the expected shape for `query, `key`, `value`, 
`key_padding_mask` and `attn_mask` + # and returns if the input is batched or not. + # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + + # Shape check. + if query.dim() == 3: + # Batched Inputs + is_batched = True + assert key.dim() == 3 and value.dim() == 3, ( + "For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) + if key_padding_mask is not None: + assert key_padding_mask.dim() == 2, ( + "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), ( + "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) + elif query.dim() == 2: + # Unbatched Inputs + is_batched = False + assert key.dim() == 2 and value.dim() == 2, ( + "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) + + if key_padding_mask is not None: + assert key_padding_mask.dim() == 1, ( + "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) + + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), ( + "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) + if attn_mask.dim() == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert ( + attn_mask.shape == expected_shape + ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" + ) + + return is_batched + + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + is_batched = _mha_shape_check( + query, key, value, key_padding_mask, attn_mask, num_heads + ) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. 
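+ # e.g. an unbatched (L, E) query becomes (L, 1, E) below and is squeezed back to (L, E) before returning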
+ if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != pi.bool and not pi.is_floating_point(key_padding_mask): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, pi.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == pi.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(pi.bool) + else: + assert ( + attn_mask.is_floating_point() or attn_mask.dtype == pi.bool + ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
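+ # bias_k / bias_v are expected to be shaped (1, 1, embed_dim) (as in torch.nn.MultiheadAttention), so the repeat + cat below appends one extra key/value position per sequence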
+ k = pi.cat([k, bias_k.repeat(1, bsz, 1)]) + v = pi.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_k.size(0) == bsz * num_heads + ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert ( + static_k.size(2) == head_dim + ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert ( + static_v.size(0) == bsz * num_heads + ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert ( + static_v.size(2) == head_dim + ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = pi.cat( + [k, pi.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + ) + v = pi.cat( + [v, pi.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + ) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == pi.bool: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == pi.bool: + new_attn_mask = pi.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if attn_mask is not None: + if attn_mask.size(0) == 1: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output, attn_output_weights = _scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, need_weights, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + ) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, 
attn_output.size(1)) + + if need_weights: + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.sum(dim=1) / num_heads + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None diff --git a/pi/nn/init.py b/pi/nn/init.py new file mode 100644 index 0000000..19a8065 --- /dev/null +++ b/pi/nn/init.py @@ -0,0 +1,335 @@ +import math +import warnings + +from pi import Tensor +import pi + + +# These no_grad_* functions are necessary as wrappers around the parts of these +# functions that use `with pi.no_grad()`. The JIT doesn't support context +# managers, so these need to be implemented as builtins. Using these wrappers +# lets us keep those builtins small and re-usable. +def _no_grad_uniform_(tensor, a, b): + return tensor.uniform_(a, b) + + +def _no_grad_normal_(tensor, mean, std): + return tensor.normal_(mean, std) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
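+ # (norm_cdf above is Phi(x) = (1 + erf(x / sqrt(2))) / 2, so its inverse is sqrt(2) * erfinv(2u - 1); hence the [2l-1, 2u-1] range, the erfinv_ call and the later mul_(std * sqrt(2.0)))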
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def _no_grad_fill_(tensor, val): + return tensor.fill_(val) + + +def _no_grad_zero_(tensor): + return tensor.zero_() + + +def calculate_gain(nonlinearity, param=None): + + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return math.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pyshark.pyshark.pull/50664) + else: + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> Tensor: + return _no_grad_uniform_(tensor, a, b) + + +def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: + return _no_grad_normal_(tensor, mean, std) + + +def trunc_normal_( + tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> Tensor: + + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def constant_(tensor: Tensor, val: float) -> Tensor: + return _no_grad_fill_(tensor, val) + + +def ones_(tensor: Tensor) -> Tensor: + return _no_grad_fill_(tensor, 1.0) + + +def zeros_(tensor: Tensor) -> Tensor: + return _no_grad_zero_(tensor) + + +def eye_(tensor): + + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + pi.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad) + return tensor + + +def dirac_(tensor, groups=1): + + dimensions = tensor.ndimension() + if dimensions not in [3, 4, 5]: + raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported") + + sizes = tensor.size() + + if sizes[0] % groups != 0: + raise ValueError("dim 0 must be divisible by groups") + + out_chans_per_grp = sizes[0] // groups + min_dim = min(out_chans_per_grp, sizes[1]) + + tensor.zero_() + + for g in range(groups): + for d in range(min_dim): + if dimensions == 3: # Temporal convolution + tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1 + elif dimensions == 4: # Spatial convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + ] = 1 + else: # Volumetric convolution + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + tensor.size(4) // 2, + ] = 1 + return tensor + + +def _calculate_fan_in_and_fan_out(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + num_input_fmaps = tensor.size(1) + num_output_fmaps = tensor.size(0) + receptive_field_size = 1 + if tensor.dim() > 2: + # math.prod is not always 
available, accumulate the product manually + # we could use functools.reduce but that is not supported by pi.cript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> Tensor: + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return _no_grad_uniform_(tensor, -a, a) + + +def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> Tensor: + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + + return _no_grad_normal_(tensor, 0.0, std) + + +def _calculate_correct_fan(tensor, mode): + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError( + "Mode {} not supported, please use one of {}".format(mode, valid_modes) + ) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == "fan_in" else fan_out + + +def kaiming_uniform_( + tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" +): + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return tensor.uniform_(-bound, bound) + + +def kaiming_normal_( + tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" +): + + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + return tensor.normal_(0, std) + + +def orthogonal_(tensor, gain=1): + + if tensor.ndimension() < 2: + raise ValueError("Only tensors with 2 or more dimensions are supported") + + if tensor.numel() == 0: + # no-op + return tensor + rows = tensor.size(0) + cols = tensor.numel() // rows + flattened = tensor.new(rows, cols).normal_(0, 1) + + if rows < cols: + flattened.t_() + + # Compute the qr factorization + q, r = pi.linalg.qr(flattened) + # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf + d = pi.diag(r, 0) + ph = d.sign() + q *= ph + + if rows < cols: + q.t_() + + tensor.view_as(q).copy_(q) + tensor.mul_(gain) + return tensor + + +def sparse_(tensor, sparsity, std=0.01): + + if tensor.ndimension() != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + rows, cols = tensor.shape + num_zeros = int(math.ceil(sparsity * rows)) + + tensor.normal_(0, std) + for col_idx in range(cols): + row_indices = pi.randperm(rows) + zero_indices = row_indices[:num_zeros] + tensor[zero_indices, col_idx] = 0 + return tensor + + +# for backward compatibility +def _make_deprecate(meth): + new_name = meth.__name__ + old_name = new_name[:-1] + + def deprecated_init(*args, **kwargs): + warnings.warn( + "nn.init.{} is now deprecated in favor of nn.init.{}.".format( + old_name, new_name + ), + stacklevel=2, + ) + return meth(*args, **kwargs) + + deprecated_init.__doc__ = r""" + {old_name}(...) + + .. warning:: + This method is now deprecated in favor of :func:`pi.nn.init.{new_name}`. 
+ + See :func:`~pi.nn.init.{new_name}` for details.""".format( + old_name=old_name, new_name=new_name + ) + deprecated_init.__name__ = old_name + return deprecated_init + + +uniform = _make_deprecate(uniform_) +normal = _make_deprecate(normal_) +constant = _make_deprecate(constant_) +eye = _make_deprecate(eye_) +dirac = _make_deprecate(dirac_) +xavier_uniform = _make_deprecate(xavier_uniform_) +xavier_normal = _make_deprecate(xavier_normal_) +kaiming_uniform = _make_deprecate(kaiming_uniform_) +kaiming_normal = _make_deprecate(kaiming_normal_) +orthogonal = _make_deprecate(orthogonal_) +sparse = _make_deprecate(sparse_) diff --git a/pi/nn/modules/__init__.py b/pi/nn/modules/__init__.py new file mode 100644 index 0000000..45bea76 --- /dev/null +++ b/pi/nn/modules/__init__.py @@ -0,0 +1,340 @@ +from .module import Module + +from .linear import ( + Identity, + Linear, + Bilinear, +) +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + +from .activation import ( + Threshold, + ReLU, + Hardtanh, + ReLU6, + Sigmoid, + Tanh, + Softmax, + Softmax2d, + LogSoftmax, + ELU, + SELU, + CELU, + GELU, + Hardshrink, + LeakyReLU, + LogSigmoid, + Softplus, + Softshrink, + MultiheadAttention, + PReLU, + Softsign, + Softmin, + Tanhshrink, + RReLU, + GLU, + Hardsigmoid, + Hardswish, + SiLU, + Mish, +) + +from .loss import ( + L1Loss, + NLLLoss, + KLDivLoss, + MSELoss, + BCELoss, + BCEWithLogitsLoss, + NLLLoss2d, + CosineEmbeddingLoss, + CTCLoss, + HingeEmbeddingLoss, + MarginRankingLoss, + MultiLabelMarginLoss, + MultiLabelSoftMarginLoss, + MultiMarginLoss, + SmoothL1Loss, + HuberLoss, + SoftMarginLoss, + CrossEntropyLoss, + TripletMarginLoss, + TripletMarginWithDistanceLoss, + PoissonNLLLoss, + GaussianNLLLoss, +) +from .container import ( + Container, + Sequential, + ModuleList, + ModuleDict, + ParameterList, + ParameterDict, +) +from .pooling import ( + AvgPool1d, + AvgPool2d, + AvgPool3d, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MaxUnpool1d, + MaxUnpool2d, + MaxUnpool3d, + FractionalMaxPool2d, + FractionalMaxPool3d, + LPPool1d, + LPPool2d, + AdaptiveMaxPool1d, + AdaptiveMaxPool2d, + AdaptiveMaxPool3d, + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, +) +from .batchnorm import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + SyncBatchNorm, + # LazyBatchNorm1d, + # LazyBatchNorm2d, + # LazyBatchNorm3d, +) +from .instancenorm import ( + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + # LazyInstanceNorm1d, + # LazyInstanceNorm2d, + # LazyInstanceNorm3d, +) +from .normalization import ( + LocalResponseNorm, + # CrossMapLRN2d, + LayerNorm, + GroupNorm, +) +from .dropout import ( + Dropout, + Dropout1d, + Dropout2d, + Dropout3d, + AlphaDropout, + FeatureAlphaDropout, +) +from .padding import ( + ReflectionPad1d, + ReflectionPad2d, + ReflectionPad3d, + ReplicationPad1d, + ReplicationPad2d, + ReplicationPad3d, + ZeroPad2d, + ConstantPad1d, + ConstantPad2d, + ConstantPad3d, +) + +from .sparse import Embedding, EmbeddingBag + +# from .rnn import ( +# RNNBase, +# RNN, +# LSTM, +# GRU, +# RNNCellBase, +# RNNCell, +# LSTMCell, +# GRUCell, +# ) +from .pixelshuffle import ( + PixelShuffle, + PixelUnshuffle, +) +from .upsampling import ( + UpsamplingNearest2d, + UpsamplingBilinear2d, + Upsample, +) +from .distance import PairwiseDistance, CosineSimilarity + +from .fold import ( + Fold, + Unfold, +) +from .adaptive import AdaptiveLogSoftmaxWithLoss +from .transformer import ( + TransformerEncoder, + TransformerDecoder, + TransformerEncoderLayer, 
+ TransformerDecoderLayer, + Transformer, +) +from .flatten import Flatten, Unflatten + +from .channelshuffle import ChannelShuffle + +__all__ = [ + "Module", + "Identity", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "AdaptiveLogSoftmaxWithLoss", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AlphaDropout", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "BCELoss", + "BCEWithLogitsLoss", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "Bilinear", + "CELU", + "CTCLoss", + "ChannelShuffle", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "Container", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "CosineEmbeddingLoss", + "CosineSimilarity", + "CrossEntropyLoss", + # "CrossMapLRN2d", + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "ELU", + "Embedding", + "EmbeddingBag", + "FeatureAlphaDropout", + "Flatten", + "Fold", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "GELU", + "GLU", + # "GRU", + # "GRUCell", + "GaussianNLLLoss", + "GroupNorm", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "HingeEmbeddingLoss", + "HuberLoss", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "KLDivLoss", + "L1Loss", + "LPPool1d", + "LPPool2d", + # "LSTM", + # "LSTMCell", + "LayerNorm", + # "LazyBatchNorm1d", + # "LazyBatchNorm2d", + # "LazyBatchNorm3d", + # "LazyConv1d", + # "LazyConv2d", + # "LazyConv3d", + # "LazyConvTranspose1d", + # "LazyConvTranspose2d", + # "LazyConvTranspose3d", + # "LazyInstanceNorm1d", + # "LazyInstanceNorm2d", + # "LazyInstanceNorm3d", + # "LazyLinear", + "LeakyReLU", + "Linear", + "LocalResponseNorm", + "LogSigmoid", + "LogSoftmax", + "MSELoss", + "MarginRankingLoss", + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "Mish", + "ModuleDict", + "ModuleList", + "MultiLabelMarginLoss", + "MultiLabelSoftMarginLoss", + "MultiMarginLoss", + "MultiheadAttention", + "NLLLoss", + "NLLLoss2d", + "PReLU", + "PairwiseDistance", + "ParameterDict", + "ParameterList", + "PixelShuffle", + "PixelUnshuffle", + "PoissonNLLLoss", + # "RNN", + # "RNNBase", + # "RNNCell", + # "RNNCellBase", + "RReLU", + "ReLU", + "ReLU6", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "SELU", + "Sequential", + "SiLU", + "Sigmoid", + "SmoothL1Loss", + "SoftMarginLoss", + "Softmax", + "Softmax2d", + "Softmin", + "Softplus", + "Softshrink", + "Softsign", + "SyncBatchNorm", + "Tanh", + "Tanhshrink", + "Threshold", + "Transformer", + "TransformerDecoder", + "TransformerDecoderLayer", + "TransformerEncoder", + "TransformerEncoderLayer", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "Unflatten", + "Unfold", + "Upsample", + "UpsamplingBilinear2d", + "UpsamplingNearest2d", + "ZeroPad2d", +] diff --git a/pi/nn/modules/activation.py b/pi/nn/modules/activation.py new file mode 100644 index 0000000..df24d9b --- /dev/null +++ b/pi/nn/modules/activation.py @@ -0,0 +1,792 @@ +import warnings +from typing import Optional, Tuple, Union + +import pi +from pi import Tensor + +from .module import Module +from .. 
import functional as F +from ..init import xavier_uniform_, constant_, xavier_normal_ +from ..parameter import UninitializedParameter +from .linear import NonDynamicallyQuantizableLinear + +__all__ = [ + "Threshold", + "ReLU", + "RReLU", + "Hardtanh", + "ReLU6", + "Sigmoid", + "Hardsigmoid", + "Tanh", + "SiLU", + "Mish", + "Hardswish", + "ELU", + "CELU", + "SELU", + "GLU", + "GELU", + "Hardshrink", + "LeakyReLU", + "LogSigmoid", + "Softplus", + "Softshrink", + "MultiheadAttention", + "PReLU", + "Softsign", + "Tanhshrink", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", +] + + +class Threshold(Module): + + __constants__ = ["threshold", "value", "inplace"] + + threshold: float + value: float + inplace: bool + + def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: + super(Threshold, self).__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace + # TODO: check in THNN (if inplace == True, then assert value <= threshold) + + def forward(self, input: Tensor) -> Tensor: + return F.threshold(input, self.threshold, self.value, self.inplace) + + def extra_repr(self): + inplace_str = ", inplace=True" if self.inplace else "" + return "threshold={}, value={}{}".format( + self.threshold, self.value, inplace_str + ) + + +class ReLU(Module): + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super(ReLU, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.relu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class RReLU(Module): + + __constants__ = ["lower", "upper", "inplace"] + + lower: float + upper: float + inplace: bool + + def __init__( + self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False + ): + super(RReLU, self).__init__() + self.lower = lower + self.upper = upper + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) + + def extra_repr(self): + inplace_str = ", inplace=True" if self.inplace else "" + return "lower={}, upper={}{}".format(self.lower, self.upper, inplace_str) + + +class Hardtanh(Module): + + __constants__ = ["min_val", "max_val", "inplace"] + + min_val: float + max_val: float + inplace: bool + + def __init__( + self, + min_val: float = -1.0, + max_val: float = 1.0, + inplace: bool = False, + min_value: Optional[float] = None, + max_value: Optional[float] = None, + ) -> None: + super(Hardtanh, self).__init__() + if min_value is not None: + warnings.warn( + "keyword argument min_value is deprecated and rename to min_val" + ) + min_val = min_value + if max_value is not None: + warnings.warn( + "keyword argument max_value is deprecated and rename to max_val" + ) + max_val = max_value + + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + assert self.max_val > self.min_val + + def forward(self, input: Tensor) -> Tensor: + return F.hardtanh(input, self.min_val, self.max_val, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return "min_val={}, max_val={}{}".format( + self.min_val, self.max_val, inplace_str + ) + + +class ReLU6(Hardtanh): + def __init__(self, inplace: bool = False): + super(ReLU6, self).__init__(0.0, 6.0, inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + 
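+ # e.g. ReLU6()(x) is F.hardtanh(x, 0.0, 6.0), clamping elementwise to [0, 6]: [-1.0, 3.0, 8.0] -> [0.0, 3.0, 6.0]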
+class Sigmoid(Module): + def forward(self, input: Tensor) -> Tensor: + return pi.sigmoid(input) + + +class Hardsigmoid(Module): + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super(Hardsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardsigmoid(input, self.inplace) + + +class Tanh(Module): + def forward(self, input: Tensor) -> Tensor: + return pi.tanh(input) + + +class SiLU(Module): + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super(SiLU, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.silu(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Mish(Module): + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.mish(input, inplace=self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class Hardswish(Module): + + __constants__ = ["inplace"] + + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super(Hardswish, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.hardswish(input, self.inplace) + + +class ELU(Module): + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super(ELU, self).__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.elu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return "alpha={}{}".format(self.alpha, inplace_str) + + +class CELU(Module): + + __constants__ = ["alpha", "inplace"] + alpha: float + inplace: bool + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super(CELU, self).__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.celu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return "alpha={}{}".format(self.alpha, inplace_str) + + +class SELU(Module): + + __constants__ = ["inplace"] + inplace: bool + + def __init__(self, inplace: bool = False) -> None: + super(SELU, self).__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.selu(input, self.inplace) + + def extra_repr(self) -> str: + inplace_str = "inplace=True" if self.inplace else "" + return inplace_str + + +class GLU(Module): + + __constants__ = ["dim"] + dim: int + + def __init__(self, dim: int = -1) -> None: + super(GLU, self).__init__() + self.dim = dim + + def forward(self, input: Tensor) -> Tensor: + return F.glu(input, self.dim) + + def extra_repr(self) -> str: + return "dim={}".format(self.dim) + + +class GELU(Module): + + __constants__ = ["approximate"] + approximate: str + + def __init__(self, approximate: str = "none") -> None: + super(GELU, self).__init__() + self.approximate = approximate + + def forward(self, input: Tensor) -> Tensor: + return F.gelu(input, approximate=self.approximate) + + def extra_repr(self) -> str: + return 
"approximate={}".format(repr(self.approximate)) + + +class Hardshrink(Module): + + __constants__ = ["lambd"] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super(Hardshrink, self).__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + return F.hardshrink(input, self.lambd) + + def extra_repr(self) -> str: + return "{}".format(self.lambd) + + +class LeakyReLU(Module): + + __constants__ = ["inplace", "negative_slope"] + inplace: bool + negative_slope: float + + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super(LeakyReLU, self).__init__() + self.negative_slope = negative_slope + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return F.leaky_relu(input, self.negative_slope, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ", inplace=True" if self.inplace else "" + return "negative_slope={}{}".format(self.negative_slope, inplace_str) + + +class LogSigmoid(Module): + def forward(self, input: Tensor) -> Tensor: + return F.logsigmoid(input) + + +class Softplus(Module): + __constants__ = ["beta", "threshold"] + beta: int + threshold: int + + def __init__(self, beta: int = 1, threshold: int = 20) -> None: + super(Softplus, self).__init__() + self.beta = beta + self.threshold = threshold + + def forward(self, input: Tensor) -> Tensor: + return F.softplus(input, self.beta, self.threshold) + + def extra_repr(self) -> str: + return "beta={}, threshold={}".format(self.beta, self.threshold) + + +class Softshrink(Module): + __constants__ = ["lambd"] + lambd: float + + def __init__(self, lambd: float = 0.5) -> None: + super(Softshrink, self).__init__() + self.lambd = lambd + + def forward(self, input: Tensor) -> Tensor: + return F.softshrink(input, self.lambd) + + def extra_repr(self) -> str: + return str(self.lambd) + + +class MultiheadAttention(Module): + __constants__ = ["batch_first"] + bias_k: Optional[Union[UninitializedParameter, pi.Tensor]] + bias_v: Optional[Union[UninitializedParameter, pi.Tensor]] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if not self._qkv_same_embed_dim: + self.q_proj_weight = UninitializedParameter( + (embed_dim, embed_dim), **factory_kwargs + ) + self.k_proj_weight = UninitializedParameter( + (embed_dim, self.kdim), **factory_kwargs + ) + self.v_proj_weight = UninitializedParameter( + (embed_dim, self.vdim), **factory_kwargs + ) + # self.register_("in_proj_weight", None) + else: + self.in_proj_weight = UninitializedParameter( + (3 * embed_dim, embed_dim), **factory_kwargs + ) + # self.register_("q_proj_weight", None) + # self.register_("k_proj_weight", None) + # self.register_("v_proj_weight", None) + + if bias: + self.in_proj_bias = UninitializedParameter(3 * embed_dim, **factory_kwargs) + else: + self.register_("in_proj_bias", None) + 
self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if add_bias_kv: + self.bias_k = UninitializedParameter((1, 1, embed_dim), **factory_kwargs) + self.bias_v = UninitializedParameter((1, 1, embed_dim), **factory_kwargs) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + + if attn_mask is not None and is_causal: + raise AssertionError("Only allow causal mask or attn_mask") + + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != pi.bool and not pi.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. 
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif query.is_nested and ( + key_padding_mask is not None or attn_mask is not None + ): + why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ + is not supported with NestedTensor input" + elif pi.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if pi.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif pi.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + merged_mask, mask_type = self.merge_masks( + attn_mask, key_padding_mask, query + ) + + return pi._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + merged_mask, + need_weights, + average_attn_weights, + mask_type, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. 
" + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + def merge_masks( + self, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + query: Tensor, + ) -> Tuple[Optional[Tensor], Optional[int]]: + + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + if attn_mask is not None: + mask_type = 0 + merged_mask = attn_mask + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + if (attn_mask is not None) and (key_padding_mask is not None): + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + key_padding_mask_expanded = key_padding_mask.view( + batch_size, 1, 1, seq_len + ).expand(-1, self.num_heads, -1, -1) + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand( + batch_size, self.num_heads, -1, -1 + ) + merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) + return merged_mask, mask_type + + +class PReLU(Module): + + __constants__ = ["num_parameters"] + num_parameters: int + + def __init__( + self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + self.num_parameters = num_parameters + super(PReLU, self).__init__() + self.weight = UninitializedParameter(num_parameters, **factory_kwargs).fill_( + init + ) + + def forward(self, input: Tensor) -> Tensor: + return F.prelu(input, self.weight) + + def extra_repr(self) -> str: + return "num_parameters={}".format(self.num_parameters) + + +class Softsign(Module): + def forward(self, input: Tensor) -> Tensor: + return F.softsign(input) + + +class Tanhshrink(Module): + def forward(self, input: Tensor) -> Tensor: + return F.tanhshrink(input) + + +class Softmin(Module): + + __constants__ = ["dim"] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + 
super(Softmin, self).__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.softmin(input, self.dim, _stacklevel=5) + + def extra_repr(self): + return "dim={dim}".format(dim=self.dim) + + +class Softmax(Module): + __constants__ = ["dim"] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super(Softmax, self).__init__() + self.dim = dim + + def forward(self, input: Tensor) -> Tensor: + return F.softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: + return "dim={dim}".format(dim=self.dim) + + +class Softmax2d(Module): + def forward(self, input: Tensor) -> Tensor: + assert ( + input.dim() == 4 or input.dim() == 3 + ), "Softmax2d requires a 3D or 4D tensor as input" + return F.softmax(input, -3, _stacklevel=5) + + +class LogSoftmax(Module): + + __constants__ = ["dim"] + dim: Optional[int] + + def __init__(self, dim: Optional[int] = None) -> None: + super(LogSoftmax, self).__init__() + self.dim = dim + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, "dim"): + self.dim = None + + def forward(self, input: Tensor) -> Tensor: + return F.log_softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self): + return "dim={dim}".format(dim=self.dim) diff --git a/pi/nn/modules/adaptive.py b/pi/nn/modules/adaptive.py new file mode 100644 index 0000000..6f8bb07 --- /dev/null +++ b/pi/nn/modules/adaptive.py @@ -0,0 +1,218 @@ +from collections import namedtuple + +import pi + +from pi import Tensor +from typing import List, Sequence + +from . import Sequential, ModuleList, Linear +from .module import Module +from ..functional import log_softmax + +__all__ = ["AdaptiveLogSoftmaxWithLoss"] + +_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"]) + + +class AdaptiveLogSoftmaxWithLoss(Module): + + in_features: int + n_classes: int + cutoffs: List[int] + div_value: float + head_bias: bool + head: Linear + tail: ModuleList + + def __init__( + self, + in_features: int, + n_classes: int, + cutoffs: Sequence[int], + div_value: float = 4.0, + head_bias: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(AdaptiveLogSoftmaxWithLoss, self).__init__() + + cutoffs = list(cutoffs) + + if ( + (cutoffs != sorted(cutoffs)) + or (min(cutoffs) <= 0) + or (max(cutoffs) > (n_classes - 1)) + or (len(set(cutoffs)) != len(cutoffs)) + or any([int(c) != c for c in cutoffs]) + ): + + raise ValueError( + "cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1" + ) + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head = Linear( + self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs + ) + self.tail = ModuleList() + + for i in range(self.n_clusters): + + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + + projection = Sequential( + Linear(self.in_features, hsz, bias=False, **factory_kwargs), + Linear(hsz, osz, bias=False, **factory_kwargs), + ) + + self.tail.append(projection) + + def reset_parameters(self) -> None: + 
self.head.reset_parameters() + for i2h, h2o in self.tail: + i2h.reset_parameters() + h2o.reset_parameters() + + def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: + targ_dim = target_.dim() + + if targ_dim == 1: + if input_.size(0) != target_.size(0): + raise RuntimeError( + "Input and target should have the same size " + "in the batch dimension." + ) + if input_.dim() != 2: + raise RuntimeError( + "1D target tensor expects 2D input tensors, " + "but found inputs with size", + input_.size(), + ) + elif targ_dim == 0: + if input_.dim() != 1: + raise RuntimeError( + "0D target tensor expects 1D input tensors, " + "but found inputs with size", + input_.size(), + ) + else: + raise RuntimeError( + "0D or 1D target tensor expected, " "multi-target not supported" + ) + + is_batched = targ_dim > 0 + input = input_ if is_batched else input_.unsqueeze(0) + target = target_ if is_batched else target_.unsqueeze(0) + + used_rows = 0 + batch_size = target.size(0) + + output = input.new_zeros(batch_size) + gather_inds = target.new_empty(batch_size) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + target_mask = (target >= low_idx) & (target < high_idx) + row_indices = target_mask.nonzero().squeeze() + + if row_indices.numel() == 0: + continue + + if i == 0: + gather_inds.index_copy_(0, row_indices, target[target_mask]) + + else: + relative_target = target[target_mask] - low_idx + input_subset = input.index_select(0, row_indices) + + cluster_output = self.tail[i - 1](input_subset) + cluster_index = self.shortlist_size + i - 1 + + gather_inds.index_fill_(0, row_indices, cluster_index) + cluster_logprob = log_softmax(cluster_output, dim=1) + local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) + output.index_copy_(0, row_indices, local_logprob.squeeze(1)) + + used_rows += row_indices.numel() + + if used_rows != batch_size: + raise RuntimeError( + "Target values should be in [0, {}], " + "but values in range [{}, {}] " + "were found. 
".format( + self.n_classes - 1, target.min().item(), target.max().item() + ) + ) + + head_output = self.head(input) + head_logprob = log_softmax(head_output, dim=1) + output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() + loss = (-output).mean() + + if not is_batched: + output = output.squeeze(0) + + return _ASMoutput(output, loss) + + def _get_full_log_prob(self, input, head_output): + """Given input tensor, and output of `self.head`, + compute the log of the full distribution""" + + out = input.new_empty((head_output.size(0), self.n_classes)) + head_logprob = log_softmax(head_output, dim=1) + + out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size] + + for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): + cluster_output = self.tail[i](input) + cluster_logprob = log_softmax(cluster_output, dim=1) + output_logprob = cluster_logprob + head_logprob[ + :, self.shortlist_size + i + ].unsqueeze(1) + + out[:, start_idx:stop_idx] = output_logprob + + return out + + def log_prob(self, input: Tensor) -> Tensor: + + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) + + def predict(self, input: Tensor) -> Tensor: + + head_output = self.head(input) + output = pi.argmax(head_output, dim=1) + not_in_shortlist = output >= self.shortlist_size + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return pi.argmax(log_prob, dim=1) + + else: + log_prob = self._get_full_log_prob( + input[not_in_shortlist], head_output[not_in_shortlist] + ) + output[not_in_shortlist] = pi.argmax(log_prob, dim=1) + return output diff --git a/pi/nn/modules/batchnorm.py b/pi/nn/modules/batchnorm.py new file mode 100644 index 0000000..fa1bb9c --- /dev/null +++ b/pi/nn/modules/batchnorm.py @@ -0,0 +1,438 @@ +from typing import Optional, Any + +import pi +from pi import Tensor +from .module import Module +from .. import functional as F +from ..parameter import UninitializedParameter, UninitializedBuffer +from .. import init + + +__all__ = [ + "BatchNorm1d", + # "LazyBatchNorm1d", + "BatchNorm2d", + # "LazyBatchNorm2d", + "BatchNorm3d", + # "LazyBatchNorm3d", + "SyncBatchNorm", +] + + +class _NormBase(Module): + """Common base of _InstanceNorm and _BatchNorm""" + + _version = 2 + __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] + num_features: int + eps: float + momentum: float + affine: bool + track_running_stats: bool + # WARNING: weight and bias purposely not defined here. 
+ # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(_NormBase, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = UninitializedParameter(num_features, **factory_kwargs) + self.bias = UninitializedParameter(num_features, **factory_kwargs) + else: + self.register_("weight", None) + self.register_("bias", None) + if self.track_running_stats: + self.register_buffer( + "running_mean", + UninitializedBuffer(pi.zeros, (num_features,), **factory_kwargs), + ) + self.register_buffer( + "running_var", + UninitializedBuffer(pi.ones, (num_features,), **factory_kwargs), + ) + self.running_mean: Optional[Tensor] + self.running_var: Optional[Tensor] + self.register_buffer( + "num_batches_tracked", + UninitializedBuffer( + 0, + dtype=pi.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) + self.num_batches_tracked: Optional[Tensor] + else: + self.register_buffer("running_mean", None) + self.register_buffer("running_var", None) + self.register_buffer("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self) -> None: + if self.track_running_stats: + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[union-attr] + self.running_var.fill_(1) # type: ignore[union-attr] + self.num_batches_tracked.zero_() # type: ignore[union-attr,operator] + + def reset_parameters(self) -> None: + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def _check_input_dim(self, input): + raise NotImplementedError + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in state_dict: + state_dict[num_batches_tracked_key] = pi.tensor(0, dtype=pi.long) + + super(_NormBase, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class _BatchNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(_BatchNorm, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. 
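+        # A momentum of None selects a cumulative moving average: the factor is
+        # set to 0.0 here and, during training, replaced below by
+        # 1 / num_batches_tracked, so every batch seen so far is weighted equally.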
+ if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + return F.batch_norm( + input, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean + if not self.training or self.track_running_stats + else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + + +# class _LazyNormBase(LazyModuleMixin, _NormBase): +# +# weight: UninitializedParameter # type: ignore[assignment] +# bias: UninitializedParameter # type: ignore[assignment] +# +# def __init__( +# self, +# eps=1e-5, +# momentum=0.1, +# affine=True, +# track_running_stats=True, +# device=None, +# dtype=None, +# ) -> None: +# factory_kwargs = {"device": device, "dtype": dtype} +# super(_LazyNormBase, self).__init__( +# # affine and track_running_stats are hardcoded to False to +# # avoid creating tensors that will soon be overwritten. +# 0, +# eps, +# momentum, +# False, +# False, +# **factory_kwargs, +# ) +# self.affine = affine +# self.track_running_stats = track_running_stats +# if self.affine: +# self.weight = Uninitialized(**factory_kwargs) +# self.bias = Uninitialized(**factory_kwargs) +# if self.track_running_stats: +# self.running_mean = UninitializedBuffer(**factory_kwargs) +# self.running_var = UninitializedBuffer(**factory_kwargs) +# self.num_batches_tracked = pi.tensor( +# 0, +# dtype=pi.long, +# **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, +# ) +# +# def reset_parameters(self) -> None: +# if not self.has_uninitialized_params() and self.num_features != 0: +# super().reset_parameters() +# +# def initialize_parameters(self, input) -> None: # type: ignore[override] +# if self.has_uninitialized_params(): +# self.num_features = input.shape[1] +# if self.affine: +# assert isinstance(self.weight, UninitializedParameter) +# assert isinstance(self.bias, UninitializedParameter) +# self.weight.materialize((self.num_features,)) +# self.bias.materialize((self.num_features,)) +# if self.track_running_stats: +# self.running_mean.materialize( +# (self.num_features,) +# ) # type:ignore[union-attr] +# self.running_var.materialize( +# (self.num_features,) +# ) # type:ignore[union-attr] +# self.reset_parameters() + + +class BatchNorm1d(_BatchNorm): + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError( + "expected 2D or 3D input (got {}D input)".format(input.dim()) + ) + + +# class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): +# +# cls_to_become = BatchNorm1d # type: ignore[assignment] +# +# def _check_input_dim(self, input): +# if input.dim() != 2 and input.dim() != 3: +# raise ValueError( +# "expected 2D or 3D input (got {}D input)".format(input.dim()) +# ) + + +class BatchNorm2d(_BatchNorm): + def _check_input_dim(self, input): + if input.dim() != 4: + raise 
ValueError("expected 4D input (got {}D input)".format(input.dim())) + + +# class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): +# +# cls_to_become = BatchNorm2d # type: ignore[assignment] +# +# def _check_input_dim(self, input): +# if input.dim() != 4: +# raise ValueError("expected 4D input (got {}D input)".format(input.dim())) + + +class BatchNorm3d(_BatchNorm): + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError("expected 5D input (got {}D input)".format(input.dim())) + + +# class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): +# +# cls_to_become = BatchNorm3d # type: ignore[assignment] +# +# def _check_input_dim(self, input): +# if input.dim() != 5: +# raise ValueError("expected 5D input (got {}D input)".format(input.dim())) + + +class SyncBatchNorm(_BatchNorm): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + process_group: Optional[Any] = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(SyncBatchNorm, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.process_group = process_group + + def _check_input_dim(self, input): + if input.dim() < 2: + raise ValueError( + "expected at least 2D input (got {}D input)".format(input.dim()) + ) + + def _check_non_zero_input_channels(self, input): + if input.size(1) == 0: + raise ValueError( + "SyncBatchNorm number of input channels should be non-zero" + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + self._check_non_zero_input_channels(input) + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None + self.num_batches_tracked.add_(1) + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + # If buffers are not to be tracked, ensure that they won't be updated + running_mean = ( + self.running_mean if not self.training or self.track_running_stats else None + ) + running_var = ( + self.running_var if not self.training or self.track_running_stats else None + ) + + # Don't sync batchnorm stats in inference mode (model.eval()). 
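+        # Synchronization is only attempted when batch statistics are being
+        # computed, the module is in training mode, and pi.distributed is both
+        # available and initialized; if any of these fail (or the process group
+        # has world_size == 1), the plain F.batch_norm path below is taken.
+        # The truly synchronized path is not implemented in this port yet and
+        # raises NotImplementedError.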
+ need_sync = ( + bn_training + and self.training + and pi.distributed.is_available() + and pi.distributed.is_initialized() + ) + if need_sync: + # currently only GPU input is supported + if not input.is_cuda: + raise ValueError("SyncBatchNorm expected input tensor to be on GPU") + + process_group = pi.distributed.group.WORLD + if self.process_group: + process_group = self.process_group + world_size = pi.distributed.get_world_size(process_group) + need_sync = world_size > 1 + + # fallback to framework BN when synchronization is not necessary + if not need_sync: + return F.batch_norm( + input, + running_mean, + running_var, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + else: + raise NotImplementedError + + # assert bn_training + # return sync_batch_norm.apply( + # input, + # self.weight, + # self.bias, + # running_mean, + # running_var, + # self.eps, + # exponential_average_factor, + # process_group, + # world_size, + # ) + + @classmethod + def convert_sync_batchnorm(cls, module, process_group=None): + + module_output = module + if isinstance(module, pi.nn.modules.batchnorm._BatchNorm): + module_output = pi.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module( + name, cls.convert_sync_batchnorm(child, process_group) + ) + del module + return module_output diff --git a/pi/nn/modules/channelshuffle.py b/pi/nn/modules/channelshuffle.py new file mode 100644 index 0000000..d4319e0 --- /dev/null +++ b/pi/nn/modules/channelshuffle.py @@ -0,0 +1,21 @@ +from pi import Tensor +from .module import Module +from .. import functional as F + +__all__ = ["ChannelShuffle"] + + +class ChannelShuffle(Module): + + __constants__ = ["groups"] + groups: int + + def __init__(self, groups: int) -> None: + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, input: Tensor) -> Tensor: + return F.channel_shuffle(input, self.groups) + + def extra_repr(self) -> str: + return "groups={}".format(self.groups) diff --git a/pi/nn/modules/container.py b/pi/nn/modules/container.py new file mode 100644 index 0000000..17afe9e --- /dev/null +++ b/pi/nn/modules/container.py @@ -0,0 +1,674 @@ +import warnings +from collections import OrderedDict, abc as container_abcs +from itertools import chain, islice +import operator + +import pi +from .module import Module +from ..parameter import Parameter + +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + overload, + Tuple, + TypeVar, + Union, +) + +__all__ = [ + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + "ParameterList", + "ParameterDict", +] + +T = TypeVar("T", bound=Module) + + +class Container(Module): + def __init__(self, **kwargs: Any) -> None: + super(Container, self).__init__() + # DeprecationWarning is ignored by default + warnings.warn( + "nn.Container is deprecated. All of it's functionality " + "is now implemented in nn.Module. Subclass that instead." 
+ ) + for key, value in kwargs.items(): + self.add_module(key, value) + + +class Sequential(Module): + + _modules: Dict[str, Module] # type: ignore[assignment] + + @overload + def __init__(self, *args: Module) -> None: + ... + + @overload + def __init__(self, arg: "OrderedDict[str, Module]") -> None: + ... + + def __init__(self, *args): + super(Sequential, self).__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator, idx) -> T: + """Get the idx-th item of the iterator""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError("index {} is out of range".format(idx)) + idx %= size + return next(islice(iterator, idx, None)) + + def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> "Sequential": + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " + "of Sequential class, but {} is given.".format(str(type(other))) + ) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> "Sequential": + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " + "of Sequential class, but {} is given.".format(str(type(other))) + ) + + def __mul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> "Sequential": + return self.__mul__(other) + + def __imul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + 
def __dir__(self): + keys = super(Sequential, self).__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> "Sequential": + + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> "Sequential": + if not isinstance(module, Module): + raise AssertionError("module should be of type: {}".format(Module)) + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError("Index out of range: {}".format(index)) + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential) -> "Sequential": + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super(ModuleList, self).__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + return str(idx) + + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> "ModuleList": + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> "ModuleList": + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __dir__(self): + keys = super(ModuleList, self).__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> "ModuleList": + + self.add_module(str(len(self)), module) + return self + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, 
modules: Iterable[Module]) -> "ModuleList": + + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ModuleDict(Module): + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super(ModuleDict, self).__init__() + if modules is not None: + self.update(modules) + + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + + v = self[key] + del self[key] + return v + + def keys(self) -> Iterable[str]: + + return self._modules.keys() + + def items(self) -> Iterable[Tuple[str, Module]]: + + return self._modules.items() + + def values(self) -> Iterable[Module]: + + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) + if not len(m) == 2: + raise ValueError( + "ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ParameterList(Module): + def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + super(ParameterList, self).__init__() + self._size = 0 + if values is not None: + self += values + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: int) -> Any: + ... + + @overload + def __getitem__(self: T, idx: slice) -> T: + ... 
+ + def __getitem__(self, idx): + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + out = self.__class__() + for i in range(start, stop, step): + out.append(self[i]) + return out + else: + idx = self._get_abs_string_index(idx) + return getattr(self, str(idx)) + + def __setitem__(self, idx: int, param: Any) -> None: + # Note that all other function that add an entry to the list part of + # the ParameterList end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the list part and thus won't + # call into this function. + idx = self._get_abs_string_index(idx) + if isinstance(param, pi.Tensor) and not isinstance(param, Parameter): + param = param + return setattr(self, str(idx), param) + + def __len__(self) -> int: + return self._size + + def __iter__(self) -> Iterator[Any]: + return iter(self[i] for i in range(len(self))) + + def __iadd__(self, parameters: Iterable[Any]) -> "ParameterList": + return self.extend(parameters) + + def __dir__(self): + keys = super(ParameterList, self).__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def append(self, value: Any) -> "ParameterList": + """Appends a given value at the end of the list. + + Args: + value (Any): value to append + """ + new_idx = len(self) + self._size += 1 + self[new_idx] = value + return self + + def extend(self, values: Iterable[Any]) -> "ParameterList": + """Appends values from a Python iterable to the end of the list. + + Args: + values (iterable): iterable of values to append + """ + # Tensor is an iterable but we never want to unpack it here + if not isinstance(values, container_abcs.Iterable) or isinstance( + values, pi.Tensor + ): + raise TypeError( + "ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__ + ) + for value in values: + self.append(value) + return self + + def extra_repr(self) -> str: + child_lines = [] + for k, p in enumerate(self): + if isinstance(p, pi.Tensor): + size_str = "x".join(str(size) for size in p.size()) + device_str = "" if not p.is_cuda else " (GPU {})".format(p.get_device()) + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + p.dtype, + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, *args, **kwargs): + raise RuntimeError("ParameterList should not be called.") + + +class ParameterDict(Module): + def __init__(self, parameters: Any = None) -> None: + super(ParameterDict, self).__init__() + self._keys: Dict[str, None] = {} + if parameters is not None: + self.update(parameters) + + def _key_to_attr(self, key: str) -> str: + if not isinstance(key, str): + raise TypeError( + "Index given to ParameterDict cannot be used as a key as it is " + f"not a string (type is '{type(key).__name__}'). Open an issue on " + "github if you need non-string keys." + ) + else: + # Use the key as-is so that `.named_parameters()` returns the right thing + return key + + def __getitem__(self, key: str) -> Any: + attr = self._key_to_attr(key) + return getattr(self, attr) + + def __setitem__(self, key: str, value: Any) -> None: + # Note that all other function that add an entry to the dictionary part of + # the ParameterDict end up here. 
So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the dictionary part and thus won't + # call into this function. + self._keys[key] = None + attr = self._key_to_attr(key) + if isinstance(value, pi.Tensor) and not isinstance(value, Parameter): + value = value + setattr(self, attr, value) + + def __delitem__(self, key: str) -> None: + del self._keys[key] + attr = self._key_to_attr(key) + delattr(self, attr) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __reversed__(self) -> Iterator[str]: + return reversed(list(self._keys)) + + def copy(self) -> "ParameterDict": + """Returns a copy of this :class:`~pi.nn.ParameterDict` instance.""" + # We have to use an OrderedDict because the ParameterDict constructor + # behaves differently on plain dict vs OrderedDict + return ParameterDict(OrderedDict((k, self[k]) for k in self._keys)) + + def __contains__(self, key: str) -> bool: + return key in self._keys + + def setdefault(self, key: str, default: Optional[Any] = None) -> Any: + """If key is in the ParameterDict, return its value. + If not, insert `key` with a parameter `default` and return `default`. + `default` defaults to `None`. + + Args: + key (str): key to set default for + default (Any): the parameter set to the key + """ + + if key not in self: + self[key] = default + return self[key] + + def clear(self) -> None: + """Remove all items from the ParameterDict.""" + for k in self._keys.copy(): + del self[k] + + def pop(self, key: str) -> Any: + + v = self[key] + del self[key] + return v + + def popitem(self) -> Tuple[str, Any]: + """Remove and return the last inserted `(key, parameter)` pair + from the ParameterDict + """ + k, _ = self._keys.popitem() + # We need the key in the _keys to be able to access/del + self._keys[k] = None + val = self[k] + del self[k] + return k, val + + def get(self, key: str, default: Optional[Any] = None) -> Any: + + return self[key] if key in self else default + + def fromkeys( + self, keys: Iterable[str], default: Optional[Any] = None + ) -> "ParameterDict": + + return ParameterDict(((k, default) for k in keys)) + + def keys(self) -> Iterable[str]: + + return self._keys.keys() + + def items(self) -> Iterable[Tuple[str, Any]]: + + return ((k, self[k]) for k in self._keys) + + def values(self) -> Iterable[Any]: + + return (self[k] for k in self._keys) + + def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None: + + if not isinstance(parameters, container_abcs.Iterable): + raise TypeError( + "ParametersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(parameters).__name__ + ) + + if isinstance(parameters, (OrderedDict, ParameterDict)): + for key, parameter in parameters.items(): + self[key] = parameter + elif isinstance(parameters, container_abcs.Mapping): + for key, parameter in sorted(parameters.items()): + self[key] = parameter + else: + for j, p in enumerate(parameters): + if not isinstance(p, container_abcs.Iterable): + raise TypeError( + "ParameterDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) + if not len(p) == 2: + raise ValueError( + "ParameterDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) + # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment + self[p[0]] = p[1] # type: ignore[assignment] + + def 
extra_repr(self) -> str: + child_lines = [] + for k, p in self.items(): + if isinstance(p, pi.Tensor): + size_str = "x".join(str(size) for size in p.size()) + device_str = "" if not p.is_cuda else " (GPU {})".format(p.get_device()) + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + pi.typename(p), + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError("ParameterDict should not be called.") + + def __or__(self, other: "ParameterDict") -> "ParameterDict": + copy = self.copy() + copy.update(other) + return copy + + def __ror__(self, other: "ParameterDict") -> "ParameterDict": + copy = other.copy() + copy.update(self) + return copy + + def __ior__(self, other: "ParameterDict") -> "ParameterDict": + self.update(other) + return self diff --git a/pi/nn/modules/conv.py b/pi/nn/modules/conv.py new file mode 100644 index 0000000..e607b1a --- /dev/null +++ b/pi/nn/modules/conv.py @@ -0,0 +1,631 @@ +from typing import Optional, List, Tuple, Union + +import math + +from pi import Tensor +from .module import Module +from .utils import _single, _pair, _reverse_repeat_tuple, _triple +from .. import functional as F +from ..common_types import _size_1_t, _size_2_t, _size_3_t +from ..parameter import UninitializedParameter +from .. import init + + +class _ConvNd(Module): + def _conv_forward( + self, input: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: + ... + + in_channels: int + _reversed_padding_repeated_twice: List[int] + out_channels: int + kernel_size: Tuple[int, ...] + stride: Tuple[int, ...] + padding: Union[str, Tuple[int, ...]] + dilation: Tuple[int, ...] + transposed: bool + output_padding: Tuple[int, ...] 
+ groups: int + padding_mode: str + weight: Union[Tensor, UninitializedParameter] + bias: Optional[Union[Tensor, UninitializedParameter]] + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + transposed: bool, + output_padding: Tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"dtype": dtype} + super(_ConvNd, self).__init__() + if groups <= 0: + raise ValueError("groups must be a positive integer") + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + valid_padding_strings = {"same", "valid"} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + "Invalid padding string {!r}, should be one of {}".format( + padding, valid_padding_strings + ) + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError( + "padding='same' is not supported for strided convolutions" + ) + + valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} + if padding_mode not in valid_padding_modes: + raise ValueError( + "padding_mode must be one of {}, but got padding_mode='{}'".format( + valid_padding_modes, padding_mode + ) + ) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + # `_reversed_padding_repeated_twice` is the padding to be passed to + # `F.pad` if needed (e.g., for non-zero padding types that are + # implemented as two ops: padding + conv). `F.pad` accepts paddings in + # reverse order than the dimension. 
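+        # e.g. Conv2d padding=(1, 2) becomes (2, 2, 1, 1) here: last spatial
+        # dimension first, with each amount repeated for both sides
+        # (left/right of the width, then top/bottom of the height).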
+ if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == "same": + for d, k, i in zip( + dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) + ): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad + ) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple( + self.padding, 2 + ) + + if transposed: + self.weight = UninitializedParameter( + (in_channels, out_channels // groups, *kernel_size), **factory_kwargs + ) + else: + self.weight = UninitializedParameter( + (out_channels, in_channels // groups, *kernel_size), **factory_kwargs + ) + if bias: + self.bias = UninitializedParameter((out_channels,), **factory_kwargs) + else: + self.bias = None + + # self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = ( + "{in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}" + ) + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" + return s.format(**self.__dict__) + + +class Conv1d(_ConvNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + # we create new variables below to make mypy happy since kernel_size has + # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super(Conv1d, self).__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _single(0), + self.dilation, + self.groups, + ) + return F.conv1d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv2d(_ConvNd): + def __init__( + self, + in_channels: int, + out_channels: int, + 
kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"dtype": dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super(Conv2d, self).__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv3d(_ConvNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: Union[str, _size_3_t] = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super(Conv3d, self).__init__( + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv3d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _triple(0), + self.dilation, + self.groups, + ) + return F.conv3d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class _ConvTransposeNd(_ConvNd): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ) -> None: + if padding_mode != "zeros": + raise ValueError( + 'Only "zeros" padding mode is supported for {}'.format( + self.__class__.__name__ + ) + ) + + factory_kwargs = {"device": device, "dtype": dtype} + super(_ConvTransposeNd, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + # dilation being an optional parameter is for backwards + # compatibility + def _output_padding( + self, + input: Tensor, + output_size: Optional[List[int]], + stride: List[int], + padding: List[int], + kernel_size: List[int], + num_spatial_dims: int, + dilation: Optional[List[int]] = None, + ) -> List[int]: + if output_size is None: + ret = _single(self.output_padding) # converting to list if was not already + else: + has_batch_dim = input.dim() == num_spatial_dims + 2 + 
num_non_spatial_dims = 2 if has_batch_dim else 1 + if len(output_size) == num_non_spatial_dims + num_spatial_dims: + output_size = output_size[num_non_spatial_dims:] + if len(output_size) != num_spatial_dims: + raise ValueError( + "ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})".format( + num_spatial_dims, + input.dim(), + num_spatial_dims, + num_non_spatial_dims + num_spatial_dims, + len(output_size), + ) + ) + + min_sizes: List[int] = [] + max_sizes: List[int] = [] + for d in range(num_spatial_dims): + dim_size = ( + (input.size(d + num_non_spatial_dims) - 1) * stride[d] + - 2 * padding[d] + + (dilation[d] if dilation is not None else 1) + * (kernel_size[d] - 1) + + 1 + ) + min_sizes.append(dim_size) + max_sizes.append(min_sizes[d] + stride[d] - 1) + + for i in range(len(output_size)): + size = output_size[i] + min_size = min_sizes[i] + max_size = max_sizes[i] + if size < min_size or size > max_size: + raise ValueError( + ( + "requested an output size of {}, but valid sizes range " + "from {} to {} (for an input of {})" + ).format(output_size, min_sizes, max_sizes, input.size()[2:]) + ) + + res: List[int] = [] + for d in range(num_spatial_dims): + res.append(output_size[d] - min_sizes[d]) + + ret = res + return ret + + +class ConvTranspose1d(_ConvTransposeNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + super(ConvTranspose1d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose1d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 
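+        # `output_size`, when given, disambiguates the output length: for a fixed
+        # stride several output lengths are compatible with the same input, so the
+        # requested size is converted here into an explicit output_padding and then
+        # forwarded to F.conv_transpose1d below.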
+ num_spatial_dims = 1 + output_padding = self._output_padding( + input, + output_size, + self.stride, + self.padding, + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, + ) # type: ignore[arg-type] + return F.conv_transpose1d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +class ConvTranspose2d(_ConvTransposeNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_2_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + super(ConvTranspose2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose2d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 2 + output_padding = self._output_padding( + input, + output_size, + self.stride, + self.padding, + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, + ) # type: ignore[arg-type] + + return F.conv_transpose2d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + +class ConvTranspose3d(_ConvTransposeNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + super(ConvTranspose3d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose3d" + ) + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 
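+        # For reference, each output spatial dim follows the usual transposed-conv
+        # arithmetic: (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
+        # + 1 + output_padding, which is what `_output_padding` solves for when an
+        # explicit `output_size` is supplied.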
+ num_spatial_dims = 3 + output_padding = self._output_padding( + input, + output_size, + self.stride, + self.padding, + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, + ) # type: ignore[arg-type] + + return F.conv_transpose3d( + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) diff --git a/pi/nn/modules/distance.py b/pi/nn/modules/distance.py new file mode 100644 index 0000000..d596fbf --- /dev/null +++ b/pi/nn/modules/distance.py @@ -0,0 +1,40 @@ +from .module import Module +from .. import functional as F + +from pi import Tensor + +__all__ = ["PairwiseDistance", "CosineSimilarity"] + + +class PairwiseDistance(Module): + + __constants__ = ["norm", "eps", "keepdim"] + norm: float + eps: float + keepdim: bool + + def __init__( + self, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False + ) -> None: + super(PairwiseDistance, self).__init__() + self.norm = p + self.eps = eps + self.keepdim = keepdim + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) + + +class CosineSimilarity(Module): + + __constants__ = ["dim", "eps"] + dim: int + eps: float + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super(CosineSimilarity, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/pi/nn/modules/dropout.py b/pi/nn/modules/dropout.py new file mode 100644 index 0000000..eb49136 --- /dev/null +++ b/pi/nn/modules/dropout.py @@ -0,0 +1,61 @@ +from .module import Module +from .. import functional as F + +from pi import Tensor + +__all__ = [ + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] + + +class _DropoutNd(Module): + __constants__ = ["p", "inplace"] + p: float + inplace: bool + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + super(_DropoutNd, self).__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + self.inplace = inplace + + def extra_repr(self) -> str: + return "p={}, inplace={}".format(self.p, self.inplace) + + +class Dropout(_DropoutNd): + def forward(self, input: Tensor) -> Tensor: + return F.dropout(input, self.p, self.training, self.inplace) + + +class Dropout1d(_DropoutNd): + def forward(self, input: Tensor) -> Tensor: + return F.dropout1d(input, self.p, self.training, self.inplace) + + +class Dropout2d(_DropoutNd): + def forward(self, input: Tensor) -> Tensor: + return F.dropout2d(input, self.p, self.training, self.inplace) + + +class Dropout3d(_DropoutNd): + def forward(self, input: Tensor) -> Tensor: + return F.dropout3d(input, self.p, self.training, self.inplace) + + +class AlphaDropout(_DropoutNd): + def forward(self, input: Tensor) -> Tensor: + return F.alpha_dropout(input, self.p, self.training) + + +class FeatureAlphaDropout(_DropoutNd): + def forward(self, input: Tensor) -> Tensor: + return F.feature_alpha_dropout(input, self.p, self.training) diff --git a/pi/nn/modules/flatten.py b/pi/nn/modules/flatten.py new file mode 100644 index 0000000..1f63270 --- /dev/null +++ b/pi/nn/modules/flatten.py @@ -0,0 +1,87 @@ +from typing import Tuple, Union + +from .module import Module +from ...types_ import Size + +from pi import Tensor + +__all__ = ["Flatten", "Unflatten"] + + +class Flatten(Module): + 
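+    # Flattens the contiguous span of dims [start_dim, end_dim] into a single dim.
+    # With the defaults (start_dim=1, end_dim=-1) an (N, C, H, W) input becomes
+    # (N, C*H*W), e.g. (hypothetical shapes):
+    #
+    #   flatten = Flatten()
+    #   y = flatten(x)  # x: (N, C, H, W) -> y: (N, C*H*W)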
__constants__ = ["start_dim", "end_dim"] + start_dim: int + end_dim: int + + def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: + super(Flatten, self).__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, input: Tensor) -> Tensor: + return input.flatten(self.start_dim, self.end_dim) + + def extra_repr(self) -> str: + return "start_dim={}, end_dim={}".format(self.start_dim, self.end_dim) + + +class Unflatten(Module): + NamedShape = Tuple[Tuple[str, int]] + + __constants__ = ["dim", "unflattened_size"] + dim: Union[int, str] + unflattened_size: Union[Size, NamedShape] + + def __init__( + self, dim: Union[int, str], unflattened_size: Union[Size, NamedShape] + ) -> None: + super(Unflatten, self).__init__() + + if isinstance(dim, int): + self._require_tuple_int(unflattened_size) + elif isinstance(dim, str): + self._require_tuple_tuple(unflattened_size) + else: + raise TypeError("invalid argument type for dim parameter") + + self.dim = dim + self.unflattened_size = unflattened_size + + def _require_tuple_tuple(self, input): + if isinstance(input, tuple): + for idx, elem in enumerate(input): + if not isinstance(elem, tuple): + raise TypeError( + "unflattened_size must be tuple of tuples, " + + "but found element of type {} at pos {}".format( + type(elem).__name__, idx + ) + ) + return + raise TypeError( + "unflattened_size must be a tuple of tuples, " + + "but found type {}".format(type(input).__name__) + ) + + def _require_tuple_int(self, input): + if isinstance(input, (tuple, list)): + for idx, elem in enumerate(input): + if not isinstance(elem, int): + raise TypeError( + "unflattened_size must be tuple of ints, " + + "but found element of type {} at pos {}".format( + type(elem).__name__, idx + ) + ) + return + raise TypeError( + "unflattened_size must be a tuple of ints, but found type {}".format( + type(input).__name__ + ) + ) + + def forward(self, input: Tensor) -> Tensor: + return input.unflatten(self.dim, self.unflattened_size) + + def extra_repr(self) -> str: + return "dim={}, unflattened_size={}".format(self.dim, self.unflattened_size) diff --git a/pi/nn/modules/fold.py b/pi/nn/modules/fold.py new file mode 100644 index 0000000..c039ef4 --- /dev/null +++ b/pi/nn/modules/fold.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +from .module import Module +from .. 
import functional as F + +from pi import Tensor +from ..common_types import _size_any_t + +__all__ = ["Fold", "Unfold"] + + +class Fold(Module): + + __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"] + output_size: _size_any_t + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1, + ) -> None: + super(Fold, self).__init__() + self.output_size = output_size + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.fold( + input, + self.output_size, + self.kernel_size, + self.dilation, + self.padding, + self.stride, + ) + + def extra_repr(self) -> str: + return ( + "output_size={output_size}, kernel_size={kernel_size}, " + "dilation={dilation}, padding={padding}, stride={stride}".format( + **self.__dict__ + ) + ) + + +class Unfold(Module): + + __constants__ = ["kernel_size", "dilation", "padding", "stride"] + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1, + ) -> None: + super(Unfold, self).__init__() + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.unfold( + input, self.kernel_size, self.dilation, self.padding, self.stride + ) + + def extra_repr(self) -> str: + return ( + "kernel_size={kernel_size}, dilation={dilation}, padding={padding}," + " stride={stride}".format(**self.__dict__) + ) diff --git a/pi/nn/modules/instancenorm.py b/pi/nn/modules/instancenorm.py new file mode 100644 index 0000000..2c34780 --- /dev/null +++ b/pi/nn/modules/instancenorm.py @@ -0,0 +1,182 @@ +from pi import Tensor + +from .batchnorm import _NormBase +from .. 
import functional as F + +__all__ = [ + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + # "LazyInstanceNorm1d", + # "LazyInstanceNorm2d", + # "LazyInstanceNorm3d", +] + + +class _InstanceNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(_InstanceNorm, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def _check_input_dim(self, input): + raise NotImplementedError + + def _get_no_batch_dim(self): + raise NotImplementedError + + def _handle_no_batch_input(self, input): + return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0) + + def _apply_instance_norm(self, input): + return F.instance_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + self.momentum, + self.eps, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ("running_mean", "running_var"): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + "Unexpected running stats buffer(s) {names} for {klass} " + "with track_running_stats=False. If state_dict is a " + "checkpoint saved before 0.4.0, this may be expected " + "because {klass} does not track running stats by default " + "since 0.4.0. Please remove these keys from state_dict. If " + "the running stats are actually needed, instead set " + "track_running_stats=True in {klass} to enable them. 
See " + "the documentation of {klass} for details.".format( + names=" and ".join( + '"{}"'.format(k) for k in running_stats_keys + ), + klass=self.__class__.__name__, + ) + ) + for key in running_stats_keys: + state_dict.pop(key) + + super(_InstanceNorm, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + if input.dim() == self._get_no_batch_dim(): + return self._handle_no_batch_input(input) + + return self._apply_instance_norm(input) + + +class InstanceNorm1d(_InstanceNorm): + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): + if input.dim() not in (2, 3): + raise ValueError( + "expected 2D or 3D input (got {}D input)".format(input.dim()) + ) + + +# class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): +# +# cls_to_become = InstanceNorm1d # type: ignore[assignment] +# +# def _get_no_batch_dim(self): +# return 2 +# +# def _check_input_dim(self, input): +# if input.dim() not in (2, 3): +# raise ValueError( +# "expected 2D or 3D input (got {}D input)".format(input.dim()) +# ) + + +class InstanceNorm2d(_InstanceNorm): + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): + if input.dim() not in (3, 4): + raise ValueError( + "expected 3D or 4D input (got {}D input)".format(input.dim()) + ) + + +# class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): +# +# cls_to_become = InstanceNorm2d # type: ignore[assignment] +# +# def _get_no_batch_dim(self): +# return 3 +# +# def _check_input_dim(self, input): +# if input.dim() not in (3, 4): +# raise ValueError( +# "expected 3D or 4D input (got {}D input)".format(input.dim()) +# ) + + +class InstanceNorm3d(_InstanceNorm): + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): + if input.dim() not in (4, 5): + raise ValueError( + "expected 4D or 5D input (got {}D input)".format(input.dim()) + ) + + +# class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): +# +# cls_to_become = InstanceNorm3d # type: ignore[assignment] +# +# def _get_no_batch_dim(self): +# return 4 +# +# def _check_input_dim(self, input): +# if input.dim() not in (4, 5): +# raise ValueError( +# "expected 4D or 5D input (got {}D input)".format(input.dim()) +# ) diff --git a/pi/nn/modules/linear.py b/pi/nn/modules/linear.py new file mode 100644 index 0000000..ec6c399 --- /dev/null +++ b/pi/nn/modules/linear.py @@ -0,0 +1,139 @@ +from typing import Any, Union + +import math + +from pi import Tensor +from .module import Module +from .. import functional as F +from ..parameter import UninitializedParameter +from .. 
import init + +__all__ = [ + "Bilinear", + "Identity", + # 'LazyLinear', + "Linear", + "NonDynamicallyQuantizableLinear", +] + + +class Identity(Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(Identity, self).__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +class Linear(Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: Union[Tensor, UninitializedParameter] + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"dtype": dtype} + super(Linear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = UninitializedParameter( + (out_features, in_features), **factory_kwargs + ) + if bias: + self.bias = UninitializedParameter(out_features, **factory_kwargs) + else: + self.register_("bias", None) + # self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return "in_features={}, out_features={}, bias={}".format( + self.in_features, self.out_features, self.bias is not None + ) + + +# This class exists solely to avoid triggering an obscure error when scripting +# an improperly quantized attention layer. 
See this issue for details: +# https://github.com/pytorch/pytorch/issues/58969 +# TODO: fail fast on quantization API usage error, then remove this class +# and replace uses of it with plain Linear +class NonDynamicallyQuantizableLinear(Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__( + in_features, out_features, bias=bias, device=device, dtype=dtype + ) + + +class Bilinear(Module): + __constants__ = ["in1_features", "in2_features", "out_features"] + in1_features: int + in2_features: int + out_features: int + weight: Union[Tensor, UninitializedParameter] + + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"dtype": dtype} + super(Bilinear, self).__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.weight = UninitializedParameter( + (out_features, in1_features, in2_features), **factory_kwargs + ) + + if bias: + self.bias = UninitializedParameter(out_features, **factory_kwargs) + else: + self.register_("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + bound = 1 / math.sqrt(self.weight.size(1)) + init.uniform_(self.weight, -bound, bound) + if self.bias is not None: + init.uniform_(self.bias, -bound, bound) + + def forward(self, input1: Tensor, input2: Tensor) -> Tensor: + return F.bilinear(input1, input2, self.weight, self.bias) + + def extra_repr(self) -> str: + return "in1_features={}, in2_features={}, out_features={}, bias={}".format( + self.in1_features, + self.in2_features, + self.out_features, + self.bias is not None, + ) diff --git a/pi/nn/modules/loss.py b/pi/nn/modules/loss.py new file mode 100644 index 0000000..3b67ebd --- /dev/null +++ b/pi/nn/modules/loss.py @@ -0,0 +1,543 @@ +import warnings + +from .distance import PairwiseDistance +from .module import Module +from .. import functional as F +from .. 
import _reduction as _Reduction + +from pi import Tensor +from typing import Callable, Optional + +__all__ = [ + "L1Loss", + "NLLLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "GaussianNLLLoss", + "KLDivLoss", + "MSELoss", + "BCELoss", + "BCEWithLogitsLoss", + "HingeEmbeddingLoss", + "MultiLabelMarginLoss", + "SmoothL1Loss", + "HuberLoss", + "SoftMarginLoss", + "CrossEntropyLoss", + "MultiLabelSoftMarginLoss", + "CosineEmbeddingLoss", + "MarginRankingLoss", + "MultiMarginLoss", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "CTCLoss", +] + + +class _Loss(Module): + reduction: str + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super(_Loss, self).__init__() + if size_average is not None or reduce is not None: + self.reduction: str = _Reduction.legacy_get_string(size_average, reduce) + else: + self.reduction = reduction + + +class _WeightedLoss(_Loss): + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super(_WeightedLoss, self).__init__(size_average, reduce, reduction) + self.register_buffer("weight", weight) + self.weight: Optional[Tensor] + + +class L1Loss(_Loss): + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super(L1Loss, self).__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.l1_loss(input, target, reduction=self.reduction) + + +class NLLLoss(_WeightedLoss): + + __constants__ = ["ignore_index", "reduction"] + ignore_index: int + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: + super(NLLLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.nll_loss( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + ) + + +class NLLLoss2d(NLLLoss): + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: + warnings.warn( + "NLLLoss2d has been deprecated. " + "Please use NLLLoss instead as a drop-in replacement and see " + "https://pytorch.org/docs/master/nn.html#torch.nn.NLLLoss for more details." 
+ ) + super(NLLLoss2d, self).__init__( + weight, size_average, ignore_index, reduce, reduction + ) + + +class PoissonNLLLoss(_Loss): + + __constants__ = ["log_input", "full", "eps", "reduction"] + log_input: bool + full: bool + eps: float + + def __init__( + self, + log_input: bool = True, + full: bool = False, + size_average=None, + eps: float = 1e-8, + reduce=None, + reduction: str = "mean", + ) -> None: + super(PoissonNLLLoss, self).__init__(size_average, reduce, reduction) + self.log_input = log_input + self.full = full + self.eps = eps + + def forward(self, log_input: Tensor, target: Tensor) -> Tensor: + return F.poisson_nll_loss( + log_input, + target, + log_input=self.log_input, + full=self.full, + eps=self.eps, + reduction=self.reduction, + ) + + +class GaussianNLLLoss(_Loss): + + __constants__ = ["full", "eps", "reduction"] + full: bool + eps: float + + def __init__( + self, *, full: bool = False, eps: float = 1e-6, reduction: str = "mean" + ) -> None: + super(GaussianNLLLoss, self).__init__(None, None, reduction) + self.full = full + self.eps = eps + + def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor: + return F.gaussian_nll_loss( + input, target, var, full=self.full, eps=self.eps, reduction=self.reduction + ) + + +class KLDivLoss(_Loss): + + __constants__ = ["reduction"] + + def __init__( + self, + size_average=None, + reduce=None, + reduction: str = "mean", + log_target: bool = False, + ) -> None: + super(KLDivLoss, self).__init__(size_average, reduce, reduction) + self.log_target = log_target + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.kl_div( + input, target, reduction=self.reduction, log_target=self.log_target + ) + + +class MSELoss(_Loss): + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super(MSELoss, self).__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.mse_loss(input, target, reduction=self.reduction) + + +class BCELoss(_WeightedLoss): + + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super(BCELoss, self).__init__(weight, size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy( + input, target, weight=self.weight, reduction=self.reduction + ) + + +class BCEWithLogitsLoss(_Loss): + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, + ) -> None: + super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction) + self.register_buffer("weight", weight) + self.register_buffer("pos_weight", pos_weight) + self.weight: Optional[Tensor] + self.pos_weight: Optional[Tensor] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy_with_logits( + input, + target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction, + ) + + +class HingeEmbeddingLoss(_Loss): + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 1.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return 
F.hinge_embedding_loss( + input, target, margin=self.margin, reduction=self.reduction + ) + + +class MultiLabelMarginLoss(_Loss): + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_margin_loss(input, target, reduction=self.reduction) + + +class SmoothL1Loss(_Loss): + + __constants__ = ["reduction"] + + def __init__( + self, size_average=None, reduce=None, reduction: str = "mean", beta: float = 1.0 + ) -> None: + super(SmoothL1Loss, self).__init__(size_average, reduce, reduction) + self.beta = beta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) + + +class HuberLoss(_Loss): + + __constants__ = ["reduction", "delta"] + + def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None: + super().__init__(reduction=reduction) + self.delta = delta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta) + + +class SoftMarginLoss(_Loss): + + __constants__ = ["reduction"] + + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: + super(SoftMarginLoss, self).__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.soft_margin_loss(input, target, reduction=self.reduction) + + +class CrossEntropyLoss(_WeightedLoss): + + __constants__ = ["ignore_index", "reduction", "label_smoothing"] + ignore_index: int + label_smoothing: float + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> None: + super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.cross_entropy( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing, + ) + + +class MultiLabelSoftMarginLoss(_WeightedLoss): + + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super(MultiLabelSoftMarginLoss, self).__init__( + weight, size_average, reduce, reduction + ) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_soft_margin_loss( + input, target, weight=self.weight, reduction=self.reduction + ) + + +class CosineEmbeddingLoss(_Loss): + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.cosine_embedding_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) + + +class MarginRankingLoss(_Loss): + + __constants__ = ["margin", "reduction"] + margin: float + + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str 
= "mean", + ) -> None: + super(MarginRankingLoss, self).__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.margin_ranking_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) + + +class MultiMarginLoss(_WeightedLoss): + + __constants__ = ["p", "margin", "reduction"] + margin: float + p: int + + def __init__( + self, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: + super(MultiMarginLoss, self).__init__(weight, size_average, reduce, reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + assert weight is None or weight.dim() == 1 + self.p = p + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multi_margin_loss( + input, + target, + p=self.p, + margin=self.margin, + weight=self.weight, + reduction=self.reduction, + ) + + +class TripletMarginLoss(_Loss): + + __constants__ = ["margin", "p", "eps", "swap", "reduction"] + margin: float + p: float + eps: float + swap: bool + + def __init__( + self, + margin: float = 1.0, + p: float = 2.0, + eps: float = 1e-6, + swap: bool = False, + size_average=None, + reduce=None, + reduction: str = "mean", + ): + super(TripletMarginLoss, self).__init__(size_average, reduce, reduction) + self.margin = margin + self.p = p + self.eps = eps + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_loss( + anchor, + positive, + negative, + margin=self.margin, + p=self.p, + eps=self.eps, + swap=self.swap, + reduction=self.reduction, + ) + + +class TripletMarginWithDistanceLoss(_Loss): + + __constants__ = ["margin", "swap", "reduction"] + margin: float + swap: bool + + def __init__( + self, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", + ): + super(TripletMarginWithDistanceLoss, self).__init__( + size_average=None, reduce=None, reduction=reduction + ) + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ( + distance_function if distance_function is not None else PairwiseDistance() + ) + self.margin = margin + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_with_distance_loss( + anchor, + positive, + negative, + distance_function=self.distance_function, + margin=self.margin, + swap=self.swap, + reduction=self.reduction, + ) + + +class CTCLoss(_Loss): + + __constants__ = ["blank", "reduction"] + blank: int + zero_infinity: bool + + def __init__( + self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False + ): + super(CTCLoss, self).__init__(reduction=reduction) + self.blank = blank + self.zero_infinity = zero_infinity + + def forward( + self, + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + ) -> Tensor: + return F.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + self.blank, + self.reduction, + self.zero_infinity, + ) + + +# TODO: L1HingeEmbeddingCriterion +# TODO: MSECriterion weight +# TODO: ClassSimplexCriterion diff --git a/pi/nn/modules/module.py b/pi/nn/modules/module.py new file mode 100644 index 0000000..b4a149a --- /dev/null +++ b/pi/nn/modules/module.py @@ -0,0 +1,231 @@ +import inspect +import 
itertools +from abc import abstractmethod +from collections import OrderedDict +from typing import Dict, Optional, Callable, Union + +from ..._tensor import Tensor +from ...types_ import dtype as pi_dtype +from ..parameter import ( + Parameter, + UninitializedParameter, + UninitializedBuffer, + is_uninitialized, +) +from ...utils import hooks +from ...utils.hooks import RemovableHandle + + +class Module: + _parameters: Dict[str, Optional[Union[Parameter, UninitializedParameter]]] + _buffers: Dict[str, Optional[Union[Tensor, UninitializedBuffer]]] + _modules: Dict[str, Optional["Module"]] + _forward_pre_hooks: OrderedDict[str, Callable] + _forward_post_hooks: OrderedDict[str, Callable] + _forward: Callable + + def __init__(self): + _set = super().__setattr__ + _get = super().__getattribute__ + + _set("_parameters", {}) + _set("_buffers", {}) + _set("_modules", {}) + _set("_forward_pre_hooks", OrderedDict()) + _set("_forward_post_hooks", OrderedDict()) + _set( + "_initialize_hook", + _get("register_forward_pre_hook")(_get("_infer_parameters")), + ) + + if "forward" in dir(self): + orig_forward = _get("forward") + # super attr is Module.__call__ d'oh + call = self.__call__ + _set("_forward", orig_forward) + _set("forward", call) + # TODO(max): checks here + if hasattr(orig_forward, "__placeholders__"): + # setattr(call, "__annotations__", orig_forward.__annotations__) + call.__dict__["__placeholders__"] = orig_forward.__placeholders__ + + super(Module, self).__init__() + + @abstractmethod + def forward(self, *args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, **kwargs): + for hook_id, hook in self._forward_pre_hooks.items(): + result = hook(self, *args, **kwargs) # type: ignore[misc] + if result is not None: + if isinstance(result, tuple) and len(result) == 2: + args, kwargs = result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {result}." 
+ ) + result = self._forward(*args, **kwargs) + for hook_id, hook in self._forward_post_hooks.items(): + result = hook(self, result, *args, **kwargs) + + return result + + def _register_hook( + self, + hook: Callable[..., None], + hook_dict: OrderedDict[str, Callable], + *, + prepend: bool = False, + ) -> RemovableHandle: + handle = hooks.RemovableHandle(hook_dict) + hook_name = hook.__func__.__name__ if inspect.ismethod(hook) else hook.__name__ + hook_id = f"{hook_name}_{handle.id}" + hook_dict[hook_id] = hook + if prepend: + hook_dict.move_to_end(hook_id, last=False) # type: ignore[attr-defined] + return handle + + def register_forward_pre_hook( + self, + hook: Callable[..., None], + *, + prepend: bool = False, + ) -> RemovableHandle: + return self._register_hook(hook, self._forward_pre_hooks, prepend=prepend) + + def register_forward_post_hook( + self, + hook: Callable[..., None], + *, + prepend: bool = False, + ) -> RemovableHandle: + return self._register_hook(hook, self._forward_post_hooks, prepend=prepend) + + def __getattribute__(self, item): + return super(Module, self).__getattribute__(item) + + def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: + def remove_from(*dicts_or_sets): + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + if value.__class__.__name__ == UninitializedParameter.__name__: + remove_from( + self.__dict__, + self._buffers, + self._modules, + ) + assert isinstance( + value, UninitializedParameter + ), f"class comparison failed {type(value)} {UninitializedParameter}" + self.register_(name, value) + elif name in self._parameters: + assert value is None or ( + isinstance(self._parameters[name], UninitializedParameter) + and isinstance(value, Tensor) + ), f"{name}:{type(value).__module__}.{type(value).__name__} cannot override parameter {name}:{type(self._parameters[name]).__module__}.{type(self._parameters[name]).__name__}" + self.register_(name, value) + else: + if isinstance(value, Module): + remove_from( + self.__dict__, + self._parameters, + self._buffers, + ) + self.register_module(name, value) + elif name in self._modules: + assert value is None, f"{type(value)} cannot override module {name}" + self.register_module(name, value) + else: + if name in self._buffers: + assert value is None, f"{type(value)} cannot override buffer {name}" + self.register_buffer(name, value) + else: + super().__setattr__(name, value) + + def __getattr__( + self, name: str + ) -> Union[Tensor, "Module", UninitializedParameter, UninitializedBuffer]: + _parameters = self.__dict__["_parameters"] + if name in _parameters: + return _parameters[name] + _buffers = self.__dict__["_buffers"] + if name in _buffers: + return _buffers[name] + modules = self.__dict__["_modules"] + if name in modules: + return modules[name] + raise AttributeError(f"{type(self).__name__} object has no attribute {name}") + + def register_buffer( + self, + name: str, + tensor: Optional[Union[Tensor, UninitializedBuffer]], + persistent: bool = True, + ) -> None: + self._buffers[name] = tensor + + def register_( + self, name: str, param: Optional[Union[Parameter, UninitializedParameter]] + ) -> None: + self._parameters[name] = param + + def register_module(self, name: str, module: Optional["Module"]) -> None: + self._modules[name] = module + + def initialize_parameters(self, *_args, **_kwargs): + parameters = self.__dict__["_parameters"] + for name, param in sorted(parameters.items()): + if param.__class__.__name__ == 
UninitializedParameter.__name__: + assert isinstance( + param, UninitializedParameter + ), f"class comparison failed {type(param)} {UninitializedParameter}" + parameters[name] = param() + + def has_uninitialized_params(self): + params = self._parameters.values() + buffers = self._buffers.values() + for param in itertools.chain(params, buffers): + if is_uninitialized(param): + return param + return None + + def not_uninitialized(self): + params = self._parameters.values() + buffers = self._buffers.values() + + for param in itertools.chain(params, buffers): + if not is_uninitialized(param): + return param + return None + + def _infer_parameters(self, _self, *args, **kwargs): + self.initialize_parameters(*args, **kwargs) + if uninitialized_param := self.has_uninitialized_params(): + raise RuntimeError( + f"module {self.__class__.__name__} has not been fully initialized; {uninitialized_param}" + ) + self._initialize_hook.remove() + delattr(self, "_initialize_hook") + + def to(self, dtype: pi_dtype): + if initialized_param := self.not_uninitialized(): + raise RuntimeError( + f"module {self.__class__.__name__} has already been initialized; {initialized_param}" + ) + + for name, param in self._parameters.items(): + assert is_uninitialized(param), f"{param} already initialized" + self._parameters[name] = UninitializedParameter(*param.size, dtype=dtype) + + for name, buffer in self._buffers.items(): + assert is_uninitialized(buffer), f"{buffer} already initialized" + self._buffers[name] = UninitializedBuffer(*buffer.size, dtype=dtype) + + return self diff --git a/pi/nn/modules/normalization.py b/pi/nn/modules/normalization.py new file mode 100644 index 0000000..dc8b7ac --- /dev/null +++ b/pi/nn/modules/normalization.py @@ -0,0 +1,167 @@ +import pi +import numbers +from .module import Module + +# from ._functions import CrossMapLRN2d as _cross_map_lrn2d +from .. import functional as F +from .. import init +from ..parameter import UninitializedParameter + +from pi import Tensor +from typing import Union, List, Tuple + +__all__ = [ + "LocalResponseNorm", + # "CrossMapLRN2d", + "LayerNorm", + "GroupNorm", +] + + +class LocalResponseNorm(Module): + + __constants__ = ["size", "alpha", "beta", "k"] + size: int + alpha: float + beta: float + k: float + + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0 + ) -> None: + super(LocalResponseNorm, self).__init__() + self.size = size + self.alpha = alpha + self.beta = beta + self.k = k + + def forward(self, input: Tensor) -> Tensor: + return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self): + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) + + +# class CrossMapLRN2d(Module): +# size: int +# alpha: float +# beta: float +# k: float +# +# def __init__( +# self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1 +# ) -> None: +# super(CrossMapLRN2d, self).__init__() +# self.size = size +# self.alpha = alpha +# self.beta = beta +# self.k = k +# +# def forward(self, input: Tensor) -> Tensor: +# return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k) +# +# def extra_repr(self) -> str: +# return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) + + +_shape_t = Union[int, List[int]] + + +class LayerNorm(Module): + + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = UninitializedParameter( + self.normalized_shape, **factory_kwargs + ) + self.bias = UninitializedParameter(self.normalized_shape, **factory_kwargs) + else: + self.register_("weight", None) + self.register_("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class GroupNorm(Module): + + __constants__ = ["num_groups", "num_channels", "eps", "affine"] + num_groups: int + num_channels: int + eps: float + affine: bool + + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(GroupNorm, self).__init__() + if num_channels % num_groups != 0: + raise ValueError("num_channels must be divisible by num_groups") + + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = UninitializedParameter(num_channels, **factory_kwargs) + self.bias = UninitializedParameter(num_channels, **factory_kwargs) + else: + self.register_("weight", None) + self.register_("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return "{num_groups}, {num_channels}, eps={eps}, " "affine={affine}".format( + **self.__dict__ + ) diff --git a/pi/nn/modules/padding.py b/pi/nn/modules/padding.py new file mode 100644 index 0000000..e848b9e --- /dev/null +++ b/pi/nn/modules/padding.py @@ -0,0 +1,143 @@ +from typing import Sequence, Tuple + +from .module import Module +from .utils import _pair, _quadruple, _ntuple +from .. 
import functional as F + +from pi import Tensor +from ..common_types import _size_2_t, _size_4_t, _size_6_t + + +__all__ = [ + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad2d", +] + + +class _ConstantPadNd(Module): + __constants__ = ["padding", "value"] + value: float + padding: Sequence[int] + + def __init__(self, value: float) -> None: + super(_ConstantPadNd, self).__init__() + self.value = value + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "constant", self.value) + + def extra_repr(self) -> str: + return "padding={}, value={}".format(self.padding, self.value) + + +class ConstantPad1d(_ConstantPadNd): + padding: Tuple[int, int] + + def __init__(self, padding: _size_2_t, value: float): + super(ConstantPad1d, self).__init__(value) + self.padding = _pair(padding) + + +class ConstantPad2d(_ConstantPadNd): + __constants__ = ["padding", "value"] + padding: Tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t, value: float) -> None: + super(ConstantPad2d, self).__init__(value) + self.padding = _quadruple(padding) + + +class ConstantPad3d(_ConstantPadNd): + padding: Tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t, value: float) -> None: + super(ConstantPad3d, self).__init__(value) + self.padding = _ntuple(6)(padding) + + +class _ReflectionPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "reflect") + + def extra_repr(self) -> str: + return "{}".format(self.padding) + + +class ReflectionPad1d(_ReflectionPadNd): + padding: Tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super(ReflectionPad1d, self).__init__() + self.padding = _pair(padding) + + +class ReflectionPad2d(_ReflectionPadNd): + padding: Tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super(ReflectionPad2d, self).__init__() + self.padding = _quadruple(padding) + + +class ReflectionPad3d(_ReflectionPadNd): + padding: Tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super(ReflectionPad3d, self).__init__() + self.padding = _ntuple(6)(padding) + + +class _ReplicationPadNd(Module): + __constants__ = ["padding"] + padding: Sequence[int] + + def forward(self, input: Tensor) -> Tensor: + return F.pad(input, self.padding, "replicate") + + def extra_repr(self) -> str: + return "{}".format(self.padding) + + +class ReplicationPad1d(_ReplicationPadNd): + padding: Tuple[int, int] + + def __init__(self, padding: _size_2_t) -> None: + super(ReplicationPad1d, self).__init__() + self.padding = _pair(padding) + + +class ReplicationPad2d(_ReplicationPadNd): + padding: Tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super(ReplicationPad2d, self).__init__() + self.padding = _quadruple(padding) + + +class ReplicationPad3d(_ReplicationPadNd): + padding: Tuple[int, int, int, int, int, int] + + def __init__(self, padding: _size_6_t) -> None: + super(ReplicationPad3d, self).__init__() + self.padding = _ntuple(6)(padding) + + +class ZeroPad2d(ConstantPad2d): + padding: Tuple[int, int, int, int] + + def __init__(self, padding: _size_4_t) -> None: + super(ZeroPad2d, self).__init__(padding, 0.0) + + def extra_repr(self) -> str: + return "{}".format(self.padding) diff --git a/pi/nn/modules/pixelshuffle.py 
b/pi/nn/modules/pixelshuffle.py new file mode 100644 index 0000000..0c5ba29 --- /dev/null +++ b/pi/nn/modules/pixelshuffle.py @@ -0,0 +1,38 @@ +from .module import Module +from .. import functional as F + +from pi import Tensor + +__all__ = ["PixelShuffle", "PixelUnshuffle"] + + +class PixelShuffle(Module): + + __constants__ = ["upscale_factor"] + upscale_factor: int + + def __init__(self, upscale_factor: int) -> None: + super(PixelShuffle, self).__init__() + self.upscale_factor = upscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_shuffle(input, self.upscale_factor) + + def extra_repr(self) -> str: + return "upscale_factor={}".format(self.upscale_factor) + + +class PixelUnshuffle(Module): + + __constants__ = ["downscale_factor"] + downscale_factor: int + + def __init__(self, downscale_factor: int) -> None: + super(PixelUnshuffle, self).__init__() + self.downscale_factor = downscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: + return "downscale_factor={}".format(self.downscale_factor) diff --git a/pi/nn/modules/pooling.py b/pi/nn/modules/pooling.py new file mode 100644 index 0000000..3e83765 --- /dev/null +++ b/pi/nn/modules/pooling.py @@ -0,0 +1,590 @@ +from typing import List, Optional + +from pi import Tensor +from .module import Module +from .utils import _single, _pair, _triple +from .. import functional as F + +from ..common_types import ( + _size_any_t, + _size_1_t, + _size_2_t, + _size_3_t, + _ratio_3_t, + _ratio_2_t, + _size_any_opt_t, + _size_2_opt_t, + _size_3_opt_t, +) + +__all__ = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "LPPool1d", + "LPPool2d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] + + +class _MaxPoolNd(Module): + __constants__ = [ + "kernel_size", + "stride", + "padding", + "dilation", + "return_indices", + "ceil_mode", + ] + return_indices: bool + ceil_mode: bool + + def __init__( + self, + kernel_size: _size_any_t, + stride: Optional[_size_any_t] = None, + padding: _size_any_t = 0, + dilation: _size_any_t = 1, + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: + super(_MaxPoolNd, self).__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return ( + "kernel_size={kernel_size}, stride={stride}, padding={padding}" + ", dilation={dilation}, ceil_mode={ceil_mode}".format(**self.__dict__) + ) + + +class MaxPool1d(_MaxPoolNd): + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + dilation: _size_1_t + + def forward(self, input: Tensor): + return F.max_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class MaxPool2d(_MaxPoolNd): + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + dilation: _size_2_t + + def forward(self, input: Tensor): + return F.max_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class 
MaxPool3d(_MaxPoolNd): + # noqa: E501 + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + dilation: _size_3_t + + def forward(self, input: Tensor): + return F.max_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) + + +class _MaxUnpoolNd(Module): + def extra_repr(self) -> str: + return "kernel_size={}, stride={}, padding={}".format( + self.kernel_size, self.stride, self.padding + ) + + +class MaxUnpool1d(_MaxUnpoolNd): + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + + def __init__( + self, + kernel_size: _size_1_t, + stride: Optional[_size_1_t] = None, + padding: _size_1_t = 0, + ) -> None: + super(MaxUnpool1d, self).__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if (stride is not None) else kernel_size) + self.padding = _single(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None + ) -> Tensor: + return F.max_unpool1d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class MaxUnpool2d(_MaxUnpoolNd): + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + + def __init__( + self, + kernel_size: _size_2_t, + stride: Optional[_size_2_t] = None, + padding: _size_2_t = 0, + ) -> None: + super(MaxUnpool2d, self).__init__() + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride if (stride is not None) else kernel_size) + self.padding = _pair(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None + ) -> Tensor: + return F.max_unpool2d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class MaxUnpool3d(_MaxUnpoolNd): + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + + def __init__( + self, + kernel_size: _size_3_t, + stride: Optional[_size_3_t] = None, + padding: _size_3_t = 0, + ) -> None: + super(MaxUnpool3d, self).__init__() + self.kernel_size = _triple(kernel_size) + self.stride = _triple(stride if (stride is not None) else kernel_size) + self.padding = _triple(padding) + + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None + ) -> Tensor: + return F.max_unpool3d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) + + +class _AvgPoolNd(Module): + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + ] + + def extra_repr(self) -> str: + return "kernel_size={}, stride={}, padding={}".format( + self.kernel_size, self.stride, self.padding + ) + + +class AvgPool1d(_AvgPoolNd): + + kernel_size: _size_1_t + stride: _size_1_t + padding: _size_1_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_1_t, + stride: _size_1_t = None, + padding: _size_1_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> None: + super(AvgPool1d, self).__init__() + self.kernel_size = _single(kernel_size) + self.stride = _single(stride if stride is not None else kernel_size) + self.padding = _single(padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + ) + + +class AvgPool2d(_AvgPoolNd): + + __constants__ = [ + "kernel_size", + "stride", + "padding", + 
"ceil_mode", + "count_include_pad", + "divisor_override", + ] + + kernel_size: _size_2_t + stride: _size_2_t + padding: _size_2_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_2_t, + stride: Optional[_size_2_t] = None, + padding: _size_2_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + ) -> None: + super(AvgPool2d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) + + +class AvgPool3d(_AvgPoolNd): + + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] + + kernel_size: _size_3_t + stride: _size_3_t + padding: _size_3_t + ceil_mode: bool + count_include_pad: bool + + def __init__( + self, + kernel_size: _size_3_t, + stride: Optional[_size_3_t] = None, + padding: _size_3_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + ) -> None: + super(AvgPool3d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride if (stride is not None) else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor_override = divisor_override + + def forward(self, input: Tensor) -> Tensor: + return F.avg_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) + + def __setstate__(self, d): + super(AvgPool3d, self).__setstate__(d) + self.__dict__.setdefault("padding", 0) + self.__dict__.setdefault("ceil_mode", False) + self.__dict__.setdefault("count_include_pad", True) + + +class FractionalMaxPool2d(Module): + + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] + + kernel_size: _size_2_t + return_indices: bool + output_size: _size_2_t + output_ratio: _ratio_2_t + + def __init__( + self, + kernel_size: _size_2_t, + output_size: Optional[_size_2_t] = None, + output_ratio: Optional[_ratio_2_t] = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: + super(FractionalMaxPool2d, self).__init__() + self.kernel_size = _pair(kernel_size) + self.return_indices = return_indices + self.register_buffer("_random_samples", _random_samples) + self.output_size = _pair(output_size) if output_size is not None else None + self.output_ratio = _pair(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError( + "FractionalMaxPool2d requires specifying either " + "an output size, or a pooling ratio" + ) + if output_size is not None and output_ratio is not None: + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) + if self.output_ratio is not None: + if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1): + raise ValueError( + "output_ratio must be between 0 and 1 (got {})".format(output_ratio) + ) + + def forward(self, input: Tensor): + return F.fractional_max_pool2d( + input, + self.kernel_size, + self.output_size, + self.output_ratio, + self.return_indices, + 
_random_samples=self._random_samples, + ) + + +class FractionalMaxPool3d(Module): + + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] + kernel_size: _size_3_t + return_indices: bool + output_size: _size_3_t + output_ratio: _ratio_3_t + + def __init__( + self, + kernel_size: _size_3_t, + output_size: Optional[_size_3_t] = None, + output_ratio: Optional[_ratio_3_t] = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: + super(FractionalMaxPool3d, self).__init__() + self.kernel_size = _triple(kernel_size) + self.return_indices = return_indices + self.register_buffer("_random_samples", _random_samples) + self.output_size = _triple(output_size) if output_size is not None else None + self.output_ratio = _triple(output_ratio) if output_ratio is not None else None + if output_size is None and output_ratio is None: + raise ValueError( + "FractionalMaxPool3d requires specifying either " + "an output size, or a pooling ratio" + ) + if output_size is not None and output_ratio is not None: + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) + if self.output_ratio is not None: + if not ( + 0 < self.output_ratio[0] < 1 + and 0 < self.output_ratio[1] < 1 + and 0 < self.output_ratio[2] < 1 + ): + raise ValueError( + "output_ratio must be between 0 and 1 (got {})".format(output_ratio) + ) + + def forward(self, input: Tensor): + return F.fractional_max_pool3d( + input, + self.kernel_size, + self.output_size, + self.output_ratio, + self.return_indices, + _random_samples=self._random_samples, + ) + + +class _LPPoolNd(Module): + __constants__ = ["norm_type", "kernel_size", "stride", "ceil_mode"] + + norm_type: float + ceil_mode: bool + + def __init__( + self, + norm_type: float, + kernel_size: _size_any_t, + stride: Optional[_size_any_t] = None, + ceil_mode: bool = False, + ) -> None: + super(_LPPoolNd, self).__init__() + self.norm_type = norm_type + self.kernel_size = kernel_size + self.stride = stride + self.ceil_mode = ceil_mode + + def extra_repr(self) -> str: + return ( + "norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, " + "ceil_mode={ceil_mode}".format(**self.__dict__) + ) + + +class LPPool1d(_LPPoolNd): + + kernel_size: _size_1_t + stride: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + return F.lp_pool1d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class LPPool2d(_LPPoolNd): + + kernel_size: _size_2_t + stride: _size_2_t + + def forward(self, input: Tensor) -> Tensor: + return F.lp_pool2d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) + + +class _AdaptiveMaxPoolNd(Module): + __constants__ = ["output_size", "return_indices"] + return_indices: bool + + def __init__( + self, output_size: _size_any_opt_t, return_indices: bool = False + ) -> None: + super(_AdaptiveMaxPoolNd, self).__init__() + self.output_size = output_size + self.return_indices = return_indices + + def extra_repr(self) -> str: + return "output_size={}".format(self.output_size) + + +# FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and +# output shapes are, and how the operation computes output. 
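+# Illustrative note (a sketch, not taken from this diff): these adaptive pooling
+# layers take only an output_size; kernel size and stride are derived internally
+# per spatial dimension, so no padding/stride arguments are exposed. Assuming pi
+# mirrors torch.nn semantics, a hypothetical
+#     pool = AdaptiveAvgPool2d((5, 7))
+#     y = pool(x)  # x: (N, C, H_in, W_in) -> y: (N, C, 5, 7)
+# and an output_size entry of None keeps that input dimension unchanged.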
+ + +class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): + + output_size: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) + + +class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): + + output_size: _size_2_opt_t + + def forward(self, input: Tensor): + return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) + + +class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): + + output_size: _size_3_opt_t + + def forward(self, input: Tensor): + return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) + + +class _AdaptiveAvgPoolNd(Module): + __constants__ = ["output_size"] + + def __init__(self, output_size: _size_any_opt_t) -> None: + super(_AdaptiveAvgPoolNd, self).__init__() + self.output_size = output_size + + def extra_repr(self) -> str: + return "output_size={}".format(self.output_size) + + +class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): + + output_size: _size_1_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_avg_pool1d(input, self.output_size) + + +class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): + + output_size: _size_2_opt_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_avg_pool2d(input, self.output_size) + + +class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): + + output_size: _size_3_opt_t + + def forward(self, input: Tensor) -> Tensor: + return F.adaptive_avg_pool3d(input, self.output_size) diff --git a/pi/nn/modules/sparse.py b/pi/nn/modules/sparse.py new file mode 100644 index 0000000..16e775e --- /dev/null +++ b/pi/nn/modules/sparse.py @@ -0,0 +1,288 @@ +from typing import Optional, Union + +from pi import Tensor +from .module import Module +from .. import functional as F +from ..parameter import UninitializedParameter +from .. 
import init + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(Module): + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "sparse", + ] + + num_embeddings: int + embedding_dim: int + padding_idx: Optional[int] + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Union[Tensor, UninitializedParameter] + freeze: bool + sparse: bool + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Union[Tensor, UninitializedParameter]] = None, + _freeze: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"dtype": dtype} + super(Embedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = UninitializedParameter( + (num_embeddings, embedding_dim), + **factory_kwargs, + ) + # self.reset_parameters() + else: + raise NotImplementedError + # assert list(_weight.shape) == [ + # num_embeddings, + # embedding_dim, + # ], "Shape of weight does not match num_embeddings and embedding_dim" + # self.weight = (_weight, requires_grad=not _freeze) + + self.sparse = sparse + + def reset_parameters(self) -> None: + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + return s.format(**self.__dict__) + + # @classmethod + # def from_pretrained( + # cls, + # embeddings, + # freeze=True, + # padding_idx=None, + # max_norm=None, + # norm_type=2.0, + # scale_grad_by_freq=False, + # sparse=False, + # ): + # + # assert ( + # embeddings.dim() == 2 + # ), "Embeddings parameter is expected to be 2-dimensional" + # rows, cols = embeddings.shape + # embedding = cls( + # num_embeddings=rows, + # embedding_dim=cols, + # _weight=embeddings, + # _freeze=freeze, + # padding_idx=padding_idx, + # max_norm=max_norm, + # norm_type=norm_type, + # scale_grad_by_freq=scale_grad_by_freq, + # sparse=sparse, + # ) + # return embedding + + +class EmbeddingBag(Module): + + __constants__ = [ + "num_embeddings", + "embedding_dim", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "mode", + "sparse", + "include_last_offset", + "padding_idx", + ] + + num_embeddings: int + 
embedding_dim: int + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + mode: str + sparse: bool + include_last_offset: bool + padding_idx: Optional[int] + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(EmbeddingBag, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + if _weight is None: + self.weight = UninitializedParameter( + (num_embeddings, embedding_dim), **factory_kwargs + ) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = _weight + self.mode = mode + self.sparse = sparse + self.include_last_offset = include_last_offset + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + self.weight[self.padding_idx].fill_(0) + + def forward( + self, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + return F.embedding_bag( + input, + self.weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ", mode={mode}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + return s.format(**self.__dict__) + + @classmethod + def from_pretrained( + cls, + embeddings: Tensor, + freeze: bool = True, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + ) -> "EmbeddingBag": + + assert ( + embeddings.dim() == 2 + ), "Embeddings parameter is expected to be 2-dimensional" + rows, cols = embeddings.shape + embeddingbag = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + include_last_offset=include_last_offset, + padding_idx=padding_idx, + ) + embeddingbag.weight.requires_grad = not freeze + return embeddingbag diff --git a/pi/nn/modules/transformer.py b/pi/nn/modules/transformer.py new file mode 100644 index 0000000..06c785d --- /dev/null 
+++ b/pi/nn/modules/transformer.py @@ -0,0 +1,654 @@ +import copy +from typing import Optional, Any, Union, Callable + +import pi +from pi import Tensor +from .. import functional as F +from .module import Module +from .activation import MultiheadAttention +from .container import ModuleList +from ..init import xavier_uniform_ +from .dropout import Dropout +from .linear import Linear +from .normalization import LayerNorm + +__all__ = [ + "Transformer", + "TransformerEncoder", + "TransformerDecoder", + "TransformerEncoderLayer", + "TransformerDecoderLayer", +] + + +class Transformer(Module): + def __init__( + self, + d_model: int = 512, + nhead: int = 8, + num_encoder_layers: int = 6, + num_decoder_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + custom_encoder: Optional[Any] = None, + custom_decoder: Optional[Any] = None, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(Transformer, self).__init__() + pi._C._log_api_usage_once(f"pi.nn.modules.{self.__class__.__name__}") + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + **factory_kwargs, + ) + encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + **factory_kwargs, + ) + decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + self.batch_first = batch_first + + def forward( + self, + src: Tensor, + tgt: Tensor, + src_mask: Optional[Tensor] = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + + is_batched = src.dim() == 3 + if not self.batch_first and src.size(1) != tgt.size(1) and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + elif self.batch_first and src.size(0) != tgt.size(0) and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model: + raise RuntimeError( + "the feature number of src and tgt must be equal to d_model" + ) + + memory = self.encoder( + src, mask=src_mask, src_key_padding_mask=src_key_padding_mask + ) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) + return output + + @staticmethod + def generate_square_subsequent_mask(sz: int, device="cpu") -> Tensor: + + return pi.triu( + pi.full((sz, sz), float("-inf"), device=device), diagonal=1 + ) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + 
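+# Illustrative usage (a sketch, assuming pi mirrors torch.nn.Transformer semantics;
+# the shapes below are assumptions, not part of this diff):
+#     model = Transformer(d_model=512, nhead=8)
+#     out = model(src, tgt)  # src: (S, N, 512), tgt: (T, N, 512) when batch_first=False
+# generate_square_subsequent_mask(sz) builds an additive float mask that is -inf
+# strictly above the diagonal and 0 elsewhere, suitable as tgt_mask for causal decoding.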
+class TransformerEncoder(Module): + + __constants__ = ["norm"] + + def __init__( + self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True, + ): + super(TransformerEncoder, self).__init__() + pi._C._log_api_usage_once(f"pi.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.enable_nested_tensor = enable_nested_tensor + self.mask_check = mask_check + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False, + ) -> Tensor: + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != pi.bool and not pi.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + output = src + convert_to_nested = False + first_layer = self.layers[0] + src_key_padding_mask_for_layers = src_key_padding_mask + why_not_sparsity_fast_path = "" + str_first_layer = "self.layers[0]" + if not isinstance(first_layer, pi.nn.TransformerEncoderLayer): + why_not_sparsity_fast_path = ( + f"{str_first_layer} was not TransformerEncoderLayer" + ) + elif first_layer.norm_first: + why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True" + elif first_layer.training: + why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" + elif not first_layer.self_attn.batch_first: + why_not_sparsity_fast_path = ( + f" {str_first_layer}.self_attn.batch_first was not True" + ) + elif not first_layer.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = ( + f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True" + ) + elif not first_layer.activation_relu_or_gelu: + why_not_sparsity_fast_path = ( + f" {str_first_layer}.activation_relu_or_gelu was not True" + ) + elif not (first_layer.norm1.eps == first_layer.norm2.eps): + why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps" + elif not src.dim() == 3: + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) + elif not self.enable_nested_tensor: + why_not_sparsity_fast_path = "enable_nested_tensor was not True" + elif src_key_padding_mask is None: + why_not_sparsity_fast_path = "src_key_padding_mask was None" + elif ( + (not hasattr(self, "mask_check")) or self.mask_check + ) and not pi._nested_tensor_from_mask_left_aligned( + src, src_key_padding_mask.logical_not() + ): + why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned" + elif output.is_nested: + why_not_sparsity_fast_path = "NestedTensor input is not supported" + elif mask is not None: + why_not_sparsity_fast_path = ( + "src_key_padding_mask and mask were both supplied" + ) + elif first_layer.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = "num_head is odd" + elif pi.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + first_layer.self_attn.in_proj_weight, + first_layer.self_attn.in_proj_bias, + first_layer.self_attn.out_proj.weight, + first_layer.self_attn.out_proj.bias, + first_layer.norm1.weight, + first_layer.norm1.bias, + first_layer.norm2.weight, + first_layer.norm2.bias, + first_layer.linear1.weight, + first_layer.linear1.bias, + first_layer.linear2.weight, + first_layer.linear2.bias, + ) + + 
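+            # Descriptive note: tensor_args gathers every tensor the nested-tensor fast
+            # path would read; the checks below record a reason to skip that fast path if
+            # any of them overrides __torch_function__, if src is neither CUDA nor CPU, or
+            # if grad is enabled and any of these tensors requires grad.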
if pi.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not (src.is_cuda or "cpu" in str(src.device)): + why_not_sparsity_fast_path = "src is neither CUDA nor CPU" + elif pi.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + + if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): + convert_to_nested = True + output = pi._nested_tensor_from_mask( + output, src_key_padding_mask.logical_not(), mask_check=False + ) + src_key_padding_mask_for_layers = None + + # Prevent type refinement + make_causal = False + if mask is not None: + if is_causal: + raise RuntimeError("specify either mask or is_causal, but not both") + + if make_causal: + is_causal = True + mask = None + + for mod in self.layers: + output = mod( + output, + src_mask=mask, + is_causal=is_causal, + src_key_padding_mask=src_key_padding_mask_for_layers, + ) + + if convert_to_nested: + output = output.to_padded_tensor(0.0) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Module): + + __constants__ = ["norm"] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + pi._C._log_api_usage_once(f"pi.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + + output = tgt + + for mod in self.layers: + output = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(Module): + + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + + # We can't test self.activation in forward() in TorchScript, + # so stash some information about it instead. 
+ if activation is F.relu or isinstance(activation, pi.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, pi.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False, + ) -> Tensor: + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != pi.bool and not pi.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + why_not_sparsity_fast_path = "" + if not src.dim() == 3: + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) + elif self.training: + why_not_sparsity_fast_path = "training is enabled" + elif not self.self_attn.batch_first: + why_not_sparsity_fast_path = "self_attn.batch_first was not True" + elif not self.self_attn._qkv_same_embed_dim: + why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True" + elif not self.activation_relu_or_gelu: + why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" + elif not (self.norm1.eps == self.norm2.eps): + why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" + elif src.is_nested and ( + src_key_padding_mask is not None or src_mask is not None + ): + why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" + elif self.self_attn.num_heads % 2 == 1: + why_not_sparsity_fast_path = "num_head is odd" + elif pi.is_autocast_enabled(): + why_not_sparsity_fast_path = "autocast is enabled" + if not why_not_sparsity_fast_path: + tensor_args = ( + src, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + ) + + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. 
+ if pi.overrides.has_torch_function(tensor_args): + why_not_sparsity_fast_path = "some Tensor argument has_torch_function" + elif not all((x.is_cuda or "cpu" in str(x.device)) for x in tensor_args): + why_not_sparsity_fast_path = ( + "some Tensor argument is neither CUDA nor CPU" + ) + elif pi.is_grad_enabled() and any(x.requires_grad for x in tensor_args): + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + + if not why_not_sparsity_fast_path: + merged_mask, mask_type = self.self_attn.merge_masks( + src_mask, src_key_padding_mask, src + ) + return pi._transformer_encoder_layer_fwd( + src, + self.self_attn.embed_dim, + self.self_attn.num_heads, + self.self_attn.in_proj_weight, + self.self_attn.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.activation_relu_or_gelu == 2, + self.norm_first, + self.norm1.eps, + self.norm1.weight, + self.norm1.bias, + self.norm2.weight, + self.norm2.bias, + self.linear1.weight, + self.linear1.bias, + self.linear2.weight, + self.linear2.bias, + merged_mask, + mask_type, + ) + + x = src + if self.norm_first: + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm2(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerDecoderLayer(Module): + + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs + ) + self.multihead_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs + ) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + # Legacy string support for activation function. 
+ if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = F.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: bool = False, + memory_is_causal: bool = False, + ) -> Tensor: + + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal + ) + x = x + self._mha_block( + self.norm2(x), + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + ) + x = x + self._ff_block(self.norm3(x)) + else: + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask, memory_is_causal + ) + ) + x = self.norm3(x + self._ff_block(x)) + + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] + return self.dropout2(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) diff --git a/pi/nn/modules/upsampling.py b/pi/nn/modules/upsampling.py new file mode 100644 index 0000000..9a1ac0e --- /dev/null +++ b/pi/nn/modules/upsampling.py @@ -0,0 +1,83 @@ +from .module import Module +from .. 
import functional as F + +from pi import Tensor +from typing import Optional +from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t + +__all__ = ["Upsample", "UpsamplingNearest2d", "UpsamplingBilinear2d"] + + +class Upsample(Module): + + __constants__ = [ + "size", + "scale_factor", + "mode", + "align_corners", + "name", + "recompute_scale_factor", + ] + name: str + size: Optional[_size_any_t] + scale_factor: Optional[_ratio_any_t] + mode: str + align_corners: Optional[bool] + recompute_scale_factor: Optional[bool] + + def __init__( + self, + size: Optional[_size_any_t] = None, + scale_factor: Optional[_ratio_any_t] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + ) -> None: + super(Upsample, self).__init__() + self.name = type(self).__name__ + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) + + def extra_repr(self) -> str: + if self.scale_factor is not None: + info = "scale_factor=" + str(self.scale_factor) + else: + info = "size=" + str(self.size) + info += ", mode=" + self.mode + return info + + +class UpsamplingNearest2d(Upsample): + def __init__( + self, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, + ) -> None: + super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode="nearest") + + +class UpsamplingBilinear2d(Upsample): + def __init__( + self, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, + ) -> None: + super(UpsamplingBilinear2d, self).__init__( + size, scale_factor, mode="bilinear", align_corners=True + ) diff --git a/pi/nn/modules/utils.py b/pi/nn/modules/utils.py new file mode 100644 index 0000000..7334c53 --- /dev/null +++ b/pi/nn/modules/utils.py @@ -0,0 +1,63 @@ +import collections +from itertools import repeat +from typing import List, Dict, Any + +__all__ = ["consume_prefix_in_state_dict_if_present"] + + +def _ntuple(n, name="parse"): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + + +_single = _ntuple(1, "_single") +_pair = _ntuple(2, "_pair") +_triple = _ntuple(3, "_triple") +_quadruple = _ntuple(4, "_quadruple") + + +def _reverse_repeat_tuple(t, n): + + return tuple(x for x in reversed(t) for _ in range(n)) + + +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: + if isinstance(out_size, int): + return out_size + if len(defaults) <= len(out_size): + raise ValueError( + "Input dimension should be at least {}".format(len(out_size) + 1) + ) + return [ + v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :]) + ] + + +def consume_prefix_in_state_dict_if_present( + state_dict: Dict[str, Any], prefix: str +) -> None: + + keys = sorted(state_dict.keys()) + for key in keys: + if key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict[newkey] = state_dict.pop(key) + + # also strip the prefix in metadata if any. 
+ if "_metadata" in state_dict: + metadata = state_dict["_metadata"] + for key in list(metadata.keys()): + # for the metadata dict, the key can be: + # '': for the DDP module, which we want to remove. + # 'module': for the actual model. + # 'module.xx.xx': for the rest. + + if len(key) == 0: + continue + newkey = key[len(prefix) :] + metadata[newkey] = metadata.pop(key) diff --git a/pi/nn/parameter.py b/pi/nn/parameter.py new file mode 100644 index 0000000..3caa79b --- /dev/null +++ b/pi/nn/parameter.py @@ -0,0 +1,61 @@ +import inspect +from functools import partial +from typing import Union, List, Tuple + +# this is the right way to import in order to not screw up the tests (torch.dtype vs pi.type) +from ..types_ import dtype as pi_dtype + +# import pi +from .._tensor import Tensor, empty + + +class Parameter(Tensor): + def __repr__(self): + return "Parameter containing:\n" + super(Parameter, self).__repr__() + + +class Uninitialized(partial): + size: Union[List[int], Tuple[int, ...]] + dtype: pi_dtype = None + + def __new__(cls, *args, **keywords): + func = args[0] + if not inspect.isfunction(func): + func = empty + else: + args = args[1:] + + if isinstance(args[0], (tuple, list)): + assert len(args) == 1, f"unknown len args {args}" + args = args[0] + + assert all([isinstance(a, int) for a in args]), f"{args}" + instance = super(Uninitialized, cls).__new__(cls, func, *args, **keywords) + instance.size = args + if "dtype" in keywords and keywords["dtype"] is not None: + dtype = keywords["dtype"] + assert isinstance( + dtype, pi_dtype + ), f"unknown dtype {type(dtype).__module__}.{type(dtype).__name__} (should be {pi_dtype.__module__}.{pi_dtype.__name__})" + instance.dtype = dtype + + return instance + + def __call__(self, /, *args, **keywords): + keywords = {**self.keywords, **keywords} + args = (*self.args, *args) + return self.func(args, **keywords) + + +class UninitializedParameter(Uninitialized): + cls_to_become = Parameter + + +class UninitializedBuffer(Uninitialized): + cls_to_become = Tensor + + +def is_uninitialized( + v: Union[Tensor, Parameter, UninitializedBuffer, UninitializedParameter] +) -> bool: + return isinstance(v, Uninitialized) diff --git a/pi/nn/shape_functions.py b/pi/nn/shape_functions.py new file mode 100644 index 0000000..b5ff13c --- /dev/null +++ b/pi/nn/shape_functions.py @@ -0,0 +1,1031 @@ +from typing import List, Any, Optional, Union, Dict, Callable, Tuple +import math + +number = Union[int, float] + + +def broadcast(a: List[int], b: List[int]): + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes: List[int] = [] + + for i in range(ndim): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if (dimA >= 0) else 1 + sizeB = b[dimB] if (dimB >= 0) else 1 + + if sizeA != sizeB and sizeA != 1 and sizeB != 1: + # TODO: only assertion error is bound in C++ compilation right now + raise AssertionError( + "The size of tensor a {} must match the size of tensor b (" + "{}) at non-singleton dimension {}".format(sizeA, sizeB, i) + ) + + expandedSizes.append(sizeB if sizeA == 1 else sizeA) + + return expandedSizes + + +def broadcast_three(a: List[int], b: List[int], c: List[int]): + return broadcast(broadcast(a, b), c) + + +def broadcast_one_three(a: List[int], b: Any, c: List[int]): + return broadcast(a, c) + + +def adaptive_avg_pool2d(self: List[int], out: List[int]): + assert len(out) == 2 + assert len(self) == 3 or len(self) == 4 + for i in range(1, len(self)): + assert self[i] != 0 + + shape: 
List[int] = [] + for i in range(0, len(self) - 2): + shape.append(self[i]) + for elem in out: + shape.append(elem) + return shape + + +def _copy(self: List[int]): + out: List[int] = [] + for elem in self: + out.append(elem) + return out + + +def unary(self: List[int]): + return _copy(self) + + +def broadcast_inplace(a: List[int], b: List[int]): + dimsA = len(a) + dimsB = len(b) + if dimsB > dimsA: + raise AssertionError( + "The dims of tensor b ({}) must be less than or equal to" + "the dims of tensor a ({}) ".format(dimsB, dimsA) + ) + for dimA in range(dimsA): + dimB = dimsB - dimsA + dimA + sizeA = a[dimA] + sizeB = b[dimB] if (dimB >= 0) else 1 + if sizeA != sizeB and sizeB != 1: + # TODO: only assertion error is bound in C++ compilation right now + raise AssertionError( + "The size of tensor a {} must match the size of tensor b (" + "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA) + ) + return _copy(a) + + +def expand(self: List[int], sizes: List[int]): + assert len(sizes) >= len(self) + ndim = len(sizes) + tensor_dim = len(self) + if ndim == 0: + return _copy(sizes) + out: List[int] = [] + for i in range(ndim): + offset = ndim - 1 - i + dim = tensor_dim - 1 - offset + size = self[dim] if dim >= 0 else 1 + targetSize = sizes[i] + if targetSize == -1: + assert dim >= 0 + targetSize = size + if size != targetSize: + assert size == 1 + size = targetSize + out.append(size) + return out + + +def expand_one_unused(self: List[int], sizes: List[int], inp0: Any): + return expand(self, sizes) + + +def infer_size_impl(shape: List[int], numel: int) -> List[int]: + newsize = 1 + infer_dim: Optional[int] = None + for dim in range(len(shape)): + if shape[dim] == -1: + if infer_dim is not None: + raise AssertionError("only one dimension can be inferred") + infer_dim = dim + elif shape[dim] >= 0: + newsize *= shape[dim] + else: + raise AssertionError("invalid shape dimensions") + if not ( + numel == newsize + or (infer_dim is not None and newsize > 0 and numel % newsize == 0) + ): + raise AssertionError("invalid shape") + out = _copy(shape) + if infer_dim is not None: + out[infer_dim] = numel // newsize + return out + + +def numel(sizes: List[int]): + numel = 1 + for elem in sizes: + numel *= elem + return numel + + +def view(self: List[int], sizes: List[int]): + return infer_size_impl(sizes, numel(self)) + + +def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False): + return view(self, sizes) + + +def sum_mean_dim( + self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any +): + out: List[int] = [] + if opt_dims is None or len(opt_dims) == 0: + dims: List[int] = list(range(len(self))) + else: + dims = opt_dims + + for idx in range(len(self)): + is_mean_dim: bool = False + for reduce_dim in dims: + if idx == maybe_wrap_dim(reduce_dim, len(self)): + is_mean_dim = True + if is_mean_dim: + if keep_dim: + out.append(1) + else: + out.append(self[idx]) + return out + + +def max_dim(self: List[int], dim: int, keep_dim: bool): + out = sum_mean_dim(self, [dim], keep_dim, None) + return out, out + + +# note: python already rounds down towards negative infinity on integer division, special arithmetic not needed +def div_rtn(x: int, y: int): + return x // y + + +def pooling_output_shape_pad_lr( + inputSize: int, + kernelSize: int, + pad_l: int, + pad_r: int, + stride: int, + dilation: int, + ceil_mode: bool, +): + outputSize = ( + div_rtn( + inputSize + + pad_l + + pad_r + - dilation * (kernelSize - 1) + - 1 + + (stride - 1 if ceil_mode else 0), + stride, + ) 
+ + 1 + ) + if ceil_mode: + if (outputSize - 1) * stride >= inputSize + pad_l: + outputSize = outputSize - 1 + return outputSize + + +def pooling_output_shape( + inputSize: int, + kernelSize: int, + pad_l: int, + stride: int, + dilation: int, + ceil_mode: bool, +): + assert stride != 0, "stride should not be zero" + return pooling_output_shape_pad_lr( + inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode + ) + + +def pool2d_shape_check( + input: List[int], + kH: int, + kW: int, + dH: int, + dW: int, + padH: int, + padW: int, + dilationH: int, + dilationW: int, + nInputPlane: int, + inputHeight: int, + inputWidth: int, + outputHeight: int, + outputWidth: int, +): + ndim = len(input) + nOutputPlane = nInputPlane + + assert kW > 0 and kH > 0 + assert dW > 0 and dH > 0 + assert dilationH > 0 and dilationW > 0 + + valid_dims = input[1] != 0 and input[2] != 0 + assert ( + ndim == 3 + and input[0] != 0 + and valid_dims + or (ndim == 4 and valid_dims and input[3] != 0) + ) + + assert kW // 2 >= padW and kH // 2 >= padH + assert outputWidth >= 1 and outputHeight >= 1 + + +def max_pool2d( + input: List[int], + kernel_size: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + ceil_mode: bool, +): + assert ( + len(kernel_size) == 1 or len(kernel_size) == 2 + ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + kH = kernel_size[0] + kW = kH if len(kernel_size) == 1 else kernel_size[1] + + assert ( + len(stride) == 0 or len(stride) == 1 or len(stride) == 2 + ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + dH = kH if len(stride) == 0 else stride[0] + if len(stride) == 0: + dW = kW + elif len(stride) == 1: + dW = dH + else: + dW = stride[1] + + assert ( + len(padding) == 1 or len(padding) == 2 + ), "max_pool2d: padding must either be a single int, or a tuple of two ints" + padH = padding[0] + padW = padH if len(padding) == 1 else padding[1] + + assert ( + len(dilation) == 1 or len(dilation) == 2 + ), "max_pool2d: dilation must be either a single int, or a tuple of two ints" + dilationH = dilation[0] + dilationW = dilationH if len(dilation) == 1 else dilation[1] + + assert len(input) == 3 or len(input) == 4 + + nbatch = input[-4] if len(input) == 4 else 1 + nInputPlane = input[-3] + inputHeight = input[-2] + inputWidth = input[-1] + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + ) + + if len(input) == 3: + return [nInputPlane, outputHeight, outputWidth] + else: + return [nbatch, nInputPlane, outputHeight, outputWidth] + + +def max_pool2d_with_indices( + input: List[int], + kernel_size: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + ceil_mode: bool, +): + out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + return (out, out) + + +def upsample_nearest2d( + input: List[int], + output_size: Optional[List[int]], + scale_factors: Optional[List[float]], +): + out: List[int] = [] + out.append(input[0]) + out.append(input[1]) + + if scale_factors is None and output_size is None: + assert 0, "Either output_size or scale_factors must be present" + + if output_size is not None: + assert ( + scale_factors is None + ), "Must specify exactly one of 
output_size and scale_factors" + assert len(output_size) == 2 + out.append(output_size[0]) + out.append(output_size[1]) + + if scale_factors is not None: + assert ( + output_size is None + ), "Must specify exactly one of output_size and scale_factors" + assert len(scale_factors) == 2 + out.append(int(input[2] * scale_factors[0])) + out.append(int(input[3] * scale_factors[1])) + + return out + + +def mm(self: List[int], mat2: List[int]): + assert len(self) == 2, "self must be a matrix" + assert len(mat2) == 2, "mat2 must be a matrix" + + assert self[1] == mat2[0] + return [self[0], mat2[1]] + + +def dot(self: List[int], tensor: List[int]): + assert len(self) == 1 and len(tensor) == 1 + assert self[0] == tensor[0] + out: List[int] = [] + return out + + +def mv(self: List[int], vec: List[int]): + assert len(self) == 2 and len(vec) == 1 + assert self[1] == vec[0] + # TODO: return self + return [self[0]] + + +def unsqueeze(li: List[int], dim: int): + dim = maybe_wrap_dim(dim, len(li) + 1) + out = _copy(li) + out.insert(dim, 1) + return out + + +def squeeze_nodim(li: List[int]): + out: List[int] = [] + for i in range(len(li)): + if li[i] != 1: + out.append(li[i]) + return out + + +def squeeze(li: List[int], dim: int): + out: List[int] = [] + wrapped_dim = maybe_wrap_dim(dim, len(li)) + for i in range(len(li)): + if i == wrapped_dim: + if li[i] != 1: + out.append(li[i]) + else: + out.append(li[i]) + return out + + +def index_select(self: List[int], dim: int, index: List[int]): + dim = maybe_wrap_dim(dim, len(self)) + numel = multiply_integers(index) + assert len(index) <= 1 + assert dim == 0 or dim < len(self) + result_size: List[int] = [] + for i in range(len(self)): + if dim == i: + result_size.append(numel) + else: + result_size.append(self[i]) + return result_size + + +def embedding( + weight: List[int], + indices: List[int], + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +): + assert len(weight) == 2 + if len(indices) == 1: + return index_select(weight, 0, indices) + size = _copy(indices) + size.append(weight[1]) + return size + + +def max_int(): + return 9223372036854775807 + + +def slice( + self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int +): + ndim = len(self) + assert ndim != 0 + dim = maybe_wrap_dim(dim, ndim) + start_val = start if start is not None else 0 + end_val = end if end is not None else max_int() + assert step > 0 + if start_val == max_int(): + start_val = 0 + if start_val < 0: + start_val += self[dim] + if end_val < 0: + end_val += self[dim] + if start_val < 0: + start_val = 0 + elif start_val > self[dim]: + start_val = self[dim] + if end_val < start_val: + end_val = start_val + elif end_val >= self[dim]: + end_val = self[dim] + slice_len = end_val - start_val + out = _copy(self) + out[dim] = (slice_len + step - 1) // step + return out + + +def check_cat_no_zero_dim(tensors: List[List[int]]): + for tensor in tensors: + assert len(tensor) > 0 + + +def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]): + out_dim: Optional[int] = None + for size in tensor_sizes: + if not (len(size) == 1 and size[0] == 0): + if out_dim is None: + out_dim = maybe_wrap_dim(dim, len(size)) + if out_dim is None: + out_dim = dim + return out_dim + + +def should_skip(tensor: List[int]): + return numel(tensor) == 0 and len(tensor) == 1 + + +def check_cat_shape_except_dim( + first: List[int], second: List[int], dimension: int, index: int +): + first_dims = len(first) + second_dims = len(second) + assert first_dims == 
second_dims, "Tensors must have same number of dimensions" + for dim in range(0, first_dims): + if dim != dimension: + assert ( + first[dim] == second[dim] + ), "Sizes of tensors must match except in dimension" + + +def cat(tensors: List[List[int]], dim: int): + check_cat_no_zero_dim(tensors) + dim = legacy_cat_wrap_dim(dim, tensors) + assert len(tensors) > 0 + not_skipped_tensor: Optional[List[int]] = None + for tensor in tensors: + if not should_skip(tensor): + not_skipped_tensor = tensor + if not_skipped_tensor is None: + return [0] + + cat_dim_size = 0 + + for i in range(len(tensors)): + tensor = tensors[i] + if not should_skip(tensor): + check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i) + cat_dim_size = cat_dim_size + tensor[dim] + + result_size = _copy(not_skipped_tensor) + result_size[dim] = cat_dim_size + return result_size + + +def select(self: List[int], dim: int, index: int): + ndim = len(self) + assert ndim != 0 + dim = maybe_wrap_dim(dim, ndim) + size = self[dim] + assert not (index < -size or index >= size) + if index < 0: + index += size + out: List[int] = [] + for i in range(ndim): + if i != dim: + out.append(self[i]) + return out + + +def matmul(tensor1: List[int], tensor2: List[int]): + dim_tensor1 = len(tensor1) + dim_tensor2 = len(tensor2) + if dim_tensor1 == 1 and dim_tensor2 == 1: + return dot(tensor1, tensor2) + elif dim_tensor1 == 2 and dim_tensor2 == 1: + return mv(tensor1, tensor2) + elif dim_tensor1 == 1 and dim_tensor2 == 2: + return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0) + elif dim_tensor1 == 2 and dim_tensor2 == 2: + return mm(tensor1, tensor2) + elif dim_tensor1 >= 1 and dim_tensor2 >= 1: + # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); + # we track m1 vs m2 separately even though they must match for nicer error messages + n = tensor1[-2] if dim_tensor1 > 1 else 1 + m1 = tensor1[-1] + batch_tensor1: List[int] = [] + # TODO: handling of slice + for i in range(dim_tensor1 - 2): + batch_tensor1.append(tensor1[i]) + m2 = tensor2[-1] if dim_tensor2 > 1 else 1 + p = tensor2[-1] + batch_tensor2: List[int] = [] + # TODO: handling of slice + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2[i]) + + # expand the batch portion (i.e. cut off matrix dimensions and expand rest) + expand_batch_portion = broadcast(batch_tensor1, batch_tensor2) + + # todo: copy ? 
+ output_shape = expand_batch_portion + if dim_tensor1 > 1: + output_shape.append(n) + + if dim_tensor2 > 1: + output_shape.append(p) + + return output_shape + else: + assert False, "both arguments to matmul need to be at least 1D" + + +def t(self: List[int]): + assert len(self) <= 2 + self_len = len(self) + if self_len == 0: + out: List[int] = [] + return out + elif self_len == 1: + return [self[0]] + else: + return [self[1], self[0]] + + +def transpose(self: List[int], dim0: int, dim1: int): + ndims = len(self) + dim0 = maybe_wrap_dim(dim0, ndims) + dim1 = maybe_wrap_dim(dim1, ndims) + if dim0 == dim1: + return _copy(self) + out: List[int] = [] + for i in range(ndims): + if i == dim0: + out.append(self[dim1]) + elif i == dim1: + out.append(self[dim0]) + else: + out.append(self[i]) + return out + + +def linear(input: List[int], weight: List[int], bias: Optional[List[int]]): + out = matmul(input, t(weight)) + if bias is not None: + assert broadcast(bias, out) == out + return out + + +def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any): + return broadcast(self, mm(mat1, mat2)) + + +def check_non_negative(array: List[int]) -> bool: + # TODO: look into rewriting with early return and getting loop unrolling to fire + non_negative = False + for val in array: + if val < 0: + non_negative = True + return non_negative + + +def check_shape_forward( + input: List[int], + weight_sizes: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, +): + k = len(input) + weight_dim = len(weight_sizes) + + # TODO: assertions could be expanded with the error messages + assert not check_non_negative(padding) + assert not check_non_negative(stride) + + assert weight_dim == k + assert weight_sizes[0] >= groups + assert (weight_sizes[0] % groups) == 0 + # only handling not transposed + assert input[1] == weight_sizes[1] * groups + assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0]) + + for i in range(2, k): + assert (input[i] + 2 * padding[i - 2]) >= ( + dilation[i - 2] * (weight_sizes[i] - 1) + 1 + ) + + # this is not handling transposed convolution yet + + +def conv_output_size( + input_size: List[int], + weight_size: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, +): + check_shape_forward( + input_size, weight_size, bias, stride, padding, dilation, groups + ) + + has_dilation = len(dilation) > 0 + dim = len(input_size) + output_size: List[int] = [] + input_batch_size_dim = 0 + weight_output_channels_dim = 0 + output_size.append(input_size[input_batch_size_dim]) + output_size.append(weight_size[weight_output_channels_dim]) + + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight_size[d] - 1) + 1 + output_size.append( + (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1 + ) + return output_size + + +def conv1d( + input: List[int], + weight: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, +): + assert len(weight) == 3 + assert len(input) == 3 + return conv_output_size(input, weight, bias, stride, padding, dilation, groups) + + +def conv2d( + input: List[int], + weight: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, +): + assert len(weight) == 4 + assert len(input) == 4 + return conv_output_size(input, weight, 
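+        # Spatial dims follow (in + 2 * pad - dilation * (k - 1) - 1) // stride + 1.
+        # Illustrative numbers: a [1, 3, 32, 32] input with an [8, 3, 3, 3] weight,
+        # stride [2, 2], padding [1, 1], dilation [1, 1] gives
+        # (32 + 2 - 3) // 2 + 1 == 16 per spatial dim, i.e. [1, 8, 16, 16].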
bias, stride, padding, dilation, groups) + + +def conv_backwards( + grad_output: List[int], + input: List[int], + weight: List[int], + biases: Optional[List[int]], +): + # Bias gradient is always generated regardess of if biases is supplied + return _copy(input), _copy(weight), [grad_output[1]] + + +def conv_transpose2d_input( + input: List[int], + weight: List[int], + bias: Optional[List[int]] = None, + stride: Optional[List[int]] = None, + padding: Optional[List[int]] = None, + output_padding: Optional[List[int]] = None, + groups: int = 1, + dilation: Optional[List[int]] = None, +) -> List[int]: + if stride is None: + stride = [1, 1] + if padding is None: + padding = [0, 0] + if output_padding is None: + output_padding = [0, 0] + if dilation is None: + dilation = [1, 1] + has_dilation = len(dilation) > 0 + dim = len(input) + output_size: List[int] = [] + input_batch_size_dim = 0 + weight_output_channels_dim = 1 + output_size.append(input[input_batch_size_dim]) + output_size.append(weight[weight_output_channels_dim]) + + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight[d] - 1) + output_size.append( + (input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1 + ) + return output_size + + +def conv_forwards( + input: List[int], + weight: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, +) -> List[int]: + has_dilation = len(dilation) > 0 + dim = len(input) + output_size: List[int] = [] + input_batch_size_dim = 0 + weight_output_channels_dim = 1 if transposed else 0 + output_size.append(input[input_batch_size_dim]) + output_size.append(weight[weight_output_channels_dim]) + + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + if transposed: + kernel = dilation_ * (weight[d] - 1) + output_size.append( + (input[d] - 1) * stride[d - 2] - 2 * padding[d - 2] + kernel + 1 + ) + else: + kernel = dilation_ * (weight[d] - 1) + 1 + output_size.append( + (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1 + ) + return output_size + + +def batch_norm( + input: List[int], + weight: Optional[List[int]], + bias: Optional[List[int]], + running_mean: Optional[List[int]], + running_var: Optional[List[int]], + training: bool, + momentum: float, + eps: float, + cudnn_enabled: bool, +): + out: List[int] = [] + for elem in input: + out.append(elem) + return out + + +def conv3d( + input: List[int], + weight: List[int], + bias: Optional[List[int]], + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, +): + assert len(weight) == 5 + assert len(input) == 5 + return conv_output_size(input, weight, bias, stride, padding, dilation, groups) + + +def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): + if dim_post_expr <= 0: + assert wrap_scalar + dim_post_expr = 1 + min = -dim_post_expr + max = dim_post_expr - 1 + assert not (dim < min or dim > max) + if dim < 0: + dim += dim_post_expr + return dim + + +def zero_dim_tensor(input: Any): + out: List[int] = [] + return out + + +def multiply_integers(li: List[int]): + out = 1 + for elem in li: + out = out * elem + return out + + +def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any): + assert end >= 0 + return [int(math.ceil(end))] + + +def arange_start( + start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any +): + assert end >= 0 + assert end >= start + return 
[int(math.ceil(end - start))] + + +def arange_start_step( + start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any +): + assert step != 0 + if step < 0: + assert start >= end + else: + assert end >= start + return [int(math.ceil((end - start) / step))] + + +def permute(input: List[int], dims: List[int]): + assert len(input) == len(dims) + ndim = len(dims) + seen_dims: List[int] = [] + newSizes: List[int] = [] + for i in range(ndim): + dim = maybe_wrap_dim(dims[i], ndim) + seen_dims.append(dim) + newSizes.append(input[dim]) + for i in range(1, ndim): + for j in range(i): + assert seen_dims[i] != seen_dims[j] + return newSizes + + +def flatten(input: List[int], start_dim: int, end_dim: int): + start_dim = maybe_wrap_dim(start_dim, len(input)) + end_dim = maybe_wrap_dim(end_dim, len(input)) + assert start_dim <= end_dim + if len(input) == 0: + return [1] + if start_dim == end_dim: + # TODO: return self + out: List[int] = [] + for elem in input: + out.append(elem) + return out + slice_numel = 1 + for i in range(start_dim, end_dim + 1): + slice_numel *= input[i] + # TODO: use slicing when slice optimization has landed + # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1]) + shape: List[int] = [] + for i in range(start_dim): + shape.append(input[i]) + shape.append(slice_numel) + for i in range(end_dim + 1, len(input)): + shape.append(input[i]) + return shape + + +def nonzero_lower_bound(input: List[int]): + return [0, len(input)] + + +def nonzero_upper_bound(input: List[int]): + return [numel(input), len(input)] + + +def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): + dim = maybe_wrap_dim(dim, len(self)) + out: List[int] = [] + for i, self_dim in enumerate(self): + if i == dim: + if keepdim: + out.append(1) + else: + out.append(self_dim) + return out + + +def argmax( + self: List[int], dim: Optional[int] = None, keepdim: bool = False +) -> List[int]: + if dim is None: + return [] + return _reduce_along_dim(self, dim, keepdim) + + +def bmm(self: List[int], mat2: List[int]) -> List[int]: + assert len(self) == 3, "bmm only supports 3D tensors" + assert len(mat2) == 3, "bmm only supports 3D tensors" + assert self[0] == mat2[0], "mismatching batch dimension" + assert self[2] == mat2[1], "mismatching contracting dimension" + return [self[0], self[1], mat2[2]] + + +def _shape_as_tensor(self: List[int]) -> List[int]: + return [len(self)] + + +def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]: + if len(self) == 0: + result: List[int] = [] + else: + assert ( + k <= self[dim] + ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + result = _copy(self) + result[dim] = k + return result, result + + +def nll_loss_forward( + self: List[int], target: List[int], weight: Optional[List[int]], reduction: int +) -> Tuple[List[int], List[int]]: + # This is taken shamelessly from the meta function in LossNLL.cpp + self_dim = len(self) + target_dim = len(target) + assert 0 < self_dim <= 2 + assert target_dim <= 1 + no_batch_dim = self_dim == 1 and target_dim == 0 + assert no_batch_dim or (self[0] == target[0]) + n_classes = self[-1] + scalar_shape: List[int] = [] + assert weight is None or (len(weight) == 1 and weight[0] == n_classes) + if reduction == 0 and self_dim == 2: + reduction_shape = [self[0]] + else: + reduction_shape = scalar_shape + return reduction_shape, scalar_shape + + +def native_layer_norm( + input: List[int], normalized_shape: List[int] +) -> Tuple[List[int], List[int], List[int]]: + 
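+    # Shape-only illustration: input == [2, 5, 10] with normalized_shape == [10]
+    # produces mean/rstd shapes of [2, 5, 1], so the result is
+    # ([2, 5, 10], [2, 5, 1], [2, 5, 1]).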
reduction_shape: List[int] = [] + num_unreduced_dimensions = len(input) - len(normalized_shape) + assert num_unreduced_dimensions >= 0 + for i in range(num_unreduced_dimensions): + reduction_shape.append(input[i]) + for i in range(num_unreduced_dimensions, len(input)): + reduction_shape.append(1) + return _copy(input), reduction_shape, reduction_shape + + +def native_batch_norm( + input: List[int], + weight: Optional[List[int]], + bias: Optional[List[int]], + running_mean: Optional[List[int]], + running_var: Optional[List[int]], + training: bool, +) -> Tuple[List[int], List[int], List[int]]: + if training: + _size = [input[1]] + else: + _size = [0] + return _copy(input), _size, _size diff --git a/pi/testing/__init__.py b/pi/testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pi/testing/framework.py b/pi/testing/framework.py new file mode 100644 index 0000000..f78e57b --- /dev/null +++ b/pi/testing/framework.py @@ -0,0 +1,38 @@ +from typing import Any, Callable, NamedTuple, Union + +import numpy as np +from torch_mlir import ir + +from pi.types_ import float32, int64 +from pi import nn +from pi.compiler.annotations import TensorPlaceholder + + +# Utilities for common testing trace generation. +# Also, resets the random seed for reproducibility. +# TODO: If generating in parallel, how to have manual_seed be local? +class TestUtils: + def __init__(self): + np.random.seed(0) + + # TODO: Add zeros/ones/etc. as convenient. + def rand(self, *sizes, low=0.0, high=1.0): + # return uniform(low, high, sizes) + return TensorPlaceholder(sizes, dtype=float32) + + def randn(self, *sizes): + # return randn(low, high, sizes) + return TensorPlaceholder(sizes, dtype=float32) + + def randint(self, *sizes, low=0, high=10): + # return randint(low, high, sizes) + return TensorPlaceholder(sizes, dtype=int64) + + +TestResult = Union[ir.OpView, ir.Operation, ir.Value, ir.OpResultList] + + +class Test(NamedTuple): + unique_name: str + program_factory: Callable[[], nn.Module] + program_invoker: Callable[[Any, TestUtils], None] diff --git a/pi/testing/registry.py b/pi/testing/registry.py new file mode 100644 index 0000000..82c4d7c --- /dev/null +++ b/pi/testing/registry.py @@ -0,0 +1,37 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from typing import Callable + +from pi import nn + +from .framework import Test + +# The global registry of tests. +GLOBAL_TEST_REGISTRY = [] +# Ensure that there are no duplicate names in the global test registry. +_SEEN_UNIQUE_NAMES = set() + + +def register_test_case(module_factory: Callable[[], nn.Module]): + def decorator(f): + # Ensure that there are no duplicate names in the global test registry. + if f.__name__ in _SEEN_UNIQUE_NAMES: + raise Exception( + f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name." + ) + _SEEN_UNIQUE_NAMES.add(f.__name__) + + # Store the test in the registry. 
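+        # Typical usage, mirroring the torch-mlir e2e suite (the module and
+        # test names below are placeholders):
+        #
+        #   @register_test_case(module_factory=lambda: MyModule())
+        #   def MyModule_basic(module, tu):
+        #       module.forward(tu.rand(3, 4))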
+ GLOBAL_TEST_REGISTRY.append( + Test( + unique_name=f.__name__, + program_factory=module_factory, + program_invoker=f, + ) + ) + return f + + return decorator diff --git a/pi/testing/util.py b/pi/testing/util.py new file mode 100644 index 0000000..816f9c4 --- /dev/null +++ b/pi/testing/util.py @@ -0,0 +1,176 @@ +import warnings +from typing import Any, OrderedDict + +import numpy as np +import torch +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_dialect, func as func_dialect + +from .framework import Test, TestUtils +from ..mlir_utils import run_pipeline_with_repro_report, mlir_cm +from .. import nn, DEBUG +from .._tensor import Tensor + +FIXED = np.linspace(0, 0.1, 101) + + +def set_weights( + mod, typ=torch.float32, val=1, requires_grad=False, fixed=False, random=False +): + import torch + from torch import nn + + for m in mod.modules(): + if hasattr(m, "weight"): + if fixed: + m.weight = torch.nn.Parameter( + torch.from_numpy( + np.random.choice(FIXED, m.weight.numel()) + .astype(np.float16, casting="unsafe") + .reshape(m.weight.shape) + ).type(typ), + requires_grad=requires_grad, + ) + elif random: + nn.init.constant_(m.weight, np.random.randint(1, 100)) + m.weight.requires_grad_(False) + m.weight = torch.nn.Parameter( + m.weight.type(typ), requires_grad=requires_grad + ) + else: + nn.init.constant_(m.weight, val) + m.weight.requires_grad_(False) + m.weight = torch.nn.Parameter( + m.weight.type(typ), requires_grad=requires_grad + ) + if hasattr(m, "bias") and m.bias is not None: + if fixed: + m.bias = torch.nn.Parameter( + torch.from_numpy( + np.random.choice(FIXED, m.bias.numel()) + .astype(np.float16, casting="unsafe") + .reshape(m.bias.shape) + ).type(typ), + requires_grad=requires_grad, + ) + elif random: + nn.init.constant_(m.bias, np.random.randint(1, 100)) + m.bias.requires_grad_(False) + m.bias = torch.nn.Parameter( + m.bias.type(typ), requires_grad=requires_grad + ) + else: + nn.init.constant_(m.bias, val) + m.bias.requires_grad_(False) + m.bias = torch.nn.Parameter( + m.bias.type(typ), requires_grad=requires_grad + ) + + +class TorchDialectConfig: + import torch + + """Base class for TestConfig's that are implemented with linalg-on-tensors. + + This class handles all the common lowering that torch-mlir does before + reaching the linalg-on-tensors abstraction level. 
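+    A rough usage sketch (assuming a torch.nn.Module whose forward carries
+    torch-mlir export annotations):
+
+        module = TorchDialectConfig().compile(MyAnnotatedModule())
+        linalg_module = lower_torch_mlir_to_linalg(module)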
+ """ + + def compile(self, program: torch.nn.Module) -> Any: + from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders + import torch_mlir + + example_args = convert_annotations_to_placeholders(program.forward) + module = torch_mlir.compile(program, example_args) + + return module + + +SMOKE_TEST = False + + +class PIConfig: + def compile(self, test_case: Test, test_module: nn.Module) -> Any: + tu = TestUtils() + with mlir_cm() as module: + module.operation.attributes["torch.debug_module_name"] = ir.StringAttr.get( + test_module.__class__.__name__ + ("SMOKE_TEST" if SMOKE_TEST else "") + ) + # TODO(max): for some reason updated __call__ doesn't stick + # (setattr doesn't work, gives 'method' object has no attribute '__annotations__' + placeholders = test_module.forward.__dict__["__placeholders__"] + if placeholders: + assert isinstance(placeholders, OrderedDict) + func_op = func_dialect.FuncOp( + name="forward", + type=( + [p.to_value_tensor_type() for p in placeholders.values()], + [], + ), + # visibility="public", + ) + func_op_entry_block = func_op.add_entry_block() + block_args = list(map(Tensor, func_op.arguments)) + + def replace_block_args(self_, *args, **kwargs): + assert not kwargs, f"kwargs not supported {kwargs}" + assert len(args) == len(block_args) + return block_args, kwargs + + test_module.register_forward_pre_hook(replace_block_args, prepend=True) + + results = [] + + def collect_results(_self, result, *_args, **_kwargs): + if len(results): + warnings.warn( + f"results already collected {results} (new result {result}); overwriting" + ) + results[0] = result + else: + results.append(result) + return result + + test_module.register_forward_post_hook(collect_results, prepend=True) + + with ir.InsertionPoint.at_block_begin(func_op_entry_block): + test_case.program_invoker(test_module, tu) + if isinstance(results[0], (tuple, list)): + results = results[0] + + # functions created from python can't return multiple results + if len(results) > 1: + results = [torch_dialect.PrimTupleConstructOp(results).result] + print(results) + + canonical_func_type = ir.FunctionType.get( + inputs=[b.type for b in block_args], + results=[r.type for r in results], + ) + func_op.attributes["function_type"] = ir.TypeAttr.get( + canonical_func_type + ) + + # these are pi tensors + results = [r.value for r in results] + func_dialect.ReturnOp(results) + + return module + + +def lower_torch_mlir_to_linalg(module): + run_pipeline_with_repro_report( + module, + "builtin.module(" + + ",".join( + [ + "cse", + # "builtin.module(torchscript-module-to-torch-backend-pipeline)", + "torch-backend-to-linalg-on-tensors-backend-pipeline", + ] + ) + + ")", + # "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + "Lowering TorchScript IR -> Torch Backend IR", + ) + return module diff --git a/pi/types_.py b/pi/types_.py new file mode 100644 index 0000000..8f172d7 --- /dev/null +++ b/pi/types_.py @@ -0,0 +1,315 @@ +import builtins +import re +import weakref +from enum import Enum +from typing import Union, List, Tuple, Any + +import numpy as np +from torch_mlir import ir +from torch_mlir.dialects._ods_common import ( + get_op_result_or_value, +) +from torch_mlir.ir import ( + Type as MLIRType, + Value as MLIRValue, +) + +# !torch.vtensor<[1,2,3],f32> +reg = re.compile(r"!torch.vtensor<\[(.*)\],(.*)>") + + +def parse_sizes_from_tensor_type_str(t: ir.OpView) -> List[int]: + # TODO(max): pull straight from the ranked type + t = get_op_result_or_value(t) + sizes, dtype = 
reg.findall(str(t.type))[0] + sizes = [s if s != "?" else "-1" for s in sizes.split(",")] + return list(map(int, sizes)), dtype + + +def get_type(t: Union[MLIRType, MLIRValue]): + if not isinstance(t, MLIRType): + assert isinstance( + t, MLIRValue + ), f"unknown type {type(t).__module__}.{type(t).__name__})" + t = t.type + return t + + +# def is_mlir_value(v): +# return isinstance(v, (ir.OpView, ir.Operation, ir.Value, ir.OpResultList, Tensor)) + + +def is_a_torch_tensor(t): + try: + t = get_op_result_or_value(t) + type_str = str(t.type) + return "torch.tensor" in type_str or "torch.vtensor" in type_str + except: + return False + + + + +# IntegerType.get_signless(32) -> i32 +# IntegerType.get_signed(32) -> si32 +# IntegerType.get_unsigned(32) -> ui32 + + +class dtype(Enum): + """ + |-------------------|--------------------| + | Torch Type | MLIR Type | + |-------------------|--------------------| + | torch.bfloat16 | bf16 | + | torch.bool | i1 | + | torch.complex* | complex<*> | + | torch.float16 | f16 | + | torch.float32 | f32 | + | torch.float64 | f64 | + | torch.int16 | si16 | + | torch.int32 | si32 | + | torch.int64 | si64 | + | torch.int8 | si8 | + | torch.qint8 | !torch.qint8 | + | torch.quint8 | !torch.quint8 | + | torch.uint8 | ui8 | + |-------------------|--------------------| + """ + + uint8 = 0 + int8 = 1 + int16 = 2 + int32 = 3 + int64 = 4 + float16 = 5 + float32 = 6 + float64 = 7 + # complex_half 8 + complex32 = 9 + complex64 = 10 + bool = 11 + qint8 = 12 + quint8 = 13 + # qint32 14 + bfloat16 = 15 + # qint4x2 16 + # qint2x4 17 + + def to_mlir_type(self): + # if ctx is None: + # ctx = get_default_loc_context() + match self: + case dtype.bfloat16: + return ir.BF16Type.get() + case dtype.bool: + return ir.IntegerType.get_signless(1) + case dtype.complex32: + return ir.ComplexType.get(ir.F32Type.get()) + case dtype.complex64: + return ir.ComplexType.get(ir.F64Type.get()) + case dtype.float16: + return ir.F16Type.get() + case dtype.float32: + return ir.F32Type.get() + case dtype.float64: + return ir.F64Type.get() + case dtype.int8: + return ir.IntegerType.get_signed(8) + case dtype.int16: + return ir.IntegerType.get_signed(16) + case dtype.int32: + return ir.IntegerType.get_signed(32) + case dtype.int64: + return ir.IntegerType.get_signed(64) + case dtype.uint8: + return ir.IntegerType.get_unsigned(8) + case _: + raise NotImplementedError("Something's wrong with the internet") + + def to_np_type(self): + # if ctx is None: + # ctx = get_default_loc_context() + match self: + case dtype.bfloat16 | dtype.float16: + return np.half + case dtype.bool: + return np.bool_ + case dtype.complex32 | dtype.complex64: + return np.complex_ + case dtype.float32: + return np.float32 + case dtype.float64: + return np.float64 + case dtype.int8: + return np.int8 + case dtype.int16: + return np.int16 + case dtype.int32: + return np.int32 + case dtype.int64: + return np.int64 + case dtype.uint8: + return np.uint8 + case _: + raise NotImplementedError("Something's wrong with the internet") + + @staticmethod + def from_mlir_type(t: str): + match t: + case "bf16": + return dtype.bfloat16 + case "i1": + return dtype.bool + case "complex32": + return dtype.complex32 + case "complex64": + return dtype.complex64 + case "f16": + return dtype.float16 + case "f32": + return dtype.float32 + case "f64": + return dtype.float64 + case "si8": + return dtype.int8 + case "si16": + return dtype.int16 + case "si32": + return dtype.int32 + case "si64": + return dtype.int64 + case "ui8": + return dtype.uint8 + case _: + 
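+                # Examples from the mapping above: from_mlir_type("f32") is
+                # dtype.float32 and from_mlir_type("si64") is dtype.int64; the
+                # quantized torch types (qint8, quint8) have no entry here,
+                # hence the error below.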
raise NotImplementedError(f"Something's wrong with the internet {t}") + + +# attr = DenseFPElementsAttr(Attribute.parse("dense<0.0> : tensor<3x5xf32>")) + + +bfloat16 = dtype.bfloat16 +bool = dtype.bool +complex32 = dtype.complex32 +complex64 = dtype.complex64 +half = float16 = dtype.float16 +float = float32 = dtype.float32 +double = float64 = dtype.float64 +int8 = dtype.int8 +int16 = dtype.int16 +int32 = dtype.int32 +long = int64 = dtype.int64 +qint8 = dtype.qint8 +quint8 = dtype.quint8 +uint8 = dtype.uint8 + +# _int = builtins.int +# _float = builtins.float +# _bool = builtins.bool +size = Union[List[int], Tuple[int, ...]] + +Number = Union[builtins.int, builtins.float, builtins.bool] +Generator = Device = Any + + +class BroadcastingListCls(object): + def __getitem__(self, types): + return + + +BroadcastingList1 = BroadcastingList2 = BroadcastingList3 = BroadcastingListCls() + +# Wrapper functions that can call either of 2 functions depending on a boolean +# argument +boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = ( + weakref.WeakKeyDictionary() +) # noqa: T484 + + +def boolean_dispatch( + arg_name, arg_index, default, if_true, if_false, module_name, func_name +): + """ + Dispatches to either of 2 script functions based on a boolean argument. + In TorchScript, the boolean argument must be constant so that the correct + function to use can be determined at compile time. + """ + + def fn(*args, **kwargs): + dispatch_flag = False + if arg_name in kwargs: + dispatch_flag = kwargs[arg_name] + elif arg_index < len(args): + dispatch_flag = args[arg_index] + + if dispatch_flag: + return if_true(*args, **kwargs) + else: + return if_false(*args, **kwargs) + + if if_true.__doc__ is None and if_false.__doc__ is not None: + doc = if_false.__doc__ + if_true.__doc__ = doc + elif if_false.__doc__ is None and if_true.__doc__ is not None: + doc = if_true.__doc__ + if_false.__doc__ = doc + elif if_false.__doc__ is None and if_true.__doc__ is None: + # neither function has a docstring + doc = None + else: + raise RuntimeError("only one function can have a docstring") + fn.__doc__ = doc + + if module_name is not None: + fn.__module__ = module_name + if func_name is not None: + fn.__name__ = func_name + + boolean_dispatched[fn] = { + "if_true": if_true, + "if_false": if_false, + "index": arg_index, + "default": default, + "arg_name": arg_name, + } + return fn + + +# def _overload(func): +# qual_name = func.__name__ +# global _overloaded_fns +# fn_overload_list = _overloaded_fns.get(qual_name) +# if fn_overload_list is None: +# fn_overload_list = [] +# _overloaded_fns[qual_name] = fn_overload_list +# fn_overload_list.append(func) +# return func + +Size = Union[List[int], Tuple[int, ...]] + +# namespace c10 { +# enum class MemoryFormat : int8_t { +# Contiguous, +# Preserve, +# ChannelsLast, +# ChannelsLast3d, +# NumOptions +# }; +# enum MemoryFormat { +# Contiguous, +# Preserve, +# ChannelsLast, +# ChannelsLast3d +# }; + + +class memory_format(Enum): + contiguous_format = 0 + preserve_format = 1 + channels_last = 2 + channels_last_3d = 3 + + +contiguous_format = memory_format.contiguous_format.value +preserve_format = memory_format.preserve_format.value +channels_last = memory_format.channels_last.value +channels_last_3d = memory_format.channels_last_3d.value diff --git a/pi/utils/__init__.py b/pi/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pi/utils/hooks.py b/pi/utils/hooks.py new file mode 100644 index 0000000..1034a06 --- /dev/null +++ 
b/pi/utils/hooks.py @@ -0,0 +1,40 @@ +from collections import OrderedDict +import weakref +from typing import Any + +__all__ = [ + "RemovableHandle", +] + + +class RemovableHandle(object): + id: int + next_id: int = 0 + + def __init__(self, hooks_dict: Any) -> None: + self.hooks_dict_ref = weakref.ref(hooks_dict) + self.id = RemovableHandle.next_id + RemovableHandle.next_id += 1 + + def remove(self) -> None: + hooks_dict = self.hooks_dict_ref() + if hooks_dict is not None and self.id in hooks_dict: + del hooks_dict[self.id] + + def __getstate__(self): + return self.hooks_dict_ref(), self.id + + def __setstate__(self, state) -> None: + if state[0] is None: + # create a dead reference + self.hooks_dict_ref = weakref.ref(OrderedDict()) + else: + self.hooks_dict_ref = weakref.ref(state[0]) + self.id = state[1] + RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1) + + def __enter__(self) -> "RemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() diff --git a/pyproject.toml b/pyproject.toml index 620bc3b..4fe360a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,5 +4,6 @@ requires = [ "wheel", "numpy", "PyYAML", + "pybind11" ] build-backend = "setuptools.build_meta" diff --git a/scripts/generate_stuff/.gitignore b/scripts/generate_stuff/.gitignore new file mode 100644 index 0000000..60db189 --- /dev/null +++ b/scripts/generate_stuff/.gitignore @@ -0,0 +1,2 @@ +*.pyi* +*.yaml \ No newline at end of file diff --git a/scripts/generate_stuff/download_templates.sh b/scripts/generate_stuff/download_templates.sh new file mode 100644 index 0000000..7be9edd --- /dev/null +++ b/scripts/generate_stuff/download_templates.sh @@ -0,0 +1,14 @@ +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/aten/src/ATen/native/native_functions.yaml +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/aten/src/ATen/native/tags.yaml +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/tools/autograd/deprecated.yaml +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/torch/nn/functional.pyi.in +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/torch/_C/_nn.pyi.in +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/torch/_C/__init__.pyi.in +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/torch/_C/_VariableFunctions.pyi.in +wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/torch/_C/return_types.pyi.in + + +#wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/tools/autograd/gen_python_functions.py +#wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/tools/pyi/gen_pyi.py +#wget https://raw.githubusercontent.com/pytorch/pytorch/nightly/tools/autograd/gen_trace_type.py +# diff --git a/scripts/generate_stuff/gen_pyi.py b/scripts/generate_stuff/gen_pyi.py new file mode 100644 index 0000000..0c20d56 --- /dev/null +++ b/scripts/generate_stuff/gen_pyi.py @@ -0,0 +1,959 @@ +import argparse +import collections +from pprint import pformat +from typing import Dict, List, Sequence + +from torchgen.api.python import ( + PythonSignatureGroup, + PythonSignatureNativeFunctionPair, + returns_named_tuple_pyi, +) +from torchgen.gen import parse_native_yaml + +from torchgen.model import DispatchKey, Variant +from torchgen.utils import FileManager + +from gen_python_functions import ( + group_overloads, + load_signatures, + should_generate_py_binding, +) + +""" +This module implements generation of type stubs for PyTorch, +enabling use of autocomplete 
in IDEs like PyCharm, which otherwise +don't understand C extension modules. + +At the moment, this module only handles type stubs for torch and +torch.Tensor. It should eventually be expanded to cover all functions +which come are autogenerated. + +Here's our general strategy: + +- We start off with a hand-written __init__.pyi.in file. This + file contains type definitions for everything we cannot automatically + generate, including pure Python definitions directly in __init__.py + (the latter case should be pretty rare). + +- We go through automatically bound functions based on the + type information recorded in native_functions.yaml and + generate type hints for them (generate_type_hints) + +There are a number of type hints which we've special-cased; +read gen_pyi for the gory details. +""" + + +def get_py_torch_functions( + python_funcs: Sequence[PythonSignatureNativeFunctionPair], + method: bool = False, +) -> Sequence[PythonSignatureGroup]: + """ + Get declarations (grouped by name) which should be generated + as either functions in the "torch" module or methods on Tensor. + """ + + def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool: + return ( + should_generate_py_binding(python_func.function) + and not python_func.function.python_module + and Variant.function in python_func.function.variants + ) + + def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: + return ( + should_generate_py_binding(python_func.function) + and not python_func.function.python_module + and Variant.method in python_func.function.variants + ) + + should_bind = should_bind_method if method else should_bind_function + return group_overloads([f for f in python_funcs if should_bind(f)]) + + +# TODO: Consider defining some aliases for our Union[...] types, to make +# the stubs to read on the human eye. + +DEVICE_PARAM = "device: Device=None" +FACTORY_PARAMS = ( + f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False" +) + +# this could be more precise w.r.t list contents etc. How to do Ellipsis? +INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]" + +blocklist = [ + "__init_subclass__", + "__new__", + "__subclasshook__", + "cdist", + "device", + "grad", + "requires_grad", + "range", + # defined in functional + "einsum", + # reduction argument; these bindings don't make sense + "binary_cross_entropy_with_logits", + "ctc_loss", + "cosine_embedding_loss", + "hinge_embedding_loss", + "kl_div", + "margin_ranking_loss", + "triplet_margin_loss", + # Somehow, these are defined in both _C and in functional. Ick! 
+ "broadcast_tensors", + # Manually define named tensor type stubs in __init__.pyi.in + "align_tensors", + "meshgrid", + "cartesian_prod", + "block_diag", + "norm", + "chain_matmul", + "stft", + "tensordot", + "split", + "unique_consecutive", + "atleast_1d", + "atleast_2d", + "atleast_3d", + # These are handled specially by python_arg_parser.cpp + "add", + "add_", + "add_out", + "sub", + "sub_", + "sub_out", + "mul", + "mul_", + "mul_out", + "div", + "div_", + "div_out", + "true_divide", + "true_divide_", + "true_divide_out", + "floor_divide", + "floor_divide_", + "floor_divide_out", + "to", + "_to_copy", + "copy_", +] + +binary_ops = ( + "add", + "sub", + "mul", + "div", + "pow", + "lshift", + "rshift", + "mod", + "truediv", + "matmul", + "floordiv", + "radd", + "rsub", + "rmul", + "rtruediv", + "rfloordiv", + "rpow", # reverse arithmetic + "and", + "or", + "xor", + "rand", + "ror", + "rxor", # logic + "iadd", + "iand", + "idiv", + "ilshift", + "imul", + "ior", + "irshift", + "isub", + "ixor", + "ifloordiv", + "imod", # inplace ops +) +symmetric_comparison_ops = ("eq", "ne") +asymmetric_comparison_ops = ("ge", "gt", "lt", "le") +comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops + +unary_ops = ("neg", "abs", "invert") +to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero") +all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops + + +def sig_for_ops(opname: str) -> List[str]: + """sig_for_ops(opname : str) -> List[str] + + Returns signatures for operator special functions (__add__ etc.)""" + + # we have to do this by hand, because they are hand-bound in Python + + assert opname.endswith("__") and opname.startswith("__"), "Unexpected op {}".format( + opname + ) + + name = opname[2:-2] + if name in binary_ops: + return ["def {}(self, other: Any) -> Tensor: ...".format(opname)] + elif name in comparison_ops: + sig = "def {}(self, other: Any) -> Tensor: ...".format(opname) + if name in symmetric_comparison_ops: + # unsafe override https://github.com/python/mypy/issues/5704 + sig += " # type: ignore[override]" + return [sig] + elif name in unary_ops: + return ["def {}(self) -> Tensor: ...".format(opname)] + elif name in to_py_type_ops: + if name in {"bool", "float", "complex"}: + tname = name + elif name == "nonzero": + tname = "bool" + else: + tname = "int" + if tname in {"float", "int", "bool", "complex"}: + tname = "builtins." + tname + return ["def {}(self) -> {}: ...".format(opname, tname)] + else: + raise Exception("unknown op", opname) + + +def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: + type_hints: List[str] = [] + + # Some deprecated ops that are on the blocklist are still included in pyi + if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: + return type_hints + + # deprecated signatures have separate entries for their functional and out variants + # (as opposed to the native ops, which fuse the two into a single signature). + # generate the functional variant here, if an out variant exists. + if sig_group.signature.deprecated and sig_group.outplace is not None: + type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) + type_hints.append(type_hint) + + # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument + # Generates the out variant if one exists. 
Otherwise, generate the functional variant + type_hint = sig_group.signature.signature_str_pyi( + skip_outputs=sig_group.outplace is None + ) + type_hints.append(type_hint) + + # Some operators also additionally have a vararg variant of their signature + type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( + skip_outputs=sig_group.outplace is None + ) + if type_hint_vararg: + type_hints.append(type_hint_vararg) + + return type_hints + + +def gen_nn_functional(fm: FileManager) -> None: + # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered + # through an `_add_docstr` call + imports = [ + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "conv_tbc", + "avg_pool1d", + "relu_", + "selu_", + "celu_", + "rrelu_", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "native_channel_shuffle", + "pdist", + "cosine_similarity", + ] + # Functions generated by `torch._jit_internal.boolean_dispatch` + dispatches = [ + "fractional_max_pool2d", + "fractional_max_pool3d", + "max_pool1d", + "max_pool2d", + "max_pool3d", + "adaptive_max_pool1d", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + ] + # Functions directly imported from `torch._C` + from_c = [ + "avg_pool2d", + "avg_pool3d", + "hardtanh_", + "elu_", + "leaky_relu_", + "logsigmoid", + "softplus", + "softshrink", + "one_hot", + ] + import_code = ["from .. import {0} as {0}".format(_) for _ in imports] + # TODO make these types more precise + dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)] + fm.write_with_template( + "functional.pyi", + "functional.pyi.in", + lambda: { + "imported_hints": import_code, + "dispatched_hints": dispatch_code, + }, + ) + + # functional.pyi already contains the definitions for those functions + # so, we don't export then to it + from_c.extend(["hardtanh", "leaky_relu", "hardsigmoid"]) + dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)] + fm.write_with_template( + "_nn.pyi", + "_nn.pyi.in", + lambda: { + "imported_hints": import_code, + "dispatched_hints": dispatch_code, + }, + ) + + +def gen_pyi( + native_yaml_path: str, + tags_yaml_path: str, + deprecated_yaml_path: str, + fm: FileManager, +) -> None: + """gen_pyi() + + This function generates a pyi file for torch. + """ + + # Some of this logic overlaps with generate_python_signature in + # tools/autograd/gen_python_functions.py; however, this + # function is all about generating mypy type signatures, whereas + # the other function generates are custom format for argument + # checking. If you are update this, consider if your change + # also needs to update the other file. + + # Dictionary for NamedTuple definitions + namedtuples: Dict[str, str] = {} + + # Generate type signatures for top-level functions + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list) + + for n, n1, n2 in [ + ("csr", "crow", "col"), + ("csc", "ccol", "row"), + ("bsr", "crow", "col"), + ("bsc", "ccol", "row"), + ]: + unsorted_function_hints.update( + { + f"sparse_{n}_tensor": [ + f"def sparse_{n}_tensor({n1}_indices: Union[Tensor, List]," + f"{n2}_indices: Union[Tensor, List]," + " values: Union[Tensor, List], size: Optional[_size]=None," + " *, dtype: Optional[_dtype]=None," + " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..." 
+ ], + f"_sparse_{n}_tensor_unsafe": [ + f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List]," + f"{n2}_indices: Union[Tensor, List]," + " values: Union[Tensor, List], size: List[int]," + " dtype: Optional[_dtype] = None, device: Optional[_device] = None," + " requires_grad: bool = False) -> Tensor: ..." + ], + } + ) + + unsorted_function_hints.update( + { + "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."], + "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."], + "asarray": [ + "def asarray(obj: Any, *, dtype: Optional[_dtype]=None, " + "device: Union[_device, str, None]=None, copy: Optional[_bool]=None, " + "requires_grad: _bool=False) -> Tensor: ..." + ], + "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."], + "frombuffer": [ + "def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, " + "offset: int=0, device: Union[_device, str, None]=None, " + "requires_grad: _bool=False) -> Tensor: ..." + ], + "numel": ["def numel(self: Tensor) -> _int: ..."], + "as_tensor": [ + f"def as_tensor(data: Any, dtype: Optional[_dtype]=None, {DEVICE_PARAM}) -> Tensor: ..." + ], + "get_num_threads": ["def get_num_threads() -> _int: ..."], + "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."], + "init_num_threads": ["def init_num_threads() -> None: ..."], + "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."], + "set_num_interop_threads": [ + "def set_num_interop_threads(num: _int) -> None: ..." + ], + # These functions are explicitly disabled by + # SKIP_PYTHON_BINDINGS because they are hand bound. + # Correspondingly, we must hand-write their signatures. + "tensor": [ + "def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS) + ], + "sparse_coo_tensor": [ + "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List]," + " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None," + " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..." + ], + "_sparse_coo_tensor_unsafe": [ + "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int]," + " dtype: Optional[_dtype] = None, device: Optional[_device] = None," + " requires_grad: bool = False) -> Tensor: ..." + ], + "sparse_compressed_tensor": [ + "def sparse_compressed_tensor(compressed_indices: Union[Tensor, List]," + "plain_indices: Union[Tensor, List]," + " values: Union[Tensor, List], size: Optional[_size]=None," + " *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None," + " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..." + ], + "_sparse_compressed_tensor_unsafe": [ + "def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List]," + "plain_indices: Union[Tensor, List]," + " values: Union[Tensor, List], size: List[int]," + " dtype: Optional[_dtype] = None, layout: Optional[_layout] = None," + " device: Optional[_device] = None," + " requires_grad: bool = False) -> Tensor: ..." + ], + "_sync": ["def _sync(t: Tensor) -> None: ..."], + "_is_functional_tensor": [ + "def _is_functional_tensor(t: Tensor) -> _bool: ..." + ], + "_from_functional_tensor": [ + "def _from_functional_tensor(t: Tensor) -> Tensor: ..." + ], + "_to_functional_tensor": [ + "def _to_functional_tensor(t: Tensor) -> Tensor: ..." + ], + "_enable_functionalization": [ + "def _enable_functionalization(*, reapply_views: _bool = False): ..." 
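+            # Note: names that collect more than one signature (e.g. "arange"
+            # and "nonzero" below) get "@overload" prepended to each hint when
+            # function_hints is assembled later in this function.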
+ ], + "_disable_functionalization": ["def _disable_functionalization(): ..."], + "range": [ + "def range(start: Number, end: Number," + " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ) + ], + "arange": [ + "def arange(start: Number, end: Number, step: Number, *," + " out: Optional[Tensor]=None, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ), + "def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ), + "def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ), + ], + "linspace": [ + "def linspace(start: Number, end: Number, steps: Optional[_int]=None, *," + " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS) + ], + "logspace": [ + "def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *," + " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS) + ], + "randint": [ + "def randint(low: _int, high: _int, size: _size, *," + " generator: Optional[Generator]=None, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ), + "def randint(high: _int, size: _size, *," + " generator: Optional[Generator]=None, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ), + ], + "full": [ + "def full(size: _size, fill_value: Number, *," + " out: Optional[Tensor]=None," + " layout: _layout=strided, {}) -> Tensor: ...".format(FACTORY_PARAMS), + "def full(size: _size, fill_value: Number, *," + " names: List[Union[str, None]]," + " layout: _layout=strided, {}) -> Tensor: ...".format(FACTORY_PARAMS), + ], + "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."], + "is_inference_mode_enabled": [ + "def is_inference_mode_enabled() -> _bool: ..." + ], + "nonzero": [ + "def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...", + "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", + ], + "binary_cross_entropy_with_logits": [ + "def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, " + "weight: Optional[Tensor] = None, size_average: Optional[bool] = None, " + "reduce: Optional[bool] = None, reduction: str = ..., " + "pos_weight: Optional[Tensor] = None) -> Tensor: ..." + ], + "cosine_embedding_loss": [ + "def cosine_embedding_loss(input1: Tensor, input2: Tensor, " + "target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., " + "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..." + ], + "ctc_loss": [ + "def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor," + " blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ..." + ], + "hinge_embedding_loss": [ + "def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...," + " size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., " + "reduction: str = ...) -> Tensor: ..." + ], + "kl_div": [ + "def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., " + "reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ..." + ], + "margin_ranking_loss": [ + "def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor," + " margin: float = ..., size_average: Optional[bool] = ..., " + " reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..." 
+ ], + "triplet_margin_loss": [ + "def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, " + "margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., " + "size_average: Optional[bool] = ..., " + "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..." + ], + "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], + "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], + "saddmm": [ + "def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, " + "alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ..." + ], + "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], + "div": [ + "def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, " + "rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ..." + ], + } + ) + for binop in ["mul", "true_divide", "floor_divide"]: + unsorted_function_hints[binop].append( + "def {}(input: Union[Tensor, Number]," + " other: Union[Tensor, Number]," + " *, out: Optional[Tensor]=None) -> Tensor: ...".format(binop) + ) + for binop in ["add", "sub"]: + unsorted_function_hints[binop].append( + "def {}(input: Union[Tensor, Number]," + " other: Union[Tensor, Number]," + " *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...".format( + binop + ) + ) + + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + native_functions = list(filter(should_generate_py_binding, native_functions)) + + function_signatures = load_signatures( + native_functions, deprecated_yaml_path, method=False, pyi=True + ) + sig_groups = get_py_torch_functions(function_signatures) + for group in sorted(sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_function_hints[name] += generate_type_hints(group) + + named_tuple = returns_named_tuple_pyi(group.signature) + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def + + function_hints = [] + for name, hints in sorted(unsorted_function_hints.items()): + if len(hints) > 1: + hints = ["@overload\n" + h for h in hints] + function_hints += hints + + # Generate type signatures for Tensor methods + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list) + unsorted_tensor_method_hints.update( + { + "size": [ + "def size(self) -> Size: ...", + "def size(self, dim: _int) -> _int: ...", + ], + "stride": [ + "def stride(self) -> Tuple[_int, ...]: ...", + "def stride(self, _int) -> _int: ...", + ], + "new_ones": [ + "def new_ones(self, size: _size, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ) + ], + "new_tensor": [ + "def new_tensor(self, data: Any, {}) -> Tensor: ...".format( + FACTORY_PARAMS + ) + ], + # new and __init__ have the same signatures differ only in return type + # Adapted from legacy_tensor_ctor and legacy_tensor_new + "new": [ + "def new(self, *args: Any, {}) ->Tensor: ...".format(DEVICE_PARAM), + "def new(self, storage: Storage) -> Tensor: ...", + "def new(self, other: Tensor) -> Tensor: ...", + "def new(self, size: _size, *, {}) -> Tensor: ...".format(DEVICE_PARAM), + ], + "__init__": [ + "def __init__(self, *args: Any, {}) -> None: ...".format(DEVICE_PARAM), + "def __init__(self, storage: 
Storage) -> None: ...", + "def __init__(self, other: Tensor) -> None: ...", + "def __init__(self, size: _size, *, {}) -> None: ...".format( + DEVICE_PARAM + ), + ], + "as_subclass": ["def as_subclass(self, cls: Type[S]) -> S: ..."], + "_make_subclass": [ + "def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False," + " dispatch_device: _bool=False, device_for_backend_keys: Optional[_device] = None) -> Tensor: ..." + ], + "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)], + "__setitem__": [ + "def __setitem__(self, {}, val: Union[Tensor, Number])" + " -> None: ...".format(INDICES) + ], + "tolist": ["def tolist(self) -> List: ..."], + "requires_grad_": [ + "def requires_grad_(self, mode: _bool=True) -> Tensor: ..." + ], + "element_size": ["def element_size(self) -> _int: ..."], + "data_ptr": ["def data_ptr(self) -> _int: ..."], + "dim": ["def dim(self) -> _int: ..."], + "nonzero": [ + "def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...", + "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", + ], + "numel": ["def numel(self) -> _int: ..."], + "ndimension": ["def ndimension(self) -> _int: ..."], + "nelement": ["def nelement(self) -> _int: ..."], + "cuda": [ + "def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..." + ], + "numpy": ["def numpy(self, *, force: _bool=False) -> Any: ..."], + "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."], + "map_": [ + "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..." + ], + "map2_": [ + "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..." + ], + "storage": ["def untyped_storage(self) -> Storage: ..."], + "storage_type": ["def storage_type(self) -> Storage: ..."], + "type": [ + "def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...", + "def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...", + ], + "get_device": ["def get_device(self) -> _int: ..."], + "contiguous": [ + "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..." + ], + "has_names": ["def has_names(self) -> _bool: ..."], + "is_contiguous": [ + "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..." + ], + "_is_view": ["def _is_view(self) -> _bool: ..."], + "is_cuda": ["is_cuda: _bool"], + "is_leaf": ["is_leaf: _bool"], + "is_nested": ["is_nested: _bool"], + "is_sparse": ["is_sparse: _bool"], + "is_sparse_csr": ["is_sparse_csr: _bool"], + "is_quantized": ["is_quantized: _bool"], + "is_meta": ["is_meta: _bool"], + "is_mps": ["is_mps: _bool"], + "is_ort": ["is_ort: _bool"], + "is_mkldnn": ["is_mkldnn: _bool"], + "is_vulkan": ["is_vulkan: _bool"], + "is_ipu": ["is_ipu: _bool"], + "storage_offset": ["def storage_offset(self) -> _int: ..."], + "to": [ + "def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...", + "def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, " + "non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...", + "def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...", + ], + "item": ["def item(self) -> Number: ..."], + "copy_": [ + "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..." 
+ ], + "set_": [ + "def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...", + "def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...", + ], + "split": [ + "def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...", + "def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...", + ], + "div": [ + "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." + ], + "div_": [ + "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." + ], + } + ) + for binop in ["mul", "true_divide", "floor_divide"]: + for inplace in [False, True]: + out_suffix = ", *, out: Optional[Tensor]=None" + if inplace: + binop += "_" + out_suffix = "" + unsorted_tensor_method_hints[binop].append( + "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})" + " -> Tensor: ...".format(binop, out_suffix) + ) + for binop in ["add", "sub"]: + for inplace in [False, True]: + out_suffix = ", out: Optional[Tensor]=None" + if inplace: + binop += "_" + out_suffix = "" + unsorted_tensor_method_hints[binop].append( + "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], " + "*, alpha: Optional[Number]=1{})" + " -> Tensor: ...".format(binop, out_suffix) + ) + simple_conversions = [ + "byte", + "char", + "cpu", + "double", + "float", + "half", + "int", + "long", + "short", + "bool", + "bfloat16", + ] + for name in simple_conversions: + unsorted_tensor_method_hints[name].append( + "def {}(self) -> Tensor: ...".format(name) + ) + + # pyi tensor methods don't currently include deprecated signatures for some reason + # TODO: we should probably add them in + tensor_method_signatures = load_signatures( + native_functions, + deprecated_yaml_path, + method=True, + skip_deprecated=True, + pyi=True, + ) + tensor_method_sig_groups = get_py_torch_functions( + tensor_method_signatures, method=True + ) + + for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_tensor_method_hints[name] += generate_type_hints(group) + + named_tuple = returns_named_tuple_pyi(group.signature) + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def + + for op in all_ops: + name = "__{}__".format(op) + unsorted_tensor_method_hints[name] += sig_for_ops(name) + + tensor_method_hints = [] + for name, hints in sorted(unsorted_tensor_method_hints.items()): + if len(hints) > 1: + hints = ["@overload\n" + h for h in hints] + tensor_method_hints += hints + + # TODO: Missing type hints for nn + + # Generate namedtuple definitions + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + namedtuple_defs = [ + "{} = {}".format(name, defn) for name, defn in namedtuples.items() + ] + + # Generate type signatures for legacy classes + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + legacy_storage_base_hints = ["class StorageBase(object): ..."] + + legacy_class_hints = [] + for c in ( + "DoubleTensor", + "FloatTensor", + "LongTensor", + "IntTensor", + "ShortTensor", + "HalfTensor", + "CharTensor", + "ByteTensor", + "BoolTensor", + ): + legacy_class_hints.append("class {}(Tensor): ...".format(c)) + + # Generate type signatures for dtype classes + # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # TODO: don't explicitly list dtypes here; get it from canonical + # source + dtype_class_hints = [ + "{}: dtype = ...".format(n) + for n in [ + "float32", + "float", + "float64", + "double", + "float16", + "bfloat16", + "half", + "uint8", + "int8", + "int16", + "short", + "int32", + "int", + "int64", + "long", + "complex32", + "complex64", + "cfloat", + "complex128", + "cdouble", + "quint8", + "qint8", + "qint32", + "bool", + "quint4x2", + "quint2x4", + ] + ] + + # Generate __all__ directive + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # Include only the functions that contain hints, to prevent undefined + # symbols to be included in the `__all__` directive. + hinted_function_names = [ + name for name, hint in unsorted_function_hints.items() if hint + ] + all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names) + all_directive = pformat(all_symbols, width=100, compact=True).split("\n") + all_directive[0] = "__all__ = {}".format(all_directive[0]) + + # Dispatch key hints + # ~~~~~~~~~~~~~~~~~~ + dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey] + + # Write out the stub + # ~~~~~~~~~~~~~~~~~~ + + env = { + "namedtuple_defs": namedtuple_defs, + "function_hints": function_hints, + "tensor_method_hints": tensor_method_hints, + "legacy_class_hints": legacy_class_hints, + "legacy_storage_base_hints": legacy_storage_base_hints, + "dtype_class_hints": dtype_class_hints, + "dispatch_key_hints": dispatch_key_hints, + "all_directive": all_directive, + } + fm.write_with_template( + "__init__.pyi", + "__init__.pyi.in", + lambda: { + "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in", + **env, + }, + ) + fm.write_with_template( + "_VariableFunctions.pyi", + "_VariableFunctions.pyi.in", + lambda: { + "generated_comment": "@" + + "generated from torch/_C/_VariableFunctions.pyi.in", + **env, + }, + ) + fm.write_with_template( + "_VF.pyi", + "_VariableFunctions.pyi.in", + lambda: { + "generated_comment": "@" + + "generated from torch/_C/_VariableFunctions.pyi.in", + **env, + }, + ) + fm.write_with_template( + "return_types.pyi", + "return_types.pyi.in", + lambda: { + "generated_comment": "@" + "generated from torch/_C/return_types.pyi", + **env, + }, + ) + gen_nn_functional(fm) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch") + parser.add_argument( + "--native-functions-path", + metavar="NATIVE", + default="native_functions.yaml", + help="path to native_functions.yaml", + ) + parser.add_argument( + "--tags-path", + metavar="TAGS", + default="tags.yaml", + help="path to tags.yaml", + ) + parser.add_argument( + "--deprecated-functions-path", + metavar="DEPRECATED", + default="deprecated.yaml", + help="path to deprecated.yaml", + ) + parser.add_argument( + "--out", metavar="OUT", default=".", help="path to output directory" + ) + args = parser.parse_args() + fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False) + gen_pyi( + args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_stuff/gen_python_functions.py b/scripts/generate_stuff/gen_python_functions.py new file mode 100644 index 0000000..2a85a27 --- /dev/null +++ b/scripts/generate_stuff/gen_python_functions.py @@ -0,0 +1,1306 @@ +# Generates Python bindings for ATen functions +# +# The bindings are generated as methods on python_variable or functions on the +# torch._C._nn. 
torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse +# or torch._C._special objects. +# + +# Code tries to stick to the following rules: +# +# - templates should be colocated with the functions that use them. +# no templates are currently shared between functions, but if that +# happens, maybe put the template with the first one +# +# - don't use environment dictionaries when calling template.substitute(). +# pass named arguments directly for everything, otherwise it's much too +# hard to track what's actually being used and by who +# +# - colocate any new hacks/adjustments with existing ones of the same kind. +# ideally in a data structure rather than code if possible. See e.g. +# SCHEMA_DEFAULT_CONVERSION_HACKS, etc. +# +# - similarly, conversions from one format to another should ideally happen +# all at once in a single place. +# +# - no nontrivial nested functions. couple-liners are ok but please no more. +# especially avoid functions that read/write outer variables defined far away. +# +# - raise RuntimeError instead of asserting, and put as much +# information as is available into the message. I.e. no need to +# plumb in new params whose only purpose is to fill out an error +# message, but use what's there +# + +import itertools +import re +from collections import defaultdict + +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import yaml +from torchgen.api import cpp +from torchgen.api.python import ( + arg_parser_output_exprs, + cpp_dispatch_exprs, + cpp_dispatch_target, + dispatch_lambda_args, + dispatch_lambda_exprs, + dispatch_lambda_return_str, + has_tensor_options, + namedtuple_fieldnames, + PythonSignature, + PythonSignatureDeprecated, + PythonSignatureGroup, + PythonSignatureNativeFunctionPair, + signature, + signature_from_schema, +) + +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml +from torchgen.model import ( + Argument, + BaseOperatorName, + FunctionSchema, + NativeFunction, + Type, + Variant, +) +from torchgen.utils import FileManager, split_name_params, YamlLoader + +from gen_trace_type import should_trace + +# +# declarations blocklist +# We skip codegen for these functions, for various reasons. +# Future PRs will categorize this list and eliminate or hoist +# them out of eager-only codegen. 
+# See https://github.com/pytorch/pytorch/issues/30788 +# + +# These functions require manual Python bindings or are not exposed to Python +_SKIP_PYTHON_BINDINGS = [ + "alias", + "contiguous", + "is_cuda", + "is_sparse", + "is_sparse_csr", + "size", + "stride", + ".*_backward", + ".*_backward_(out|input|weight|bias)", + ".*_forward", + ".*_forward_out", + ".*_jvp", + "_unsafe_view", + "tensor", + "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*", + "_range.*", + "_sparse_add_out", + "_sparse_div.*", + "_sparse_mul.*", + "_sparse_sub.*", + "_sparse_dense_add_out", + "index", + "index_out", + "unique_dim_consecutive", + "_cumsum.*", + "_cumprod.*", + "_sum.*", + "_prod.*", + "_th_.*", + "_thnn_.*", + "range.*", + "_solve.*", + "_inverse.*", + "_cholesky.*", + "_triangular_solve.*", + "_qr.*", + "_symeig.*", + "_svd.*", + "slice", + "item", + "_local_scalar_dense", + "to", + "_to_copy", + "_reshape_copy", + "copy_sparse_to_sparse_", + "copy_", + "numpy_T", + "matrix_H", + "mT", + "mH", # these need to be an attributes in Python, not functions + "nonzero(_(out|numpy))?", + "set_data", + ".*_overrideable", # overrideable functions for backend extension + "data", + "is_leaf", + "output_nr", + "_version", + "requires_grad_", + "retains_grad", + "set_", + "_fw_primal", + "fake_quantize_per_tensor_affine_cachemask", + "fake_quantize_per_channel_affine_cachemask", + "_new_zeros_with_same_feature_meta", + "_has_same_storage_numel", # used for forward AD internals + "_reshape_alias", + "replace_", # only used by the functionalization pass, doesn't need to be exposed to python + "copy", # only used by the functionalization pass + "fill.Tensor", # only used by the functionalization pass + "fill.Scalar", # only used by the functionalization pass + "lift.*", + "normal_functional", # only used by the functionalization pas + "_nested_tensor_strides", # don't want to expose this to python + "_nested_tensor_offsets", # don't want to expose this to python + "_nested_view_from_buffer", # View only version of _nested_from_buffer. This will force users to only use the "safe" version. + "_nested_view_from_buffer_copy", +] + +SKIP_PYTHON_BINDINGS = list( + map(lambda pattern: re.compile(rf"^{pattern}$"), _SKIP_PYTHON_BINDINGS) +) + +# These function signatures are not exposed to Python. Note that this signature +# list does not support regex. +SKIP_PYTHON_BINDINGS_SIGNATURES = [ + "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "mul.Scalar(Tensor self, Scalar other) -> Tensor", + "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", + "div.Scalar(Tensor self, Scalar other) -> Tensor", + "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", +] + + +@with_native_function +def should_generate_py_binding(f: NativeFunction) -> bool: + # So far, all NativeFunctions that are entirely code-generated do not get python bindings. 
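+    # (Entirely code-generated functions are marked with the "generated" tag in
+    # native_functions.yaml; that tag is what the check below keys on.)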
+ if "generated" in f.tags: + return False + name = cpp.name(f.func) + for skip_regex in SKIP_PYTHON_BINDINGS: + if skip_regex.match(name): + return False + + signature = str(f.func) + for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: + if pattern == signature: + return False + + return True + + +def get_pycname(name: BaseOperatorName) -> str: + return f"THPVariable_{name}" + + +def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool: + return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0 + + +def is_py_variable_method(f: NativeFunction) -> bool: + return f.python_module is None and Variant.method in f.variants + + +def is_py_torch_function(f: NativeFunction) -> bool: + return f.python_module is None and Variant.function in f.variants + + +def is_py_nn_function(f: NativeFunction) -> bool: + return f.python_module == "nn" + + +def is_py_fft_function(f: NativeFunction) -> bool: + return f.python_module == "fft" + + +def is_py_linalg_function(f: NativeFunction) -> bool: + return f.python_module == "linalg" + + +def is_py_nested_function(f: NativeFunction) -> bool: + return f.python_module == "nested" + + +def is_py_sparse_function(f: NativeFunction) -> bool: + return f.python_module == "sparse" + + +def is_py_special_function(f: NativeFunction) -> bool: + return f.python_module == "special" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Main Function +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def gen( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + deprecated_yaml_path: str, + template_path: str, + *, + symint: bool = True, +) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + native_functions = list(filter(should_generate_py_binding, native_functions)) + + methods = load_signatures(native_functions, deprecated_yaml_path, method=True) + create_python_bindings( + fm, + methods, + is_py_variable_method, + None, + "python_variable_methods.cpp", + method=True, + symint=symint, + ) + + # NOTE: num_shards here must be synced with gatherTorchFunctions in + # torch/csrc/autograd/python_torch_functions_manual.cpp + functions = load_signatures(native_functions, deprecated_yaml_path, method=False) + create_python_bindings_sharded( + fm, + functions, + is_py_torch_function, + "torch", + "python_torch_functions.cpp", + method=False, + num_shards=3, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nn_function, + "torch.nn", + "python_nn_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_fft_function, + "torch.fft", + "python_fft_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_linalg_function, + "torch.linalg", + "python_linalg_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nested_function, + "torch.nested", + "python_nested_functions.cpp", + method=False, + ) + + create_python_bindings( + fm, + functions, + is_py_sparse_function, + "torch.sparse", + "python_sparse_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_special_function, + "torch.special", + "python_special_functions.cpp", + method=False, + symint=symint, + ) + + # Currently, we only use `functions` to generate `return_types` 
bindings. + # All methods which return namedtuple have function variant at this point. + # If any method only operator with namedtuple is added in the future, + # we will have to address that. + create_python_return_type_bindings( + fm, functions, lambda fn: True, "python_return_types.cpp" + ) + + valid_tags = parse_tags_yaml(tags_yaml_path) + + def gen_tags_enum() -> Dict[str, str]: + return { + "enum_of_valid_tags": ( + "".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags]) + ) + } + + fm.write("python_enum_tag.cpp", gen_tags_enum) + + +def group_filter_overloads( + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], +) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]: + grouped: Dict[ + BaseOperatorName, List[PythonSignatureNativeFunctionPair] + ] = defaultdict(list) + for pair in pairs: + if pred(pair.function): + grouped[pair.function.func.name.name].append(pair) + return grouped + + +def create_python_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: Optional[str], + filename: str, + *, + method: bool, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + py_methods: List[str] = [] + ops_headers: List[str] = [] + py_method_defs: List[str] = [] + py_forwards: List[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=lambda x: str(x)): + overloads = grouped[name] + py_methods.append( + method_impl(name, module, overloads, method=method, symint=symint) + ) + py_method_defs.append(method_def(name, module, overloads, method=method)) + py_forwards.extend(forward_decls(name, overloads, method=method)) + ops_headers.append(f"#include ") + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "ops_headers": ops_headers, + "py_forwards": py_forwards, + "py_methods": py_methods, + "py_method_defs": py_method_defs, + }, + ) + + +def create_python_return_type_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and relevant entry for the map in `python_return_types.cpp`. 
+ """ + py_return_types_definition: List[str] = [] + py_return_types_map: List[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=lambda x: str(x)): + overloads = grouped[name] + definitions, map_entries = generate_return_type_definition_and_map_entry( + overloads + ) + py_return_types_definition.append( + "" if not definitions else "\n".join(definitions) + ) + py_return_types_map.append("" if not map_entries else "\n".join(map_entries)) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types": py_return_types_definition, + "py_return_types_map": py_return_types_map, + }, + ) + + +def create_python_bindings_sharded( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: Optional[str], + filename: str, + *, + method: bool, + num_shards: int, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + grouped = group_filter_overloads(pairs, pred) + + def key_func( + kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] + ) -> str: + return kv[0].base + + def env_func( + kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] + ) -> Dict[str, List[str]]: + name, fn_pairs = kv + return { + "ops_headers": [f"#include "], + "py_forwards": list(forward_decls(name, fn_pairs, method=method)), + "py_methods": [ + method_impl(name, module, fn_pairs, method=method, symint=symint) + ], + "py_method_defs": [method_def(name, module, fn_pairs, method=method)], + } + + fm.write_sharded( + filename, + grouped.items(), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + }, + key_fn=key_func, + env_callable=env_func, + num_shards=num_shards, + sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"}, + ) + + +def load_signatures( + native_functions: List[NativeFunction], + deprecated_yaml_path: str, + *, + method: bool, + skip_deprecated: bool = False, + pyi: bool = False, +) -> Sequence[PythonSignatureNativeFunctionPair]: + @with_native_function + def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair: + return PythonSignatureNativeFunctionPair( + signature=signature(f, method=method, pyi=pyi), + function=f, + ) + + pairs = list(map(gen_signature_pairs, native_functions)) + deprecated = load_deprecated_signatures( + pairs, deprecated_yaml_path, method=method, pyi=pyi + ) + return pairs if skip_deprecated else pairs + deprecated + + +def load_deprecated_signatures( + pairs: Sequence[PythonSignatureNativeFunctionPair], + deprecated_yaml_path: str, + *, + method: bool, + pyi: bool, +) -> List[PythonSignatureNativeFunctionPair]: + # The deprecated.yaml doesn't have complete type information, we need + # find and leverage the original ATen signature (to which it delegates + # the call) to generate the full python signature. + # We join the deprecated and the original signatures using type-only form. 
+ + # group the original ATen signatures by name + grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + grouped[pair.signature.name].append(pair) + + # find matching original signatures for each deprecated signature + results: List[PythonSignatureNativeFunctionPair] = [] + + with open(deprecated_yaml_path, "r") as f: + deprecated_defs = yaml.load(f, Loader=YamlLoader) + + for deprecated in deprecated_defs: + schema = FunctionSchema.parse(deprecated["name"]) + aten_name, call_args = split_name_params(deprecated["aten"]) + is_out = aten_name.endswith("_out") + if is_out: + aten_name = aten_name.replace("_out", "") + + # HACK: these are fixed constants used to pass the the aten function. + # The type must be known ahead of time + known_constants = { + "1": Type.parse("Scalar"), + } + schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} + for name in call_args: + assert ( + name in schema_args_by_name or name in known_constants + ), f"deprecation definiton: Unrecognized value {name}" + + # Map deprecated signature arguments to their aten signature and test + # if the types and alias annotation match. + def is_schema_compatible( + aten_schema: FunctionSchema, + ) -> bool: + arguments: Iterable[Argument] + if is_out: + arguments = itertools.chain( + aten_schema.arguments.out, aten_schema.arguments.flat_non_out + ) + else: + arguments = aten_schema.arguments.flat_all + + for i, arg in enumerate(arguments): + if i < len(call_args): + arg_name = call_args[i] + if arg_name in known_constants: + schema_type = known_constants[arg_name] + schema_annotation = None + else: + schema_arg = schema_args_by_name[arg_name] + schema_type = schema_arg.type + schema_annotation = schema_arg.annotation + + if schema_type != arg.type or schema_annotation != arg.annotation: + return False + else: + if arg.default is None: + return False + + return len(schema.returns) == len(aten_schema.returns) and all( + a == b for a, b in zip(schema.returns, aten_schema.returns) + ) + + any_schema_found = False + for pair in grouped[aten_name]: + if not is_schema_compatible(pair.function.func): + continue + any_schema_found = True + + python_sig = signature_from_schema( + schema, + category_override=pair.function.category_override, + method=method, + pyi=pyi, + ) + + results.append( + PythonSignatureNativeFunctionPair( + signature=PythonSignatureDeprecated( + name=python_sig.name, + input_args=python_sig.input_args, + input_kwargs=python_sig.input_kwargs, + output_args=python_sig.output_args, + tensor_options_args=python_sig.tensor_options_args, + method=python_sig.method, + deprecated_schema=schema, + deprecated_args_exprs=tuple(call_args), + returns=python_sig.returns, + ), + function=pair.function, + ) + ) + assert ( + any_schema_found + ), f"No native function with name {aten_name} matched signature:\n {str(schema)}" + + return results + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Named Tuple Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@with_native_function +def gen_namedtuple_typename_key(f: NativeFunction) -> str: + name = cpp.name(f.func) + fieldnames = namedtuple_fieldnames(f.func.returns) + return "_".join([name] + fieldnames) + + +def emit_namedtuple_call( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> Tuple[List[str], Dict[str, str]]: + """ + Generate block of named tuple type def inits, and add typeref snippets + to declarations that use them + """ + 
typenames: Dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + typedefs: List[str] = [] # typedef declarations and init code + + for overload in overloads: + fieldnames = namedtuple_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_namedtuple_typename_key(overload.function) + typename = typenames.get(tn_key) + if typename is None: + typename = f'NamedTuple{"" if not typedefs else len(typedefs)}' + typenames[tn_key] = typename + typedefs.append( + f"""\ +static PyTypeObject* {typename} = get_namedtuple("{name}");""" + ) + + return typedefs, typenames + + +def generate_return_type_definition_and_map_entry( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> Tuple[List[str], List[str]]: + """ + Generate block of function in `python_return_types.cpp` to initialize + and return named tuple for a native function which returns named tuple + and relevant entry for the map in same file. + """ + typenames: Dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + definitions: List[str] = [] # function defintion to register the typedef + map_entries: List[ + str + ] = [] # C++ map entry of + + for overload in overloads: + fieldnames = namedtuple_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames) + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_namedtuple_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}' + typenames[tn_key] = typename + definitions.append( + f"""\ +PyTypeObject* get_{name}_namedtuple() {{ + static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }}; + static PyTypeObject {typename}; + static bool is_initialized = false; + static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }}; + if (!is_initialized) {{ + PyStructSequence_InitType(&{typename}, &desc); + {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr; + is_initialized = true; + }} + return &{typename}; +}} +""" + ) + map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ') + + return definitions, map_entries + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Impl Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# python binding for all overloads of a particular function/method +PY_VARIABLE_METHOD_VARARGS = CodeTemplate( + r"""\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + switch (_r.idx) { + ${dispatch} + } + ${method_footer} +} + +""" +) + +# handler for a single parsed signature - may be a single overload or +# a pair of overloads that whose signatures only differ in output params +# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch}) +PY_VARIABLE_CASE = CodeTemplate( + """\ +case ${overload_index}: { + ${body} +} +""" +) + +# python binding for single-overload function/method +PY_VARIABLE_METHOD_VARARGS_SINGLETON = 
CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + +# python binding for a method with no args, shortcuts parsing +PY_VARIABLE_METHOD_NOARGS = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args) +{ + ${method_header} + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + + +def method_impl( + name: BaseOperatorName, + module: Optional[str], + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, + symint: bool = True, +) -> str: + """ + Generate a python binding for all overloads of an op. + """ + pycname = get_pycname(name) + noarg = is_noarg(overloads) + namedtuple_inits, namedtuple_typenames = emit_namedtuple_call(overloads) + + method_header = ["HANDLE_TH_ERRORS"] + method_header += namedtuple_inits + method_header += ( + ["const Tensor& self = THPVariable_Unpack(self_);"] if method else [] + ) + + method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"] + + traceable = "true" if all(should_trace(o.function) for o in overloads) else "false" + + grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads( + overloads, symint=symint + ) + is_singleton = len(grouped_overloads) == 1 + signatures: List[str] = [] + dispatch: List[str] = [] + for overload_index, overload in enumerate(grouped_overloads): + signature = overload.signature.signature_str(symint=symint) + signatures.append(f"{cpp_string(str(signature))},") + dispatch_body = emit_dispatch_case( + overload, namedtuple_typenames, symint=symint + ) + dispatch.append( + PY_VARIABLE_CASE.substitute( + overload_index=overload_index, body=dispatch_body + ) + if not is_singleton + else dispatch_body + ) + + if noarg: + template = PY_VARIABLE_METHOD_NOARGS + elif is_singleton: + template = PY_VARIABLE_METHOD_VARARGS_SINGLETON + else: + template = PY_VARIABLE_METHOD_VARARGS + + return template.substitute( + name=name, + pycname=pycname, + method_header=method_header, + max_args=max(map(lambda o: o.signature.arguments_count(), overloads)), + signatures=signatures, + traceable=traceable, + check_has_torch_function=gen_has_torch_function_check( + name=name, + module=module, + noarg=noarg, + method=method, + ), + dispatch=dispatch, + method_footer=method_footer, + self_="self_" if method else "nullptr", + ) + + +def gen_has_torch_function_check( + name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool +) -> str: + if noarg: + if method: + return f"""\ +if(check_has_torch_function(self_)) {{ + return handle_torch_function(self_, "{name}"); +}} +""" + else: + return "" + + self_ = "self_" if method else "nullptr" + namespace = ( + { + "torch": "THPVariableFunctionsModule", + "torch.nn": "THPNNVariableFunctionsModule", + "torch.fft": "THPFFTVariableFunctionsModule", + "torch.linalg": "THPLinalgVariableFunctionsModule", + "torch.nested": "THPNestedVariableFunctionsModule", + "torch.sparse": "THPSparseVariableFunctionsModule", + "torch.special": "THPSpecialVariableFunctionsModule", + }[module] + if module + else "THPVariableClass" + ) + + return f"""\ +if(_r.has_torch_function()) {{ + return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or 
"torch.Tensor"}"); +}} +""" + + +# handler for output/no-output overload pair +PY_VARIABLE_OUT = CodeTemplate( + """\ +if (_r.isNone(${out_idx})) { + ${call_dispatch} +} else { + ${call_dispatch_out} +} +""" +) + + +def emit_dispatch_case( + overload: PythonSignatureGroup, + namedtuple_typenames: Dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single parsed signature. This corresponds to either + a single native function, or a pair that differ only in output params. In the + latter case, a single python signature is used for both and dispatching + switches on the presence/absence of passed output args. + """ + if overload.outplace is not None: + # dispatch output and no-output variants, branch on _r.isNone() + return PY_VARIABLE_OUT.substitute( + out_idx=overload.signature.output_idx(), + call_dispatch=emit_single_dispatch( + overload.signature, overload.base, namedtuple_typenames, symint=symint + ), + call_dispatch_out=emit_single_dispatch( + overload.signature, + overload.outplace, + namedtuple_typenames, + symint=symint, + ), + ) + else: + # no-output version only + return emit_single_dispatch( + overload.signature, overload.base, namedtuple_typenames, symint=symint + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Forward Declarations Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def forward_decls( + name: BaseOperatorName, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> Tuple[str, ...]: + if method: + return () + + pycname = get_pycname(name) + if is_noarg(overloads): + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args); +""", + ) + else: + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); +""", + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Def (Binding Table Entry) Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def method_def( + name: BaseOperatorName, + module: Optional[str], + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> str: + """ + Generate method def entry. 
+ """ + pycname = get_pycname(name) + + if is_noarg(overloads): + pyfunc_cast = "" + flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS" + else: + pyfunc_cast = "castPyCFunctionWithKeywords" + flags = "METH_VARARGS | METH_KEYWORDS" + + if module == "torch": + flags += " | METH_STATIC" + + if name.dunder_method: + # PyMethodDef entry for binary op, throws not implemented error + return f"""\ +{{"{name}", {pyfunc_cast}(TypeError_to_NotImplemented_<{pycname}>), {flags}, NULL}},""" + else: + # PyMethodDef entry + return f"""\ +{{"{name}", {pyfunc_cast}({pycname}), {flags}, NULL}},""" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Overload Sorting and Grouping +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def group_overloads( + overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + bases: Dict[str, PythonSignatureNativeFunctionPair] = {} + outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} + + # first group by signature ignoring out arguments + for overload in overloads: + sig = overload.signature.signature_str(skip_outputs=True, symint=symint) + if overload.function.func.is_out_fn(): + if sig in outplaces: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {outplaces[sig].function.func}." + ) + outplaces[sig] = overload + else: + if sig in bases: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {bases[sig].function.func}." + ) + bases[sig] = overload + + for sig, out in outplaces.items(): + if sig not in bases: + candidates: List[str] = [] + for overload in overloads: + if ( + str(overload.function.func.name.name) + == str(out.function.func.name.name) + and not overload.function.func.is_out_fn() + and not overload.signature.deprecated + ): + candidates.append( + overload.signature.signature_str( + skip_outputs=True, symint=symint + ) + ) + out_sig = out.signature.signature_str(symint=symint) + raise RuntimeError( + f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. " + f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema " + "correctly in native_functions.yaml. We discovered the following candidate(s): \n" + + "\n".join(f"- {candidate}" for candidate in candidates) + ) + + grouped = [ + PythonSignatureGroup.from_pairs( + functional=base, + out=outplaces.get(sig), + ) + for sig, base in bases.items() + ] + return sort_overloads(grouped, symint=symint) + + +# This function declares a partial order on declarations, and sorts them according +# to its linear extension. This is necessary, because there's some ambiguity in the +# choice of overload, and we want a different order. +# +# See Note[Order of overloads matters] +# +# A few examples of ambiguous python signature pairs. +# +# All parameters have the same type, except one taking Tensor the other taking +# Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor +# object can be accepted as Scalar type parameter (see python_arg_parser.cpp). +# Therefore, same input arguments might be accepted by either python signature. +# We want to always parse the one taking Tensor first. 
+# +# bitwise_and(Tensor input, Tensor other, *, Tensor out=None) +# bitwise_and(Tensor input, Scalar other, *, Tensor out=None) +# +# If they have different number of parameters then they are not ambiguous - but +# the difference on output param can be ignored as it's optional. +# +# multiply(Tensor input, Tensor other, *, Tensor out=None) +# multiply(Tensor input, Scalar other) +# +# Both positional args and keyword-only args are considered together. +# +# subtract(Tensor other, *, Scalar alpha=1) +# subtract(Scalar other, Scalar alpha=1) +# +# A few ambiguous cases which it does NOT handle yet. +# +# If there is any difference in other parameters besides the Tensor/Scalar +# difference, then they are not considered ambiguous by this method anymore. +# However, the difference could be too trivial to disambiguate. +# +# foo(Tensor input, Scalar other, Scalar bar) +# foo(Tensor input, Tensor other, double bar) +# +# If they are taking different number of parameters then they are not considered +# ambiguous anymore, even if the difference is only on optional kwargs. +# +# foo(Scalar other, Scalar alpha=1) +# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1) +# + + +def sort_overloads( + grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + # NB: Smaller here means lower priority + + def is_arg_smaller(t1: Type, t2: Type) -> bool: + return ( + str(t1) == "Scalar" + and str(t2) == "Tensor" + or str(t1) == "Scalar?" + and str(t2) == "Tensor?" + or "Dimname" in str(t1) + and "Dimname" not in str(t2) + or + # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been + # discussed why it is important to prioritize int/int? over int[] + str(t1) == "int[]" + and (str(t2) == "int" or str(t2) == "int?") + or + # TensorList currently throws an error during argument parsing, that's why it needs to be + # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087 + str(t1) == "Tensor[]" + and str(t2).find("[]") != -1 + or + # Prioritize IntArrayRef overload over SymIntArrayRef + str(t1) == "SymInt[]" + and str(t2) == "int[]" + or + # Make sure both in, SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly + # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed. + (str(t1) == "SymInt" or str(t1) == "int") + and str(t2) == "Tensor" + ) + + def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: + """Returns True if s1 < s2 in the partial order.""" + args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True) + if len(args1) != len(args2): + return False + # TODO: should use some canonical form instead of 'str(arg.type)' - see comments + # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which + # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'. 
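+        # Worked example (hypothetical op "foo"): comparing
+        # foo(Tensor self, Scalar other) against foo(Tensor self, Tensor other),
+        # `equal` is False and `smaller_or_equal` is True (Scalar sorts below Tensor
+        # per is_arg_smaller), so the Scalar overload gets lower priority and the
+        # Tensor overload is parsed first.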
+ equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2)) + smaller_or_equal = all( + str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type) + for arg1, arg2 in zip(args1, args2) + ) + return smaller_or_equal and not equal + + # First sort by signature + grouped_overloads = sorted( + grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint) + ) + + # Construct the relation graph + larger_than: Dict[int, Set[int]] = defaultdict(set) + for i1, overload1 in enumerate(grouped_overloads): + for i2, overload2 in enumerate(grouped_overloads): + if is_smaller(overload1.signature, overload2.signature): + larger_than[i1].add(i2) + + if not larger_than: + return list(grouped_overloads) + + # Use a topological sort to sort overloads according to the partial order. + N = len(grouped_overloads) + sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N))) + + for idx in range(N): + # The size of sorted_ids will grow to N eventually. + i = sorted_ids[idx] + for j in sorted(larger_than.keys()): + larger = larger_than[j] + larger.discard(i) + if not larger: + del larger_than[j] + sorted_ids.append(j) + + return list(map(lambda x: grouped_overloads[x], sorted_ids)) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Codegen API Integration +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def emit_single_dispatch( + ps: PythonSignature, + f: NativeFunction, + namedtuple_typenames: Dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single native function. + """ + + @with_native_function + def go(f: NativeFunction) -> str: + # header comments + if isinstance(ps, PythonSignatureDeprecated): + schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" + else: + schema_comment = f"// aten::{f.func}" + + deprecated = "[deprecated] " if ps.deprecated else "" + + # dispatch lambda signature + name = cpp.name(f.func) + lambda_formals = ", ".join( + map( + lambda a: f"{a.type_str} {a.name}", + dispatch_lambda_args(ps, f, symint=symint), + ) + ) + lambda_return = dispatch_lambda_return_str(f) + + # dispatch lambda body + dispatch_callee = cpp_dispatch_target(f) + dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) + + # from arg parser outputs to dispatch lambda arguments + parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint) + inits = "\n".join(lambda_arg_exprs.inits) + lambda_args = ", ".join(lambda_arg_exprs.exprs) + + # scatter fields + # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky + # solution for enabling the 'requires_grad' argument for tensor methods + # new_full, new_empty, and new_zeros. 
A much better but more difficult to + # implement solution involves refactoring according to Ed's description here: + # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 + need_set_requires_grad = ps.tensor_options_args and ( + not has_tensor_options(f) + or (ps.method and ("requires_grad" in parser_outputs)) + ) + set_requires_grad = ( + f'.set_requires_grad({parser_outputs["requires_grad"].expr})' + if need_set_requires_grad + else "" + ) + + if lambda_return == "void": + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + {dispatch_callee}({dispatch_args}); +}}; +dispatch_{name}({lambda_args}){set_requires_grad}; +Py_RETURN_NONE; +""" + else: + typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f)) + namedtuple_typeref = f"{typename}, " if typename is not None else "" + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + return {dispatch_callee}({dispatch_args}); +}}; +return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad}); +""" + + return go(f) diff --git a/scripts/generate_stuff/gen_trace_type.py b/scripts/generate_stuff/gen_trace_type.py new file mode 100644 index 0000000..45796d8 --- /dev/null +++ b/scripts/generate_stuff/gen_trace_type.py @@ -0,0 +1,548 @@ +import itertools +from typing import Dict, List, Sequence, Union + +from torchgen.api import cpp + +from torchgen.api.types import DispatcherSignature +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments +from torchgen.utils import FileManager + +# Note [Manual Backend kernels] +# For these ops, we want to manually register to dispatch key Backend and +# skip codegen-ed registeration to all keys before Backend. +# For codegen this means: +# - op set below must match ops with manual_kernel_registration=True in native_functions.yaml +# where we skip codegen backend kernels +# - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration +# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration +# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_BACKEND = set( + [ + "options", + "data", + "set_data", + "is_leaf", + "output_nr", + "_version", + "retain_grad", + "_backward", + "requires_grad_", + ] +) + +# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD_AND_TRACER = set( + [ + "resize_", + "resize_as_", + "detach", + "detach_", + "copy_", + "_fw_primal", + "_make_dual", + ] +) + +# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: +# union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER) +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER + +# These functions we don't want to record for tracing, because we always want +# to trace their constituent parts. This is a temporary hack in lieue +# of proper scopes, where subsequent compilation passes can ask for the unfolding +# on demand. 
Only concrete ATen methods can be disabled this way; it will have +# NO EFFECT otherwise. +DONT_RECORD_TRACE = { + "convolution", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "lstm_cell", + "gru_cell", + "rnn_tanh_cell", + "rnn_relu_cell", + # FIXME: figure out a better way when we support sparse tensors in jit + "_coalesced", +} + + +def should_trace(f: NativeFunction) -> bool: + # Operations involving Storage or Type are not traceable at the moment + if any( + str(arg.type) in {"Storage", "Type", "ConstQuantizerPtr"} + for arg in f.func.schema_order_arguments() + ): + return False + # We can't trace functions which don't have any Tensor or TensorList returns + if not any(r.type.is_tensor_like() for r in f.func.returns): + return False + return f.func.name.name.base not in DONT_RECORD_TRACE + + +SELECT = CodeTemplate( + """\ + +if (${cond}) { + ${true} +} else { + ${false} +} +""" +) + +OP_NAME = CodeTemplate( + """\ +op_name = c10::Symbol::fromQualString("aten::${trace_name}"); +""" +) + +# These functions have their names recorded under trace renamed, +RENAME_TRACE = { + "zero": "zeros_like", # replacing aten::zero_ with aten::zeros_like + "fill": "full_like", # replacing aten::fill_ with aten::full_like +} + + +def format_trace_op_name(f: NativeFunction) -> str: + # TODO: byte-for-byte compatible with old codegen behavior - should clean up + if ( + f.func.kind() in (SchemaKind.functional, SchemaKind.out) + or f.func.name.name.dunder_method + ): + # special case for *_out functions: the in-place and out-of-place ops + # are overloaded with the same name in the JIT + trace_name = str(f.func.name.name) + trace_name = RENAME_TRACE.get(trace_name, trace_name) + return OP_NAME.substitute(trace_name=trace_name) + + # otherwise, this is an in-place op and we need to emit both in- and + # out-of-place versions + outplace_trace_name = f.func.name.name.base + inplace_trace_name = cpp.name(f.func) + outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) + inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) + + return SELECT.substitute( + cond="tracer_state->force_outplace", + true=OP_NAME.substitute(trace_name=outplace_trace_name), + false=OP_NAME.substitute(trace_name=inplace_trace_name), + ) + + +ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") + + +def format_trace_inputs(f: NativeFunction) -> str: + def dispatch_trace_input( + arg: Union[Argument, TensorOptionsArguments] + ) -> Sequence[str]: + if isinstance(arg, TensorOptionsArguments): + name = "options" + return [ + ADD_TRACE_INPUT.substitute( + name=name, input="optTypeMetaToScalarType(options.dtype_opt())" + ), + ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.device()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"), + ] + else: + name = arg.name + if str(arg.type) == "Tensor?[]": + return [f'jit::tracer::addInputs(node, "{name}", {name});'] + else: + return [ADD_TRACE_INPUT.substitute(name=name, input=name)] + + args: List[Union[Argument, TensorOptionsArguments]] = list( + f.func.schema_order_arguments() + ) + + if f.func.is_out_fn(): + # *_out functions take the result as a separate argument, but we don't want to + # trace that argument directly. Instead, we trace its TensorOptions. + # So first, we need to remove the out argument from the list of arguments to trace. 
+ # TODO: byte-for-byte compatible with old codegen behavior - it's incorrect to assume + # there is only one output argument. + args = args[:-1] + + trace_inputs = itertools.chain.from_iterable( + dispatch_trace_input(arg) for arg in args + ) + + if f.func.is_out_fn(): + # for *_out functions, handle the result argument differently for inplace/outplace. + # For inplace: just add the input to the end to confirm with the JIT schema + name = f.func.arguments.out[0].name # TODO: old codegen behavior - should fix + inplace = ADD_TRACE_INPUT.substitute(name=name, input=name) + + # for outplace: do nothing, except if the function is a factory. + # Factories are a bit special because their out-of-place overloads + # take an extra TensorOptions argument, which is missing in the _out function + has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in f.func.arguments.flat_non_out + ) + is_factory_method = f.category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + + # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` + # flag for the whole family of ops with the same basename if any of them is a + # factory method. For most cases the whole family of ops are indeed all factory + # method - 'normal' is the only exception. So we handle it specially here to avoid + # cloning the old logic. + if f.func.name.name.base == "normal": + is_factory_method = True + + if is_factory_method: + outplace = [ + ADD_TRACE_INPUT.substitute( + name="out", + input="optTypeMetaToScalarType(out.options().dtype_opt())", + ), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"), + ADD_TRACE_INPUT.substitute( + name="out", input="out.options().pinned_memory()" + ), + ] + else: + outplace = [] + + trace_inputs = itertools.chain( + trace_inputs, + [ + SELECT.substitute( + cond="tracer_state->force_outplace", + true="\n".join(outplace), + false=inplace, + ) + ], + ) + + return "\n".join(trace_inputs) + + +# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, +# which force jit to replace functions with outplace variants (for +# example `aten::add_` becomes `aten::add`). +# +# This replacement implemented in-place with minimum modifications of +# arguments stack (as it assumes that outplace call has the same arguments +# as inplace version). +# +# However there are no such substitutions available for `aten::fill_` +# and `aten::zero_` operators, as we never implemented `aten::fill` +# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with +# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. +# +# But as they potentially can have different arguments, we also have +# to hack into the stack and add missing ones. 
+# +# A possible alternative would be: +# +# - Add `aten::fill` and `aten::zero` +# +# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` +# arguments (inside of the `native_functions.yaml`) +RENAME_TRACE_ADD_ARGS = { + "fill": """\ + jit::tracer::addInputs(node, "options", c10::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt)); + c10::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", + "zero": """\ + jit::tracer::addInputs(node, "options", c10::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt)); + c10::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", +} + +INPLACE_GUARD = CodeTemplate( + """\ +jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); +""" +) + +PRE_RECORD_TRACE = CodeTemplate( + """\ +torch::jit::Node* node = nullptr; +std::shared_ptr tracer_state; +if (jit::tracer::isTracing()) { + tracer_state = jit::tracer::getTracingState(); + at::Symbol op_name; + ${set_op_name} + node = tracer_state->createNode(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + ${add_trace_inputs} + tracer_state->insertNode(node); + ${inplace_guard} + jit::tracer::setTracingState(nullptr); +} +""" +) + + +def format_prerecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # TODO: clean up old codegen behavior + is_inplace = ( + f.func.kind() in (SchemaKind.inplace, SchemaKind.out) + and not f.func.name.name.dunder_method + ) + add_args = ( + RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else "" + ) + additional_inputs = ( + SELECT.substitute( + cond="tracer_state->force_outplace", + true=add_args, + false="", + ) + if add_args + else "" + ) + + return PRE_RECORD_TRACE.substitute( + set_op_name=format_trace_op_name(f), + add_trace_inputs=format_trace_inputs(f) + additional_inputs, + inplace_guard=INPLACE_GUARD.substitute( + name=cpp.name(f.func), + mutable_input=f.func.arguments.out[0].name + if f.func.arguments.out + else "self", + ) + if is_inplace + else "", + ) + + +POST_RECORD_TRACE = CodeTemplate( + """\ +if (tracer_state) { + jit::tracer::setTracingState(std::move(tracer_state)); + ${add_trace_outputs} +} +""" +) + + +def format_postrecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # For outplacing ops, *_out overloads require special handling to move the + # output *argument* to a return value + if f.func.is_out_fn(): + output_names_outplace = [arg.name for arg in f.func.arguments.out] + output_names_inplace = cpp.return_names(f) + + # Code size optimization: the common case is that the return value is + # the same for both variants + if output_names_outplace == output_names_inplace: + outputs = [ + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + selection = SELECT.substitute( + cond="force_outplace", + true="\n".join( + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ), + false="\n".join( + 
f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace + ), + ) + return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) + else: + output_names = cpp.return_names(f) + outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + +def declare_returned_variables(f: NativeFunction) -> str: + modifies_arguments = f.func.kind() in (SchemaKind.inplace, SchemaKind.out) + if modifies_arguments: + return "" + if len(f.func.returns) == 1: + return "" + types = [cpp.return_type(r, symint=True) for r in f.func.returns] + names = cpp.return_names(f) + return "\n".join(f"{type.cpp_type()} {name};" for type, name in zip(types, names)) + + +def tie_return_values(f: NativeFunction) -> str: + if len(f.func.returns) == 1: + return f'auto {f.func.returns[0].name or "result"}' + names = cpp.return_names(f) + return f'std::tie({", ".join(names)})' + + +def get_return_value(f: NativeFunction) -> str: + names = cpp.return_names(f) + if len(f.func.returns) == 1: + return names[0] + if f.func.kind() == SchemaKind.out: + return f'std::forward_as_tuple({", ".join(names)})' + else: + moved = ", ".join(f"std::move({name})" for name in names) + return f"std::make_tuple({moved})" + + +TRACE_DISPATCH = CodeTemplate( + """\ +${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});""" +) + + +def emit_trace_body(f: NativeFunction) -> List[str]: + trace_body: List[str] = [] + + trace_body.append(format_prerecord_trace(f)) + trace_body.append(declare_returned_variables(f)) + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + assign_return_values = ( + f"{tie_return_values(f)} = " + if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable] + and f.func.returns + else "" + ) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal. + trace_body.append( + TRACE_DISPATCH.substitute( + assign_return_values=assign_return_values, + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + + trace_body.append(format_postrecord_trace(f)) + if f.func.returns: + trace_body.append(f"return {get_return_value(f)};") + return trace_body + + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + + +def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str: + if f.func.name.overload_name: + name = f"{cpp.name(f.func)}_{f.func.name.overload_name}" + else: + name = cpp.name(f.func) + + # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key. + # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated, + # the key argument should not be passed. + # We do not append key if it is Default so that generated functions from + # before per-dispatch-key derivatives were added retain the same names. 
+ if key != "Default": + name = name + f"_{key}" + return name + + +@with_native_function +def method_definition(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + formals = ", ".join( + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + ["c10::DispatchKeySet ks"] + + [ + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ] + ) + + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=formals, + type_definition_body=emit_trace_body(f), + ) + + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${name}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + + +@with_native_function +def method_registration(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + return WRAPPER_REGISTRATION.substitute( + name=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="TraceType", + ) + + +def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]: + return { + "ops_headers": [f"#include "], + "trace_method_definitions": [method_definition(fn)], + "trace_wrapper_registrations": [method_registration(fn)], + } + + +def gen_trace_type( + out: str, native_functions: List[NativeFunction], template_path: str +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "TraceType.cpp", + [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], + key_fn=lambda fn: fn.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp", + }, + env_callable=gen_trace_type_func, + num_shards=5, + sharded_keys={ + "ops_headers", + "trace_method_definitions", + "trace_wrapper_registrations", + }, + ) diff --git a/scripts/generate_stuff/generate_pytorch_wrappers.py b/scripts/generate_stuff/generate_pytorch_wrappers.py new file mode 100644 index 0000000..481698b --- /dev/null +++ b/scripts/generate_stuff/generate_pytorch_wrappers.py @@ -0,0 +1,172 @@ +import collections +from typing import List + +from torchgen.api.python import PythonSignatureGroup +from torchgen.gen import parse_native_yaml +from torchgen.model import SelfArgument, TensorOptionsArguments, NativeFunction + +from scripts.generate_stuff.gen_pyi import ( + get_py_torch_functions, + blocklist, + binary_ops, + comparison_ops, + symmetric_comparison_ops, + unary_ops, + to_py_type_ops, +) +from scripts.generate_stuff.gen_python_functions import ( + should_generate_py_binding, + load_signatures, +) + + +def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: + type_hints: List[str] = [] + + # Some deprecated ops that are on the blocklist are still included in pyi + if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: + return type_hints + + # deprecated signatures have separate entries for their functional and out variants + # (as opposed to the native ops, which fuse the two into a single signature). + # generate the functional variant here, if an out variant exists. 
+ if sig_group.signature.deprecated and sig_group.outplace is not None: + type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) + type_hints.append(type_hint) + + # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument + # Generates the out variant if one exists. Otherwise, generate the functional variant + type_hint = sig_group.signature.signature_str_pyi( + skip_outputs=sig_group.outplace is None + ) + type_hints.append(type_hint) + + # Some operators also additionally have a vararg variant of their signature + type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( + skip_outputs=sig_group.outplace is None + ) + if type_hint_vararg: + type_hints.append(type_hint_vararg) + + return type_hints + + +def sig_for_ops(opname: str) -> List[str]: + """sig_for_ops(opname : str) -> List[str] + + Returns signatures for operator special functions (__add__ etc.)""" + + # we have to do this by hand, because they are hand-bound in Python + + assert opname.endswith("__") and opname.startswith("__"), "Unexpected op {}".format( + opname + ) + + name = opname[2:-2] + if name in binary_ops: + return ["def {}(self, other: Any) -> Tensor: ...".format(opname)] + elif name in comparison_ops: + sig = "def {}(self, other: Any) -> Tensor: ...".format(opname) + if name in symmetric_comparison_ops: + # unsafe override https://github.com/python/mypy/issues/5704 + sig += " # type: ignore[override]" + return [sig] + elif name in unary_ops: + return ["def {}(self) -> Tensor: ...".format(opname)] + elif name in to_py_type_ops: + if name in {"bool", "float", "complex"}: + tname = name + elif name == "nonzero": + tname = "bool" + else: + tname = "int" + if tname in {"float", "int", "bool", "complex"}: + tname = "builtins." + tname + return ["def {}(self) -> {}: ...".format(opname, tname)] + else: + raise Exception("unknown op", opname) + + +native_yaml_path = "native_functions.yaml" +tags_yaml_path = "tags.yaml" +deprecated_yaml_path = "deprecated.yaml" + +native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions +native_functions = list(filter(should_generate_py_binding, native_functions)) + +function_signatures = load_signatures( + native_functions, deprecated_yaml_path, method=False, pyi=True +) +function_sig_groups = get_py_torch_functions(function_signatures) + +tensor_method_signatures = load_signatures( + native_functions, + deprecated_yaml_path, + method=True, + skip_deprecated=True, + pyi=True, +) +tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True) + + +# def create_unique_key(sig_group: PythonSignatureGroup) -> str: +def create_unique_key(base: NativeFunction) -> str: + func = base.func + + # is_vararg = sig_group.signature.signature_str_pyi_vararg( + # skip_outputs=sig_group.outplace is None + # ) + # if is_vararg is not None: + # print(is_vararg) + is_vararg = False + is_varret = False + + overload = "" if not func.name.overload_name else f".{func.name.overload_name}" + if is_vararg: + arg_str = "..." 
+ else: + arg_str = [] + for arg in func.arguments.all: + if isinstance(arg, SelfArgument): + arg_str.append(str(arg.argument.type)) + elif isinstance(arg, TensorOptionsArguments): + arg_str.append(str(arg.dtype.type)) + arg_str.append(str(arg.layout.type)) + arg_str.append(str(arg.device.type)) + arg_str.append(str(arg.pin_memory.type)) + else: + typ_str = str(arg.type) + if typ := arg.type.is_list_like(): + if typ.size is not None: + typ_str = typ_str.replace(str(typ.size), "") + + arg_str.append(typ_str) + + arg_str = ", ".join(arg_str) + + if is_varret: + ret_str = "..." + else: + ret_str = ", ".join(str(ret.type) for ret in func.returns) + if "." in str(func.name): + unqualified_name, _overload = str(func.name).split(".") + else: + unqualified_name = str(func.name) + return f"{base.namespace}::{unqualified_name}{overload} : ({arg_str}) -> ({ret_str})" + + +function_signatures_dict = { + create_unique_key(group.base): group + for group in sorted(function_sig_groups, key=lambda g: g.signature.name) +} +for uniq in function_signatures_dict: + print(uniq) + +tensor_method_signatures_dict = { + create_unique_key(group.base): group + for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name) +} +for uniq in tensor_method_signatures_dict: + print(uniq) + +print() \ No newline at end of file diff --git a/scripts/generate_stuff/generate_torch_mlir_extensions.py b/scripts/generate_stuff/generate_torch_mlir_extensions.py new file mode 100644 index 0000000..dc0415a --- /dev/null +++ b/scripts/generate_stuff/generate_torch_mlir_extensions.py @@ -0,0 +1,496 @@ +import warnings +from textwrap import dedent +from typing import List, Callable + +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry import ( + JitOperator, + Registry, + _pytype_to_decomposition_fn_pytype, + _get_default_value, + _rename_python_keyword_parameter_name, +) +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.utils import TextEmitter +from torchgen.api.python import signature_from_schema, FunctionSchema + + +# from scripts.generate_stuff.generate_pytorch_wrappers import ( +# tensor_method_signatures_dict, +# function_signatures_dict, +# ) + + +ALL = [] + + +def _get_function_signature( + self, + function_kind: str, + parameter_decl_builder: Callable[["SIG_ATTR_TYPE"], str], + ret_decl_builder: Callable[["SIG_ATTR_TYPE"], str], +) -> str: + mlir_op_name, _ = self.get_mlir_names() + # Replace `.` with a valid Python identifier character. + # `〇` vaguely looks like `.`. 
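+ # (Here the segments are actually joined with `_` and `aten_` is stripped, rather than
+ # using `〇`, so the emitted wrappers get plain snake_case names; these names are also
+ # collected in ALL to build the module's __all__.)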
+ def_name = "_".join(mlir_op_name.split(".")).replace("aten_", "") + parameter_decls = list(map(parameter_decl_builder, self.arguments)) + ret_decls = list(map(ret_decl_builder, self.returns)) + parameters = ", ".join(parameter_decls) + result = ", ".join(ret_decls) + if len(ret_decls) >= 2: + result = f"Tuple[{result}]" + + if len(ret_decls) == 0: + result = "None" + + ALL.append(def_name) + return f"def {def_name}({parameters}) -> {result}:" + + +JitOperator._get_function_signature = _get_function_signature + + +TORCH_TYPE_TO_ODS_TYPE = { + "Tensor": "AnyTorchTensorType", + "Tensor?": "AnyTorchOptionalTensorType", + "Tensor?[]": "AnyTorchListOfOptionalTensorType", + "Tensor[]": "AnyTorchListOfTensorType", + "Scalar": "AnyTorchScalarType", + "Scalar?": "AnyTorchOptionalScalarType", + "int": "Torch_IntType", + "int[]": "AnyTorchListOfTorchIntType", + "int?": "AnyTorchOptionalIntType", + "int[]?": "AnyTorchOptionalListOfTorchIntType", + "bool": "Torch_BoolType", + "bool[]": "AnyTorchListOfTorchBoolType", + "bool?": "AnyTorchOptionalBoolType", + "float": "Torch_FloatType", + "float?": "AnyTorchOptionalFloatType", + "float[]": "AnyTorchListOfTorchFloatType", + "float[]?": "AnyTorchOptionalListOfTorchFloatType", + "t[]": "AnyTorchListType", + "t": "AnyTorchType", + "t1": "AnyTorchType", + "t2": "AnyTorchType", + "Any": "AnyTorchType", + "Device": "Torch_DeviceType", + "Device?": "AnyTorchOptionalDeviceType", + "Generator": "Torch_GeneratorType", + "Generator?": "AnyTorchOptionalGeneratorType", + "str": "Torch_StringType", + "str?": "AnyTorchOptionalStringType", + "str[]": "AnyTorchListOfTorchStringType", + "Dict": "Torch_DictType", + "__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType", +} + + +# pyt_type is reversed +def convert_type(pyt_type: str): + if pyt_type.endswith("?"): + nested, interior = convert_type(pyt_type[:-1]) + return f"Optional[{nested}]", interior + elif pyt_type.endswith("[]"): + nested, interior = convert_type(pyt_type[:-2]) + return f"List[{nested}]", interior + else: + subs = { + "Scalar": "Number", + "t2": "Tensor", + "t1": "Tensor", + "t": "Tensor", + "Dict(str, t)": "Dict[str, Tensor]", + "device": "Device", + } + interior = subs.get(pyt_type, pyt_type) + return interior, interior + + +def convert_type_to_op(arg_name, pyt_type, p_td, emitter_td): + _, interior = convert_type(pyt_type) + + if pyt_type.endswith("?"): + p_td(f"if {arg_name} is not None:") + with emitter_td.indent(): + convert_type_to_op(arg_name, pyt_type[:-1], p_td, emitter_td) + p_td(f"else:") + with emitter_td.indent(): + p_td(f"{arg_name} = torch_dialect.ConstantNoneOp()") + p_td("") + elif pyt_type.endswith("[]"): + if interior == "Tensor": + pass + # p_td(f"{arg_name} = get_op_results_or_values({arg_name})") + else: + op = convert_type_to_op(None, pyt_type[:-2], p_td, emitter_td) + p_td(f"{arg_name} = list(map({op}, {arg_name}))") + p_td(f"{arg_name} = torch_dialect.PrimListConstructOp({arg_name})") + else: + if interior in {"int", "bool", "float", "Number", "str"}: + op = f"torch_dialect.Constant{interior.capitalize()}Op" + if arg_name is not None: + p_td(f"{arg_name} = {op}({arg_name})") + else: + return op + else: + if arg_name is not None: + p_td( + f"assert is_mlir_value({arg_name}), f'`{arg_name}` should be a Value but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'" + ) + + +EXISTING = { + "ConstantFloatOp", + "ConstantIntOp", + "ConstantStrOp", + "ConstantBoolOp", + "PrimListConstructOp", +} + + +def py_reserved_keywords(k): + subs = { + 
"from": "from_", + "self": "self_", + "list": "list_", + } + return subs.get(k, k) + + +def get_wrapper_function_signature(operator): + """Gets the Python function signature for this op's decomposition function. + + While this is technically debug-only output, it is useful to copy-paste + it from the debug dump into the shape library definitions, as many + ops have extra default arguments and stuff that are tedious to write out + right. + """ + + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + pytype = convert_type(arg["type"])[0] + default = _get_default_value(arg) + if arg["name"] == "out": + default = "= None" + parameter_name = py_reserved_keywords( + _rename_python_keyword_parameter_name(arg["name"]) + ) + return f"{parameter_name}: {pytype}{default}" + + def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + ret = convert_type(arg["type"])[0] + if not ret: + ret = "None" + return ret + + return operator._get_function_signature( + "", parameter_decl_builder, ret_decl_builder + ) + + +def raw_emit_op( + operator: JitOperator, + emitter_td: TextEmitter, + *, + traits: List[str], + has_folder: bool, + has_canonicalizer: bool, +): + p_td = lambda *args: emitter_td.print(*args) + stub_td = lambda *args: stubs_emitter_td.print(*args) + op_name, cpp_class_name = operator.get_mlir_names() + if cpp_class_name in {"QuantizedLinearOp"} | EXISTING: + return + + # Generate unique result names for ops with nameless results + multiple_results = len(operator.returns) > 1 + + if any( + [ + arg["type"].endswith("?[]") + or arg["type"].endswith("[]?") + or "Dict" in arg["type"] + or "Device" in arg["type"] + or "Generator" in arg["type"] + for arg in operator.arguments + ] + ): + print(f"{cpp_class_name} has weird args") + return + + if operator.is_vararg: + print(f"{cpp_class_name} is vararg") + return + else: + args = { + py_reserved_keywords(arg["name"]): convert_type(arg["type"])[0] + for arg in operator.arguments + } + for k, v in args.items(): + args[k] = v.replace("Tensor", "Value").replace("Number", '"Number"') + + def generic_result_name(i): + return "result" + (str(i) if multiple_results else "") + + if operator.is_varret: + print(f"{cpp_class_name} is vararg") + return + else: + ret_names = [ + f'{ret["name"] or generic_result_name(e)}' + for e, ret in enumerate(operator.returns) + ] + + if any([ret["type"] == "Device" for ret in operator.returns]): + print(f"{cpp_class_name} returns device") + return + + p_td(f"class {cpp_class_name}:") + ret_type_names = [] + arg_names = [] + with emitter_td.indent(): + args_str = ", ".join([f"{k}: {v}" for k, v in args.items()]) + if args_str: + args_str = f" {args_str}," + else: + args_str = "" + p_td(f"def __init__(self,{args_str} *, loc=None, ip=None):") + with emitter_td.indent(): + if any( + [ + convert_type(arg["type"])[1] != "Tensor" + or "?" 
in arg["type"] + or "[]" in arg["type"] + for arg in operator.arguments + ] + ): + p_td(f"from torch_mlir.dialects import torch as torch_dialect\n\n") + + for arg in operator.arguments: + arg_name = py_reserved_keywords(arg["name"]) + arg_names.append(arg_name) + arg_type = arg["type"] + p_td(f"if not is_mlir_value({arg_name}):") + with emitter_td.indent(): + convert_type_to_op(arg_name, arg_type, p_td, emitter_td) + p_td(f"else:") + not_none_arg_type = arg["type"].replace("?", "") + with emitter_td.indent(): + p_td(f"{arg_name} = get_op_result_or_value({arg_name})") + if not_none_arg_type in {"Tensor", "t"}: + p_td( + f"""assert str({arg_name}.type).startswith("!torch.vtensor"), f'`{arg_name}` should be a torch.vtensor but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "int": + p_td( + f"""assert str({arg_name}.type) == '!torch.int', f'`{arg_name}` should be a !torch.int but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "str": + p_td( + f"""assert str({arg_name}.type) == '!torch.str', f'`{arg_name}` should be a !torch.str but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "float": + p_td( + f"""assert str({arg_name}.type) == '!torch.float', f'`{arg_name}` should be a !torch.float but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "bool": + p_td( + f"""assert str({arg_name}.type) == '!torch.bool', f'`{arg_name}` should be a !torch.bool but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "Scalar": + p_td( + f"""assert str({arg_name}.type) in {{'!torch.float', '!torch.int'}}, f'`{arg_name}` should be a !torch.number but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "Any": + p_td( + f"""assert str({arg_name}.type) == '!torch.Any', f'`{arg_name}` should be a !torch.Any but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type == "int[]": + p_td( + f"""assert str({arg_name}.type) == '!torch.list', f'`{arg_name}` should be a !torch.list but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + elif not_none_arg_type in {"t[]", "Tensor[]"}: + p_td( + f"""assert str({arg_name}.type) == '!torch.list', f'`{arg_name}` should be a !torch.list but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'""" + ) + else: + print( + f"{cpp_class_name} weird arg {arg_name} type {arg['type']}" + ) + p_td(f"# should be {arg['type']}") + p_td(f"pass") + + p_td("\n") + + for e, ret in enumerate(operator.returns): + name = f'{ret["name"] or generic_result_name(e)}' + if ret["type"] in {"Tensor", "t"}: + p_td(f"""{name}_type = Type.parse("!torch.vtensor")""") + elif ret["type"] == "int": + continue + # p_td(f"""{name}_type = Type.parse("!torch.int")""") + elif ret["type"] == "str": + continue + # p_td(f"""{name}_type = Type.parse("!torch.str")""") + elif ret["type"] == "float": + continue + # p_td(f"""{name}_type = Type.parse("!torch.float")""") + elif ret["type"] == "bool": + continue + # p_td(f"""{name}_type = Type.parse("!torch.bool")""") + elif ret["type"] == "Scalar": + continue + # p_td(f"""{name}_type = Type.parse("!torch.number")""") + elif ret["type"] == "Any": + continue + # p_td(f"""{name}_type = Type.parse("!torch.Any")""") + elif ret["type"] == "int[]": + p_td(f"""{name}_type = Type.parse("!torch.list")""") + 
elif ret["type"] == "t[]": + p_td(f"""{name}_type = Type.parse("!torch.list")""") + else: + raise Exception( + f"{cpp_class_name} weird ret {name} type {ret['type']}" + ) + ret_type_names.append(f"{name}_type") + + if ret_type_names: + ret_type_names = f"{', '.join(ret_type_names)}, " + else: + ret_type_names = "" + + if arg_names: + arg_names = f"{', '.join(arg_names)}, " + else: + arg_names = "" + + p_td( + f"super({cpp_class_name}, self).__init__({ret_type_names}{arg_names}loc=loc, ip=ip)" + ) + p_td("\n") + p_td("\n") + + stub_td(get_wrapper_function_signature(operator)) + with stubs_emitter_td.indent(): + for arg in operator.arguments: + arg_name = py_reserved_keywords(arg["name"]) + if arg_name == "dtype": + stub_td("if dtype is not None and isinstance(dtype, Enum):") + with stubs_emitter_td.indent(): + stub_td("dtype = dtype.value") + if arg["type"] == "Tensor": + stub_td( + f"assert isinstance({arg_name}, Tensor), f'`{arg_name}` should be a {{Tensor.__module__}}.{{Tensor.__name__}} but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'" + ) + stub_td(f"{arg_name} = {arg_name}.value") + elif arg["type"] == "Tensor?": + stub_td(f"if {arg_name} is not None:") + with stubs_emitter_td.indent(): + stub_td( + f"assert isinstance({arg_name}, Tensor), f'`{arg_name}` should be a {{Tensor.__module__}}.{{Tensor.__name__}} but is {{type({arg_name}).__module__}}.{{type({arg_name}).__name__}}'" + ) + stub_td(f"{arg_name} = {arg_name}.value") + elif arg["type"] == "Tensor[]": + stub_td( + f"assert builtins.all(isinstance(t, Tensor) for t in {arg_name})" + ) + stub_td(f"{arg_name} = [t.value for t in {arg_name}]") + + call_str = f'torch_dialect.{cpp_class_name}({", ".join([f"{k}" for k, _v in args.items()])})' + if len(operator.returns) == 0: + ret = call_str + elif len(operator.returns) == 1: + ret = f"return Tensor({call_str})" + else: + stub_td(f"op_results = get_op_results_or_values({call_str})") + ret = f"return tuple([Tensor(o) if is_a_torch_tensor(o) else o for o in op_results])" + stub_td(f"{ret}") + stub_td("\n") + + +import torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen + +torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen.raw_emit_op = ( + raw_emit_op +) + +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen import emit_ops + +# native_yaml_path = "native_functions.yaml" +# tags_yaml_path = "tags.yaml" +# deprecated_yaml_path = "deprecated.yaml" +# +# native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions +# native_functions = list(filter(should_generate_py_binding, native_functions)) +# +# function_signatures = load_signatures( +# native_functions, deprecated_yaml_path, method=False, pyi=True +# ) +# sig_groups = get_py_torch_functions(function_signatures) +# +# +# tensor_method_signatures = load_signatures( +# native_functions, +# deprecated_yaml_path, +# method=True, +# skip_deprecated=True, +# pyi=True, +# ) +# tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True) + +_torch_ops_ext_fp = "../../pi/dialects/_torch_ops_ext.py" +_torch_wrappers_fp = "../../pi/dialects/_torch_wrappers.py" + +registry = Registry.load() +with open(_torch_ops_ext_fp, "w") as f_td: + emitter_td = TextEmitter(f_td) + emitter_td._INDENT = " " + with open(_torch_wrappers_fp, "w") as stubs_td: + stubs_emitter_td = TextEmitter(stubs_td) + stubs_emitter_td._INDENT = " " + stubs_emitter_td.print( + dedent( + f"""\ + from enum import Enum + import builtins + + from .._tensor 
import Tensor + from ..types_ import Number, is_a_torch_tensor + from typing import List, Optional, Any, Tuple + + from torch_mlir.dialects import torch as torch_dialect + from torch_mlir.dialects._ods_common import ( + get_default_loc_context, + get_op_result_or_value, + get_op_results_or_values, + ) + + """ + ) + ) + + emitter_td.print( + dedent( + f"""\ + try: + # from pi import Tensor, Number + from torch_mlir.ir import * + from torch_mlir.dialects._ods_common import ( + get_default_loc_context, + get_op_result_or_value, + get_op_results_or_values, + ) + from ._torch_ops_ext_custom import * + except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + from typing import List, Optional, Any + + + """ + ) + ) + emit_ops(emitter_td, registry) + + assert len(ALL) == len(set(ALL)), "duplicate ALL" + ALL = [f"'{a}'" for a in ALL] + stubs_emitter_td.print("\n\n") + stubs_emitter_td.print(f"__all__ = [{', '.join(ALL)}]") diff --git a/setup.py b/setup.py index 16fbf54..d7c63a7 100644 --- a/setup.py +++ b/setup.py @@ -1,29 +1,161 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import os +import platform +import re +import subprocess import sys +import tarfile +import urllib.request +from pathlib import Path + +from setuptools import Extension, setup, find_namespace_packages +from setuptools.command.build_ext import build_ext + +# Convert distutils Windows platform specifiers to CMake -A arguments +PLAT_TO_CMAKE = { + "win-amd64": "x64", + "win-arm32": "ARM", + "win-arm64": "ARM64", +} + + +def get_llvm_package(): + # download if nothing is installed + system = platform.system() + system_suffix = {"Linux": "linux-gnu-ubuntu-20.04", "Darwin": "apple-darwin"}[ + system + ] + LIB_ARCH = os.environ.get("LIB_ARCH", platform.machine()) + assert LIB_ARCH is not None + print(f"ARCH {LIB_ARCH}") + name = f"llvm+mlir+python-{sys.version_info.major}.{sys.version_info.minor}-15.0.4-{LIB_ARCH}-{system_suffix}-release" + here = Path(__file__).parent + if not (here / "llvm_install").exists(): + url = f"https://github.com/makslevental/llvm-releases/releases/latest/download/{name}.tar.xz" + print(f"downloading and extracting {url} ...") + ftpstream = urllib.request.urlopen(url) + file = tarfile.open(fileobj=ftpstream, mode="r|*") + file.extractall(path=str(here)) + + print("done downloading") + return str((here / "llvm_install").absolute()) + + +# A CMakeExtension needs a sourcedir instead of a file list. +# The name must be the _single_ output extension from the CMake build. +# If you need multiple extensions, see scikit-build. +class CMakeExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[]) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext: CMakeExtension) -> None: + # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call] + ext_build_lib_dir = ext_fullpath.parent.resolve() + + # Using this requires trailing slash for auto-detection & inclusion of + # auxiliary "native" libs + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. 
+ # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + llvm_install_dir = os.environ.get("LLVM_INSTALL_DIR", None) + if llvm_install_dir is None: + llvm_install_dir = get_llvm_package() + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + # f"-DCMAKE_BUILD_TYPE=Debug", + # f"-DCMAKE_C_COMPILER=clang", + # f"-DCMAKE_CXX_COMPILER=clang++", + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={ext_build_lib_dir}{os.sep}pi", + f"-DCMAKE_PREFIX_PATH={llvm_install_dir}", + f"-DCMAKE_MODULE_LINKER_FLAGS=-L{llvm_install_dir}/lib", + f"-DCMAKE_SHARED_LINKER_FLAGS=-L{llvm_install_dir}/lib", + f"-DCMAKE_EXE_LINKER_FLAGS=-L{llvm_install_dir}/lib", + f"-DPython3_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + # Adding CMake arguments set as environment variable + # (needed e.g. to build for ARM OSx on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + # In this example, we pass in the version to C++. You might not need to. + # Using Ninja-build since it a) is available as a wheel and b) + # multithreads automatically. MSVC would require all variables be + # exported for Ninja to pick it up, which is a little tricky to do. + # Users can override the generator with CMAKE_GENERATOR in CMake + # 3.15+. + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja # noqa: F401 + + ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. 
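+ # The flag is forwarded to the `cmake --build` invocation below, which passes the job
+ # count on to the underlying build tool (typically Ninja here).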
+ build_args += [f"-j{self.parallel}"] + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", "."] + build_args, cwd=build_temp, check=True + ) -from setuptools import find_namespace_packages, setup packages = find_namespace_packages( include=[ - "shark", - "shark.*", + "pi", + "pi.*", ], ) -VERSION = "0.0.1" + +VERSION = "0.0.2" if len(sys.argv) > 1 and sys.argv[1] == "--version": print(VERSION) else: setup( - name="SharkPy", + name="PI", version=VERSION, author="Maksim Levental", author_email="maksim.levental@gmail.com", - description="Python frontend for MLIR (and torch-mlir)", - zip_safe=False, + description="A lightweight MLIR Python frontend with PyTorch like syntax", + # ext_modules=[CMakeExtension("_tensor")], + # cmdclass={"build_ext": CMakeBuild}, packages=packages, - install_requires=["PyYAML", "pyccolo", "torch-mlir"], + zip_safe=False, + install_requires=["PyYAML", "pyccolo", "torch-mlir", "multiprocess"], ) diff --git a/shark/__init__.py b/shark/__init__.py deleted file mode 100644 index 54947fb..0000000 --- a/shark/__init__.py +++ /dev/null @@ -1 +0,0 @@ -import shark.dialects \ No newline at end of file diff --git a/shark/compiler/utils.py b/shark/compiler/utils.py deleted file mode 100644 index 0cd9caf..0000000 --- a/shark/compiler/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. -import ast -import os -import sys -import tempfile -from io import StringIO - -from shark._mlir_libs._mlir.ir import ( - Operation, - IntegerType, -) -from shark._mlir_libs._mlir.passmanager import PassManager -from shark.dialects import arith - -ONE_SHOT_BUFFERIZATION_PIPELINE = [ - "func.func(linalg-init-tensor-to-alloc-tensor)", - "one-shot-bufferize", - "func-bufferize", - "arith-bufferize", - "func.func(finalizing-bufferize)", -] - -LOWERING_PIPELINE = [ - # Lower to LLVM - "convert-scf-to-cf", - # "func.func(refback-expand-ops-for-llvm)", - "func.func(arith-expand)", - "func.func(convert-math-to-llvm)", - # Handle some complex mlir::math ops (e.g. atan2) - "convert-math-to-libm", - "convert-linalg-to-llvm", - "convert-memref-to-llvm", - "func.func(convert-arith-to-llvm)", - "convert-func-to-llvm", - "convert-cf-to-llvm", - "reconcile-unrealized-casts", -] - - -def run_pipeline_with_repro_report(module, pipeline: str, description: str = None): - try: - original_stderr = sys.stderr - sys.stderr = StringIO() - asm_for_error_report = module.operation.get_asm( - large_elements_limit=10, enable_debug_info=True - ) - # Lower module in place to make it ready for compiler backends. 
- with module.context: - pm = PassManager.parse(pipeline) - pm.run(module) - except Exception as e: - filename = os.path.join(tempfile.gettempdir(), "tmp.mlir") - with open(filename, "w") as f: - f.write(asm_for_error_report) - debug_options = "-mlir-print-ir-after-all -mlir-disable-threading" - description = description or f"tmp compile" - - message = f"""\ - {description} failed with the following diagnostics: - {sys.stderr.getvalue()} - - For MLIR developers, the error can be reproduced with: - $ mlir-opt -pass-pipeline='{pipeline}' {filename} - Add '{debug_options}' to get the IR dump for debugging purpose. - """ - trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")]) - raise Exception(trimmed_message) from None - finally: - sys.stderr = original_stderr - - -def traverse_op_region_block_iterators(op, handler): - for i, region in enumerate(op.regions): - for j, block in enumerate(region): - for k, child_op in enumerate(block): - res = handler(child_op) - if res is not None and isinstance(res, Exception): - return res - res = traverse_op_region_block_iterators(child_op, handler) - if res is not None and isinstance(res, Exception): - return res - - -def parse_attrs_to_dict(attrs): - d = {} - for named_attr in attrs: - if named_attr.name in {"lpStartTime", "value"}: - d[named_attr.name] = ast.literal_eval( - str(named_attr.attr).split(":")[0].strip() - ) - elif named_attr.name in {"opr"}: - d[named_attr.name] = ast.literal_eval(str(named_attr.attr)) - else: - d[named_attr.name] = ast.literal_eval(str(named_attr.attr).replace('"', "")) - return d - - -def make_i32_int(x): - return arith.ConstantOp(IntegerType.get_signless(32), x).result - - -def add_dummy_value(): - return Operation.create( - "custom.value", results=[IntegerType.get_signless(32)] - ).result diff --git a/shark/dialects/__init__.py b/shark/dialects/__init__.py deleted file mode 100644 index 5b8048a..0000000 --- a/shark/dialects/__init__.py +++ /dev/null @@ -1,105 +0,0 @@ -import sys - -import logging -import threading -from contextlib import contextmanager -from importlib.abc import MetaPathFinder -from importlib.machinery import SourceFileLoader, ModuleSpec -from importlib.util import find_spec, spec_from_loader -from pathlib import Path -from typing import Generator - -logger = logging.getLogger(__name__) - - -OVERLOADS = { - "torch_mlir.dialects._arith_ops_ext": str( - Path(__file__).parent / "_arith_ops_ext.py" - ), - "torch_mlir.dialects._memref_ops_ext": str( - Path(__file__).parent / "_memref_ops_ext.py" - ), -} - - -# this is based on the birdseye finder (which uses import hooks based on MacroPy's): -# https://github.com/alexmojaki/birdseye/blob/9974af715b1801f9dd99fef93ff133d0ab5223af/birdseye/import_hook.py -class Overloader(MetaPathFinder): - def __init__(self) -> None: - self.tracers = None - self._thread = threading.current_thread() - - @contextmanager - def _clear_preceding_finders(self) -> Generator[None, None, None]: - """ - Clear all preceding finders from sys.meta_path, and restore them afterwards. 
- """ - orig_finders = sys.meta_path - try: - sys.meta_path = sys.meta_path[sys.meta_path.index(self) + 1 :] # noqa: E203 - yield - finally: - sys.meta_path = orig_finders - - def _find_plain_spec(self, fullname, path, target): - """Try to find the original module using all the - remaining meta_path finders.""" - spec = None - self_seen = False - for finder in sys.meta_path: - if finder is self: - self_seen = True - continue - elif not self_seen or "pytest" in finder.__module__: - # when testing with pytest, it installs a finder that for - # some yet unknown reasons makes birdseye - # fail. For now it will just avoid using it and pass to - # the next one - continue - if hasattr(finder, "find_spec"): - spec = finder.find_spec(fullname, path, target=target) - elif hasattr(finder, "load_module"): - spec = spec_from_loader(fullname, finder) - - if spec is not None and spec.origin != "builtin": - return spec - - def find_spec(self, fullname, path=None, target=None): - if threading.current_thread() is not self._thread: - return None - if target is None: - with self._clear_preceding_finders(): - spec = find_spec(fullname, path) - else: - spec = self._find_plain_spec(fullname, path, target) - if spec is None or not ( - hasattr(spec.loader, "get_source") and callable(spec.loader.get_source) - ): # noqa: E128 - if fullname != "org": - # stdlib pickle.py at line 94 contains a ``from - # org.python.core for Jython which is always failing, - # of course - logger.debug("Failed finding spec for %s", fullname) - return None - - if not isinstance(spec.loader, SourceFileLoader): - return None - if fullname in OVERLOADS: - # spec = OVERLOADS[fullname] - new_path = OVERLOADS[fullname] - source_file_loader = SourceFileLoader(fullname, new_path) - spec = ModuleSpec( - name=fullname, - loader=source_file_loader, - origin=new_path, - is_package=False, - ) - spec.has_location = True - return spec - - -if len(sys.meta_path) > 0 and isinstance(sys.meta_path[0], Overloader): - orig_meta_path_entry = sys.meta_path[0] - sys.meta_path[0] = Overloader() -else: - sys.meta_path.insert(0, Overloader()) diff --git a/shark/dialects/_affine_ops_gen.py b/shark/dialects/_affine_ops_gen.py deleted file mode 100644 index f2dadc1..0000000 --- a/shark/dialects/_affine_ops_gen.py +++ /dev/null @@ -1,380 +0,0 @@ - -# Autogenerated by mlir-tblgen; don't manually edit. - -from ._ods_common import _cext as _ods_cext -from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values -_ods_ir = _ods_cext.ir - -try: - from . 
import _affine_ops_ext as _ods_ext_module -except ImportError: - _ods_ext_module = None - -import builtins - - -@_ods_cext.register_dialect -class _Dialect(_ods_ir.Dialect): - DIALECT_NAMESPACE = "affine" - pass - - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineApplyOp(_ods_ir.OpView): - OPERATION_NAME = "affine.apply" - - _ODS_REGIONS = (0, True) - - def __init__(self, result, map, mapOperands, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.extend(_get_op_results_or_values(mapOperands)) - attributes["map"] = map - results.append(result) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def mapOperands(self): - _ods_variadic_group_length = len(self.operation.operands) - 1 + 1 - return self.operation.operands[0:0 + _ods_variadic_group_length] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineForOp(_ods_ir.OpView): - OPERATION_NAME = "affine.for" - - _ODS_REGIONS = (1, True) - - @builtins.property - def results_(self): - _ods_variadic_group_length = len(self.operation.results) - 1 + 1 - return self.operation.results[0:0 + _ods_variadic_group_length] - - @builtins.property - def region(self): - return self.regions[0] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineIfOp(_ods_ir.OpView): - OPERATION_NAME = "affine.if" - - _ODS_REGIONS = (2, True) - - @builtins.property - def results_(self): - _ods_variadic_group_length = len(self.operation.results) - 1 + 1 - return self.operation.results[0:0 + _ods_variadic_group_length] - - @builtins.property - def thenRegion(self): - return self.regions[0] - - @builtins.property - def elseRegion(self): - return self.regions[1] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineLoadOp(_ods_ir.OpView): - OPERATION_NAME = "affine.load" - - _ODS_REGIONS = (0, True) - - def __init__(self, result, memref, indices, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.append(_get_op_result_or_value(memref)) - operands.extend(_get_op_results_or_values(indices)) - results.append(result) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def memref(self): - return self.operation.operands[0] - - @builtins.property - def indices(self): - _ods_variadic_group_length = len(self.operation.operands) - 2 + 1 - return self.operation.operands[1:1 + _ods_variadic_group_length] - - @builtins.property - def result(self): - return self.operation.results[0] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineMaxOp(_ods_ir.OpView): - OPERATION_NAME = "affine.max" - - _ODS_REGIONS = (0, True) - - def __init__(self, result, map, operands_, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.extend(_get_op_results_or_values(operands_)) - attributes["map"] = map - results.append(result) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - 
@builtins.property - def operands_(self): - _ods_variadic_group_length = len(self.operation.operands) - 1 + 1 - return self.operation.operands[0:0 + _ods_variadic_group_length] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineMinOp(_ods_ir.OpView): - OPERATION_NAME = "affine.min" - - _ODS_REGIONS = (0, True) - - def __init__(self, result, map, operands_, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.extend(_get_op_results_or_values(operands_)) - attributes["map"] = map - results.append(result) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def operands_(self): - _ods_variadic_group_length = len(self.operation.operands) - 1 + 1 - return self.operation.operands[0:0 + _ods_variadic_group_length] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineParallelOp(_ods_ir.OpView): - OPERATION_NAME = "affine.parallel" - - _ODS_REGIONS = (1, True) - - def __init__(self, results_, reductions, lowerBoundsMap, lowerBoundsGroups, upperBoundsMap, upperBoundsGroups, steps, mapOperands, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.extend(_get_op_results_or_values(mapOperands)) - attributes["reductions"] = reductions - attributes["lowerBoundsMap"] = lowerBoundsMap - attributes["lowerBoundsGroups"] = lowerBoundsGroups - attributes["upperBoundsMap"] = upperBoundsMap - attributes["upperBoundsGroups"] = upperBoundsGroups - attributes["steps"] = steps - results.extend(results_) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def mapOperands(self): - _ods_variadic_group_length = len(self.operation.operands) - 1 + 1 - return self.operation.operands[0:0 + _ods_variadic_group_length] - - @builtins.property - def lowerBoundsGroups(self): - return _ods_ir.DenseIntElementsAttr(self.operation.attributes["lowerBoundsGroups"]) - - @lowerBoundsGroups.setter - def lowerBoundsGroups(self, value): - if value is None: - raise ValueError("'None' not allowed as value for mandatory attributes") - self.operation.attributes["lowerBoundsGroups"] = value - - @builtins.property - def upperBoundsGroups(self): - return _ods_ir.DenseIntElementsAttr(self.operation.attributes["upperBoundsGroups"]) - - @upperBoundsGroups.setter - def upperBoundsGroups(self, value): - if value is None: - raise ValueError("'None' not allowed as value for mandatory attributes") - self.operation.attributes["upperBoundsGroups"] = value - - @builtins.property - def results_(self): - _ods_variadic_group_length = len(self.operation.results) - 1 + 1 - return self.operation.results[0:0 + _ods_variadic_group_length] - - @builtins.property - def region(self): - return self.regions[0] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffinePrefetchOp(_ods_ir.OpView): - OPERATION_NAME = "affine.prefetch" - - _ODS_REGIONS = (0, True) - - def __init__(self, memref, indices, isWrite, localityHint, isDataCache, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.append(_get_op_result_or_value(memref)) - operands.extend(_get_op_results_or_values(indices)) - 
attributes["isWrite"] = isWrite - attributes["localityHint"] = localityHint - attributes["isDataCache"] = isDataCache - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def memref(self): - return self.operation.operands[0] - - @builtins.property - def indices(self): - _ods_variadic_group_length = len(self.operation.operands) - 2 + 1 - return self.operation.operands[1:1 + _ods_variadic_group_length] - - @builtins.property - def isWrite(self): - return _ods_ir.BoolAttr(self.operation.attributes["isWrite"]) - - @isWrite.setter - def isWrite(self, value): - if value is None: - raise ValueError("'None' not allowed as value for mandatory attributes") - self.operation.attributes["isWrite"] = value - - @builtins.property - def localityHint(self): - return _ods_ir.IntegerAttr(self.operation.attributes["localityHint"]) - - @localityHint.setter - def localityHint(self, value): - if value is None: - raise ValueError("'None' not allowed as value for mandatory attributes") - self.operation.attributes["localityHint"] = value - - @builtins.property - def isDataCache(self): - return _ods_ir.BoolAttr(self.operation.attributes["isDataCache"]) - - @isDataCache.setter - def isDataCache(self, value): - if value is None: - raise ValueError("'None' not allowed as value for mandatory attributes") - self.operation.attributes["isDataCache"] = value - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineStoreOp(_ods_ir.OpView): - OPERATION_NAME = "affine.store" - - _ODS_REGIONS = (0, True) - - @builtins.property - def value(self): - return self.operation.operands[0] - - @builtins.property - def memref(self): - return self.operation.operands[1] - - @builtins.property - def indices(self): - _ods_variadic_group_length = len(self.operation.operands) - 3 + 1 - return self.operation.operands[2:2 + _ods_variadic_group_length] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineVectorLoadOp(_ods_ir.OpView): - OPERATION_NAME = "affine.vector_load" - - _ODS_REGIONS = (0, True) - - def __init__(self, result, memref, indices, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.append(_get_op_result_or_value(memref)) - operands.extend(_get_op_results_or_values(indices)) - results.append(result) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def memref(self): - return self.operation.operands[0] - - @builtins.property - def indices(self): - _ods_variadic_group_length = len(self.operation.operands) - 2 + 1 - return self.operation.operands[1:1 + _ods_variadic_group_length] - - @builtins.property - def result(self): - return self.operation.results[0] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineVectorStoreOp(_ods_ir.OpView): - OPERATION_NAME = "affine.vector_store" - - _ODS_REGIONS = (0, True) - - @builtins.property - def value(self): - return self.operation.operands[0] - - @builtins.property - def memref(self): - return self.operation.operands[1] - - @builtins.property - def indices(self): - _ods_variadic_group_length = len(self.operation.operands) - 3 + 1 - return self.operation.operands[2:2 + 
_ods_variadic_group_length] - -@_ods_cext.register_operation(_Dialect) -@_ods_extend_opview_class(_ods_ext_module) -class AffineYieldOp(_ods_ir.OpView): - OPERATION_NAME = "affine.yield" - - _ODS_REGIONS = (0, True) - - def __init__(self, operands_, *, loc=None, ip=None): - operands = [] - results = [] - attributes = {} - regions = None - operands.extend(_get_op_results_or_values(operands_)) - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, - successors=_ods_successors, regions=regions, loc=loc, ip=ip)) - - @builtins.property - def operands_(self): - _ods_variadic_group_length = len(self.operation.operands) - 1 + 1 - return self.operation.operands[0:0 + _ods_variadic_group_length] diff --git a/shark/dialects/_ods_common.py b/shark/dialects/_ods_common.py deleted file mode 100644 index 30eef39..0000000 --- a/shark/dialects/_ods_common.py +++ /dev/null @@ -1,162 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Provide a convenient name for sub-packages to resolve the main C-extension -# with a relative import. -from torch_mlir._mlir_libs import _mlir as _cext - -from typing import Sequence as _Sequence, Union as _Union - -__all__ = [ - "equally_sized_accessor", - "extend_opview_class", - "get_default_loc_context", - "get_op_result_or_value", - "get_op_results_or_values", - "segmented_accessor", -] - - -def extend_opview_class(ext_module): - """Decorator to extend an OpView class from an extension module. - - Extension modules can expose various entry-points: - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). A name-based match is attempted first before falling back - to a below mechanism. - - def select_opview_mixin(parent_opview_cls): - If defined, allows an appropriate mixin class to be selected dynamically - based on the parent OpView class. Should return NotImplemented if a - decision is not made. - - Args: - ext_module: A module from which to locate extensions. Can be None if not - available. - - Returns: - A decorator that takes an OpView subclass and further extends it as - needed. - """ - - def class_decorator(parent_opview_cls: type): - if ext_module is None: - return parent_opview_cls - mixin_cls = NotImplemented - # First try to resolve by name. - try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) - except AttributeError: - # Fall back to a select_opview_mixin hook. - try: - select_mixin = getattr(ext_module, "select_opview_mixin") - except AttributeError: - pass - else: - mixin_cls = select_mixin(parent_opview_cls) - - if mixin_cls is NotImplemented or mixin_cls is None: - return parent_opview_cls - - # Have a mixin_cls. Create an appropriate subclass. - try: - - class LocalOpView(mixin_cls, parent_opview_cls): - pass - except TypeError as e: - raise TypeError( - f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e - LocalOpView.__name__ = parent_opview_cls.__name__ - LocalOpView.__qualname__ = parent_opview_cls.__qualname__ - return LocalOpView - - return class_decorator - - -def segmented_accessor(elements, raw_segments, idx): - """ - Returns a slice of elements corresponding to the idx-th segment. - - elements: a sliceable container (operands or results). - raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing - sizes of the segments. - idx: index of the segment. 
- """ - segments = _cext.ir.DenseI32ArrayAttr(raw_segments) - start = sum(segments[i] for i in range(idx)) - end = start + segments[idx] - return elements[start:end] - - -def equally_sized_accessor(elements, n_variadic, n_preceding_simple, - n_preceding_variadic): - """ - Returns a starting position and a number of elements per variadic group - assuming equally-sized groups and the given numbers of preceding groups. - - elements: a sequential container. - n_variadic: the number of variadic groups in the container. - n_preceding_simple: the number of non-variadic groups preceding the current - group. - n_preceding_variadic: the number of variadic groups preceding the current - group. - """ - - total_variadic_length = len(elements) - n_variadic + 1 - # This should be enforced by the C++-side trait verifier. - assert total_variadic_length % n_variadic == 0 - - elements_per_group = total_variadic_length // n_variadic - start = n_preceding_simple + n_preceding_variadic * elements_per_group - return start, elements_per_group - - -def get_default_loc_context(location=None): - """ - Returns a context in which the defaulted location is created. If the location - is None, takes the current location from the stack, raises ValueError if there - is no location on the stack. - """ - if location is None: - # Location.current raises ValueError if there is no current location. - return _cext.ir.Location.current.context - return location.context - - -def get_op_result_or_value( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList] -) -> _cext.ir.Value: - """Returns the given value or the single result of the given op. - - This is useful to implement op constructors so that they can take other ops as - arguments instead of requiring the caller to extract results for every op. - Raises ValueError if provided with an op that doesn't have a single result. - """ - if isinstance(arg, _cext.ir.OpView): - return arg.operation.result - elif isinstance(arg, _cext.ir.Operation): - return arg.result - elif isinstance(arg, _cext.ir.OpResultList): - return arg[0] - else: - assert isinstance(arg, _cext.ir.Value) - return arg - - -def get_op_results_or_values( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, - _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]] -) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: - """Returns the given sequence of values or the results of the given op. - - This is useful to implement op constructors so that they can take other ops as - lists of arguments instead of requiring the caller to extract results for - every op. 
- """ - if isinstance(arg, _cext.ir.OpView): - return arg.operation.results - elif isinstance(arg, _cext.ir.Operation): - return arg.results - else: - return [get_op_result_or_value(element) for element in arg] diff --git a/tests/nn_module.py b/tests/nn_module.py new file mode 100644 index 0000000..4ba93e5 --- /dev/null +++ b/tests/nn_module.py @@ -0,0 +1,33 @@ +from pi import empty +from pi import nn + + +class MyConv2d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 1, 3) + + def forward(self, x): + y = self.conv(x) + z = y + y + w = z * z + return w + + +def simple_conv2d(): + x = empty((1, 3, 32, 32)) + my_conv = MyConv2d() + y = my_conv(x) + return y + + +simple_conv2d() + + +class Conv2dNoPaddingModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(2, 10, 3, bias=False) + + def forward(self, x): + return self.conv(x) diff --git a/tests/simple_kernels.py b/tests/simple_kernels.py index f1227a7..1f11ec5 100644 --- a/tests/simple_kernels.py +++ b/tests/simple_kernels.py @@ -1,14 +1,6 @@ from torch_mlir.dialects import memref, linalg, torch, arith, tensor -from torch_mlir.ir import ( - FloatAttr, - F64Type, - RankedTensorType, - DictAttr, - DenseIntElementsAttr, - Attribute, - ShapedType, - DenseFPElementsAttr, -) +from torch_mlir.ir import Attribute, DenseFPElementsAttr, DenseElementsAttr +import numpy as np def saxpy(a: float, b: float): @@ -59,12 +51,11 @@ def linalg_ops(min: float, max: float, seed: "i32"): def torch_ops(): - f64 = F64Type.get() - z = torch.ConstantFloatOp(value=FloatAttr.get(f64, 256.0)) + z = torch.ConstantFloatOp(value=256.0) attr = DenseFPElementsAttr(Attribute.parse("dense<0.0> : tensor<3x5xf32>")) a = torch.ValueTensorLiteralOp(attr) b = torch.ValueTensorLiteralOp(attr) - c = torch.AtenAddTensorOp(a.result.type, a.result, b.result, z) + c = torch.AtenAddTensorOp(a.result, b.result, 1) return c diff --git a/tests/tensor.py b/tests/tensor.py new file mode 100644 index 0000000..38f97fb --- /dev/null +++ b/tests/tensor.py @@ -0,0 +1,21 @@ +import pi + +from torch_mlir.dialects import torch +from torch_mlir.ir import IntegerType +import numpy as np + +from pi.mlir_utils import mlir_cm + +if __name__ == "__main__": + with mlir_cm() as module: + z = pi.empty((1, 2, 3)) + t = torch.ConstantIntOp(1).result + print(t.type) + i = IntegerType.get_unsigned(32) + print(i) + print(module) + vt = pi.from_numpy(np.random.rand(10, 10)) + z = vt + vt + print(module) + # sizes = parse_sizes_from_tensor_type_str(z) + # print(sizes) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 9b6c935..a2aef39 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,6 +1,11 @@ -from shark.compiler.compiler import mlir_trace +import os.path + +from pi.compiler.compiler import mlir_trace # TODO(max): need to figure out how to reload the module so that the bytecode gets run through pycc -mlir_module = mlir_trace("simple_kernels.py") +mlir_module = mlir_trace(os.path.abspath("simple_kernels.py")) +print(mlir_module) + +mlir_module = mlir_trace(os.path.abspath("nn_module.py")) print(mlir_module) diff --git a/tests/test_tiling.py b/tests/test_tiling.py index abc2015..6e3bd81 100644 --- a/tests/test_tiling.py +++ b/tests/test_tiling.py @@ -1,5 +1,5 @@ -from shark.compiler.compiler import mlir_trace -from shark.compiler.utils import run_pipeline_with_repro_report +from pi.compiler.compiler import mlir_trace +from pi.compiler.utils import run_pipeline_with_repro_report mlir_module = 
mlir_trace("simple_kernels.py") diff --git a/tests/torch_mlir/main.py b/tests/torch_mlir/main.py new file mode 100644 index 0000000..96130e9 --- /dev/null +++ b/tests/torch_mlir/main.py @@ -0,0 +1,212 @@ +import difflib +import sys +import traceback +from pathlib import Path + + +# noinspection PyUnresolvedReferences +import pi + +# noinspection PyUnresolvedReferences +import torch_mlir +import torch_mlir_e2e_test + +from xfail import PI_XFAIL_SET + +pi_package_root_path = Path(pi.__file__).parent +torch_mlir_package_root_path = Path(torch_mlir.__file__).parent +torch_mlir_e2e_test_package_root_path = Path(torch_mlir_e2e_test.__file__).parent + +from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS +from torch_mlir_e2e_test.test_suite import ( + register_all_tests as torch_mlir_register_all_tests, +) + +from pi.dialects import ( + remove_modules, + ImportOverload, + BASE_OVERLOADS, + patch_meta_path, +) +from pi.mlir_utils import lower_pi_to_linalg +from pi.testing.util import ( + PIConfig, +) +from pi.testing.util import ( + TorchDialectConfig, + set_weights, + lower_torch_mlir_to_linalg, +) + + +def run_torch_mlir_tests(): + torch_mlir_register_all_tests() + import torch_mlir_e2e_test.registry + import torch_mlir_e2e_test.framework + + tu = torch_mlir_e2e_test.framework.TestUtils() + tests = sorted( + torch_mlir_e2e_test.registry.GLOBAL_TEST_REGISTRY, key=lambda t: t.unique_name + ) + + torch_dialect_config = TorchDialectConfig() + torch_mlir_linalg_module_strs = {} + for test in tests: + if test.unique_name in PI_XFAIL_SET | COMMON_TORCH_MLIR_LOWERING_XFAILS: + continue + + mod = test.program_factory() + set_weights(mod) + mod.eval() + torch_mlir_module = torch_dialect_config.compile(mod) + torch_mlir_linalg_module_strs[test.unique_name] = str( + lower_torch_mlir_to_linalg(torch_mlir_module) + ) + # test.program_invoker(mod, tu) + + return torch_mlir_linalg_module_strs + + +def run_pi_tests(torch_mlir_linalg_module_strs): + torch_mlir_register_all_tests() + # after remapping, this imports pi test registry + import torch_mlir_e2e_test.registry + + tests = sorted( + torch_mlir_e2e_test.registry.GLOBAL_TEST_REGISTRY, key=lambda t: t.unique_name + ) + assert tests, "failed to load tests" + + from torch import nn + + assert ( + nn.__spec__.origin == f"{pi_package_root_path}/nn/__init__.py" + ), f"monkey patch failed {nn.__spec__.origin}" + # for compatibility + nn.Module.train = lambda *args, **kwargs: None + + # torchscript_config = TorchScriptTestConfig() + pi_config = PIConfig() + torch_dialect_config = TorchDialectConfig() + PASS, FAIL, TOTAL = 0, 0, 0 + for test in tests: + if test.unique_name in PI_XFAIL_SET | COMMON_TORCH_MLIR_LOWERING_XFAILS: + print(f"skipping {test.unique_name}") + continue + print(f"running {test.unique_name}") + TOTAL += 1 + + test_module = test.program_factory() + torch_mlir_linalg_module_str = torch_mlir_linalg_module_strs[test.unique_name] + + try: + pi_mlir_module = pi_config.compile(test, test_module) + except NotImplementedError as e: + print(traceback.format_exc(-2)) + print(f"{e}") + print(f"FAIL pi compile NotImplementedError") + FAIL += 1 + continue + except Exception as e: + print(traceback.format_exc()) + print(f"{e}") + print("\ntorch_mlir module\n") + print(torch_mlir_linalg_module_str) + # torch_script_compiled = torchscript_config.compile(mod) + # frozen = torch.jit.freeze(torch_script_compiled) + # torch_mlir_module = torch_dialect_config.compile(mod) + # print("frozen.graph\n", frozen.graph) + # 
print("torch_mlir_module\n", torch_mlir_module) + print(f"FAIL pi compile Exception") + FAIL += 1 + raise e + + try: + pi_mlir_linalg_module_str = str( + lower_pi_to_linalg(pi_mlir_module) + ) + except Exception as e: + print(traceback.format_exc()) + print("\ntorch_mlir module\n") + print(torch_mlir_linalg_module_str) + print(f"FAIL lower_pi_to_linalg Exception") + FAIL += 1 + raise e + + diff = list( + difflib.unified_diff( + str(pi_mlir_linalg_module_str).splitlines(), + str(torch_mlir_linalg_module_str).splitlines(), + lineterm="", + ) + ) + + if len(diff): + print(f"\n{''.join('*' * 10)}\ndiff\n{''.join('*' * 10)}\n") + print("\n".join(diff)) + print() + # print("torch_mlir_linalg_module_str:\n", torch_mlir_linalg_module_str) + # print("pi_mlir_linalg_module_str:\n", pi_mlir_linalg_module_str) + print(f"FAIL IR diff") + FAIL += 1 + + # torch_script_compiled = torchscript_config.compile(mod) + # frozen = torch.jit.freeze(torch_script_compiled) + # torch_mlir_module = torch_dialect_config.compile(mod) + # print("frozen.graph\n", frozen.graph) + # print("torch_mlir_module\n", torch_mlir_module) + else: + print("PASS") + PASS += 1 + + print(f"\n{''.join('*' * 10)}\n\n{PASS=} {FAIL=} out of {TOTAL=}\n\n{''.join('*' * 10)}\n") + + +def main(): + torch_mlir_linalg_module_strs = run_torch_mlir_tests() + remove_modules(lambda mod: mod.startswith("torch_mlir_e2e_test")) + remove_modules(lambda mod: mod == "torch" or mod.startswith("torch.")) + + # remap to torch so that isintance works... + # remove_modules(lambda mod: mod == "pi" or mod.startswith("pi.")) + + overloads = [ + ImportOverload( + "torch_mlir_e2e_test.framework", + pi_package_root_path / "testing/framework.py", + False, + ), + ImportOverload( + "torch_mlir_e2e_test.registry", + pi_package_root_path / "testing/registry.py", + False, + ), + ImportOverload( + "torch_mlir_e2e_test.annotations", + pi_package_root_path / "compiler/annotations.py", + False, + ), + ImportOverload( + "torch", + pi_package_root_path / "__init__.py", + True, + ), + ImportOverload( + "torch.jit._shape_functions", + Path(""), + False, + ), + # ImportOverload( + # "torch._functorch", + # pi_package_root_path / "_torch/_functorch/__init__.py", + # True, + # ), + ] + overloads = {o.name: o for o in overloads} + overloads.update(BASE_OVERLOADS) + with patch_meta_path(overloads): + run_pi_tests(torch_mlir_linalg_module_strs) + + +if __name__ == "__main__": + main() diff --git a/tests/torch_mlir/xfail.py b/tests/torch_mlir/xfail.py new file mode 100644 index 0000000..5015ff6 --- /dev/null +++ b/tests/torch_mlir/xfail.py @@ -0,0 +1,249 @@ +PI_XFAIL_SET = { + # view op + "ElementwiseFlattenBroadcastModule_basic", + "FlattenDynamicModule_basic", + "ElementwiseExpm1Module_basic", + "FlattenRank0Module_basic", + # unknown + "UniformModule_basic", + "BernoulliFloatModule_basic", + # torchvision + "IouOfModule_basic", + "ResNet18Module", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "MobilenetV3Module_basic", + # deprecated _convolution signature + "_ConvolutionDeprecated2DAllFalseModule_basic", + "_ConvolutionDeprecated2DBenchmarkModule_basic", + "_ConvolutionDeprecated2DDeterministicModule_basic", + "_ConvolutionDeprecated2DCudnnModule_basic", + # tuple returns + "VarMeanUnbiasedModule_basic", + "VarMeanBiasedModule_basic", + "TestMultipleTensorReturn_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + 
"ElementwiseClampModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "NativeLayerNormModule_basic", + "NativeLayerNormDynamicModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm1DModule_basic", + "Aten_EmbeddingBagExample_basic", + "DropoutTrainModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + # type/api overload + "ToDtypeLayoutStridedModule_basic", + "SqueezeModule_broadcast", + "SqueezeDimModule_unitDim", + "SqueezeDimModule_static", + "SqueezeDimModule_negDim", + "SqueezeDimModule_identity", + "SqueezeDimModule_dynamic", + "ArangeZeroElementOutputModule_basic", + "ArangeStartStepIntModule_basic", + "ArangeStartStepFloatModule_basic", + "ArangeStartNegativeStepIntModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartIntModule_basic", + "ArangeStartFloatModule_basic", + "ArangeNegativeStartIntModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeFloatModule_basic", + "ArangeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeDtypeFloatModule_basic", + "VarMeanCorrectionNoneModule_basic", + "VarMeanCorrectionModule_basic", + "VarDimUnbiasedModule_basic", + "VarDimSingleDimModule_basic", + "VarDimNoneDimModule_basic", + "VarDimNegativeModule_basic", + "VarDimMultiDimModule_basic", + "VarDimModule_basic", + "VarDimEmptyDimModule_basic", + "VarDimBiasedModule_basic", + "VarDimAllDimReduceModule_basic", + "VarCorrectionSingleDimReduceModule_basic", + "VarCorrectionNoneModule_basic", + "VarCorrectionModule_basic", + "VarCorrectionLargeInputModule_basic", + "VarCorrectionKeepDimModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarCorrectionAllDimReduceModule_basic", + "StdDimNoneDimModule_basic", + "StdDimKeepDimTrueModule_basic", + "StdDimKeepDimFalseModule_basic", + "StdDimEmptyDimModule_basic", + "StdDimBiasedModule_basic", + "StdCorrectionSingleDimReduceModule_basic", + "StdCorrectionNoneModule_basic", + "StdCorrectionModule_basic", + "StdCorrectionLargeInputModule_basic", + "StdCorrectionKeepDimModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdCorrectionAllDimReduceModule_basic", + "MeanDimNoneDimModule_basic", + "MeanDimNegativeModule_basic", + "MeanDimModule_basic", + "MeanDimKeepdimModule_basic", + "MeanDimEmptyDimModule_basic", + "MeanDimDtypeModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "TensorToFloat_basic", + "TensorToFloatZeroRank_basic", + "RandnDtypeDeviceModule_basic", + "BernoulliModule_basic", + "HBC_basic", + "ElementwiseNeIntScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "ElementwiseLtIntTensorModule_basic", + "ElementwiseLtIntScalarModule_basic", + "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtFloatScalarModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLeMixedIntScalarModule_basic", + "ElementwiseLeIntScalarModule_basic", + "ElementwiseLeFloatScalarModule_basic", + 
"ElementwiseLeFloatIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseGtIntTensorModule_basic", + "ElementwiseGtIntScalarModule_basic", + "ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtFloatScalarModule_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeFloatScalarModule_basic", + "ElementwiseGeFloatIntScalarModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "AnyBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "NeFloatIntModule_basic", + "GtFloatIntModule_basic", + "GeFloatModule_basic", + "GeFloatIntModule_basic", + "SubFloatModule_basic", + "SqrtIntConstantModule_basic", + "DivFloatModule_basic", + "CeilFloatModule_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ReduceSumDimIntListKeepDimIntModule_basic", + "ReduceSumDimIntListKeepDimFloatModule_basic", + "ReduceSumDimIntListIntModule_basic", + "ReduceSumDimIntListFloatModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceMaxNegativeDim_basic", + "ReduceMaxKeepDim_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxAlongDim_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceAmaxSingleDim_basic", + "TypePromotionZeroRankHigherCategoryModule_basic", + "TypePromotionSameCategoryZeroRankWider_basic", + "TypePromotionSameCategoryDifferentWidthModule_basic", + "TypePromotionDifferentCategoryModule_basic", + "TypePromotionAlphaWiderModule_basic", + "RsubIntModule_noalpha_basic", + "RsubIntModule_basic", + "RsubFloatModule_noalpha_basic", + "RsubFloatModule_basic", + "ElementwiseWhereSelfModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarModule_basic", + "ElementwiseTernaryModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRemainderScalarModule_Float_basic", + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwisePowTensorModule_basic", + "ElementwisePowTensorBroadcastModule_basic", + "ElementwisePowModule_basic", + "ElementwiseOrIntegerModule_basic", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseMulTensorFloatModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulScalarModule_float", + "ElementwiseMulScalarModule_basic", + "ElementwiseDivTensorFloatModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwiseDivRoundingModeFloorModule_basic", + "ElementwiseAtenWhereSelfModule_basic", + "ElementwiseAndIntegerModule_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseAddScalarInt64Module_basic", + "ElementwiseAddScalarFloatModule_basic", + "BroadcastToSameRankStaticModule_basic", + "BroadcastZeroRankInputStaticModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + # segfault (lol) + "ZeroInt64Module_basic", + "ZeroInt32Module_basic", + "ZeroFloat32Module_basic", + "CopyModule_basic", + "CopyWithDifferentDTypesAndSizesModule_basic", + 
"CopyWithDifferentDTypesModule_basic", + "CopyWithDifferentSizesModule_basic", + # eager/lazy materialization + "ZerosLikeModule_int", + "ZerosLikeModule_float", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_defaultDtype", + "OnesLikeModule_int", + "OnesLikeModule_float", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_defaultDtype", + "EmptyLikeModule_int", + "EmptyLikeModule_float", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_defaultDtype", + "EmptyLikeMemoryFormatModule_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "LayerNormModule_basic", + "LayerNormLastDimModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm1DModule_basic", + "TensorLiteralModule", + "AllBoolFalseModule_basic", + "TensorIntModule_basic", + "TensorLiteralModule_basic", + "TensorOpaqueLiteralModule_basic", + # backends + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2D_basic", +} diff --git a/tutorial/benchmark.py b/tutorial/benchmark.py index 3f3fb7f..83c8efe 100644 --- a/tutorial/benchmark.py +++ b/tutorial/benchmark.py @@ -7,11 +7,11 @@ import numpy as np -from shark import ir -from shark.compiler.config import DEBUG -from shark.dialects import func, arith, memref, scf -from shark.execution_engine import ExecutionEngine -from shark.runtime import get_ranked_memref_descriptor +from pi import ir +from pi.compiler.config import DEBUG +from pi.dialects import func, arith, memref, scf +from pi.execution_engine import ExecutionEngine +from pi.runtime import get_ranked_memref_descriptor from refbackend import assert_arg_type_is_supported diff --git a/tutorial/linalg_tut.py b/tutorial/linalg_tut.py index 2429ad2..08b6837 100644 --- a/tutorial/linalg_tut.py +++ b/tutorial/linalg_tut.py @@ -8,7 +8,7 @@ BUFFERIZATION_PIPELINE, LOWER_LLVM_PIPELINE, ) -from shark.ir import ( +from pi.ir import ( Context, Location, Module, @@ -16,9 +16,9 @@ RankedTensorType, F64Type, ) -from shark.compiler.config import MLIR_C_RUNNER_UTILS, MLIR_RUNNER_UTILS -from shark.compiler.utils import run_pipeline_with_repro_report -from shark.dialects import func, linalg +from pi.compiler.config import MLIR_C_RUNNER_UTILS, MLIR_RUNNER_UTILS +from pi.compiler.utils import run_pipeline_with_repro_report +from pi.dialects import func, linalg M = 32 N = 32 diff --git a/tutorial/passes.py b/tutorial/passes.py index 529059e..fadff7b 100644 --- a/tutorial/passes.py +++ b/tutorial/passes.py @@ -1,9 +1,9 @@ import re -from shark.ir import Type, Context, InsertionPoint, Location -from shark.compiler.utils import add_dummy_value, traverse_op_region_block_iterators +from pi.ir import Type, Context, InsertionPoint, Location +from pi.compiler.utils import add_dummy_value, traverse_op_region_block_iterators -from shark.dialects import memref +from pi.dialects import memref def promote_alloc(module): diff --git a/tutorial/refbackend.py b/tutorial/refbackend.py index a765523..2cd85af 100644 --- a/tutorial/refbackend.py +++ b/tutorial/refbackend.py @@ -11,13 +11,13 @@ "RefBackendLinalgOnTensorsBackend", ] -from shark.ir import Module -from shark.compiler.config import MLIR_C_RUNNER_UTILS, MLIR_RUNNER_UTILS -from shark.compiler.utils import run_pipeline_with_repro_report +from pi.ir import Module +from pi.compiler.config import MLIR_C_RUNNER_UTILS, MLIR_RUNNER_UTILS +from pi.compiler.utils import run_pipeline_with_repro_report -from shark.execution_engine import ExecutionEngine +from 
pi.execution_engine import ExecutionEngine -from shark.runtime import ( +from pi.runtime import ( UnrankedMemRefDescriptor, unranked_memref_to_numpy, get_unranked_memref_descriptor,
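
For orientation, here is a small illustrative sketch (not part of the patch) of the tensor-building flow that tests/tensor.py above exercises; it assumes, based only on how that test uses them, that pi.mlir_utils.mlir_cm yields an MLIR module which records the torch-dialect ops emitted by pi calls such as pi.empty and pi.from_numpy. Shapes and variable names here are arbitrary.

    # Illustrative sketch mirroring tests/tensor.py; assumptions noted above.
    import numpy as np

    import pi
    from pi.mlir_utils import mlir_cm

    with mlir_cm() as module:
        z = pi.empty((1, 2, 3))                   # placeholder tensor, as in tests/tensor.py
        vt = pi.from_numpy(np.random.rand(4, 4))  # value tensor built from numpy data
        w = vt + vt                               # elementwise add, presumably recorded in `module`
        print(module)                             # inspect the torch-dialect IR accumulated so far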