diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 10a3423..2598cfe 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,12 +1,11 @@
name: Build
on:
- pull_request:
- branches:
- - main
- push:
- branches:
- - main
+ workflow_run:
+ workflows: [ "Test" ]
+ types:
+ - completed
+ branches: [ main ]
workflow_dispatch:
branches:
- main
@@ -44,11 +43,11 @@ jobs:
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
mkdir -p ${{ github.sha }}
- mv wheelhouse/SharkPy*.whl ${{ github.sha }}/
+ mv wheelhouse/PI*.whl ${{ github.sha }}/
- name: Upload an artifact
uses: actions/upload-artifact@v3
- if: github.event_name == 'push'
+ if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
with:
if-no-files-found: error
name: build_artifact
@@ -62,7 +61,7 @@ jobs:
needs: [ build ]
- if: ${{ github.event_name == 'push' }}
+ if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -76,10 +75,10 @@ jobs:
- name: Set up a release page
id: setup_release
run: |
- SHARKPY_VERSION=$(python setup.py --version)
- tag_name="$SHARKPY_VERSION"
- release_title="SharkPy $SHARKPY_VERSION"
- echo "SharkPy $SHARKPY_VERSION created at $(date)" > body.md
+ PI_VERSION=$(python setup.py --version)
+ tag_name="$PI_VERSION"
+ release_title="PI $PI_VERSION"
+ echo "PI $PI_VERSION created at $(date)" > body.md
echo "tag_name=${tag_name}" >> $GITHUB_OUTPUT
echo "release_title=${release_title}" >> $GITHUB_OUTPUT
@@ -88,7 +87,7 @@ jobs:
with:
artifacts: "${{ github.sha }}/*.whl"
bodyFile: body.md
- token: "${{ secrets.SHARK_PY_CI }}"
+ token: "${{ secrets.PI_CI }}"
tag: "${{ steps.setup_release.outputs.tag_name }}"
name: "${{ steps.setup_release.outputs.release_title }}"
removeArtifacts: true
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..6593b8a
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,53 @@
+name: Test
+
+on:
+ pull_request:
+ branches:
+ - main
+ - nn_module
+ push:
+ branches:
+ - main
+ - nn_module
+ workflow_dispatch:
+ branches:
+ - main
+ - nn_module
+
+jobs:
+
+ test-against-torch-mlir:
+
+ runs-on: ${{ matrix.os }}
+
+ strategy:
+ matrix:
+ os: [ ubuntu-latest ]
+ arch: [ x86_64 ]
+ python_version: [ "3.10" ]
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+# - name: Install linux system packages
+# run: |
+# sudo apt-get update
+# sudo apt-get -y install ninja-build cmake clang
+
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python_version }}
+
+ - name: Install
+ run: |
+ pip install . \
+ --pre torch-mlir torchvision \
+ -f https://llvm.github.io/torch-mlir/package-index/ \
+ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+ -v
+
+ - name: Test vs. torch-mlir
+ run: |
+ PYTHONPATH=tests/torch_mlir python tests/torch_mlir/main.py
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..17fe0db
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,62 @@
+cmake_minimum_required(VERSION 3.13.4)
+
+if (POLICY CMP0068)
+ cmake_policy(SET CMP0068 NEW)
+ set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)
+endif ()
+
+if (POLICY CMP0075)
+ cmake_policy(SET CMP0075 NEW)
+endif ()
+
+if (POLICY CMP0077)
+ cmake_policy(SET CMP0077 NEW)
+endif ()
+
+if (POLICY CMP0116)
+ cmake_policy(SET CMP0116 NEW)
+endif ()
+
+project(PI LANGUAGES CXX C)
+
+set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)
+
+set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to")
+
+find_package(MLIR REQUIRED CONFIG)
+
+message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
+message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
+
+set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
+set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
+set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
+
+list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
+list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
+include(TableGen)
+include(AddLLVM)
+include(AddMLIR)
+include(HandleLLVMOptions)
+
+include_directories(${LLVM_INCLUDE_DIRS})
+include_directories(${MLIR_INCLUDE_DIRS})
+link_directories(${LLVM_BUILD_LIBRARY_DIR})
+add_definitions(${LLVM_DEFINITIONS})
+
+##################################### Bindings path hacks
+
+include(MLIRDetectPythonEnv)
+include(AddMLIRPython)
+mlir_configure_python_dev_packages()
+mlir_detect_pybind11_install()
+
+set(PYTHON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cpp_ext) # --src-root
+set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
+# set(MLIR_TABLEGEN_EXE "" CACHE STRING "Path to mlir-tablegen")
+# message(STATUS "MLIR_TABLEGEN_EXE: ${MLIR_TABLEGEN_EXE}")
+set(MLIR_INCLUDE_TESTS 0)
+
+pybind11_add_module(_mlir cpp_ext/MainModule.cpp cpp_ext/TensorValue.cpp cpp_ext/TorchTypes.cpp)
+#target_link_libraries(_mlir PRIVATE MLIRIR MLIRSupport MLIRCAPIInterfaces MLIRCAPIIR)
+
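+# A typical out-of-tree configure/build (a sketch; the MLIR_DIR path is an
+# assumption and must point at an MLIR install that provides MLIRConfig.cmake):
+#   cmake -S . -B build -DMLIR_DIR=/path/to/mlir/lib/cmake/mlir
+#   cmake --build build --target _mlir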
diff --git a/README.md b/README.md
index cd65eb5..931bcf7 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-- [SharkPy](#sharkpy)
+- [PI](#pi)
- [Installing](#installing)
- [Minimal example](#minimal-example)
- [Moderately interesting example](#moderately-interesting-example)
@@ -8,7 +8,7 @@
-# SharkPy
+# PI
Early days of a Python frontend for MLIR.
@@ -25,10 +25,10 @@ pip install . \
and you're good to go.
-Alternatively, you can install the [latest released wheel](https://github.com/nod-ai/SharkPy/releases/latest):
+Alternatively, you can install the [latest released wheel](https://github.com/nod-ai/PI/releases/latest):
```shell
-pip install https://github.com/nod-ai/SharkPy/releases/latest/download/SharkPy-$CURRENT_VERSION-py3-none-any.whl \
+pip install https://github.com/nod-ai/PI/releases/latest/download/PI-$CURRENT_VERSION-py3-none-any.whl \
--pre torch-mlir torchvision \
-f https://llvm.github.io/torch-mlir/package-index/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
@@ -39,7 +39,7 @@ pip install https://github.com/nod-ai/SharkPy/releases/latest/download/SharkPy-$
[simple_kernels.py](./tests/simple_kernels.py) (in [tests](./tests)) looks like this
```python
-from shark.dialects import memref, linalg
+from pi.dialects import memref, linalg
def saxpy(a: float, b: float):
A = memref.AllocaOp((10, 30))
@@ -233,29 +233,47 @@ func.func private @saxpy(%arg0: f64, %arg1: f64) -> memref<10x20xf64> {
Preliminary support for the `torch-mlir` dialect is available:
```python
-def torch_ops():
- f64 = F64Type.get()
- z = torch.ConstantFloatOp(value=FloatAttr.get(f64, 256.0))
- attr = DenseFPElementsAttr(Attribute.parse("dense<0.0> : tensor<3x5xf32>"))
- a = torch.ValueTensorLiteralOp(attr)
- b = torch.ValueTensorLiteralOp(attr)
- c = torch.AtenAddTensorOp(a.result.type, a.result, b.result, z)
- return c
+class MyConv2d(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(3, 1, 3)
+
+ def forward(self, x):
+ y = self.conv(x)
+ z = y + y
+ w = z * z
+ return w
```
lowers to
```mlir
-func.func private @torch_ops() -> !torch.vtensor<[3,5],f32> {
- %float2.560000e02 = torch.constant.float 2.560000e+02
- %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.vtensor<[3,5],f32>
- %1 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.vtensor<[3,5],f32>
- %2 = torch.aten.add.Tensor %0, %1, %float2.560000e02 :
- !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.float -> !torch.vtensor<[3,5],f32>
- return %2 : !torch.vtensor<[3,5],f32>
+module {
+ func.func private @simple_conv2d() -> !torch.vtensor {
+ %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x3x32x32xf32>) : !torch.vtensor<[1,3,32,32],f32>
+ %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+ %2 = torch.vtensor.literal(dense<1.000000e+00> : tensor<1x3x3x3xf32>) : !torch.vtensor<[1,3,3,3],f32>
+ %int1 = torch.constant.int 1
+ %int1_0 = torch.constant.int 1
+    %3 = torch.prim.ListConstruct %int1, %int1_0 : (!torch.int, !torch.int) -> !torch.list<int>
+ %int0 = torch.constant.int 0
+ %int0_1 = torch.constant.int 0
+    %4 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list<int>
+ %int1_2 = torch.constant.int 1
+ %int1_3 = torch.constant.int 1
+    %5 = torch.prim.ListConstruct %int1_2, %int1_3 : (!torch.int, !torch.int) -> !torch.list<int>
+ %int1_4 = torch.constant.int 1
+    %6 = torch.aten.conv2d %0, %2, %1, %3, %4, %5, %int1_4 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[1,3,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor
+ %7 = "torch.constant.number"() {value = 1 : i64} : () -> !torch.number
+ %8 = torch.aten.add.Tensor %6, %6, %7 : !torch.vtensor, !torch.vtensor, !torch.number -> !torch.vtensor
+ %9 = torch.aten.mul.Tensor %8, %8 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
+ return %9 : !torch.vtensor
+ }
}
```
+This is still very rough; to get an idea of the current status, check the [latest tests](https://github.com/nod-ai/PI/actions?query=branch%3Ann_module+) on the `nn_module` branch.
+
# Build Wheel
```shell
diff --git a/cpp_ext/IRModule.h b/cpp_ext/IRModule.h
new file mode 100644
index 0000000..bf7679d
--- /dev/null
+++ b/cpp_ext/IRModule.h
@@ -0,0 +1,158 @@
+//===- IRModules.h - IR Submodules of pybind module -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
+#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+
+#include <cassert>
+#include <utility>
+
+#include <pybind11/pybind11.h>
+
+#include "mlir-c/AffineExpr.h"
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/IntegerSet.h"
+
+namespace py = pybind11;
+
+namespace mlir::python {
+
+class PyOperation;
+
+/// Template for a reference to a concrete type which captures a python
+/// reference to its underlying python object.
+template <typename T>
+class PyObjectRef {
+public:
+ PyObjectRef(T *referrent, pybind11::object object)
+ : referrent(referrent), object(std::move(object)) {
+ assert(this->referrent && "cannot construct PyObjectRef with null referrent");
+ assert(this->object && "cannot construct PyObjectRef with null object");
+ }
+ PyObjectRef(PyObjectRef &&other)
+ : referrent(other.referrent), object(std::move(other.object)) {
+ other.referrent = nullptr;
+ assert(!other.object);
+ }
+ PyObjectRef(const PyObjectRef &other)
+ : referrent(other.referrent), object(other.object /* copies */) {}
+ ~PyObjectRef() = default;
+
+ T *operator->() {
+ assert(referrent && object);
+ return referrent;
+ }
+ pybind11::object getObject() {
+ assert(referrent && object);
+ return object;
+ }
+ explicit operator bool() const { return referrent && object; }
+
+private:
+ T *referrent;
+ pybind11::object object;
+};
+
+/// Wrapper around MlirContext.
+class PyMlirContext {
+public:
+ PyMlirContext() = delete;
+ PyMlirContext(const PyMlirContext &) = delete;
+ PyMlirContext(PyMlirContext &&) = delete;
+
+ explicit PyMlirContext(MlirContext context) : context(context){};
+
+ MlirContext context;
+ friend class PyModule;
+ friend class PyOperation;
+};
+
+using PyMlirContextRef = PyObjectRef<PyMlirContext>;
+
+class BaseContextObject {
+public:
+ explicit BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
+ assert(this->contextRef && "context object constructed with null context ref");
+ }
+ PyMlirContextRef contextRef;
+};
+
+class PyOperation : public BaseContextObject {
+public:
+ PyOperation &getOperation() { return *this; }
+ PyOperation(PyMlirContextRef contextRef, MlirOperation operation) : BaseContextObject(std::move(contextRef)), operation(operation){};
+
+ pybind11::handle handle;
+ MlirOperation operation;
+ pybind11::object parentKeepAlive;
+ bool attached = true;
+ bool valid = true;
+
+ friend class PyOperationBase;
+ friend class PySymbolTable;
+};
+
+using PyOperationRef = PyObjectRef<PyOperation>;
+
+class PyValue {
+public:
+ PyValue(PyOperationRef parentOperation, MlirValue value)
+ : parentOperation(std::move(parentOperation)), value(value) {}
+ explicit operator MlirValue() const { return value; }
+
+private:
+ PyOperationRef parentOperation;
+ MlirValue value;
+};
+
+struct PyType : public BaseContextObject {
+ PyType(PyMlirContextRef contextRef, MlirType type)
+ : BaseContextObject(std::move(contextRef)), type(type) {}
+ explicit operator MlirType() const { return type; }
+ [[nodiscard]] MlirType get() const { return type; }
+
+
+ MlirType type;
+};
+
+
+
+template <typename DerivedTy, typename BaseTy = PyType>
+struct PyConcreteType : public BaseTy {
+// using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+
+ PyConcreteType() = default;
+ PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+ : BaseTy(std::move(contextRef), t) {}
+
+ static DerivedTy createFromCapsule_(py::capsule& capsule) {
+ MlirType rawType = {capsule.get_pointer()};
+ if (mlirTypeIsNull(rawType))
+ throw py::error_already_set();
+
+ MlirContext ctx = mlirTypeGetContext(rawType);
+ auto *unownedContextWrapper = new PyMlirContext(ctx);
+    auto pyCtxRef = py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(ctx));
+ assert(pyCtxRef && "cast to py::object failed");
+ auto ctxRef = PyMlirContextRef(unownedContextWrapper, std::move(pyCtxRef));
+
+ return {std::move(ctxRef), rawType};
+ }
+
+};
+
+void populateTorchTypes(py::module &m);
+
+}// namespace mlir::python
+
+#endif// MLIR_BINDINGS_PYTHON_IRMODULES_H
diff --git a/cpp_ext/MainModule.cpp b/cpp_ext/MainModule.cpp
new file mode 100644
index 0000000..520958a
--- /dev/null
+++ b/cpp_ext/MainModule.cpp
@@ -0,0 +1,33 @@
+#include "dylib.hpp"
+#include "IRModule.h"
+#include "TensorValue.h"
+#include "TorchTypes.h"
+#include "TorchTypesCAPI.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+
+#include <iostream>
+
+namespace py = pybind11;
+using namespace mlir::python;
+
+// No clue why, but without this we get a missing symbol error.
+namespace llvm {
+int DisableABIBreakingChecks = 1;
+int EnableABIBreakingChecks = 0;
+}// namespace llvm
+
+PYBIND11_MODULE(_mlir, m) {
+// dylib lib1("_torchMlir.cpython-310-darwin.so", dylib::no_filename_decorations);
+ dylib lib2("TorchMLIRAggregateCAPI");
+
+// if (!lib1.has_symbol("mlirValueIsAOpResult"))
+// std::cerr << "symbol 'mlirValueIsAOpResult' not found in '_torchMlir' lib" << std::endl;
+ if (!lib2.has_symbol("mlirValueIsAOpResult"))
+ std::cerr << "symbol 'mlirValueIsAOpResult' not found in 'TorchMLIRAggregateCAPI' lib" << std::endl;
+ else
+ std::cerr << "found symbol 'mlirValueIsAOpResult' in 'TorchMLIRAggregateCAPI' lib" << std::endl;
+
+ bindValues(m);
+ bindTypes(m);
+ bindTypeHelpers(m);
+}
diff --git a/cpp_ext/TensorValue.cpp b/cpp_ext/TensorValue.cpp
new file mode 100644
index 0000000..e065439
--- /dev/null
+++ b/cpp_ext/TensorValue.cpp
@@ -0,0 +1,55 @@
+//===- TensorValue.cpp - Python bindings for torch tensor values ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TensorValue.h"
+#include "TorchTypesCAPI.h"
+
+#include "mlir/CAPI/Support.h"
+
+using namespace mlir;
+using namespace mlir::python;
+
+Torch_Tensor Torch_Tensor::createFromCapsule_(const py::capsule &capsule) {
+ MlirValue value = {capsule.get_pointer()};
+ if (mlirValueIsNull(value))
+ throw py::error_already_set();
+ MlirOperation owner;
+ if (mlirValueIsAOpResult(value))
+ owner = mlirOpResultGetOwner(value);
+ if (mlirValueIsABlockArgument(value))
+ owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+ if (mlirOperationIsNull(owner))
+ throw py::error_already_set();
+
+ MlirContext ctx = mlirOperationGetContext(owner);
+ auto *unownedContextWrapper = new PyMlirContext(ctx);
+  auto pyCtxRef = py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(ctx));
+ assert(pyCtxRef && "cast to py::object failed");
+ auto ctxRef = PyMlirContextRef(unownedContextWrapper, std::move(pyCtxRef));
+
+  auto pyOpRef = py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(owner));
+ auto *unownedOperation =
+ new PyOperation(std::move(ctxRef), owner);
+ unownedOperation->handle = pyOpRef;
+ auto ownerRef = PyOperationRef(unownedOperation, std::move(pyOpRef));
+
+ return {ownerRef, value};
+}
+
+void bindValues(py::module &m) {
+ py::object value_ =
+ (py::object) py::module_::import("torch_mlir.ir").attr("Value");
+ py::object op_result_ =
+ (py::object) py::module_::import("torch_mlir.ir").attr("OpResult");
+
+ py::class_(m, "_Torch_Tensor", value_)
+ .def(py::init<>([](const py::capsule &capsule) {
+ return Torch_Tensor::createFromCapsule_(capsule);
+ }));
+}
\ No newline at end of file
diff --git a/cpp_ext/TensorValue.h b/cpp_ext/TensorValue.h
new file mode 100644
index 0000000..ff0e2e9
--- /dev/null
+++ b/cpp_ext/TensorValue.h
@@ -0,0 +1,27 @@
+//===- TensorValue.h - Python bindings for torch tensor values ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::python;
+
+
+struct Torch_Tensor : PyValue {
+ Torch_Tensor(PyOperationRef operationRef, MlirValue value)
+ : PyValue(std::move(operationRef), value) {}
+
+ static Torch_Tensor createFromCapsule_(const py::capsule& capsule);
+};
+
+void bindValues(py::module &m);
\ No newline at end of file
diff --git a/cpp_ext/TorchTypes.cpp b/cpp_ext/TorchTypes.cpp
new file mode 100644
index 0000000..51288b0
--- /dev/null
+++ b/cpp_ext/TorchTypes.cpp
@@ -0,0 +1,84 @@
+//===- TorchTypes.cpp - Python bindings for torch types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TorchTypes.h"
+#include "IRModule.h"
+#include "TorchTypesCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
+
+using namespace mlir;
+using namespace mlir::python;
+
+void bindTypes(py::module &m) {
+ py::object type_ =
+ (py::object) py::module_::import("torch_mlir.ir").attr("Type");
+
+ py::class_(m, "_Torch_IntType", type_)
+ .def(py::init<>([](py::capsule capsule) {
+ return Torch_IntType::createFromCapsule_(capsule);
+ }));
+
+ py::class_(m, "_Torch_BoolType", type_)
+ .def(py::init<>([](py::capsule capsule) {
+ return Torch_BoolType::createFromCapsule_(capsule);
+ }));
+
+ py::class_(m, "_Torch_StringType", type_)
+ .def(py::init<>([](py::capsule capsule) {
+ return Torch_StringType::createFromCapsule_(capsule);
+ }));
+
+ py::class_(m, "_Torch_FloatType", type_)
+ .def(py::init<>([](py::capsule capsule) {
+ return Torch_FloatType::createFromCapsule_(capsule);
+ }));
+
+ py::class_(m, "_Torch_ValueTensorType", type_)
+ .def(py::init<>([](py::capsule capsule) {
+ return Torch_ValueTensorType::createFromCapsule_(capsule);
+ }));
+
+ py::class_(m, "_Torch_NonValueTensorType", type_)
+ .def(py::init<>([](py::capsule capsule) {
+ return Torch_NonValueTensorType::createFromCapsule_(capsule);
+ }));
+}
+
+void bindTypeHelpers(py::module &m) {
+ m.def(
+ "is_a_torch_int_type", [](const py::capsule &type) {
+ MlirType rawType = {type.get_pointer()};
+ return torchMlirTypeIsATorchInt(rawType);
+ });
+ m.def(
+ "is_a_torch_bool_type", [](const py::capsule &type) {
+ MlirType rawType = {type.get_pointer()};
+ return torchMlirTypeIsATorchBool(rawType);
+ });
+ m.def(
+ "is_a_torch_string_type", [](const py::capsule &type) {
+ MlirType rawType = {type.get_pointer()};
+ return torchMlirTypeIsATorchString(rawType);
+ });
+ m.def(
+ "is_a_torch_float_type", [](const py::capsule &type) {
+ MlirType rawType = {type.get_pointer()};
+ return torchMlirTypeIsATorchFloat(rawType);
+ });
+ m.def(
+ "is_a_torch_value_tensor_type", [](const py::capsule &type) {
+ MlirType rawType = {type.get_pointer()};
+ return torchMlirTypeIsATorchValueTensor(rawType);
+ });
+ m.def(
+ "is_a_torch_nonvalue_tensor_type", [](const py::capsule &type) {
+ MlirType rawType = {type.get_pointer()};
+ return torchMlirTypeIsATorchNonValueTensor(rawType);
+ });
+}
diff --git a/cpp_ext/TorchTypes.h b/cpp_ext/TorchTypes.h
new file mode 100644
index 0000000..bc3f99b
--- /dev/null
+++ b/cpp_ext/TorchTypes.h
@@ -0,0 +1,51 @@
+//===- TorchTypes.h - Python bindings for torch types ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+#include "TorchTypesCAPI.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::python;
+
+struct Torch_IntType : public PyConcreteType<Torch_IntType> {
+ Torch_IntType(PyMlirContextRef contextRef, MlirType t)
+ : PyConcreteType(std::move(contextRef), t) {}
+};
+
+struct Torch_BoolType : public PyConcreteType<Torch_BoolType> {
+ Torch_BoolType(PyMlirContextRef contextRef, MlirType t)
+ : PyConcreteType(std::move(contextRef), t) {}
+};
+
+struct Torch_StringType : public PyConcreteType<Torch_StringType> {
+ Torch_StringType(PyMlirContextRef contextRef, MlirType t)
+ : PyConcreteType(std::move(contextRef), t) {}
+};
+
+struct Torch_FloatType : public PyConcreteType<Torch_FloatType> {
+ Torch_FloatType(PyMlirContextRef contextRef, MlirType t)
+ : PyConcreteType(std::move(contextRef), t) {}
+};
+
+struct Torch_ValueTensorType : public PyConcreteType<Torch_ValueTensorType> {
+ Torch_ValueTensorType(PyMlirContextRef contextRef, MlirType t)
+ : PyConcreteType(std::move(contextRef), t) {}
+};
+
+struct Torch_NonValueTensorType : public PyConcreteType<Torch_NonValueTensorType> {
+ Torch_NonValueTensorType(PyMlirContextRef contextRef, MlirType t)
+ : PyConcreteType(std::move(contextRef), t) {}
+};
+
+void bindTypes(py::module &m);
+void bindTypeHelpers(py::module &m);
\ No newline at end of file
diff --git a/cpp_ext/TorchTypesCAPI.h b/cpp_ext/TorchTypesCAPI.h
new file mode 100644
index 0000000..761ebea
--- /dev/null
+++ b/cpp_ext/TorchTypesCAPI.h
@@ -0,0 +1,266 @@
+//===-- torch-mlir-c/TorchTypes.h - C API for torch types ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_C_TORCHTYPES_H
+#define TORCHMLIR_C_TORCHTYPES_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// torch.nn.Module type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a torch.nn.Module type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
+
+/// Gets the !torch.nn.Module type of the specified class.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
+
+//===----------------------------------------------------------------------===//
+// torch.optional type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.optional type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
+
+/// Gets the !torch.optional type with subtype T.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchOptionalTypeGet(MlirType containedType);
+
+//===----------------------------------------------------------------------===//
+// torch.tuple type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.tuple type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchTuple(MlirType t);
+
+/// Gets the !torch.tuple type with contained types `containedTypes`.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
+ MlirType const *containedTypes);
+
+//===----------------------------------------------------------------------===//
+// torch.union type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.union type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchUnion(MlirType t);
+
+/// Gets the !torch.union type with contained types `containedTypes`.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
+ MlirType const *containedTypes);
+
+//===----------------------------------------------------------------------===//
+// torch.list type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.list type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
+
+/// Gets the !torch.list type with contained T.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
+
+//===----------------------------------------------------------------------===//
+// torch.Device type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.Device type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
+
+/// Gets the !torch.Device type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.Generator type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.Generator type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
+
+/// Gets the !torch.Generator type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.bool type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.bool type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
+
+/// Gets the !torch.bool type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.int type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.int type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
+
+/// Gets the !torch.int type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.float type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.float type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
+
+/// Gets the !torch.float type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.LinearParams type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.LinearParams type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
+
+/// Gets the !torch.LinearParams type.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchLinearParamsTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.qint8 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.qint8 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
+
+/// Gets the !torch.qint8 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.quint8 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.quint8 type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
+
+/// Gets the !torch.quint8 type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.tensor type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.tensor type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
+
+/// Gets a !torch.tensor type.
+///
+/// - `numSizes` having a value of -1 denotes an unranked tensor.
+/// - `optionalSizes` is allowed to be null, meaning that no size
+///   information is present (and `numSizes` is ignored in that case).
+/// - `optionalDtype` is allowed to be null, meaning that no dtype
+///   information is present.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGet(
+ MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
+ MlirType optionalDtype);
+
+/// Gets the !torch.tensor type with the least static information.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
+ MlirContext context);
+
+/// Gets the !torch.tensor type with the tensor attribute.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);
+
+//===----------------------------------------------------------------------===//
+// torch.vtensor type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.vtensor type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
+
+/// Gets a !torch.vtensor type.
+///
+/// - `numSizes` having a value of -1 denotes an unranked tensor.
+/// - `optionalSizes` is allowed to be null, meaning that no size
+/// information is present (and `numSizes` is ignored in that case).
+/// - `optionalDtype` is allowed to be null, meaning that no dtype
+/// information is present.
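+///
+/// For example (a sketch, assuming `ctx` is a live MlirContext and `f32` holds
+/// the f32 MlirType), the type `!torch.vtensor<[1,3],f32>` would be built as:
+///   int64_t sizes[] = {1, 3};
+///   MlirType t = torchMlirTorchValueTensorTypeGet(ctx, 2, sizes, f32);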
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
+ MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
+ MlirType optionalDtype);
+
+/// Gets the !torch.tensor type with the least static information.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
+
+/// Gets the !torch.vtensor type with the tensor attribute.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
+
+//===----------------------------------------------------------------------===//
+// !torch.none type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.none type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
+
+/// Gets the !torch.none type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !torch.str type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.str type
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
+
+/// Gets the !torch.str type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !torch.any type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.any type.
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
+
+/// Gets the !torch.any type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !torch.number type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.number type.
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
+
+/// Gets the !torch.number type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !torch.dict type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.dict type.
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
+
+/// Gets the !torch.dict type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
+ MlirType valueType);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif// TORCHMLIR_C_TORCHTYPES_H
diff --git a/cpp_ext/dylib.hpp b/cpp_ext/dylib.hpp
new file mode 100644
index 0000000..62a9b39
--- /dev/null
+++ b/cpp_ext/dylib.hpp
@@ -0,0 +1,298 @@
+/**
+ * @file dylib.hpp
+ * @version 2.1.0
+ * @brief C++ cross-platform wrapper around dynamic loading of shared libraries
+ * @link https://github.com/martin-olivier/dylib
+ *
+ * @author Martin Olivier
+ * @copyright (c) 2022 Martin Olivier
+ *
+ * This library is released under MIT license
+ */
+
+#pragma once
+
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L)
+#define DYLIB_CPP17
+#include <filesystem>
+#endif
+
+#if defined(_WIN32) || defined(_WIN64)
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#undef WIN32_LEAN_AND_MEAN
+#else
+#include <dlfcn.h>
+#endif
+
+#if defined(_WIN32) || defined(_WIN64)
+#define DYLIB_WIN_MAC_OTHER(win_def, mac_def, other_def) win_def
+#define DYLIB_WIN_OTHER(win_def, other_def) win_def
+#elif defined(__APPLE__)
+#define DYLIB_WIN_MAC_OTHER(win_def, mac_def, other_def) mac_def
+#define DYLIB_WIN_OTHER(win_def, other_def) other_def
+#else
+#define DYLIB_WIN_MAC_OTHER(win_def, mac_def, other_def) other_def
+#define DYLIB_WIN_OTHER(win_def, other_def) other_def
+#endif
+
+/**
+ * The dylib class can hold a dynamic library instance and interact with it
+ * by getting its symbols like functions or global variables
+ */
+class dylib {
+public:
+ struct filename_components {
+ static constexpr const char *prefix = DYLIB_WIN_OTHER("", "lib");
+ static constexpr const char *suffix = DYLIB_WIN_MAC_OTHER(".dll", ".dylib", ".so");
+ };
+ using native_handle_type = DYLIB_WIN_OTHER(HINSTANCE, void *);
+ using native_symbol_type = DYLIB_WIN_OTHER(FARPROC, void *);
+
+ static_assert(std::is_pointer::value, "Expecting HINSTANCE to be a pointer");
+ static_assert(std::is_pointer::value, "Expecting FARPROC to be a pointer");
+
+ static constexpr bool add_filename_decorations = true;
+ static constexpr bool no_filename_decorations = false;
+
+ /**
+ * This exception is raised when the library failed to load a dynamic library or a symbol
+ *
+ * @param message the error message
+ */
+ class exception : public std::runtime_error {
+ public:
+ explicit exception(const std::string &message) : std::runtime_error(message) {}
+ };
+
+ /**
+ * This exception is raised when the library failed to load or encountered symbol resolution issues
+ *
+ * @param message the error message
+ */
+ class load_error : public exception {
+ public:
+ explicit load_error(const std::string &message) : exception(message) {}
+ };
+
+ /**
+ * This exception is raised when the library failed to load a symbol
+ *
+ * @param message the error message
+ */
+ class symbol_error : public exception {
+ public:
+ explicit symbol_error(const std::string &message) : exception(message) {}
+ };
+
+ dylib(const dylib&) = delete;
+ dylib& operator=(const dylib&) = delete;
+
+ dylib(dylib &&other) noexcept : m_handle(other.m_handle) {
+ other.m_handle = nullptr;
+ }
+
+ dylib& operator=(dylib &&other) noexcept {
+ if (this != &other)
+ std::swap(m_handle, other.m_handle);
+ return *this;
+ }
+
+ /**
+ * @brief Loads a dynamic library
+ *
+ * @throws dylib::load_error if the library could not be opened (including
+ * the case of the library file not being found)
+ *
+ * @param dir_path the directory path where is located the dynamic library you want to load
+ * @param name the name of the dynamic library to load
+ * @param decorations add os decorations to the library name
+ */
+ ///@{
+ dylib(const char *dir_path, const char *lib_name, bool decorations = add_filename_decorations) {
+ if (!dir_path || !lib_name)
+ throw std::invalid_argument("Null parameter");
+
+ std::string final_name = lib_name;
+ std::string final_path = dir_path;
+
+ if (decorations)
+ final_name = filename_components::prefix + final_name + filename_components::suffix;
+
+ if (final_path != "" && final_path.find_last_of('/') != final_path.size() - 1)
+ final_path += '/';
+
+ m_handle = open((final_path + final_name).c_str());
+
+ if (!m_handle)
+ throw load_error("Could not load library \"" + final_path + final_name + "\"\n" + get_error_description());
+ }
+
+ dylib(const std::string &dir_path, const std::string &lib_name, bool decorations = add_filename_decorations)
+ : dylib(dir_path.c_str(), lib_name.c_str(), decorations) {}
+
+ dylib(const std::string &dir_path, const char *lib_name, bool decorations = add_filename_decorations)
+ : dylib(dir_path.c_str(), lib_name, decorations) {}
+
+ dylib(const char *dir_path, const std::string &lib_name, bool decorations = add_filename_decorations)
+ : dylib(dir_path, lib_name.c_str(), decorations) {}
+
+ explicit dylib(const std::string &lib_name, bool decorations = add_filename_decorations)
+ : dylib("", lib_name.c_str(), decorations) {}
+
+ explicit dylib(const char *lib_name, bool decorations = add_filename_decorations)
+ : dylib("", lib_name, decorations) {}
+
+#ifdef DYLIB_CPP17
+ explicit dylib(const std::filesystem::path &lib_path)
+ : dylib("", lib_path.string().c_str(), no_filename_decorations) {}
+
+ dylib(const std::filesystem::path &dir_path, const std::string &lib_name, bool decorations = add_filename_decorations)
+ : dylib(dir_path.string().c_str(), lib_name.c_str(), decorations) {}
+
+ dylib(const std::filesystem::path &dir_path, const char *lib_name, bool decorations = add_filename_decorations)
+ : dylib(dir_path.string().c_str(), lib_name, decorations) {}
+#endif
+ ///@}
+
+ ~dylib() {
+ if (m_handle)
+ close(m_handle);
+ }
+
+ /**
+ * Get a symbol from the dynamic library currently loaded in the object
+ *
+ * @throws dylib::symbol_error if the symbol could not be found
+ *
+ * @param symbol_name the symbol name to get from the dynamic library
+ *
+ * @return a pointer to the requested symbol
+ */
+ native_symbol_type get_symbol(const char *symbol_name) const {
+ if (!symbol_name)
+ throw std::invalid_argument("Null parameter");
+ if (!m_handle)
+ throw std::logic_error("The dynamic library handle is null");
+
+ auto symbol = locate_symbol(m_handle, symbol_name);
+
+ if (symbol == nullptr)
+ throw symbol_error("Could not get symbol \"" + std::string(symbol_name) + "\"\n" + get_error_description());
+ return symbol;
+ }
+
+ native_symbol_type get_symbol(const std::string &symbol_name) const {
+ return get_symbol(symbol_name.c_str());
+ }
+
+ /**
+ * Get a function from the dynamic library currently loaded in the object
+ *
+ * @throws dylib::symbol_error if the symbol could not be found
+ *
+ * @param T the template argument must be the function prototype to get
+ * @param symbol_name the symbol name of a function to get from the dynamic library
+ *
+ * @return a pointer to the requested function
+ */
+    template <typename T>
+    T *get_function(const char *symbol_name) const {
+        return reinterpret_cast<T *>(get_symbol(symbol_name));
+    }
+
+    template <typename T>
+    T *get_function(const std::string &symbol_name) const {
+        return get_function<T>(symbol_name.c_str());
+    }
+
+ /**
+ * Get a variable from the dynamic library currently loaded in the object
+ *
+ * @throws dylib::symbol_error if the symbol could not be found
+ *
+ * @param T the template argument must be the type of the variable to get
+ * @param symbol_name the symbol name of a variable to get from the dynamic library
+ *
+ * @return a reference to the requested variable
+ */
+    template <typename T>
+    T &get_variable(const char *symbol_name) const {
+        return *reinterpret_cast<T *>(get_symbol(symbol_name));
+    }
+
+    template <typename T>
+    T &get_variable(const std::string &symbol_name) const {
+        return get_variable<T>(symbol_name.c_str());
+    }
+
+ /**
+ * Check if a symbol exists in the currently loaded dynamic library.
+ * This method will return false if no dynamic library is currently loaded
+ * or if the symbol name is nullptr
+ *
+ * @param symbol_name the symbol name to look for
+ *
+ * @return true if the symbol exists in the dynamic library, false otherwise
+ */
+ bool has_symbol(const char *symbol_name) const noexcept {
+ if (!m_handle || !symbol_name)
+ return false;
+ return locate_symbol(m_handle, symbol_name) != nullptr;
+ }
+
+ bool has_symbol(const std::string &symbol) const noexcept {
+ return has_symbol(symbol.c_str());
+ }
+
+ /**
+ * @return the dynamic library handle
+ */
+ native_handle_type native_handle() noexcept {
+ return m_handle;
+ }
+
+protected:
+ native_handle_type m_handle{nullptr};
+
+ static native_handle_type open(const char *path) noexcept {
+#if defined(_WIN32) || defined(_WIN64)
+ return LoadLibraryA(path);
+#else
+ return dlopen(path, RTLD_NOW | RTLD_LOCAL);
+#endif
+ }
+
+ static native_symbol_type locate_symbol(native_handle_type lib, const char *name) noexcept {
+ return DYLIB_WIN_OTHER(GetProcAddress, dlsym)(lib, name);
+ }
+
+ static void close(native_handle_type lib) noexcept {
+ DYLIB_WIN_OTHER(FreeLibrary, dlclose)(lib);
+ }
+
+ static std::string get_error_description() noexcept {
+#if defined(_WIN32) || defined(_WIN64)
+ constexpr const size_t buf_size = 512;
+ auto error_code = GetLastError();
+ if (!error_code)
+ return "Unknown error (GetLastError failed)";
+ char description[512];
+ auto lang = MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US);
+ const DWORD length =
+ FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM, nullptr, error_code, lang, description, buf_size, nullptr);
+ return (length == 0) ? "Unknown error (FormatMessage failed)" : description;
+#else
+ auto description = dlerror();
+ return (description == nullptr) ? "Unknown error (dlerror failed)" : description;
+#endif
+ }
+};
+
+#undef DYLIB_WIN_MAC_OTHER
+#undef DYLIB_WIN_OTHER
+#undef DYLIB_CPP17
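+
+// A minimal usage sketch (the library and symbol names below mirror the ones
+// loaded in MainModule.cpp and are otherwise just examples):
+//
+//   dylib lib("TorchMLIRAggregateCAPI");
+//   if (lib.has_symbol("mlirValueIsAOpResult")) {
+//     auto *fn = lib.get_function<bool(MlirValue)>("mlirValueIsAOpResult");
+//   }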
diff --git a/pi/__init__.py b/pi/__init__.py
new file mode 100644
index 0000000..ef8035e
--- /dev/null
+++ b/pi/__init__.py
@@ -0,0 +1,43 @@
+import logging
+
+logger = logging.getLogger(__name__)
+# noinspection PyUnresolvedReferences
+from .dialects import patch_meta_path_non_context
+
+if __name__ == "pi":
+ # prevent double patching of path during testing
+ # where we've already patched torch -> pi
+ patch_meta_path_non_context()
+else:
+ logger.debug(f"reimporting pi as {__name__}")
+
+# this has to go before the above (otherwise torch extensions won't be picked up)
+# noinspection PyUnresolvedReferences
+from torch_mlir import ir
+import torch_mlir
+
+assert (
+ len(torch_mlir.dialects.torch.AtenConv2dOp.__bases__) > 1
+), "failed to import torch dialect extensions"
+
+from ._tensor import *
+from .types_ import *
+from .dialects._torch_wrappers import *
+from ._ops import _OpNamespace
+
+ops = _OpNamespace("ops")
+_nn = _OpNamespace("_nn")
+_C = _OpNamespace("_C")
+_VF = _OpNamespace("_VF")
+special = _OpNamespace("special")
+linalg = _OpNamespace("linalg")
+
+
+from . import nn as nn
+
+
+def manual_seed(*_, **__):
+ return
+
+
+DEBUG = True
diff --git a/pi/_ops.py b/pi/_ops.py
new file mode 100644
index 0000000..830758a
--- /dev/null
+++ b/pi/_ops.py
@@ -0,0 +1,83 @@
+import types
+
+from .dialects import _torch_wrappers
+
+all_ops = {o: _torch_wrappers.__dict__[o] for o in _torch_wrappers.__all__}
+
+
+class _OpNamespace(types.ModuleType):
+ def __init__(self, name):
+ super(_OpNamespace, self).__init__("pi." + name)
+ self.name = name
+ self._dir = []
+
+ def __iter__(self):
+ return iter(self._dir)
+
+ def __getattr__(self, op_name):
+ # It is not a valid op_name when __file__ is passed in
+ # if op_name == "__file__":
+ # return "pi.ops"
+ # elif op_name == "__origin__":
+ # raise AttributeError()
+
+ # namespace_name = self.name
+ # qualified_op_name = "{}::{}".format(namespace_name, op_name)
+ op_name = op_name.split(".")[-1]
+
+ if op_name in all_ops:
+ return all_ops[op_name]
+ else:
+ return _OpNamespace(op_name)
+
+ def __call__(self, *args, **kwargs):
+ if self.name in all_ops:
+ return all_ops[self.name](*args, **kwargs)
+ else:
+ raise NotImplementedError(self.name)
+
+ # TODO(max): resolve overloads correctly here
+ # Get the op `my_namespace::my_op` if available. This will also check
+ # for overloads and raise an exception if there are more than one.
+ # try:
+ # op, overload_names = pi._C._jit_get_operation(qualified_op_name)
+ # except RuntimeError as e:
+ # # Turn this into AttributeError so getattr(obj, key, default)
+ # # works (this is called by TorchScript with __origin__)
+ # raise AttributeError(
+ # f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
+ # ) from e
+ #
+ # # let the script frontend know that op is identical to the builtin op
+ # # with qualified_op_name
+ # pi.jit._builtins._register_builtin(op, qualified_op_name)
+ # op.__module__ = self.__module__ + "." + namespace_name
+ # opoverloadpacket = OpOverloadPacket(
+ # qualified_op_name, op_name, op, overload_names
+ # )
+ # opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
+ # # cache the opoverloadpacket to ensure that each op corresponds to
+ # # a unique OpOverloadPacket object
+ # setattr(self, op_name, opoverloadpacket)
+ # self._dir.append(op_name)
+ # return opoverloadpacket
+
+
+# class _PyOpNamespace(_OpNamespace):
+# def __init__(self):
+# super(_PyOpNamespace, self).__init__("pi.ops")
+# self.pyop_namespace = all_ops
+
+
+# class _Ops(types.ModuleType):
+# def __init__(self, name):
+# super(_Ops, self).__init__(name)
+#
+# def __getattr__(self, name):
+# # Check if the name is a pyop
+# if name in self.pyops.pyop_namespace:
+# return self.pyops.pyop_namespace[name]
+#
+# namespace = _OpNamespace(name)
+# setattr(self, name, namespace)
+# return namespace
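+
+
+# A rough usage sketch of the dispatch above (assuming "add" is among the
+# wrappers exported by _torch_wrappers):
+#
+#   import pi
+#   add = pi.ops.aten.add            # resolves to _torch_wrappers.add via __getattr__
+#   ns = pi.ops.aten.not_a_real_op   # unknown names yield a nested _OpNamespace
+#   ns(1, 2)                         # __call__ raises NotImplementedError("not_a_real_op")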
diff --git a/pi/_tensor.py b/pi/_tensor.py
new file mode 100644
index 0000000..42bc4cf
--- /dev/null
+++ b/pi/_tensor.py
@@ -0,0 +1,2216 @@
+from __future__ import annotations
+import warnings
+from typing import Tuple, Optional, Any
+
+# noinspection PyUnresolvedReferences
+import numpy as np
+from torch_mlir.dialects import torch as torch_dialect
+from torch_mlir.dialects._ods_common import get_op_result_or_value
+from torch_mlir.ir import (
+ DenseElementsAttr,
+)
+from torch_mlir.ir import (
+ Value as MLIRValue,
+)
+
+import pi
+from .types_ import dtype as pi_dtype
+
+
+class TorchTensorWrapper(type):
+ # def __new__(mcs, name, bases, class_dict):
+ # for k, f in class_dict.items():
+ # if k in {"__init__", "__hash__", "_version", "value", "__class__", "type"}:
+ # continue
+ # if inspect.isfunction(f) and not isinstance(f, property):
+ # def run_on_actual_value(*args, **kwargs):
+ # self = args[0]
+ # return f((self.value, *args[1:]), **kwargs)
+ #
+ # class_dict[k] = run_on_actual_value
+ # return type.__new__(mcs, name, bases, class_dict)
+
+ def __subclasscheck__(cls, subclass):
+ print(cls, subclass)
+ return False
+
+ @classmethod
+ def __instancecheck__(cls, instance):
+ try:
+ return instance.is_pi_tensor
+ except:
+ return False
+
+
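+# `Tensor` below relies on the metaclass above: `isinstance(x, Tensor)` is
+# answered by the wrapped object's `is_pi_tensor` flag, and instances report
+# `__class__` as torch_mlir's `Value`, presumably so they pass checks that
+# expect raw MLIR values.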
+class Tensor(metaclass=TorchTensorWrapper):
+ @property
+ def is_pi_tensor(self):
+ return True
+
+ @property
+ def __class__(self):
+ return MLIRValue
+
+ @property
+ def type(self):
+ return self._value.type
+
+ @property
+ def value(self):
+ return self._value
+
+ def __init__(self, tensor: MLIRValue):
+ self._value = get_op_result_or_value(tensor)
+
+ def abs(self):
+ raise NotImplementedError
+
+ def absolute(self):
+ raise NotImplementedError
+
+ def absolute_(self):
+ raise NotImplementedError
+
+ def abs_(self):
+ raise NotImplementedError
+
+ def acos(self):
+ raise NotImplementedError
+
+ def acosh(self):
+ raise NotImplementedError
+
+ def acosh_(self):
+ raise NotImplementedError
+
+ def acos_(self):
+ raise NotImplementedError
+
+ def add(self, other, *args, **kwargs):
+ raise NotImplementedError
+
+ def addbmm(self, batch1, batch2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addbmm_(self, batch1, batch2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addcdiv(self, tensor1, tensor2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addcdiv_(self, tensor1, tensor2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addcmul(self, tensor1, tensor2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addcmul_(self, tensor1, tensor2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addmm(self, mat1, mat2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addmm_(self, mat1, mat2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addmv(self, mat, vec, *args, **kwargs):
+ raise NotImplementedError
+
+ def addmv_(self, mat, vec, *args, **kwargs):
+ raise NotImplementedError
+
+ def addr(self, vec1, vec2, *args, **kwargs):
+ raise NotImplementedError
+
+ def addr_(self, vec1, vec2, *args, **kwargs):
+ raise NotImplementedError
+
+ def add_(self, other, *args, **kwargs):
+ raise NotImplementedError
+
+ def adjoint(self):
+ raise NotImplementedError
+
+ def align_as(self, other):
+ raise NotImplementedError
+
+ def align_to(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def all(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def allclose(self, other, rtol=1, *args, **kwargs):
+ raise NotImplementedError
+
+ def amax(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def amin(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def aminmax(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def angle(self):
+ raise NotImplementedError
+
+ def any(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def apply_(self, callable):
+ raise NotImplementedError
+
+ def arccos(self):
+ raise NotImplementedError
+
+ def arccosh(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def arccosh_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def arccos_(self):
+ raise NotImplementedError
+
+ def arcsin(self):
+ raise NotImplementedError
+
+ def arcsinh(self):
+ raise NotImplementedError
+
+ def arcsinh_(self):
+ raise NotImplementedError
+
+ def arcsin_(self):
+ raise NotImplementedError
+
+ def arctan(self):
+ raise NotImplementedError
+
+ def arctan2(self, other):
+ raise NotImplementedError
+
+ def arctan2_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def arctanh(self):
+ raise NotImplementedError
+
+ def arctanh_(self, other):
+ raise NotImplementedError
+
+ def arctan_(self):
+ raise NotImplementedError
+
+ def argmax(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def argmin(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def argsort(self, dim=-1, descending=False):
+ raise NotImplementedError
+
+ def argwhere(self):
+ raise NotImplementedError
+
+ def asin(self):
+ raise NotImplementedError
+
+ def asinh(self):
+ raise NotImplementedError
+
+ def asinh_(self):
+ raise NotImplementedError
+
+ def asin_(self):
+ raise NotImplementedError
+
+ def as_strided(self, size, stride, storage_offset=None):
+ raise NotImplementedError
+
+ def as_strided_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def as_strided_scatter(self, src, size, stride, storage_offset=None):
+ raise NotImplementedError
+
+ def as_subclass(self, cls):
+ raise NotImplementedError
+
+ def atan(self):
+ raise NotImplementedError
+
+ def atan2(self, other):
+ raise NotImplementedError
+
+ def atan2_(self, other):
+ raise NotImplementedError
+
+ def atanh(self):
+ raise NotImplementedError
+
+ def atanh_(self, other):
+ raise NotImplementedError
+
+ def atan_(self):
+ raise NotImplementedError
+
+ def baddbmm(self, batch1, batch2, *args, **kwargs):
+ raise NotImplementedError
+
+ def baddbmm_(self, batch1, batch2, *args, **kwargs):
+ raise NotImplementedError
+
+ def bernoulli(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def bernoulli_(self, p=0.5, *args, **kwargs):
+ raise NotImplementedError
+
+ def bfloat16(self, memory_format=None):
+ raise NotImplementedError
+
+ def bincount(self, weights=None, minlength=0):
+ raise NotImplementedError
+
+ def bitwise_and(self):
+ raise NotImplementedError
+
+ def bitwise_and_(self):
+ raise NotImplementedError
+
+ def bitwise_left_shift(self, other):
+ raise NotImplementedError
+
+ def bitwise_left_shift_(self, other):
+ raise NotImplementedError
+
+ def bitwise_not(self):
+ raise NotImplementedError
+
+ def bitwise_not_(self):
+ raise NotImplementedError
+
+ def bitwise_or(self):
+ raise NotImplementedError
+
+ def bitwise_or_(self):
+ raise NotImplementedError
+
+ def bitwise_right_shift(self, other):
+ raise NotImplementedError
+
+ def bitwise_right_shift_(self, other):
+ raise NotImplementedError
+
+ def bitwise_xor(self):
+ raise NotImplementedError
+
+ def bitwise_xor_(self):
+ raise NotImplementedError
+
+ def bmm(self, batch2):
+ raise NotImplementedError
+
+ def bool(self, memory_format=None):
+ raise NotImplementedError
+
+ def broadcast_to(self, shape):
+ raise NotImplementedError
+
+ def byte(self, memory_format=None):
+ raise NotImplementedError
+
+ def cauchy_(self, median=0, sigma=1, *args, **kwargs):
+ raise NotImplementedError
+
+ def ccol_indices(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def cdouble(self, memory_format=None):
+ raise NotImplementedError
+
+ def ceil(self):
+ raise NotImplementedError
+
+ def ceil_(self):
+ raise NotImplementedError
+
+ def cfloat(self, memory_format=None):
+ raise NotImplementedError
+
+ def chalf(self, memory_format=None):
+ raise NotImplementedError
+
+ def char(self, memory_format=None):
+ raise NotImplementedError
+
+ def cholesky(self, upper=False):
+ raise NotImplementedError
+
+ def cholesky_inverse(self, upper=False):
+ raise NotImplementedError
+
+ def cholesky_solve(self, input2, upper=False):
+ raise NotImplementedError
+
+ def chunk(self, chunks, dim=0):
+ raise NotImplementedError
+
+ def clamp(self, min=None, max=None):
+ raise NotImplementedError
+
+ def clamp_(self, min=None, max=None):
+ raise NotImplementedError
+
+ def clamp_max(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def clamp_max_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def clamp_min(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def clamp_min_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def clip(self, min=None, max=None):
+ raise NotImplementedError
+
+ def clip_(self, min=None, max=None):
+ raise NotImplementedError
+
+ def clone(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def coalesce(self):
+ raise NotImplementedError
+
+ def col_indices(self):
+ raise NotImplementedError
+
+ def conj(self):
+ raise NotImplementedError
+
+ def conj_physical(self):
+ raise NotImplementedError
+
+ def conj_physical_(self):
+ raise NotImplementedError
+
+ def contiguous(self, memory_format=None):
+ raise NotImplementedError
+
+ def copysign(self, other):
+ raise NotImplementedError
+
+ def copysign_(self, other):
+ raise NotImplementedError
+
+ def copy_(self, src, non_blocking=False):
+ raise NotImplementedError
+
+ def corrcoef(self):
+ raise NotImplementedError
+
+ def cos(self):
+ raise NotImplementedError
+
+ def cosh(self):
+ raise NotImplementedError
+
+ def cosh_(self):
+ raise NotImplementedError
+
+ def cos_(self):
+ raise NotImplementedError
+
+ def count_nonzero(self, dim=None):
+ raise NotImplementedError
+
+ def cov(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def cpu(self, memory_format=None):
+ raise NotImplementedError
+
+ def cross(self, other, dim=None):
+ raise NotImplementedError
+
+ def crow_indices(self):
+ raise NotImplementedError
+
+ def cuda(self, device=None, non_blocking=False, memory_format=None):
+ raise NotImplementedError
+
+ def cummax(self, dim):
+ raise NotImplementedError
+
+ def cummin(self, dim):
+ raise NotImplementedError
+
+ def cumprod(self, dim, dtype=None):
+ raise NotImplementedError
+
+ def cumprod_(self, dim, dtype=None):
+ raise NotImplementedError
+
+ def cumsum(self, dim, dtype=None):
+ raise NotImplementedError
+
+ def cumsum_(self, dim, dtype=None):
+ raise NotImplementedError
+
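+ # the introspection-style methods below (data_ptr, dim, numel, item, the is_* predicates,
+ # and friends) return dummy defaults rather than raising, presumably so code that merely
+ # probes a tensor keeps working during tracing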
+ def data_ptr(self):
+
+ return 0
+
+ def deg2rad(self):
+ raise NotImplementedError
+
+ def deg2rad_(self):
+ raise NotImplementedError
+
+ def dense_dim(self):
+
+ return 0
+
+ def dequantize(self):
+ raise NotImplementedError
+
+ def det(self):
+ raise NotImplementedError
+
+ def detach(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def detach_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def diag(self, diagonal=0):
+ raise NotImplementedError
+
+ def diagflat(self, offset=0):
+ raise NotImplementedError
+
+ def diagonal(self, offset=0, dim1=0, dim2=1):
+ raise NotImplementedError
+
+ def diagonal_scatter(self, src, offset=0, dim1=0, dim2=1):
+ raise NotImplementedError
+
+ def diag_embed(self, offset=0, dim1=-2, dim2=-1):
+ raise NotImplementedError
+
+ def diff(self, n=1, dim=-1, prepend=None, append=None):
+ raise NotImplementedError
+
+ def digamma(self):
+ raise NotImplementedError
+
+ def digamma_(self):
+ raise NotImplementedError
+
+ def dim(self):
+
+ return 0
+
+ def dist(self, other, p=2):
+ raise NotImplementedError
+
+ def div(self, value, *args, **kwargs):
+ raise NotImplementedError
+
+ def divide(self, value, *args, **kwargs):
+ raise NotImplementedError
+
+ def divide_(self, value, *args, **kwargs):
+ raise NotImplementedError
+
+ def div_(self, value, *args, **kwargs):
+ raise NotImplementedError
+
+ def dot(self, other):
+ raise NotImplementedError
+
+ def double(self, memory_format=None):
+ raise NotImplementedError
+
+ def dsplit(self, split_size_or_sections):
+ raise NotImplementedError
+
+ def element_size(self):
+
+ return 0
+
+ def eq(self, other):
+ raise NotImplementedError
+
+ def equal(self, other):
+
+ return False
+
+ def eq_(self, other):
+ raise NotImplementedError
+
+ def erf(self):
+ raise NotImplementedError
+
+ def erfc(self):
+ raise NotImplementedError
+
+ def erfc_(self):
+ raise NotImplementedError
+
+ def erfinv(self):
+ raise NotImplementedError
+
+ def erfinv_(self):
+ raise NotImplementedError
+
+ def erf_(self):
+ raise NotImplementedError
+
+ def exp(self):
+ raise NotImplementedError
+
+ def exp2(self):
+ raise NotImplementedError
+
+ def exp2_(self):
+ raise NotImplementedError
+
+ def expand(self, *sizes):
+ raise NotImplementedError
+
+ def expand_as(self, other):
+ raise NotImplementedError
+
+ def expm1(self):
+ raise NotImplementedError
+
+ def expm1_(self):
+ raise NotImplementedError
+
+ def exponential_(self, lambd=1, *args, **kwargs):
+ raise NotImplementedError
+
+ def exp_(self):
+ raise NotImplementedError
+
+ def fill_(self, value):
+ raise NotImplementedError
+
+ def fill_diagonal_(self, fill_value, wrap=False):
+ raise NotImplementedError
+
+ def fix(self):
+ raise NotImplementedError
+
+ def fix_(self):
+ raise NotImplementedError
+
+ def flatten(self, start_dim=0, end_dim=-1):
+ raise NotImplementedError
+
+ def flip(self, dims):
+ raise NotImplementedError
+
+ def fliplr(self):
+ raise NotImplementedError
+
+ def flipud(self):
+ raise NotImplementedError
+
+ def float(self, memory_format=None):
+ raise NotImplementedError
+
+ def float_power(self, exponent):
+ raise NotImplementedError
+
+ def float_power_(self, exponent):
+ raise NotImplementedError
+
+ def floor(self):
+ raise NotImplementedError
+
+ def floor_(self):
+ raise NotImplementedError
+
+ def floor_divide(self, value):
+ raise NotImplementedError
+
+ def floor_divide_(self, value):
+ raise NotImplementedError
+
+ def fmax(self, other):
+ raise NotImplementedError
+
+ def fmin(self, other):
+ raise NotImplementedError
+
+ def fmod(self, divisor):
+ raise NotImplementedError
+
+ def fmod_(self, divisor):
+ raise NotImplementedError
+
+ def frac(self):
+ raise NotImplementedError
+
+ def frac_(self):
+ raise NotImplementedError
+
+ def frexp(self, input):
+ raise NotImplementedError
+
+ def gather(self, dim, index):
+ raise NotImplementedError
+
+ def gcd(self, other):
+ raise NotImplementedError
+
+ def gcd_(self, other):
+ raise NotImplementedError
+
+ def ge(self, other):
+ raise NotImplementedError
+
+ def geometric_(self, p, *args, **kwargs):
+ raise NotImplementedError
+
+ def geqrf(self):
+ raise NotImplementedError
+
+ def ger(self, vec2):
+ raise NotImplementedError
+
+ def get_device(self):
+ raise NotImplementedError
+
+ def ge_(self, other):
+ raise NotImplementedError
+
+ def greater(self, other):
+ raise NotImplementedError
+
+ def greater_(self, other):
+ raise NotImplementedError
+
+ def greater_equal(self, other):
+ raise NotImplementedError
+
+ def greater_equal_(self, other):
+ raise NotImplementedError
+
+ def gt(self, other):
+ raise NotImplementedError
+
+ def gt_(self, other):
+ raise NotImplementedError
+
+ def half(self, memory_format=None):
+ raise NotImplementedError
+
+ def hardshrink(self, lambd=0.5):
+ raise NotImplementedError
+
+ def has_names(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def heaviside(self, values):
+ raise NotImplementedError
+
+ def heaviside_(self, values):
+ raise NotImplementedError
+
+ def histc(self, bins=100, min=0, max=0):
+ raise NotImplementedError
+
+ def histogram(self, input, bins, *args, **kwargs):
+ raise NotImplementedError
+
+ def hsplit(self, split_size_or_sections):
+ raise NotImplementedError
+
+ def hypot(self, other):
+ raise NotImplementedError
+
+ def hypot_(self, other):
+ raise NotImplementedError
+
+ def i0(self):
+ raise NotImplementedError
+
+ def i0_(self):
+ raise NotImplementedError
+
+ def igamma(self, other):
+ raise NotImplementedError
+
+ def igammac(self, other):
+ raise NotImplementedError
+
+ def igammac_(self, other):
+ raise NotImplementedError
+
+ def igamma_(self, other):
+ raise NotImplementedError
+
+ def index_add(self, dim, index, source, *args, **kwargs):
+ raise NotImplementedError
+
+ def index_add_(self, dim, index, source, *args, **kwargs):
+ raise NotImplementedError
+
+ def index_copy(self, dim, index, tensor2):
+ raise NotImplementedError
+
+ def index_copy_(self, dim, index, tensor):
+ raise NotImplementedError
+
+ def index_fill(self, dim, index, value):
+ raise NotImplementedError
+
+ def index_fill_(self, dim, index, value):
+ raise NotImplementedError
+
+ def index_put(self, indices, values, accumulate=False):
+ raise NotImplementedError
+
+ def index_put_(self, indices, values, accumulate=False):
+ raise NotImplementedError
+
+ def index_reduce(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def index_reduce_(self, dim, index, source, reduce, *args, **kwargs):
+ raise NotImplementedError
+
+ def index_select(self, dim, index):
+ raise NotImplementedError
+
+ def indices(self):
+ raise NotImplementedError
+
+ def inner(self, other):
+ raise NotImplementedError
+
+ def int(self, memory_format=None):
+ raise NotImplementedError
+
+ def int_repr(self):
+ raise NotImplementedError
+
+ def inverse(self):
+ raise NotImplementedError
+
+ def ipu(self, device=None, non_blocking=False, memory_format=None):
+ raise NotImplementedError
+
+ def isclose(self, other, rtol=1, *args, **kwargs):
+ raise NotImplementedError
+
+ def isfinite(self):
+ raise NotImplementedError
+
+ def isinf(self):
+ raise NotImplementedError
+
+ def isnan(self):
+ raise NotImplementedError
+
+ def isneginf(self):
+ raise NotImplementedError
+
+ def isposinf(self):
+ raise NotImplementedError
+
+ def isreal(self):
+ raise NotImplementedError
+
+ def istft(
+ self,
+ n_fft,
+ hop_length=None,
+ win_length=None,
+ window=None,
+ center=True,
+ normalized=False,
+ onesided=True,
+ length=None,
+ ):
+ raise NotImplementedError
+
+ def is_coalesced(self):
+
+ return False
+
+ def is_complex(self):
+
+ return False
+
+ def is_conj(self):
+
+ return False
+
+ def is_contiguous(self, memory_format=None):
+
+ return False
+
+ def is_distributed(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def is_floating_point(self):
+
+ return False
+
+ def is_inference(self):
+
+ return False
+
+ def is_neg(self):
+
+ return False
+
+ def is_nonzero(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def is_pinned(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def is_same_size(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def is_set_to(self, tensor):
+
+ return False
+
+ def is_signed(self):
+
+ return False
+
+ def item(self):
+
+ return 0
+
+ def kron(self, other):
+ raise NotImplementedError
+
+ def kthvalue(self, k, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def lcm(self, other):
+ raise NotImplementedError
+
+ def lcm_(self, other):
+ raise NotImplementedError
+
+ def ldexp(self, other):
+ raise NotImplementedError
+
+ def ldexp_(self, other):
+ raise NotImplementedError
+
+ def le(self, other):
+ raise NotImplementedError
+
+ def lerp(self, end, weight):
+ raise NotImplementedError
+
+ def lerp_(self, end, weight):
+ raise NotImplementedError
+
+ def less(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def less_(self, other):
+ raise NotImplementedError
+
+ def less_equal(self, other):
+ raise NotImplementedError
+
+ def less_equal_(self, other):
+ raise NotImplementedError
+
+ def le_(self, other):
+ raise NotImplementedError
+
+ def lgamma(self):
+ raise NotImplementedError
+
+ def lgamma_(self):
+ raise NotImplementedError
+
+ def log(self):
+ raise NotImplementedError
+
+ def log10(self):
+ raise NotImplementedError
+
+ def log10_(self):
+ raise NotImplementedError
+
+ def log1p(self):
+ raise NotImplementedError
+
+ def log1p_(self):
+ raise NotImplementedError
+
+ def log2(self):
+ raise NotImplementedError
+
+ def log2_(self):
+ raise NotImplementedError
+
+ def logaddexp(self, other):
+ raise NotImplementedError
+
+ def logaddexp2(self, other):
+ raise NotImplementedError
+
+ def logcumsumexp(self, dim):
+ raise NotImplementedError
+
+ def logdet(self):
+ raise NotImplementedError
+
+ def logical_and(self):
+ raise NotImplementedError
+
+ def logical_and_(self):
+ raise NotImplementedError
+
+ def logical_not(self):
+ raise NotImplementedError
+
+ def logical_not_(self):
+ raise NotImplementedError
+
+ def logical_or(self):
+ raise NotImplementedError
+
+ def logical_or_(self):
+ raise NotImplementedError
+
+ def logical_xor(self):
+ raise NotImplementedError
+
+ def logical_xor_(self):
+ raise NotImplementedError
+
+ def logit(self):
+ raise NotImplementedError
+
+ def logit_(self):
+ raise NotImplementedError
+
+ def logsumexp(self, dim, keepdim=False):
+ raise NotImplementedError
+
+ def log_(self):
+ raise NotImplementedError
+
+ def log_normal_(self, mean=1, std=2, *args, **kwargs):
+ raise NotImplementedError
+
+ def log_softmax(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def long(self, memory_format=None):
+ raise NotImplementedError
+
+ def lt(self, other):
+ raise NotImplementedError
+
+ def lt_(self, other):
+ raise NotImplementedError
+
+ def lu_solve(self, LU_data, LU_pivots):
+ raise NotImplementedError
+
+ def map2_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def map_(self, tensor, callable):
+ raise NotImplementedError
+
+ def masked_fill(self, mask, value):
+ raise NotImplementedError
+
+ def masked_fill_(self, mask, value):
+ raise NotImplementedError
+
+ def masked_scatter(self, mask, tensor):
+ raise NotImplementedError
+
+ def masked_scatter_(self, mask, source):
+ raise NotImplementedError
+
+ def masked_select(self, mask):
+ raise NotImplementedError
+
+ def matmul(self, tensor2):
+ raise NotImplementedError
+
+ def matrix_exp(self):
+ raise NotImplementedError
+
+ def matrix_power(self, n):
+ raise NotImplementedError
+
+ def max(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def maximum(self, other):
+ raise NotImplementedError
+
+ def mean(self, dim=None, keepdim=False, *args, **kwargs):
+ raise NotImplementedError
+
+ def median(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def min(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def minimum(self, other):
+ raise NotImplementedError
+
+ def mm(self, mat2):
+ raise NotImplementedError
+
+ def mode(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def moveaxis(self, source, destination):
+ raise NotImplementedError
+
+ def movedim(self, source, destination):
+ raise NotImplementedError
+
+ def msort(self):
+ raise NotImplementedError
+
+ def mul(self, value):
+ raise NotImplementedError
+
+ def multinomial(self, num_samples, replacement=False, *args, **kwargs):
+ raise NotImplementedError
+
+ def multiply(self, value):
+ raise NotImplementedError
+
+ def multiply_(self, value):
+ raise NotImplementedError
+
+ def mul_(self, value):
+ raise NotImplementedError
+
+ def mv(self, vec):
+ raise NotImplementedError
+
+ def mvlgamma(self, p):
+ raise NotImplementedError
+
+ def mvlgamma_(self, p):
+ raise NotImplementedError
+
+ def nanmean(self, dim=None, keepdim=False, *args, **kwargs):
+ raise NotImplementedError
+
+ def nanmedian(self, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def nanquantile(self, q, dim=None, keepdim=False, *args, **kwargs):
+ raise NotImplementedError
+
+ def nansum(self, dim=None, keepdim=False, dtype=None):
+ raise NotImplementedError
+
+ def nan_to_num(self, nan=0.0, posinf=None, neginf=None):
+ raise NotImplementedError
+
+ def nan_to_num_(self, nan=0.0, posinf=None, neginf=None):
+ raise NotImplementedError
+
+ def narrow(self, dimension, start, length):
+ raise NotImplementedError
+
+ def narrow_copy(self, dimension, start, length):
+ raise NotImplementedError
+
+ def ndimension(self):
+
+ return 0
+
+ def ne(self, other):
+ raise NotImplementedError
+
+ def neg(self):
+ raise NotImplementedError
+
+ def negative(self):
+ raise NotImplementedError
+
+ def negative_(self):
+ raise NotImplementedError
+
+ def neg_(self):
+ raise NotImplementedError
+
+ def nelement(self):
+
+ return 0
+
+ def new(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def new_empty(self, size, *args, **kwargs):
+ raise NotImplementedError
+
+ def new_empty_strided(
+ self,
+ size,
+ stride,
+ dtype=None,
+ device=None,
+ requires_grad=False,
+ layout=None,
+ pin_memory=False,
+ ):
+ raise NotImplementedError
+
+ def new_full(self, size, fill_value, *args, **kwargs):
+ raise NotImplementedError
+
+ def new_ones(self, size, *args, **kwargs):
+ raise NotImplementedError
+
+ def new_tensor(self, data, *args, **kwargs):
+ raise NotImplementedError
+
+ def new_zeros(self, size, *args, **kwargs):
+ raise NotImplementedError
+
+ def nextafter(self, other):
+ raise NotImplementedError
+
+ def nextafter_(self, other):
+ raise NotImplementedError
+
+ def ne_(self, other):
+ raise NotImplementedError
+
+ def nonzero(self):
+ raise NotImplementedError
+
+ def norm(self, p=2, dim=None, keepdim=False):
+ raise NotImplementedError
+
+ def normal_(self, mean=0, std=1, *args, **kwargs):
+ raise NotImplementedError
+
+ def not_equal(self, other):
+ raise NotImplementedError
+
+ def not_equal_(self, other):
+ raise NotImplementedError
+
+ def numel(self):
+
+ return 0
+
+ def numpy(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def orgqr(self, input2):
+ raise NotImplementedError
+
+ def ormqr(self, input2, input3, left=True, transpose=False):
+ raise NotImplementedError
+
+ def outer(self, vec2):
+ raise NotImplementedError
+
+ def permute(self, *dims):
+ raise NotImplementedError
+
+ def pinverse(self):
+ raise NotImplementedError
+
+ def pin_memory(self):
+ raise NotImplementedError
+
+ def polygamma(self, n):
+ raise NotImplementedError
+
+ def polygamma_(self, n):
+ raise NotImplementedError
+
+ def positive(self):
+ raise NotImplementedError
+
+ def pow(self, exponent):
+ raise NotImplementedError
+
+ def pow_(self, exponent):
+ raise NotImplementedError
+
+ def prelu(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def prod(self, dim=None, keepdim=False, dtype=None):
+ raise NotImplementedError
+
+ def put(self, input, index, source, accumulate=False):
+ raise NotImplementedError
+
+ def put_(self, index, source, accumulate=False):
+ raise NotImplementedError
+
+ def qr(self, some=True):
+ raise NotImplementedError
+
+ def qscheme(self):
+ raise NotImplementedError
+
+ def quantile(self, q, dim=None, keepdim=False, *args, **kwargs):
+ raise NotImplementedError
+
+ def q_per_channel_axis(self):
+
+ return 0
+
+ def q_per_channel_scales(self):
+ raise NotImplementedError
+
+ def q_per_channel_zero_points(
+ self,
+ ):
+ raise NotImplementedError
+
+ def q_scale(self):
+
+ return 0.0
+
+ def q_zero_point(self):
+
+ return 0
+
+ def rad2deg(self):
+ raise NotImplementedError
+
+ def rad2deg_(self):
+ raise NotImplementedError
+
+ def random_(self, from_=0, to=None, *args, **kwargs):
+ raise NotImplementedError
+
+ def ravel(self):
+ raise NotImplementedError
+
+ def reciprocal(self):
+ raise NotImplementedError
+
+ def reciprocal_(self):
+ raise NotImplementedError
+
+ def record_stream(self, stream):
+ raise NotImplementedError
+
+ def refine_names(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def relu(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def relu_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def remainder(self, divisor):
+ raise NotImplementedError
+
+ def remainder_(self, divisor):
+ raise NotImplementedError
+
+ def rename(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def rename_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def renorm(self, p, dim, maxnorm):
+ raise NotImplementedError
+
+ def renorm_(self, p, dim, maxnorm):
+ raise NotImplementedError
+
+ def repeat(self, *sizes):
+ raise NotImplementedError
+
+ def repeat_interleave(self, repeats, dim=None, *args, **kwargs):
+ raise NotImplementedError
+
+ def requires_grad_(self, requires_grad=True):
+ raise NotImplementedError
+
+ def reshape(self, *shape):
+ raise NotImplementedError
+
+ def reshape_as(self, other):
+ raise NotImplementedError
+
+ def resize_(self, *sizes, memory_format=None):
+ raise NotImplementedError
+
+ def resize_as_(self, tensor, memory_format=None):
+ raise NotImplementedError
+
+ def resize_as_sparse_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def resolve_conj(self):
+ raise NotImplementedError
+
+ def resolve_neg(self):
+ raise NotImplementedError
+
+ def retain_grad(self):
+ raise NotImplementedError
+
+ def roll(self, shifts, dims):
+ raise NotImplementedError
+
+ def rot90(self, k, dims):
+ raise NotImplementedError
+
+ def round(self, decimals=0):
+ raise NotImplementedError
+
+ def round_(self, decimals=0):
+ raise NotImplementedError
+
+ def row_indices(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def rsqrt(self):
+ raise NotImplementedError
+
+ def rsqrt_(self):
+ raise NotImplementedError
+
+ def scatter(self, dim, index, src):
+ raise NotImplementedError
+
+ def scatter_(self, dim, index, src, reduce=None):
+ raise NotImplementedError
+
+ def scatter_add(self, dim, index, src):
+ raise NotImplementedError
+
+ def scatter_add_(self, dim, index, src):
+ raise NotImplementedError
+
+ def scatter_reduce(self, dim, index, src, reduce, *args, **kwargs):
+ raise NotImplementedError
+
+ def scatter_reduce_(self, dim, index, src, reduce, *args, **kwargs):
+ raise NotImplementedError
+
+ def select(self, dim, index):
+ raise NotImplementedError
+
+ def select_scatter(self, src, dim, index):
+ raise NotImplementedError
+
+ def set_(self, source=None, storage_offset=0, size=None, stride=None):
+ raise NotImplementedError
+
+ def sgn(self):
+ raise NotImplementedError
+
+ def sgn_(self):
+ raise NotImplementedError
+
+ def short(self, memory_format=None):
+ raise NotImplementedError
+
+ def sigmoid(self):
+ raise NotImplementedError
+
+ def sigmoid_(self):
+ raise NotImplementedError
+
+ def sign(self):
+ raise NotImplementedError
+
+ def signbit(self):
+ raise NotImplementedError
+
+ def sign_(self):
+ raise NotImplementedError
+
+ def sin(self):
+ raise NotImplementedError
+
+ def sinc(self):
+ raise NotImplementedError
+
+ def sinc_(self):
+ raise NotImplementedError
+
+ def sinh(self):
+ raise NotImplementedError
+
+ def sinh_(self):
+ raise NotImplementedError
+
+ def sin_(self):
+ raise NotImplementedError
+
+ def size(self, dim=None):
+ raise NotImplementedError
+
+ def slice_scatter(self, src, dim=0, start=None, end=None, step=1):
+ raise NotImplementedError
+
+ def slogdet(self):
+ raise NotImplementedError
+
+ def smm(self, mat):
+ raise NotImplementedError
+
+ def softmax(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def sort(self, dim=-1, descending=False):
+ raise NotImplementedError
+
+ def sparse_dim(self):
+
+ return 0
+
+ def sparse_mask(self, mask):
+ raise NotImplementedError
+
+ def sparse_resize_(self, size, sparse_dim, dense_dim):
+ raise NotImplementedError
+
+ def sparse_resize_and_clear_(self, size, sparse_dim, dense_dim):
+ raise NotImplementedError
+
+ def split(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def split_with_sizes(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def sqrt(self):
+ raise NotImplementedError
+
+ def sqrt_(self):
+ raise NotImplementedError
+
+ def square(self):
+ raise NotImplementedError
+
+ def square_(self):
+ raise NotImplementedError
+
+ def squeeze(self, dim=None):
+ raise NotImplementedError
+
+ def squeeze_(self, dim=None):
+ raise NotImplementedError
+
+ def sspaddmm(self, mat1, mat2, *args, **kwargs):
+ raise NotImplementedError
+
+ def std(self, dim, unbiased=True, keepdim=False):
+ raise NotImplementedError
+
+ def stft(
+ self,
+ frame_length,
+ hop,
+ fft_size=None,
+ return_onesided=True,
+ window=None,
+ pad_end=0,
+ ):
+ raise NotImplementedError
+
+ def storage_offset(self):
+
+ return 0
+
+ def stride(self, dim):
+
+ return ()
+
+ def sub(self, other, *args, **kwargs):
+ raise NotImplementedError
+
+ def subtract(self, other, *args, **kwargs):
+ raise NotImplementedError
+
+ def subtract_(self, other, *args, **kwargs):
+ raise NotImplementedError
+
+ def sub_(self, other, *args, **kwargs):
+ raise NotImplementedError
+
+ def sum(self, dim=None, keepdim=False, dtype=None):
+ raise NotImplementedError
+
+ def sum_to_size(self, *size):
+ raise NotImplementedError
+
+ def svd(self, some=True, compute_uv=True):
+ raise NotImplementedError
+
+ def swapaxes(self, axis0, axis1):
+ raise NotImplementedError
+
+ def swapaxes_(self, axis0, axis1):
+ raise NotImplementedError
+
+ def swapdims(self, dim0, dim1):
+ raise NotImplementedError
+
+ def swapdims_(self, dim0, dim1):
+ raise NotImplementedError
+
+ def symeig(self, eigenvectors=False, upper=True):
+ raise NotImplementedError
+
+ def t(self):
+ raise NotImplementedError
+
+ def take(self, indices):
+ raise NotImplementedError
+
+ def take_along_dim(self, indices, dim):
+ raise NotImplementedError
+
+ def tan(self):
+ raise NotImplementedError
+
+ def tanh(self):
+ raise NotImplementedError
+
+ def tanh_(self):
+ raise NotImplementedError
+
+ def tan_(self):
+ raise NotImplementedError
+
+ def tensor_split(self, indices_or_sections, dim=0):
+ raise NotImplementedError
+
+ def tile(self, *reps):
+ raise NotImplementedError
+
+ def to(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def tolist(self):
+ raise NotImplementedError
+
+ def topk(self, k, dim=None, largest=True, sorted=True):
+ raise NotImplementedError
+
+ def to_dense(self):
+ raise NotImplementedError
+
+ def to_mkldnn(self):
+ raise NotImplementedError
+
+ def to_padded_tensor(self, padding, output_size=None):
+ raise NotImplementedError
+
+ def to_sparse(self, sparseDims):
+ raise NotImplementedError
+
+ def to_sparse_bsc(self, blocksize):
+ raise NotImplementedError
+
+ def to_sparse_bsr(self, blocksize):
+ raise NotImplementedError
+
+ def to_sparse_csc(self):
+ raise NotImplementedError
+
+ def to_sparse_csr(self):
+ raise NotImplementedError
+
+ def trace(self):
+ raise NotImplementedError
+
+ def transpose(self, dim0, dim1):
+ raise NotImplementedError
+
+ def transpose_(self, dim0, dim1):
+ raise NotImplementedError
+
+ def triangular_solve(self, A, upper=True, transpose=False, unitriangular=False):
+ raise NotImplementedError
+
+ def tril(self, diagonal=0):
+ raise NotImplementedError
+
+ def tril_(self, diagonal=0):
+ raise NotImplementedError
+
+ def triu(self, diagonal=0):
+ raise NotImplementedError
+
+ def triu_(self, diagonal=0):
+ raise NotImplementedError
+
+ def true_divide(self, value):
+ raise NotImplementedError
+
+ def true_divide_(self, value):
+ raise NotImplementedError
+
+ def trunc(self):
+ raise NotImplementedError
+
+ def trunc_(self):
+ raise NotImplementedError
+
+ # def type(
+ # self, dtype=None, non_blocking=False, **kwargs
+ # ):
+ # return ""
+
+ def type_as(self, tensor):
+ raise NotImplementedError
+
+ def t_(self):
+ raise NotImplementedError
+
+ def unbind(self, dim=0):
+ raise NotImplementedError
+
+ def unflatten(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def unfold(self, dimension, size, step):
+ raise NotImplementedError
+
+ def uniform_(self, from_=0, to=1):
+ raise NotImplementedError
+
+ def unsafe_chunk(self, chunks, dim=0):
+ raise NotImplementedError
+
+ def unsafe_split(self, split_size, dim=0):
+ raise NotImplementedError
+
+ def unsafe_split_with_sizes(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def unsqueeze(self, dim):
+ raise NotImplementedError
+
+ def unsqueeze_(self, dim):
+ raise NotImplementedError
+
+ def values(self):
+ raise NotImplementedError
+
+ def var(self, dim, unbiased=True, keepdim=False):
+ raise NotImplementedError
+
+ def vdot(self, other):
+ raise NotImplementedError
+
+ def view(self, *shape):
+ raise NotImplementedError
+
+ def view_as(self, other):
+ raise NotImplementedError
+
+ def vsplit(self, split_size_or_sections):
+ raise NotImplementedError
+
+ def where(self, condition, y):
+ raise NotImplementedError
+
+ def xlogy(self, other):
+ raise NotImplementedError
+
+ def xlogy_(self, other):
+ raise NotImplementedError
+
+ def xpu(self, device=None, non_blocking=False, memory_format=None):
+ raise NotImplementedError
+
+ def zero_(self):
+ raise NotImplementedError
+
+ def _addmm_activation(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _autocast_to_full_precision(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _autocast_to_reduced_precision(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _coalesced_(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _conj(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _conj_physical(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _dimI(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _dimV(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _fix_weakref(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _indices(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _is_view(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _is_zerotensor(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _make_subclass(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _make_wrapper_subclass(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _neg_view(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _nested_tensor_size(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _nnz(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _storage(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _to_dense(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def _values(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __add__(self, *args, **kwargs):
+ return pi.add_Tensor(self, args[0])
+
+ def __and__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __bool__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __complex__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __delitem__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __div__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __eq__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __float__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __floordiv__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __getitem__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __ge__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __gt__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ #
+ # def __iadd__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __iand__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __idiv__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __ifloordiv__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __ilshift__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __imod__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __imul__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __index__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ def __int__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __invert__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ #
+ # def __ior__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __irshift__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __isub__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __ixor__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ def __len__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __le__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __long__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ #
+ # def __lshift__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ def __lt__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ #
+ # def __matmul__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ def __mod__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ #
+ def __mul__(self, *args, **kwargs):
+ return pi.mul_Tensor(self, args[0])
+
+ def __ne__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __nonzero__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __or__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ # def __radd__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __rand__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __rmul__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __ror__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __rshift__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ # def __rxor__(self, *args, **kwargs):
+ # raise NotImplementedError
+ #
+ def __setitem__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __sub__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __truediv__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ # def __xor__(self, *args, **kwargs):
+ # raise NotImplementedError
+
+ # data = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # device = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # dtype = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # grad = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # grad_fn = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # H = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # imag = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_cpu = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_cuda = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_ipu = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_leaf = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_meta = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_mkldnn = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_mps = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_nested = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_ort = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_quantized = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_sparse = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_sparse_csr = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_vulkan = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # is_xpu = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # layout = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # mH = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # mT = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # name = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # names = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # ndim = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # output_nr = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # real = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # requires_grad = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # retains_grad = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # shape = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # T = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # volatile = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _backward_hooks = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _base = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _cdata = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _grad = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _grad_fn = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _has_symbolic_sizes_strides = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _python_dispatch = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+ #
+ # _version = property(
+ # lambda self: object(), lambda self, v: None, lambda self: None
+ # ) # default
+
+
+def from_numpy(arr: np.ndarray):
+ from pi import DEBUG
+
+ if DEBUG:
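+ # in DEBUG mode the payload is replaced with float32 ones, presumably so traced
+ # literals don't depend on the actual input data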
+ arr = np.ones_like(arr, dtype=np.float32)
+ attr = DenseElementsAttr.get(arr)
+ vt = Tensor(torch_dialect.ValueTensorLiteralOp(attr))
+ return vt
+
+
+def empty(shape: Tuple[int, ...], dtype: "pi.dtype" = None, **kwargs) -> Tensor:
+ if np.prod(shape) == 0:
+ return Tensor(None)
+ else:
+ if dtype is not None:
+ dtype = dtype.to_np_type()
+
+ return from_numpy(np.empty(shape, dtype))
+
+
+def randint(low: int, high: int, size: Tuple[int, ...]) -> Tensor:
+ return from_numpy(np.random.randint(low, high, size))
+
+
+def randn(*size: Tuple[int, ...]) -> Tensor:
+ return from_numpy(np.random.randn(*size))
+
+
+def uniform(low: float, high: float, size: Tuple[int, ...]) -> Tensor:
+ return from_numpy(np.random.uniform(low, high, size))
+
+
+def rand(*size: Tuple[int, ...], **kwargs) -> Tensor:
+ dtype = kwargs.get("dtype", None)
+ arr = np.random.rand(*size)
+ if dtype is not None:
+ # np.random.rand has no dtype parameter, so cast the result instead
+ arr = arr.astype(dtype.to_np_type())
+ return from_numpy(arr)
+
+
+def ones(*size: Tuple[int, ...], **kwargs) -> Tensor:
+ # dtype: "pi.dtype" = None, _device: Any = None
+ dtype = kwargs.get("dtype", None)
+ if dtype is not None:
+ dtype = dtype.to_np_type()
+ return from_numpy(np.ones(size, dtype=dtype))
+
+
+def zeros(*size: Tuple[int, ...], **kwargs) -> Tensor:
+ dtype = kwargs.get("dtype", None)
+ if dtype is not None:
+ dtype = dtype.to_np_type()
+ return from_numpy(np.zeros(size, dtype))
+
+
+def tensor(data: Any, dtype: Optional["pi.dtype"] = None) -> Tensor:
+ if dtype is not None:
+ dtype = dtype.to_np_type()
+
+ return from_numpy(np.array(data, dtype=dtype))
+
+
+def LongTensor(data: Any) -> Tensor:
+ return from_numpy(np.array(data, dtype=pi_dtype.int64.to_np_type()))
+
+
+def clone(x: Tensor, **kwargs):
+ # TODO(max): is this necessary?
+ warnings.warn("not actually cloning")
+ return x
+
+
+__all__ = [
+ "from_numpy",
+ "empty",
+ "randint",
+ "randn",
+ "rand",
+ "uniform",
+ "ones",
+ "tensor",
+ "Tensor",
+ "LongTensor",
+ "zeros",
+]
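+
+
+if __name__ == "__main__":
+    # Minimal usage sketch for the factories above. It assumes a bare ir.Context()
+    # from torch-mlir has the torch dialect registered (mlir_trace relies on the
+    # same thing); the shapes here are just illustrative.
+    from torch_mlir import ir
+
+    ctx = ir.Context()
+    loc = ir.Location.unknown(context=ctx)
+    with ctx, loc:
+        module = ir.Module.create(loc)
+        with ir.InsertionPoint(module.body):
+            x = ones(2, 3)  # np.ones -> DenseElementsAttr -> ValueTensorLiteralOp
+            y = rand(2, 3)
+            z = x + y  # Tensor.__add__ dispatches to pi.add_Tensor
+        print(module)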
diff --git a/shark/compiler/__init__.py b/pi/compiler/__init__.py
similarity index 100%
rename from shark/compiler/__init__.py
rename to pi/compiler/__init__.py
diff --git a/pi/compiler/annotations.py b/pi/compiler/annotations.py
new file mode 100644
index 0000000..d82f152
--- /dev/null
+++ b/pi/compiler/annotations.py
@@ -0,0 +1,102 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+import functools
+import inspect
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import pi
+from torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen import (
+ get_ods_type,
+)
+from torch_mlir import ir
+
+PI_EXPORT_ATTR_NAME = "_PI_EXPORT"
+PI_ARG_ANNOTATIONS_ATTR_NAME = "_PI_ARG_ANNOTATIONS"
+
+
+def export(fn):
+ # setattr(fn, PI_EXPORT_ATTR_NAME, True)
+ return fn
+
+
+ArgAnnotation = Union[type, Tuple[List[int], pi.dtype, bool]]
+
+
+class TensorPlaceholder:
+ def __init__(self, shape: List[int], dtype: pi.dtype):
+ self.shape = shape
+ self.dtype = dtype
+
+ def to_value_tensor_type(self):
+ dtype = self.dtype.to_mlir_type()
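+ # e.g. shape [1, 3, 32, 32] with a float dtype becomes "!torch.vtensor<[1,3,32,32],f32>"
+ # (assuming to_mlir_type() yields MLIR's "f32" spelling)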
+ type = ir.Type.parse(
+ f"!torch.vtensor<[{','.join(map(str, self.shape))}],{dtype}>"
+ )
+ return type
+
+ def to(self, dtype: pi.dtype):
+ self.dtype = dtype
+ return self
+
+ def type(self, dtype):
+ return self.to(dtype)
+
+ def bool(self):
+ return self.to(pi.dtype.bool)
+
+ def double(self):
+ self.dtype = pi.dtype.float64
+ return self
+
+
+def annotations_to_placeholders(
+ args: List[str], annotations: List[Optional[ArgAnnotation]]
+) -> OrderedDict:
+ placeholders = OrderedDict()
+ for annotation, arg in zip(annotations, args):
+ # Skip the "self" annotation.
+ if annotation is None:
+ assert arg == "self"
+ continue
+ shape, dtype, value_tensor = annotation
+ assert value_tensor, f"non-value tensors not supported: {arg}"
+ placeholders[arg] = TensorPlaceholder(shape, dtype)
+ return placeholders
+
+
+# TODO: Replace with py3 extended argument annotations when available.
+# See https://www.python.org/dev/peps/pep-0593/
+def annotate_args(annotations: List[Optional[ArgAnnotation]]):
+ def decorator(fn):
+ arg_spec = inspect.getfullargspec(fn)
+ placeholders = annotations_to_placeholders(arg_spec.args, annotations)
+ setattr(fn, "__placeholders__", placeholders)
+ return fn
+
+ return decorator
+
+
+def convert_annotations_to_placeholders(forward_method):
+ """Converts the annotations on a forward method into tensor placeholders.
+
+ These placeholders are suitable for being passed to `torch_mlir.compile`.
+ """
+ annotations = getattr(forward_method, PI_ARG_ANNOTATIONS_ATTR_NAME)
+ placeholders = []
+ # Skip the "self" annotation.
+ for annotation in annotations[1:]:
+ placeholders.append(TensorPlaceholder(annotation[0], annotation[1]))
+ return placeholders
+
+
+def pipile(annotations: List[Optional[ArgAnnotation]]):
+ def actual_decorator(func):
+ func = export(func)
+ if len(annotations):
+ func = annotate_args(annotations)(func)
+ return func
+
+ return actual_decorator
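+
+
+if __name__ == "__main__":
+    # Tiny usage sketch: the class, shapes, and pi.dtype.float32 below are
+    # illustrative assumptions. Each annotation is a (shape, dtype, is_value_tensor)
+    # tuple, with None standing in for "self", as annotations_to_placeholders expects.
+    class _Demo:
+        @export
+        @annotate_args(
+            [
+                None,
+                ([1, 3, 32, 32], pi.dtype.float32, True),
+            ]
+        )
+        def forward(self, x):
+            return x
+
+    # an OrderedDict mapping "x" to a TensorPlaceholder([1, 3, 32, 32], float32)
+    print(_Demo.forward.__placeholders__)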
diff --git a/shark/compiler/compiler.py b/pi/compiler/compiler.py
similarity index 67%
rename from shark/compiler/compiler.py
rename to pi/compiler/compiler.py
index db77661..178a1b0 100644
--- a/shark/compiler/compiler.py
+++ b/pi/compiler/compiler.py
@@ -1,21 +1,18 @@
+import os
+
from torch_mlir import ir
# noinspection PyUnresolvedReferences
-from torch_mlir.dialects import (
- arith,
- linalg,
- math,
- memref,
- torch as torch_dialect
-)
+from torch_mlir.dialects import arith, linalg, math, memref, torch as torch_dialect
# noinspection PyUnresolvedReferences
-from shark.dialects import affine_
+from pi.dialects import affine_
-from shark.compiler.tracing.trace import trace
+from pi.compiler.tracing.trace import trace
def mlir_trace(script_path):
+ assert os.path.isabs(script_path), f"script path must be absolute {script_path}"
top_mlir_context = ir.Context()
mlir_location = ir.Location.unknown(context=top_mlir_context)
with top_mlir_context, mlir_location:
diff --git a/shark/compiler/config.py b/pi/compiler/config.py
similarity index 100%
rename from shark/compiler/config.py
rename to pi/compiler/config.py
diff --git a/shark/compiler/tracing/__init__.py b/pi/compiler/tracing/__init__.py
similarity index 100%
rename from shark/compiler/tracing/__init__.py
rename to pi/compiler/tracing/__init__.py
diff --git a/pi/compiler/tracing/handlers.py b/pi/compiler/tracing/handlers.py
new file mode 100644
index 0000000..30ea345
--- /dev/null
+++ b/pi/compiler/tracing/handlers.py
@@ -0,0 +1,116 @@
+import ast
+from enum import Enum
+from typing import List, cast
+
+from pyccolo import fast, TraceEvent
+from pyccolo.extra_builtins import (
+ TRACING_ENABLED,
+ make_guard_name,
+ EMIT_EVENT,
+)
+from pyccolo.fast import make_composite_condition, make_test
+from pyccolo.stmt_inserter import StatementInserter
+
+
+def _handle_class_body(
+ self,
+ node: ast.ClassDef,
+ orig_body: List[ast.AST],
+) -> List[ast.AST]:
+ classdef_copy = cast(
+ ast.ClassDef,
+ self.orig_to_copy_mapping[id(node)],
+ )
+ if self.global_guards_enabled:
+ classdef_copy = self._global_nonlocal_stripper.visit(classdef_copy)
+ class_guard = make_guard_name(classdef_copy)
+ self.register_guard(class_guard)
+ else:
+ class_guard = None
+ docstring = []
+ if (
+ len(orig_body) > 0
+ and isinstance(orig_body[0], ast.Expr)
+ and isinstance(orig_body[0].value, ast.Str)
+ ):
+ docstring = [orig_body.pop(0)]
+ if len(orig_body) == 0:
+ return docstring
+ with fast.location_of(classdef_copy):
+ if self.global_guards_enabled:
+ ret = [
+ fast.If(
+ test=make_composite_condition(
+ [
+ make_test(TRACING_ENABLED),
+ make_test(class_guard),
+ self.emit(
+ TraceEvent.before_class_body,
+ node,
+ ret=fast.NameConstant(True),
+ )
+ if self.handler_predicate_by_event[
+ TraceEvent.before_class_body
+ ](classdef_copy)
+ else None,
+ ]
+ ),
+ body=orig_body,
+ orelse=classdef_copy.body
+ if len(docstring) == 0
+ else classdef_copy.body[len(docstring) :], # noqa: E203
+ ),
+ ]
+ return docstring + ret
+
+
+def generic_visit(self, node):
+ if self.is_tracing_disabled_context(node):
+ return node
+ for name, field in ast.iter_fields(node):
+ if isinstance(field, ast.AST):
+ setattr(node, name, self.visit(field))
+ elif isinstance(field, list):
+ new_field = []
+ future_imports = []
+ if isinstance(node, ast.Module) and name == "body":
+ node_copy = self.get_copy_node(node)
+ if self.handler_predicate_by_event[TraceEvent.init_module](node_copy):
+ with fast.location_of(node):
+ new_field.extend(
+ fast.parse(
+ f'{EMIT_EVENT}("{TraceEvent.init_module.name}", '
+ + f"{id(node_copy)})"
+ ).body
+ )
+ for inner_node in field:
+ if isinstance(inner_node, ast.stmt):
+ if (
+ isinstance(inner_node, ast.ImportFrom)
+ and inner_node.module == "__future__"
+ ):
+ future_imports.append(inner_node)
+ else:
+ new_field.extend(self._handle_stmt(node, name, inner_node))
+ elif isinstance(inner_node, ast.AST):
+ new_field.append(self.visit(inner_node))
+ else:
+ new_field.append(inner_node)
+ new_field = future_imports + new_field
+ if name == "body":
+ if isinstance(node, ast.Module):
+ new_field = self._handle_module_body(node, new_field)
+ elif isinstance(node, (ast.For, ast.While)):
+ new_field = self._handle_loop_body(node, new_field)
+ elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ new_field = self._handle_function_body(node, new_field)
+ elif isinstance(node, ast.ClassDef):
+ new_field = self._handle_class_body(node, new_field)
+ setattr(node, name, new_field)
+ else:
+ continue
+ return node
+
+
+StatementInserter.generic_visit = generic_visit
+StatementInserter._handle_class_body = _handle_class_body
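+
+# Importing this module is enough to apply the monkeypatch above; trace.py does
+# exactly that with its bare `import pi.compiler.tracing.handlers`.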
diff --git a/shark/compiler/tracing/trace.py b/pi/compiler/tracing/trace.py
similarity index 91%
rename from shark/compiler/tracing/trace.py
rename to pi/compiler/tracing/trace.py
index c5ebc6d..e5a8e5d 100644
--- a/shark/compiler/tracing/trace.py
+++ b/pi/compiler/tracing/trace.py
@@ -1,3 +1,4 @@
+import operator
import sys
from collections import namedtuple
@@ -5,17 +6,24 @@
import ctypes
import inspect
import os
+
+from pi.compiler.annotations import PI_EXPORT_ATTR_NAME
+
+# noinspection PyUnresolvedReferences
+import pi.compiler.tracing.handlers
import pyccolo as pyc
import traceback
import types
from contextlib import contextmanager
from pathlib import Path
-from pyccolo import TraceEvent, AstRewriter, fast
+from pyccolo import TraceEvent, AstRewriter, fast, register_raw_handler
from pyccolo.emit_event import _TRACER_STACK
from runpy import run_module
from typing import Optional, Union, Tuple, List
+import pi
from torch_mlir import ir
+
# this needs to be all of the dialects that will be used in the user scripts
# (in order to register ops)
# noinspection PyUnresolvedReferences
@@ -29,7 +37,7 @@
from torch_mlir.dialects._ods_common import get_op_result_or_value
from torch_mlir.ir import Type as MLIRType, IntegerType, F64Type
-from shark.dialects import affine_, value_
+from pi.dialects import affine_, value_
# from uncompyle6 import main
@@ -102,10 +110,19 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
for n in node.body:
self.visit(n)
if not isinstance(node.body[-1], ast.Return):
- return_ = ast.Return(value=ast.Constant(value=None))
+ return_ = ast.Return(
+ value=None if node.name == "__init__" else ast.Constant(value=None),
+ lineno=node.body[-1].lineno + 1,
+ col_offset=node.col_offset + len("return "),
+ )
node.body.append(return_)
return node
+ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
+ for n in node.body:
+ self.visit(n)
+ return node
+
class BreakIf(ast.NodeTransformer):
def visit_If(self, node: ast.If):
@@ -158,11 +175,13 @@ def update_frame_locals(frame, updates: dict):
class MLIRTracer(pyc.BaseTracer):
def __init__(
self,
+ script_path: str,
mlir_context: ir.Context,
mlir_location: ir.Location,
mlir_module: ir.Module,
):
super().__init__()
+ self.script_path = script_path
self.mlir_context = mlir_context
self.mlir_location = mlir_location
self.mlir_module = mlir_module
@@ -186,9 +205,10 @@ def __init__(
mlir_context=self.mlir_context or self.mlir_context,
scope_name="module",
)
- self.if_bodies_executed = set()
# dirty dirty hack
- self.compares_executed = {}
+ self.if_bodies_executed = set()
+ self.binops_executed = {}
+ self.fn_to_node = {}
def enter_mlir_block_scope(
self,
@@ -290,10 +310,22 @@ def should_propagate_handler_exception(
return True
def should_instrument_file(self, filename: str) -> bool:
- return True
+ return filename == self.script_path
# handlers
+ @TraceEvent.before_class_body
+ def handle_before_class_body(
+ self,
+ old_ret,
+ node,
+ frame: types.FrameType,
+ event,
+ guard_for_spec,
+ **_,
+ ):
+ return False
+
@pyc.before_function_body
def handle_before_function_body(
self,
@@ -366,6 +398,8 @@ def handle_after_return(
inputs=func_type.inputs, results=[mlir_return_val.type]
)
func_op.attributes["function_type"] = ir.TypeAttr.get(canonical_func_type)
+ if isinstance(mlir_return_val, pi.Tensor):
+ mlir_return_val = mlir_return_val.value
func_dialect.ReturnOp((mlir_return_val,))
else:
func_dialect.ReturnOp(())
@@ -383,6 +417,11 @@ def handle_exit_module(
guard_for_spec,
**kwargs,
):
+ for loc_name, loc in frame.f_locals.items():
+ if inspect.isfunction(loc) and getattr(
+ loc, PI_EXPORT_ATTR_NAME, False
+ ):
+ print(loc)
self.exit_mlir_block_scope(scope_name="module")
ast_rewriter_cls = MLIRRewriter
@@ -459,7 +498,7 @@ def handle_before_binop(
):
def eval_op(x, y):
hash = id(x), id(y)
- if hash not in self.compares_executed:
+ if hash not in self.binops_executed:
x, y = map(
lambda v: self.get_or_make_mlir_constant(v)
if isinstance(v, (float, int, bool))
@@ -471,8 +510,12 @@ def eval_op(x, y):
else:
# ...god damn it
op = node.op.__class__.__name__.lower().replace("mult", "mul")
- self.compares_executed[hash] = getattr(value_, op)(x, y)
- return self.compares_executed[hash]
+ if isinstance(x, pi.Tensor):
+ assert isinstance(y, pi.Tensor)
+ self.binops_executed[hash] = getattr(operator, op)(x, y)
+ else:
+ self.binops_executed[hash] = getattr(value_, op)(x, y)
+ return self.binops_executed[hash]
return eval_op
@@ -564,7 +607,12 @@ def get_script_as_module(script: str) -> str:
def trace(script_path, mlir_context, mlir_location, mlir_module) -> ir.Module:
module_to_run = get_script_as_module(script_path)
- with MLIRTracer(mlir_context, mlir_location, mlir_module):
+ with MLIRTracer(
+ script_path,
+ mlir_context,
+ mlir_location,
+ mlir_module,
+ ):
run_module(module_to_run)
return mlir_module
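+
+# The usual entry point is mlir_trace(script_path) in pi/compiler/compiler.py, which
+# builds the context/location/module and hands them to trace(); note that the
+# script path must be absolute (mlir_trace asserts this).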
diff --git a/pi/compiler/tracing/trace_events.py b/pi/compiler/tracing/trace_events.py
new file mode 100644
index 0000000..e11de88
--- /dev/null
+++ b/pi/compiler/tracing/trace_events.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+import ast
+from enum import Enum
+
+from pyccolo import fast
+
+
+class TraceEvent(Enum):
+ before_import = "before_import"
+ init_module = "init_module"
+ exit_module = "exit_module"
+ after_import = "after_import"
+
+ before_stmt = "before_stmt"
+ after_stmt = "after_stmt"
+ after_module_stmt = "after_module_stmt"
+ after_expr_stmt = "after_expr_stmt"
+
+ load_name = "load_name"
+
+ before_for_loop_body = "before_for_loop_body"
+ after_for_loop_iter = "after_for_loop_iter"
+ before_while_loop_body = "before_while_loop_body"
+ after_while_loop_iter = "after_while_loop_iter"
+
+ before_attribute_load = "before_attribute_load"
+ before_attribute_store = "before_attribute_store"
+ before_attribute_del = "before_attribute_del"
+ after_attribute_load = "after_attribute_load"
+ before_subscript_load = "before_subscript_load"
+ before_subscript_store = "before_subscript_store"
+ before_subscript_del = "before_subscript_del"
+ after_subscript_load = "after_subscript_load"
+
+ before_subscript_slice = "before_subscript_slice"
+ after_subscript_slice = "after_subscript_slice"
+ _load_saved_slice = "_load_saved_slice"
+
+ before_load_complex_symbol = "before_load_complex_symbol"
+ after_load_complex_symbol = "after_load_complex_symbol"
+
+ after_if_test = "after_if_test"
+ after_while_test = "after_while_test"
+
+ before_lambda = "before_lambda"
+ after_lambda = "after_lambda"
+
+ before_call = "before_call"
+ after_call = "after_call"
+ before_argument = "before_argument"
+ after_argument = "after_argument"
+ before_return = "before_return"
+ after_return = "after_return"
+
+ before_dict_literal = "before_dict_literal"
+ after_dict_literal = "after_dict_literal"
+ before_list_literal = "before_list_literal"
+ after_list_literal = "after_list_literal"
+ before_set_literal = "before_set_literal"
+ after_set_literal = "after_set_literal"
+ before_tuple_literal = "before_tuple_literal"
+ after_tuple_literal = "after_tuple_literal"
+
+ dict_key = "dict_key"
+ dict_value = "dict_value"
+ list_elt = "list_elt"
+ set_elt = "set_elt"
+ tuple_elt = "tuple_elt"
+
+ before_assign_rhs = "before_assign_rhs"
+ after_assign_rhs = "after_assign_rhs"
+ before_augassign_rhs = "before_augassign_rhs"
+ after_augassign_rhs = "after_augassign_rhs"
+
+ before_function_body = "before_function_body"
+ after_function_execution = "after_function_execution"
+
+ before_class_body = "before_class_body"
+
+ before_lambda_body = "before_lambda_body"
+ after_lambda_body = "after_lambda_body"
+
+ left_binop_arg = "left_binop_arg"
+ right_binop_arg = "right_binop_arg"
+ before_binop = "before_binop"
+ after_binop = "after_binop"
+
+ left_compare_arg = "left_compare_arg"
+ compare_arg = "compare_arg"
+ before_compare = "before_compare"
+ after_compare = "after_compare"
+
+ after_comprehension_if = "after_comprehension_if"
+ after_comprehension_elt = "after_comprehension_elt"
+ after_dict_comprehension_key = "after_dict_comprehension_key"
+ after_dict_comprehension_value = "after_dict_comprehension_value"
+
+ ellipses = "ellipses"
+
+ line = "line"
+ call = "call"
+ return_ = "return"
+ exception = "exception"
+ opcode = "opcode"
+
+ # these are included for completeness but will probably not be used
+ c_call = "c_call"
+ c_return = "c_return"
+ c_exception = "c_exception"
+
+ def __str__(self):
+ return self.value
+
+ def __repr__(self):
+ return "<" + str(self) + ">"
+
+ def __call__(self, handler=None, **kwargs):
+ # this will be filled by tracer.py
+ ...
+
+ def to_ast(self):
+ return fast.Constant(self.name)
+
+
+SYS_TRACE_EVENTS = {
+ TraceEvent.line,
+ TraceEvent.call,
+ TraceEvent.return_,
+ TraceEvent.exception,
+ TraceEvent.opcode,
+}
+
+
+BEFORE_EXPR_EVENTS = {
+ TraceEvent.before_argument,
+ TraceEvent.before_assign_rhs,
+ TraceEvent.before_augassign_rhs,
+ TraceEvent.before_binop,
+ TraceEvent.before_compare,
+ TraceEvent.before_dict_literal,
+ TraceEvent.before_lambda,
+ TraceEvent.before_list_literal,
+ TraceEvent.before_load_complex_symbol,
+ TraceEvent.before_return,
+ TraceEvent.before_set_literal,
+ TraceEvent.before_subscript_slice,
+ TraceEvent.before_tuple_literal,
+}
+
+
+AST_TO_EVENT_MAPPING = {
+ ast.stmt: TraceEvent.after_stmt,
+ ast.Assign: TraceEvent.after_assign_rhs,
+ ast.Module: TraceEvent.init_module,
+ ast.Name: TraceEvent.load_name,
+ ast.Attribute: TraceEvent.before_attribute_load,
+ ast.Subscript: TraceEvent.before_subscript_load,
+ ast.Call: TraceEvent.before_call,
+ ast.Dict: TraceEvent.after_dict_literal,
+ ast.List: TraceEvent.after_list_literal,
+ ast.Tuple: TraceEvent.after_tuple_literal,
+ ast.Set: TraceEvent.after_set_literal,
+ ast.Return: TraceEvent.after_return,
+ ast.BinOp: TraceEvent.after_binop,
+ ast.Compare: TraceEvent.after_compare,
+}
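+
+
+# Note: pi/dialects/__init__.py installs an ImportOverload that maps
+# "pyccolo.trace_events" to this file, so this enum is what pyccolo actually
+# loads in place of its own trace_events module.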
diff --git a/pi/dialects/__init__.py b/pi/dialects/__init__.py
new file mode 100644
index 0000000..2e00f34
--- /dev/null
+++ b/pi/dialects/__init__.py
@@ -0,0 +1,189 @@
+import sys
+
+import logging
+import threading
+from contextlib import contextmanager
+from dataclasses import dataclass
+from importlib.abc import MetaPathFinder
+from importlib.machinery import SourceFileLoader, ModuleSpec
+from importlib.util import find_spec, spec_from_loader
+from pathlib import Path
+from typing import Generator, Callable, Dict, List
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(order=True, frozen=True)
+class ImportOverload:
+ name: str
+ origin: Path
+ is_package: bool
+ submodule_search_locations: List[Path] = None
+
+ def __post_init__(self):
+ if self.is_package and self.submodule_search_locations is None:
+ assert (
+ self.origin.name == "__init__.py"
+ ), f"default search path for {self.name} isn't a package: {self.origin}"
+ object.__setattr__(self, "submodule_search_locations", [self.origin.parent])
+
+
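+# By default, redirect torch-mlir's dialect extension modules (and
+# pyccolo.trace_events) to the patched copies bundled with this package.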
+_base_overloads = [
+ ImportOverload(
+ "torch_mlir.dialects._arith_ops_ext",
+ Path(__file__).parent / "_arith_ops_ext.py",
+ False,
+ ),
+ ImportOverload(
+ "torch_mlir.dialects._memref_ops_ext",
+ Path(__file__).parent / "_memref_ops_ext.py",
+ False,
+ ),
+ ImportOverload(
+ "torch_mlir.dialects._torch_ops_ext_custom",
+ Path(__file__).parent / "_torch_ops_ext_custom.py",
+ False,
+ ),
+ ImportOverload(
+ "torch_mlir.dialects._torch_ops_ext",
+ Path(__file__).parent / "_torch_ops_ext.py",
+ False,
+ ),
+ ImportOverload(
+ "pyccolo.trace_events",
+ Path(__file__).parent.parent / "compiler" / "tracing" / "trace_events.py",
+ False,
+ ),
+]
+
+BASE_OVERLOADS: Dict[str, ImportOverload] = {i.name: i for i in _base_overloads}
+
+
+# this is based on the birdseye finder (which uses import hooks based on MacroPy's):
+# https://github.com/alexmojaki/birdseye/blob/9974af715b1801f9dd99fef93ff133d0ab5223af/birdseye/import_hook.py
+class Overloader(MetaPathFinder):
+ def __init__(self, overloads) -> None:
+ self.tracers = None
+ self._thread = threading.current_thread()
+ self.overloads: Dict[str, ImportOverload] = overloads
+
+ @contextmanager
+ def _clear_preceding_finders(self) -> Generator[None, None, None]:
+ """
+ Clear all preceding finders from sys.meta_path, and restore them afterwards.
+ """
+ orig_finders = sys.meta_path
+ try:
+ sys.meta_path = sys.meta_path[sys.meta_path.index(self) + 1 :] # noqa: E203
+ yield
+ finally:
+ sys.meta_path = orig_finders
+
+ def _find_plain_spec(self, fullname, path, target):
+ """Try to find the original module using all the
+ remaining meta_path finders."""
+ spec = None
+ self_seen = False
+ for finder in sys.meta_path:
+ if finder is self:
+ self_seen = True
+ continue
+ elif not self_seen or "pytest" in finder.__module__:
+ # when running under pytest, it installs a finder that, for
+ # reasons not yet understood, breaks this hook (the same
+ # problem birdseye works around); skip it and fall through
+ # to the next finder
+ continue
+ if hasattr(finder, "find_spec"):
+ spec = finder.find_spec(fullname, path, target=target)
+ elif hasattr(finder, "load_module"):
+ spec = spec_from_loader(fullname, finder)
+
+ if spec is not None and spec.origin != "builtin":
+ return spec
+
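+ # MetaPathFinder hook: resolve the module as usual, then, if it is one of
+ # the configured overloads, return a spec that loads the replacement file
+ # instead.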
+ def find_spec(self, fullname, path=None, target=None):
+ logger.debug(f"finding spec for {fullname=} {path=} {target=}")
+
+ if threading.current_thread() is not self._thread:
+ return None
+ if target is None:
+ with self._clear_preceding_finders():
+ spec = find_spec(fullname, path)
+ else:
+ spec = self._find_plain_spec(fullname, path, target)
+
+ if fullname not in self.overloads:
+ if spec is None or not (
+ hasattr(spec.loader, "get_source") and callable(spec.loader.get_source)
+ ): # noqa: E128
+ if fullname != "org":
+ # stdlib pickle.py at line 94 contains a ``from
+ # org.python.core ...`` import (Jython support) that
+ # always fails on CPython, so don't bother logging it
+ logger.debug("Failed finding spec for %s", fullname)
+ return None
+
+ if not isinstance(spec.loader, SourceFileLoader):
+ return None
+ return spec
+
+ logger.debug("patching spec for %s", fullname)
+
+ overload = self.overloads[fullname]
+ new_path = str(overload.origin)
+ source_file_loader = SourceFileLoader(fullname, new_path)
+ spec = ModuleSpec(
+ name=fullname,
+ loader=source_file_loader,
+ origin=new_path,
+ is_package=overload.is_package,
+ )
+ if overload.is_package:
+ spec.submodule_search_locations = [
+ str(p) for p in overload.submodule_search_locations
+ ]
+ spec.has_location = True
+ return spec
+
+
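+# Install an Overloader at the front of sys.meta_path and return a callback
+# that undoes the installation (restoring any Overloader it replaced).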
+def patch_meta_path_non_context(overloads=None) -> Callable:
+ if overloads is None:
+ overloads = BASE_OVERLOADS
+ orig_meta_path_entry = None
+
+ def cleanup_callback():
+ if orig_meta_path_entry is None:
+ del sys.meta_path[0]
+ else:
+ sys.meta_path[0] = orig_meta_path_entry
+
+ if len(sys.meta_path) > 0 and isinstance(sys.meta_path[0], Overloader):
+ orig_meta_path_entry = sys.meta_path[0]
+ sys.meta_path[0] = Overloader(overloads)
+ else:
+ sys.meta_path.insert(0, Overloader(overloads))
+
+ return cleanup_callback
+
+
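+# Context-manager form of the above; a minimal usage sketch:
+#
+#     with patch_meta_path():
+#         # resolves to pi/dialects/_torch_ops_ext.py via BASE_OVERLOADS
+#         import torch_mlir.dialects._torch_ops_ext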
+@contextmanager
+def patch_meta_path(overloads=None) -> Generator[None, None, None]:
+ cleanup_callback = None
+ try:
+ cleanup_callback = patch_meta_path_non_context(overloads)
+ yield
+ finally:
+ if cleanup_callback is not None:
+ cleanup_callback()
+
+
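+# Evict every module whose name satisfies `pred` from sys.modules so that the
+# next import re-resolves it (through the patched meta_path, if installed).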
+def remove_modules(pred: Callable):
+ to_delete = []
+ for mod in sys.modules:
+ if pred(mod):
+ logger.debug(f"removing from sys.modules {mod}")
+ to_delete.append(mod)
+ for mod in to_delete:
+ del sys.modules[mod]
diff --git a/shark/dialects/_affine_ops_ext.py b/pi/dialects/_affine_ops_ext.py
similarity index 87%
rename from shark/dialects/_affine_ops_ext.py
rename to pi/dialects/_affine_ops_ext.py
index d28741c..da50434 100644
--- a/shark/dialects/_affine_ops_ext.py
+++ b/pi/dialects/_affine_ops_ext.py
@@ -1,19 +1,13 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from torch_mlir.ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
from typing import Optional, Sequence, Union
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
+from torch_mlir.dialects._ods_common import (
+ get_op_result_or_value,
+ get_op_results_or_values,
)
+from torch_mlir.ir import *
class AffineForOp:
@@ -86,8 +80,8 @@ def __init__(
loc: user-visible location of the operation.
ip: insertion point.
"""
- memref_resolved = _get_op_result_or_value(memref)
- indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
+ memref_resolved = get_op_result_or_value(memref)
+ indices_resolved = [] if indices is None else get_op_results_or_values(indices)
return_type = MemRefType(memref_resolved.type).element_type
super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
diff --git a/shark/dialects/_arith_ops_ext.py b/pi/dialects/_arith_ops_ext.py
similarity index 94%
rename from shark/dialects/_arith_ops_ext.py
rename to pi/dialects/_arith_ops_ext.py
index 6465ce4..7b11c2b 100644
--- a/shark/dialects/_arith_ops_ext.py
+++ b/pi/dialects/_arith_ops_ext.py
@@ -1,13 +1,16 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from torch_mlir.dialects._ods_common import (
+ get_default_loc_context,
+ get_op_result_or_value,
+)
+from torch_mlir.ir import *
+from typing import Any, List, Union
try:
- from torch_mlir.ir import *
- from ._ods_common import get_op_result_or_value, get_default_loc_context
- from shark.dialects.value_ import _Value
+ from pi.dialects.value_ import _Value
- from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
diff --git a/shark/dialects/_memref_ops_ext.py b/pi/dialects/_memref_ops_ext.py
similarity index 98%
rename from shark/dialects/_memref_ops_ext.py
rename to pi/dialects/_memref_ops_ext.py
index bbb5e7a..ad8bdd3 100644
--- a/shark/dialects/_memref_ops_ext.py
+++ b/pi/dialects/_memref_ops_ext.py
@@ -5,7 +5,7 @@
try:
from torch_mlir.ir import *
from ._ods_common import get_op_result_or_value, get_op_results_or_values
- from shark.dialects.value_ import _Value
+ from pi.dialects.value_ import _Value
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
diff --git a/pi/dialects/_torch_ops_ext.py b/pi/dialects/_torch_ops_ext.py
new file mode 100644
index 0000000..c12fe1f
--- /dev/null
+++ b/pi/dialects/_torch_ops_ext.py
@@ -0,0 +1,8744 @@
+try:
+ # from pi import Tensor, Number
+ from torch_mlir.ir import *
+ from torch_mlir.dialects._ods_common import (
+ get_default_loc_context,
+ get_op_result_or_value,
+ get_op_results_or_values,
+ )
+ from ._torch_ops_ext_custom import *
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import List, Optional, Any
+
+
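+# Builder extensions for the generated torch dialect ops. Each class follows
+# the same pattern: MLIR values are unwrapped with get_op_result_or_value and
+# type-checked, plain Python arguments (e.g. numbers) are wrapped in the
+# corresponding torch_dialect constant op, and the result type is left as an
+# unrefined !torch.vtensor.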
+class AtenTanhOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTanhOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenTanh_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTanh_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenHardtanhOp:
+ def __init__(self, self_: Value, min_val: "Number", max_val: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(min_val):
+ min_val = torch_dialect.ConstantNumberOp(min_val)
+ else:
+ min_val = get_op_result_or_value(min_val)
+ assert str(min_val.type) in {'!torch.float', '!torch.int'}, f'`min_val` should be a !torch.number but is {type(min_val).__module__}.{type(min_val).__name__}'
+
+ if not is_mlir_value(max_val):
+ max_val = torch_dialect.ConstantNumberOp(max_val)
+ else:
+ max_val = get_op_result_or_value(max_val)
+ assert str(max_val.type) in {'!torch.float', '!torch.int'}, f'`max_val` should be a !torch.number but is {type(max_val).__module__}.{type(max_val).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenHardtanhOp, self).__init__(result_type, self_, min_val, max_val, loc=loc, ip=ip)
+
+
+class AtenHardtanh_Op:
+ def __init__(self, self_: Value, min_val: "Number", max_val: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(min_val):
+ min_val = torch_dialect.ConstantNumberOp(min_val)
+ else:
+ min_val = get_op_result_or_value(min_val)
+ assert str(min_val.type) in {'!torch.float', '!torch.int'}, f'`min_val` should be a !torch.number but is {type(min_val).__module__}.{type(min_val).__name__}'
+
+ if not is_mlir_value(max_val):
+ max_val = torch_dialect.ConstantNumberOp(max_val)
+ else:
+ max_val = get_op_result_or_value(max_val)
+ assert str(max_val.type) in {'!torch.float', '!torch.int'}, f'`max_val` should be a !torch.number but is {type(max_val).__module__}.{type(max_val).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenHardtanh_Op, self).__init__(result_type, self_, min_val, max_val, loc=loc, ip=ip)
+
+
+class AtenReluOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenReluOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRelu_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRelu_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRelu6Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRelu6Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRelu6_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRelu6_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenLeakyReluOp:
+ def __init__(self, self_: Value, negative_slope: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(negative_slope):
+ negative_slope = torch_dialect.ConstantNumberOp(negative_slope)
+ else:
+ negative_slope = get_op_result_or_value(negative_slope)
+ assert str(negative_slope.type) in {'!torch.float', '!torch.int'}, f'`negative_slope` should be a !torch.number but is {type(negative_slope).__module__}.{type(negative_slope).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLeakyReluOp, self).__init__(result_type, self_, negative_slope, loc=loc, ip=ip)
+
+
+class AtenLeakyRelu_Op:
+ def __init__(self, self_: Value, negative_slope: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(negative_slope):
+ negative_slope = torch_dialect.ConstantNumberOp(negative_slope)
+ else:
+ negative_slope = get_op_result_or_value(negative_slope)
+ assert str(negative_slope.type) in {'!torch.float', '!torch.int'}, f'`negative_slope` should be a !torch.number but is {type(negative_slope).__module__}.{type(negative_slope).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLeakyRelu_Op, self).__init__(result_type, self_, negative_slope, loc=loc, ip=ip)
+
+
+class AtenLogOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenLog_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLog_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSigmoidOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSigmoidOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSigmoid_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSigmoid_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenHardsigmoidOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenHardsigmoidOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenHardsigmoid_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenHardsigmoid_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenHardswishOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenHardswishOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenHardswish_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenHardswish_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenErfOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenErfOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenErf_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenErf_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSiluOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSiluOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSilu_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSilu_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSinOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSinOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSin_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSin_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenExpOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExpOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenExp_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExp_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenExpm1Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExpm1Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenExpm1_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExpm1_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenCosOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCosOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenCos_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCos_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAtan2Op:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAtan2Op, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenAtan2_Op:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAtan2_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenNegOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNegOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenNeg_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNeg_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenFloorOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFloorOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenFloor_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFloor_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenCeilOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCeilOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenCeil_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCeil_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenBitwiseNotOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBitwiseNotOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenBitwiseNot_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBitwiseNot_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenDivTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDivTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenDiv_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDiv_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalOrOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalOrOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalOr_Op:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalOr_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalAndOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalAndOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalAnd_Op:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalAnd_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalXorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalXorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalXor_Op:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalXor_Op, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLogicalNotOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalNotOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenLogicalNot_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogicalNot_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenLerpTensorOp:
+ def __init__(self, self_: Value, end: Value, weight: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(end):
+ assert is_mlir_value(end), f'`end` should be a Value but is {type(end).__module__}.{type(end).__name__}'
+ else:
+ end = get_op_result_or_value(end)
+ assert str(end.type).startswith("!torch.vtensor"), f'`end` should be a torch.vtensor but is {type(end).__module__}.{type(end).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLerpTensorOp, self).__init__(result_type, self_, end, weight, loc=loc, ip=ip)
+
+
+class AtenLerp_TensorOp:
+ def __init__(self, self_: Value, end: Value, weight: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(end):
+ assert is_mlir_value(end), f'`end` should be a Value but is {type(end).__module__}.{type(end).__name__}'
+ else:
+ end = get_op_result_or_value(end)
+ assert str(end.type).startswith("!torch.vtensor"), f'`end` should be a torch.vtensor but is {type(end).__module__}.{type(end).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLerp_TensorOp, self).__init__(result_type, self_, end, weight, loc=loc, ip=ip)
+
+
+class AtenEqTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenEqTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenEq_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenEq_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGtTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGtTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGt_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGt_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGeTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGeTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGe_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGe_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLtTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLtTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLt_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLt_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLeTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLeTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLe_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLe_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenNeTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNeTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenNe_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNe_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenDivScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDivScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenDiv_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDiv_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenNeScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNeScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenNe_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNe_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenEqScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenEqScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenEq_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenEq_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGtScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGtScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGt_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGt_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGeScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGeScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenGe_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGe_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLtScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLtScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLt_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLt_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLeScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLeScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenLe_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLe_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenFmodScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFmodScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenFmod_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFmod_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenMaskedFillScalarOp:
+ def __init__(self, self_: Value, mask: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mask):
+ assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}'
+ else:
+ mask = get_op_result_or_value(mask)
+ assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaskedFillScalarOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip)
+
+
+class AtenMaskedFill_ScalarOp:
+ def __init__(self, self_: Value, mask: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mask):
+ assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}'
+ else:
+ mask = get_op_result_or_value(mask)
+ assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaskedFill_ScalarOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip)
+
+
+class AtenMaskedFillTensorOp:
+ def __init__(self, self_: Value, mask: Value, value: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mask):
+ assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}'
+ else:
+ mask = get_op_result_or_value(mask)
+ assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}'
+
+ if not is_mlir_value(value):
+ assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}'
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaskedFillTensorOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip)
+
+
+class AtenMaskedFill_TensorOp:
+ def __init__(self, self_: Value, mask: Value, value: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mask):
+ assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}'
+ else:
+ mask = get_op_result_or_value(mask)
+ assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}'
+
+ if not is_mlir_value(value):
+ assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}'
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaskedFill_TensorOp, self).__init__(result_type, self_, mask, value, loc=loc, ip=ip)
+
+
+class AtenClampOp:
+ def __init__(self, self_: Value, min: Optional["Number"], max: Optional["Number"], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(min):
+ if min is not None:
+ min = torch_dialect.ConstantNumberOp(min)
+ else:
+ min = torch_dialect.ConstantNoneOp()
+ else:
+ min = get_op_result_or_value(min)
+ assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}'
+
+ if not is_mlir_value(max):
+ if max is not None:
+ max = torch_dialect.ConstantNumberOp(max)
+ else:
+ max = torch_dialect.ConstantNoneOp()
+ else:
+ max = get_op_result_or_value(max)
+ assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenClampOp, self).__init__(result_type, self_, min, max, loc=loc, ip=ip)
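+
+# For optional scalar operands (min/max above), a Python number is wrapped in
+# ConstantNumberOp and None is materialized as ConstantNoneOp, so e.g.
+# AtenClampOp(t, 0.0, None) clamps only from below (sketch assuming `t` is a
+# !torch.vtensor Value in an active insertion point).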
+
+
+class AtenClamp_Op:
+ def __init__(self, self_: Value, min: Optional["Number"], max: Optional["Number"], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(min):
+ if min is not None:
+ min = torch_dialect.ConstantNumberOp(min)
+ else:
+ min = torch_dialect.ConstantNoneOp()
+ else:
+ min = get_op_result_or_value(min)
+ assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}'
+
+ if not is_mlir_value(max):
+ if max is not None:
+ max = torch_dialect.ConstantNumberOp(max)
+ else:
+ max = torch_dialect.ConstantNoneOp()
+ else:
+ max = get_op_result_or_value(max)
+ assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenClamp_Op, self).__init__(result_type, self_, min, max, loc=loc, ip=ip)
+
+
+class AtenClampMinOp:
+ def __init__(self, self_: Value, min: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(min):
+ min = torch_dialect.ConstantNumberOp(min)
+ else:
+ min = get_op_result_or_value(min)
+ assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenClampMinOp, self).__init__(result_type, self_, min, loc=loc, ip=ip)
+
+
+class AtenClampMin_Op:
+ def __init__(self, self_: Value, min: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(min):
+ min = torch_dialect.ConstantNumberOp(min)
+ else:
+ min = get_op_result_or_value(min)
+ assert str(min.type) in {'!torch.float', '!torch.int'}, f'`min` should be a !torch.number but is {type(min).__module__}.{type(min).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenClampMin_Op, self).__init__(result_type, self_, min, loc=loc, ip=ip)
+
+
+class AtenClampMaxOp:
+ def __init__(self, self_: Value, max: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(max):
+ max = torch_dialect.ConstantNumberOp(max)
+ else:
+ max = get_op_result_or_value(max)
+ assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenClampMaxOp, self).__init__(result_type, self_, max, loc=loc, ip=ip)
+
+
+class AtenClampMax_Op:
+ def __init__(self, self_: Value, max: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(max):
+ max = torch_dialect.ConstantNumberOp(max)
+ else:
+ max = get_op_result_or_value(max)
+ assert str(max.type) in {'!torch.float', '!torch.int'}, f'`max` should be a !torch.number but is {type(max).__module__}.{type(max).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenClampMax_Op, self).__init__(result_type, self_, max, loc=loc, ip=ip)
+
+
+class AtenLog2Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLog2Op, self).__init__(result_type, self_, loc=loc, ip=ip)
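+
+# Unary builders such as AtenLog2Op only validate the single tensor operand;
+# e.g. AtenLog2Op(t) for a !torch.vtensor Value `t`.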
+
+
+class AtenLog2_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLog2_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSqrtOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSqrtOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSqrt_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSqrt_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenLog1pOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLog1pOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenLog1p_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLog1p_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRsqrtOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRsqrtOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRsqrt_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRsqrt_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAbsOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAbsOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAbs_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAbs_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenReciprocalOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenReciprocalOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenReciprocal_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenReciprocal_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenBitwiseAndTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBitwiseAndTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenBitwiseAnd_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBitwiseAnd_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenBitwiseOrTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBitwiseOrTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenBitwiseOr_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBitwiseOr_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenThresholdOp:
+ def __init__(self, self_: Value, threshold: "Number", value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(threshold):
+ threshold = torch_dialect.ConstantNumberOp(threshold)
+ else:
+ threshold = get_op_result_or_value(threshold)
+ assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenThresholdOp, self).__init__(result_type, self_, threshold, value, loc=loc, ip=ip)
+
+
+class AtenThreshold_Op:
+ def __init__(self, self_: Value, threshold: "Number", value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(threshold):
+ threshold = torch_dialect.ConstantNumberOp(threshold)
+ else:
+ threshold = get_op_result_or_value(threshold)
+ assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenThreshold_Op, self).__init__(result_type, self_, threshold, value, loc=loc, ip=ip)
+
+
+class AtenSquareOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSquareOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSquare_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSquare_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenUnsqueezeOp:
+ def __init__(self, self_: Value, dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenUnsqueezeOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip)
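+
+# Integer arguments (dim above) follow the same convention: a Python int is
+# wrapped in torch_dialect.ConstantIntOp, while a Value must already be
+# !torch.int. Sketch: AtenUnsqueezeOp(t, 0) inserts a leading dimension of
+# size 1.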
+
+
+class AtenUnsqueeze_Op:
+ def __init__(self, self_: Value, dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenUnsqueeze_Op, self).__init__(result_type, self_, dim, loc=loc, ip=ip)
+
+
+class AtenZeroOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenZeroOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenZero_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenZero_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenFillScalarOp:
+ def __init__(self, self_: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFillScalarOp, self).__init__(result_type, self_, value, loc=loc, ip=ip)
+
+
+class AtenFill_ScalarOp:
+ def __init__(self, self_: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFill_ScalarOp, self).__init__(result_type, self_, value, loc=loc, ip=ip)
+
+
+class AtenFillTensorOp:
+ def __init__(self, self_: Value, value: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(value):
+ assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}'
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFillTensorOp, self).__init__(result_type, self_, value, loc=loc, ip=ip)
+
+
+class AtenFill_TensorOp:
+ def __init__(self, self_: Value, value: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(value):
+ assert is_mlir_value(value), f'`value` should be a Value but is {type(value).__module__}.{type(value).__name__}'
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type).startswith("!torch.vtensor"), f'`value` should be a torch.vtensor but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFill_TensorOp, self).__init__(result_type, self_, value, loc=loc, ip=ip)
+
+
+class AtenDivTensorModeOp:
+ def __init__(self, self_: Value, other: Value, rounding_mode: Optional[str], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(rounding_mode):
+ if rounding_mode is not None:
+ rounding_mode = torch_dialect.ConstantStrOp(rounding_mode)
+ else:
+ rounding_mode = torch_dialect.ConstantNoneOp()
+ else:
+ rounding_mode = get_op_result_or_value(rounding_mode)
+ assert str(rounding_mode.type) == '!torch.str', f'`rounding_mode` should be a !torch.str but is {type(rounding_mode).__module__}.{type(rounding_mode).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDivTensorModeOp, self).__init__(result_type, self_, other, rounding_mode, loc=loc, ip=ip)
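+
+# Optional string arguments (rounding_mode above) are wrapped in ConstantStrOp,
+# or ConstantNoneOp when None. Sketch, assuming `a` and `b` are !torch.vtensor
+# Values: AtenDivTensorModeOp(a, b, "floor") vs. AtenDivTensorModeOp(a, b, None).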
+
+
+class AtenDiv_TensorModeOp:
+ def __init__(self, self_: Value, other: Value, rounding_mode: Optional[str], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(rounding_mode):
+ if rounding_mode is not None:
+ rounding_mode = torch_dialect.ConstantStrOp(rounding_mode)
+ else:
+ rounding_mode = torch_dialect.ConstantNoneOp()
+ else:
+ rounding_mode = get_op_result_or_value(rounding_mode)
+ assert str(rounding_mode.type) == '!torch.str', f'`rounding_mode` should be a !torch.str but is {type(rounding_mode).__module__}.{type(rounding_mode).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDiv_TensorModeOp, self).__init__(result_type, self_, other, rounding_mode, loc=loc, ip=ip)
+
+
+class AtenMulTensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMulTensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenMul_TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMul_TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenAddTensorOp:
+ def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddTensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
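+
+# add/sub tensor builders carry the extra `alpha` scalar, wrapped like any other
+# number operand. Sketch: AtenAddTensorOp(a, b, 1) computes a + 1*b for two
+# !torch.vtensor Values `a` and `b`.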
+
+
+class AtenAdd_TensorOp:
+ def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAdd_TensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenSubTensorOp:
+ def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSubTensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenSub_TensorOp:
+ def __init__(self, self_: Value, other: Value, alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSub_TensorOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenAddScalarOp:
+ def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenAdd_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAdd_ScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenSubScalarOp:
+ def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSubScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenSub_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSub_ScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
+class AtenMulScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMulScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenMul_ScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMul_ScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenAddcmulOp:
+ def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(tensor1):
+ assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+ else:
+ tensor1 = get_op_result_or_value(tensor1)
+ assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+
+ if not is_mlir_value(tensor2):
+ assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+ else:
+ tensor2 = get_op_result_or_value(tensor2)
+ assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddcmulOp, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip)
+
+
+class AtenAddcmul_Op:
+ def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(tensor1):
+ assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+ else:
+ tensor1 = get_op_result_or_value(tensor1)
+ assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+
+ if not is_mlir_value(tensor2):
+ assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+ else:
+ tensor2 = get_op_result_or_value(tensor2)
+ assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddcmul_Op, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip)
+
+
+class AtenAddcdivOp:
+ def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(tensor1):
+ assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+ else:
+ tensor1 = get_op_result_or_value(tensor1)
+ assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+
+ if not is_mlir_value(tensor2):
+ assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+ else:
+ tensor2 = get_op_result_or_value(tensor2)
+ assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddcdivOp, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip)
+
+
+class AtenAddcdiv_Op:
+ def __init__(self, self_: Value, tensor1: Value, tensor2: Value, value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(tensor1):
+ assert is_mlir_value(tensor1), f'`tensor1` should be a Value but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+ else:
+ tensor1 = get_op_result_or_value(tensor1)
+ assert str(tensor1.type).startswith("!torch.vtensor"), f'`tensor1` should be a torch.vtensor but is {type(tensor1).__module__}.{type(tensor1).__name__}'
+
+ if not is_mlir_value(tensor2):
+ assert is_mlir_value(tensor2), f'`tensor2` should be a Value but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+ else:
+ tensor2 = get_op_result_or_value(tensor2)
+ assert str(tensor2.type).startswith("!torch.vtensor"), f'`tensor2` should be a torch.vtensor but is {type(tensor2).__module__}.{type(tensor2).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddcdiv_Op, self).__init__(result_type, self_, tensor1, tensor2, value, loc=loc, ip=ip)
+
+
+class AtenMaximumOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaximumOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenMinimumOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMinimumOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenMishOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMishOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRsubScalarOp:
+ def __init__(self, self_: Value, other: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRsubScalarOp, self).__init__(result_type, self_, other, alpha, loc=loc, ip=ip)
+
+
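+# String arguments (e.g. `approximate` below) follow the same convention: a
+# Python str is wrapped with torch_dialect.ConstantStrOp, while an MLIR Value
+# must already carry the !torch.str type.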
+class AtenGeluOp:
+ def __init__(self, self_: Value, approximate: str, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(approximate):
+ approximate = torch_dialect.ConstantStrOp(approximate)
+ else:
+ approximate = get_op_result_or_value(approximate)
+ assert str(approximate.type) == '!torch.str', f'`approximate` should be a !torch.str but is {type(approximate).__module__}.{type(approximate).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGeluOp, self).__init__(result_type, self_, approximate, loc=loc, ip=ip)
+
+
+class AtenPowTensorScalarOp:
+ def __init__(self, self_: Value, exponent: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(exponent):
+ exponent = torch_dialect.ConstantNumberOp(exponent)
+ else:
+ exponent = get_op_result_or_value(exponent)
+ assert str(exponent.type) in {'!torch.float', '!torch.int'}, f'`exponent` should be a !torch.number but is {type(exponent).__module__}.{type(exponent).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenPowTensorScalarOp, self).__init__(result_type, self_, exponent, loc=loc, ip=ip)
+
+
+class AtenPowTensorTensorOp:
+ def __init__(self, self_: Value, exponent: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(exponent):
+ assert is_mlir_value(exponent), f'`exponent` should be a Value but is {type(exponent).__module__}.{type(exponent).__name__}'
+ else:
+ exponent = get_op_result_or_value(exponent)
+ assert str(exponent.type).startswith("!torch.vtensor"), f'`exponent` should be a torch.vtensor but is {type(exponent).__module__}.{type(exponent).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenPowTensorTensorOp, self).__init__(result_type, self_, exponent, loc=loc, ip=ip)
+
+
+class AtenThresholdBackwardOp:
+ def __init__(self, grad_output: Value, self_: Value, threshold: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(grad_output):
+ assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+ else:
+ grad_output = get_op_result_or_value(grad_output)
+ assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(threshold):
+ threshold = torch_dialect.ConstantNumberOp(threshold)
+ else:
+ threshold = get_op_result_or_value(threshold)
+ assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenThresholdBackwardOp, self).__init__(result_type, grad_output, self_, threshold, loc=loc, ip=ip)
+
+
+class AtenFloorDivideOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFloorDivideOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenSoftplusOp:
+ def __init__(self, self_: Value, beta: "Number", threshold: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(beta):
+ beta = torch_dialect.ConstantNumberOp(beta)
+ else:
+ beta = get_op_result_or_value(beta)
+ assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}'
+
+ if not is_mlir_value(threshold):
+ threshold = torch_dialect.ConstantNumberOp(threshold)
+ else:
+ threshold = get_op_result_or_value(threshold)
+ assert str(threshold.type) in {'!torch.float', '!torch.int'}, f'`threshold` should be a !torch.number but is {type(threshold).__module__}.{type(threshold).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSoftplusOp, self).__init__(result_type, self_, beta, threshold, loc=loc, ip=ip)
+
+
+class AtenPreluOp:
+ def __init__(self, self_: Value, weight: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenPreluOp, self).__init__(result_type, self_, weight, loc=loc, ip=ip)
+
+
+class AtenTriuOp:
+ def __init__(self, self_: Value, diagonal: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(diagonal):
+ diagonal = torch_dialect.ConstantIntOp(diagonal)
+ else:
+ diagonal = get_op_result_or_value(diagonal)
+ assert str(diagonal.type) == '!torch.int', f'`diagonal` should be a !torch.int but is {type(diagonal).__module__}.{type(diagonal).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTriuOp, self).__init__(result_type, self_, diagonal, loc=loc, ip=ip)
+
+
+class AtenTriu_Op:
+ def __init__(self, self_: Value, diagonal: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(diagonal):
+ diagonal = torch_dialect.ConstantIntOp(diagonal)
+ else:
+ diagonal = get_op_result_or_value(diagonal)
+ assert str(diagonal.type) == '!torch.int', f'`diagonal` should be a !torch.int but is {type(diagonal).__module__}.{type(diagonal).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTriu_Op, self).__init__(result_type, self_, diagonal, loc=loc, ip=ip)
+
+
+class AtenRoundOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRoundOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenRound_Op:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRound_Op, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
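+# For the index_put variants, `indices` is a Python list of tensor Values that
+# is packed into a !torch.list with torch_dialect.PrimListConstructOp, and a
+# plain Python bool for `accumulate` is wrapped with
+# torch_dialect.ConstantBoolOp.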
+class AtenIndexPutHackedTwinOp:
+ def __init__(self, self_: Value, indices: List[Value], values: Value, accumulate: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(indices):
+ indices = torch_dialect.PrimListConstructOp(indices)
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type) == '!torch.list', f'`indices` should be a !torch.list but is {type(indices).__module__}.{type(indices).__name__}'
+
+ if not is_mlir_value(values):
+ assert is_mlir_value(values), f'`values` should be a Value but is {type(values).__module__}.{type(values).__name__}'
+ else:
+ values = get_op_result_or_value(values)
+ assert str(values.type).startswith("!torch.vtensor"), f'`values` should be a torch.vtensor but is {type(values).__module__}.{type(values).__name__}'
+
+ if not is_mlir_value(accumulate):
+ accumulate = torch_dialect.ConstantBoolOp(accumulate)
+ else:
+ accumulate = get_op_result_or_value(accumulate)
+ assert str(accumulate.type) == '!torch.bool', f'`accumulate` should be a !torch.bool but is {type(accumulate).__module__}.{type(accumulate).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenIndexPutHackedTwinOp, self).__init__(result_type, self_, indices, values, accumulate, loc=loc, ip=ip)
+
+
+class AtenIndexPut_HackedTwinOp:
+ def __init__(self, self_: Value, indices: List[Value], values: Value, accumulate: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(indices):
+ indices = torch_dialect.PrimListConstructOp(indices)
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type) == '!torch.list', f'`indices` should be a !torch.list but is {type(indices).__module__}.{type(indices).__name__}'
+
+ if not is_mlir_value(values):
+ assert is_mlir_value(values), f'`values` should be a Value but is {type(values).__module__}.{type(values).__name__}'
+ else:
+ values = get_op_result_or_value(values)
+ assert str(values.type).startswith("!torch.vtensor"), f'`values` should be a torch.vtensor but is {type(values).__module__}.{type(values).__name__}'
+
+ if not is_mlir_value(accumulate):
+ accumulate = torch_dialect.ConstantBoolOp(accumulate)
+ else:
+ accumulate = get_op_result_or_value(accumulate)
+ assert str(accumulate.type) == '!torch.bool', f'`accumulate` should be a !torch.bool but is {type(accumulate).__module__}.{type(accumulate).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenIndexPut_HackedTwinOp, self).__init__(result_type, self_, indices, values, accumulate, loc=loc, ip=ip)
+
+
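+# Optional tensor operands such as `bias` accept None: None is lowered to
+# torch_dialect.ConstantNoneOp, any other non-MLIR value is rejected by the
+# assert, and an MLIR Value is checked to be a !torch.vtensor.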
+class AtenLinearOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLinearOp, self).__init__(result_type, input, weight, bias, loc=loc, ip=ip)
+
+
+class AtenMmOp:
+ def __init__(self, self_: Value, mat2: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mat2):
+ assert is_mlir_value(mat2), f'`mat2` should be a Value but is {type(mat2).__module__}.{type(mat2).__name__}'
+ else:
+ mat2 = get_op_result_or_value(mat2)
+ assert str(mat2.type).startswith("!torch.vtensor"), f'`mat2` should be a torch.vtensor but is {type(mat2).__module__}.{type(mat2).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMmOp, self).__init__(result_type, self_, mat2, loc=loc, ip=ip)
+
+
+class AtenAddmmOp:
+ def __init__(self, self_: Value, mat1: Value, mat2: Value, beta: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mat1):
+ assert is_mlir_value(mat1), f'`mat1` should be a Value but is {type(mat1).__module__}.{type(mat1).__name__}'
+ else:
+ mat1 = get_op_result_or_value(mat1)
+ assert str(mat1.type).startswith("!torch.vtensor"), f'`mat1` should be a torch.vtensor but is {type(mat1).__module__}.{type(mat1).__name__}'
+
+ if not is_mlir_value(mat2):
+ assert is_mlir_value(mat2), f'`mat2` should be a Value but is {type(mat2).__module__}.{type(mat2).__name__}'
+ else:
+ mat2 = get_op_result_or_value(mat2)
+ assert str(mat2.type).startswith("!torch.vtensor"), f'`mat2` should be a torch.vtensor but is {type(mat2).__module__}.{type(mat2).__name__}'
+
+ if not is_mlir_value(beta):
+ beta = torch_dialect.ConstantNumberOp(beta)
+ else:
+ beta = get_op_result_or_value(beta)
+ assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAddmmOp, self).__init__(result_type, self_, mat1, mat2, beta, alpha, loc=loc, ip=ip)
+
+
+class AtenMatmulOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMatmulOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenMvOp:
+ def __init__(self, self_: Value, vec: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(vec):
+ assert is_mlir_value(vec), f'`vec` should be a Value but is {type(vec).__module__}.{type(vec).__name__}'
+ else:
+ vec = get_op_result_or_value(vec)
+ assert str(vec.type).startswith("!torch.vtensor"), f'`vec` should be a torch.vtensor but is {type(vec).__module__}.{type(vec).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMvOp, self).__init__(result_type, self_, vec, loc=loc, ip=ip)
+
+
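+# List[int] operands (stride, padding, dilation, output_padding) given as Python
+# lists are converted element-wise with torch_dialect.ConstantIntOp and then
+# packed into a !torch.list via torch_dialect.PrimListConstructOp. A
+# hypothetical call, assuming `x` and `w` are !torch.vtensor Values:
+#   conv = AtenConv2dOp(x, w, None, [1, 1], [0, 0], [1, 1], 1)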
+class AtenConv2dOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], groups: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConv2dOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, groups, loc=loc, ip=ip)
+
+
+class AtenConvTranspose1dOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], output_padding: List[int], groups: int, dilation: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConvTranspose1dOp, self).__init__(result_type, input, weight, bias, stride, padding, output_padding, groups, dilation, loc=loc, ip=ip)
+
+
+class AtenConvTranspose2dInputOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], output_padding: List[int], groups: int, dilation: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConvTranspose2dInputOp, self).__init__(result_type, input, weight, bias, stride, padding, output_padding, groups, dilation, loc=loc, ip=ip)
+
+
+class AtenConvTranspose3dInputOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], output_padding: List[int], groups: int, dilation: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConvTranspose3dInputOp, self).__init__(result_type, input, weight, bias, stride, padding, output_padding, groups, dilation, loc=loc, ip=ip)
+
+
+class AtenConvolutionOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(transposed):
+ transposed = torch_dialect.ConstantBoolOp(transposed)
+ else:
+ transposed = get_op_result_or_value(transposed)
+ assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConvolutionOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, loc=loc, ip=ip)
+
+
+class AtenConvolutionOverrideableOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(transposed):
+ transposed = torch_dialect.ConstantBoolOp(transposed)
+ else:
+ transposed = get_op_result_or_value(transposed)
+ assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConvolutionOverrideableOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, loc=loc, ip=ip)
+
+
+class Aten_ConvolutionOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(transposed):
+ transposed = torch_dialect.ConstantBoolOp(transposed)
+ else:
+ transposed = get_op_result_or_value(transposed)
+ assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ if not is_mlir_value(benchmark):
+ benchmark = torch_dialect.ConstantBoolOp(benchmark)
+ else:
+ benchmark = get_op_result_or_value(benchmark)
+ assert str(benchmark.type) == '!torch.bool', f'`benchmark` should be a !torch.bool but is {type(benchmark).__module__}.{type(benchmark).__name__}'
+
+ if not is_mlir_value(deterministic):
+ deterministic = torch_dialect.ConstantBoolOp(deterministic)
+ else:
+ deterministic = get_op_result_or_value(deterministic)
+ assert str(deterministic.type) == '!torch.bool', f'`deterministic` should be a !torch.bool but is {type(deterministic).__module__}.{type(deterministic).__name__}'
+
+ if not is_mlir_value(cudnn_enabled):
+ cudnn_enabled = torch_dialect.ConstantBoolOp(cudnn_enabled)
+ else:
+ cudnn_enabled = get_op_result_or_value(cudnn_enabled)
+ assert str(cudnn_enabled.type) == '!torch.bool', f'`cudnn_enabled` should be a !torch.bool but is {type(cudnn_enabled).__module__}.{type(cudnn_enabled).__name__}'
+
+ if not is_mlir_value(allow_tf32):
+ allow_tf32 = torch_dialect.ConstantBoolOp(allow_tf32)
+ else:
+ allow_tf32 = get_op_result_or_value(allow_tf32)
+ assert str(allow_tf32.type) == '!torch.bool', f'`allow_tf32` should be a !torch.bool but is {type(allow_tf32).__module__}.{type(allow_tf32).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_ConvolutionOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, loc=loc, ip=ip)
+
+
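+# Usage sketch for the Aten_ConvolutionOp wrapper above (an assumption, not part of
+# the generated code): plain Python ints/bools/lists are promoted to torch dialect
+# constants (ConstantIntOp, ConstantBoolOp, PrimListConstructOp, ...) by the wrapper
+# itself, so under an active Context/Location/InsertionPoint and with `t` and `w` as
+# hypothetical !torch.vtensor Values one could write:
+#
+#   conv = Aten_ConvolutionOp(t, w, None, stride=[1, 1], padding=[0, 0],
+#                             dilation=[1, 1], transposed=False,
+#                             output_padding=[0, 0], groups=1, benchmark=False,
+#                             deterministic=False, cudnn_enabled=True,
+#                             allow_tf32=True)
+
+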
+class Aten_ConvolutionDeprecatedOp:
+ def __init__(self, input: Value, weight: Value, bias: Optional[Value], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(transposed):
+ transposed = torch_dialect.ConstantBoolOp(transposed)
+ else:
+ transposed = get_op_result_or_value(transposed)
+ assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ if not is_mlir_value(benchmark):
+ benchmark = torch_dialect.ConstantBoolOp(benchmark)
+ else:
+ benchmark = get_op_result_or_value(benchmark)
+ assert str(benchmark.type) == '!torch.bool', f'`benchmark` should be a !torch.bool but is {type(benchmark).__module__}.{type(benchmark).__name__}'
+
+ if not is_mlir_value(deterministic):
+ deterministic = torch_dialect.ConstantBoolOp(deterministic)
+ else:
+ deterministic = get_op_result_or_value(deterministic)
+ assert str(deterministic.type) == '!torch.bool', f'`deterministic` should be a !torch.bool but is {type(deterministic).__module__}.{type(deterministic).__name__}'
+
+ if not is_mlir_value(cudnn_enabled):
+ cudnn_enabled = torch_dialect.ConstantBoolOp(cudnn_enabled)
+ else:
+ cudnn_enabled = get_op_result_or_value(cudnn_enabled)
+ assert str(cudnn_enabled.type) == '!torch.bool', f'`cudnn_enabled` should be a !torch.bool but is {type(cudnn_enabled).__module__}.{type(cudnn_enabled).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_ConvolutionDeprecatedOp, self).__init__(result_type, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, loc=loc, ip=ip)
+
+
+class AtenRollOp:
+ def __init__(self, self_: Value, shifts: List[int], dims: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(shifts):
+ shifts = list(map(torch_dialect.ConstantIntOp, shifts))
+ shifts = torch_dialect.PrimListConstructOp(shifts)
+ else:
+ shifts = get_op_result_or_value(shifts)
+ assert str(shifts.type) == '!torch.list', f'`shifts` should be a !torch.list but is {type(shifts).__module__}.{type(shifts).__name__}'
+
+ if not is_mlir_value(dims):
+ dims = list(map(torch_dialect.ConstantIntOp, dims))
+ dims = torch_dialect.PrimListConstructOp(dims)
+ else:
+ dims = get_op_result_or_value(dims)
+ assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRollOp, self).__init__(result_type, self_, shifts, dims, loc=loc, ip=ip)
+
+
+class AtenConvolutionBackwardOverrideableOp:
+ def __init__(self, grad_output: Value, input: Value, weight: Value, stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(grad_output):
+ assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+ else:
+ grad_output = get_op_result_or_value(grad_output)
+ assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(transposed):
+ transposed = torch_dialect.ConstantBoolOp(transposed)
+ else:
+ transposed = get_op_result_or_value(transposed)
+ assert str(transposed.type) == '!torch.bool', f'`transposed` should be a !torch.bool but is {type(transposed).__module__}.{type(transposed).__name__}'
+
+ if not is_mlir_value(output_padding):
+ output_padding = list(map(torch_dialect.ConstantIntOp, output_padding))
+ output_padding = torch_dialect.PrimListConstructOp(output_padding)
+ else:
+ output_padding = get_op_result_or_value(output_padding)
+ assert str(output_padding.type) == '!torch.list', f'`output_padding` should be a !torch.list but is {type(output_padding).__module__}.{type(output_padding).__name__}'
+
+ if not is_mlir_value(groups):
+ groups = torch_dialect.ConstantIntOp(groups)
+ else:
+ groups = get_op_result_or_value(groups)
+ assert str(groups.type) == '!torch.int', f'`groups` should be a !torch.int but is {type(groups).__module__}.{type(groups).__name__}'
+
+ if not is_mlir_value(output_mask):
+ output_mask = list(map(torch_dialect.ConstantBoolOp, output_mask))
+ output_mask = torch_dialect.PrimListConstructOp(output_mask)
+ else:
+ output_mask = get_op_result_or_value(output_mask)
+ # output_mask is a bool[]; like the other list operands it should be a !torch.list
+ assert str(output_mask.type) == '!torch.list', f'`output_mask` should be a !torch.list but is {type(output_mask).__module__}.{type(output_mask).__name__}'
+
+ grad_input_type = Type.parse("!torch.vtensor")
+ grad_weight_type = Type.parse("!torch.vtensor")
+ grad_bias_type = Type.parse("!torch.vtensor")
+ super(AtenConvolutionBackwardOverrideableOp, self).__init__(grad_input_type, grad_weight_type, grad_bias_type, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, loc=loc, ip=ip)
+
+
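+# Note on AtenConvolutionBackwardOverrideableOp above (a hedged editorial sketch, not
+# generated code): the wrapper builds an op with three tensor results (grad_input,
+# grad_weight, grad_bias). With the standard MLIR Python bindings these should be
+# reachable via the op's results, e.g. with hypothetical Values `g`, `t`, `w`:
+#
+#   bwd = AtenConvolutionBackwardOverrideableOp(g, t, w, [1, 1], [0, 0], [1, 1],
+#                                               False, [0, 0], 1, [True, True, True])
+#   grad_input, grad_weight, grad_bias = bwd.results
+
+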
+class AtenFlipOp:
+ def __init__(self, self_: Value, dims: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dims):
+ dims = list(map(torch_dialect.ConstantIntOp, dims))
+ dims = torch_dialect.PrimListConstructOp(dims)
+ else:
+ dims = get_op_result_or_value(dims)
+ assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFlipOp, self).__init__(result_type, self_, dims, loc=loc, ip=ip)
+
+
+class AtenNativeBatchNormOp:
+ def __init__(self, input: Value, weight: Optional[Value], bias: Optional[Value], running_mean: Optional[Value], running_var: Optional[Value], training: bool, momentum: float, eps: float, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ if weight is not None:
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = torch_dialect.ConstantNoneOp()
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(running_mean):
+ if running_mean is not None:
+ assert is_mlir_value(running_mean), f'`running_mean` should be a Value but is {type(running_mean).__module__}.{type(running_mean).__name__}'
+ else:
+ running_mean = torch_dialect.ConstantNoneOp()
+ else:
+ running_mean = get_op_result_or_value(running_mean)
+ assert str(running_mean.type).startswith("!torch.vtensor"), f'`running_mean` should be a torch.vtensor but is {type(running_mean).__module__}.{type(running_mean).__name__}'
+
+ if not is_mlir_value(running_var):
+ if running_var is not None:
+ assert is_mlir_value(running_var), f'`running_var` should be a Value but is {type(running_var).__module__}.{type(running_var).__name__}'
+ else:
+ running_var = torch_dialect.ConstantNoneOp()
+ else:
+ running_var = get_op_result_or_value(running_var)
+ assert str(running_var.type).startswith("!torch.vtensor"), f'`running_var` should be a torch.vtensor but is {type(running_var).__module__}.{type(running_var).__name__}'
+
+ if not is_mlir_value(training):
+ training = torch_dialect.ConstantBoolOp(training)
+ else:
+ training = get_op_result_or_value(training)
+ assert str(training.type) == '!torch.bool', f'`training` should be a !torch.bool but is {type(training).__module__}.{type(training).__name__}'
+
+ if not is_mlir_value(momentum):
+ momentum = torch_dialect.ConstantFloatOp(momentum)
+ else:
+ momentum = get_op_result_or_value(momentum)
+ assert str(momentum.type) == '!torch.float', f'`momentum` should be a !torch.float but is {type(momentum).__module__}.{type(momentum).__name__}'
+
+ if not is_mlir_value(eps):
+ eps = torch_dialect.ConstantFloatOp(eps)
+ else:
+ eps = get_op_result_or_value(eps)
+ assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ result2_type = Type.parse("!torch.vtensor")
+ super(AtenNativeBatchNormOp, self).__init__(result0_type, result1_type, result2_type, input, weight, bias, running_mean, running_var, training, momentum, eps, loc=loc, ip=ip)
+
+
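+# Note on the Optional[Value] operands of AtenNativeBatchNormOp above (a hedged
+# sketch, not generated code): passing Python `None` for weight/bias/running_mean/
+# running_var makes the wrapper materialize a torch_dialect.ConstantNoneOp, so a
+# stateless invocation with only a hypothetical input Value `t` could look like:
+#
+#   bn = AtenNativeBatchNormOp(t, None, None, None, None,
+#                              training=True, momentum=0.1, eps=1e-5)
+
+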
+class AtenBatchNormOp:
+ def __init__(self, input: Value, weight: Optional[Value], bias: Optional[Value], running_mean: Optional[Value], running_var: Optional[Value], training: bool, momentum: float, eps: float, cudnn_enabled: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(weight):
+ if weight is not None:
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = torch_dialect.ConstantNoneOp()
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(running_mean):
+ if running_mean is not None:
+ assert is_mlir_value(running_mean), f'`running_mean` should be a Value but is {type(running_mean).__module__}.{type(running_mean).__name__}'
+ else:
+ running_mean = torch_dialect.ConstantNoneOp()
+ else:
+ running_mean = get_op_result_or_value(running_mean)
+ assert str(running_mean.type).startswith("!torch.vtensor"), f'`running_mean` should be a torch.vtensor but is {type(running_mean).__module__}.{type(running_mean).__name__}'
+
+ if not is_mlir_value(running_var):
+ if running_var is not None:
+ assert is_mlir_value(running_var), f'`running_var` should be a Value but is {type(running_var).__module__}.{type(running_var).__name__}'
+ else:
+ running_var = torch_dialect.ConstantNoneOp()
+ else:
+ running_var = get_op_result_or_value(running_var)
+ assert str(running_var.type).startswith("!torch.vtensor"), f'`running_var` should be a torch.vtensor but is {type(running_var).__module__}.{type(running_var).__name__}'
+
+ if not is_mlir_value(training):
+ training = torch_dialect.ConstantBoolOp(training)
+ else:
+ training = get_op_result_or_value(training)
+ assert str(training.type) == '!torch.bool', f'`training` should be a !torch.bool but is {type(training).__module__}.{type(training).__name__}'
+
+ if not is_mlir_value(momentum):
+ momentum = torch_dialect.ConstantFloatOp(momentum)
+ else:
+ momentum = get_op_result_or_value(momentum)
+ assert str(momentum.type) == '!torch.float', f'`momentum` should be a !torch.float but is {type(momentum).__module__}.{type(momentum).__name__}'
+
+ if not is_mlir_value(eps):
+ eps = torch_dialect.ConstantFloatOp(eps)
+ else:
+ eps = get_op_result_or_value(eps)
+ assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}'
+
+ if not is_mlir_value(cudnn_enabled):
+ cudnn_enabled = torch_dialect.ConstantBoolOp(cudnn_enabled)
+ else:
+ cudnn_enabled = get_op_result_or_value(cudnn_enabled)
+ assert str(cudnn_enabled.type) == '!torch.bool', f'`cudnn_enabled` should be a !torch.bool but is {type(cudnn_enabled).__module__}.{type(cudnn_enabled).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBatchNormOp, self).__init__(result_type, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled, loc=loc, ip=ip)
+
+
+class AtenLayerNormOp:
+ def __init__(self, input: Value, normalized_shape: List[int], weight: Optional[Value], bias: Optional[Value], eps: float, cudnn_enable: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(normalized_shape):
+ normalized_shape = list(map(torch_dialect.ConstantIntOp, normalized_shape))
+ normalized_shape = torch_dialect.PrimListConstructOp(normalized_shape)
+ else:
+ normalized_shape = get_op_result_or_value(normalized_shape)
+ assert str(normalized_shape.type) == '!torch.list', f'`normalized_shape` should be a !torch.list but is {type(normalized_shape).__module__}.{type(normalized_shape).__name__}'
+
+ if not is_mlir_value(weight):
+ if weight is not None:
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = torch_dialect.ConstantNoneOp()
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(eps):
+ eps = torch_dialect.ConstantFloatOp(eps)
+ else:
+ eps = get_op_result_or_value(eps)
+ assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}'
+
+ if not is_mlir_value(cudnn_enable):
+ cudnn_enable = torch_dialect.ConstantBoolOp(cudnn_enable)
+ else:
+ cudnn_enable = get_op_result_or_value(cudnn_enable)
+ assert str(cudnn_enable.type) == '!torch.bool', f'`cudnn_enable` should be a !torch.bool but is {type(cudnn_enable).__module__}.{type(cudnn_enable).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLayerNormOp, self).__init__(result_type, input, normalized_shape, weight, bias, eps, cudnn_enable, loc=loc, ip=ip)
+
+
+class AtenNativeLayerNormOp:
+ def __init__(self, input: Value, normalized_shape: List[int], weight: Optional[Value], bias: Optional[Value], eps: float, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(normalized_shape):
+ normalized_shape = list(map(torch_dialect.ConstantIntOp, normalized_shape))
+ normalized_shape = torch_dialect.PrimListConstructOp(normalized_shape)
+ else:
+ normalized_shape = get_op_result_or_value(normalized_shape)
+ assert str(normalized_shape.type) == '!torch.list', f'`normalized_shape` should be a !torch.list but is {type(normalized_shape).__module__}.{type(normalized_shape).__name__}'
+
+ if not is_mlir_value(weight):
+ if weight is not None:
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = torch_dialect.ConstantNoneOp()
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(bias):
+ if bias is not None:
+ assert is_mlir_value(bias), f'`bias` should be a Value but is {type(bias).__module__}.{type(bias).__name__}'
+ else:
+ bias = torch_dialect.ConstantNoneOp()
+ else:
+ bias = get_op_result_or_value(bias)
+ assert str(bias.type).startswith("!torch.vtensor"), f'`bias` should be a torch.vtensor but is {type(bias).__module__}.{type(bias).__name__}'
+
+ if not is_mlir_value(eps):
+ eps = torch_dialect.ConstantFloatOp(eps)
+ else:
+ eps = get_op_result_or_value(eps)
+ assert str(eps.type) == '!torch.float', f'`eps` should be a !torch.float but is {type(eps).__module__}.{type(eps).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ result2_type = Type.parse("!torch.vtensor")
+ super(AtenNativeLayerNormOp, self).__init__(result0_type, result1_type, result2_type, input, normalized_shape, weight, bias, eps, loc=loc, ip=ip)
+
+
+class AtenMaxPool2dOp:
+ def __init__(self, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(kernel_size):
+ kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size))
+ kernel_size = torch_dialect.PrimListConstructOp(kernel_size)
+ else:
+ kernel_size = get_op_result_or_value(kernel_size)
+ assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(ceil_mode):
+ ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode)
+ else:
+ ceil_mode = get_op_result_or_value(ceil_mode)
+ assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaxPool2dOp, self).__init__(result_type, self_, kernel_size, stride, padding, dilation, ceil_mode, loc=loc, ip=ip)
+
+
+class AtenMaxPool2dWithIndicesOp:
+ def __init__(self, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(kernel_size):
+ kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size))
+ kernel_size = torch_dialect.PrimListConstructOp(kernel_size)
+ else:
+ kernel_size = get_op_result_or_value(kernel_size)
+ assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(ceil_mode):
+ ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode)
+ else:
+ ceil_mode = get_op_result_or_value(ceil_mode)
+ assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ super(AtenMaxPool2dWithIndicesOp, self).__init__(result0_type, result1_type, self_, kernel_size, stride, padding, dilation, ceil_mode, loc=loc, ip=ip)
+
+
+class AtenMaxPool2dWithIndicesBackwardOp:
+ def __init__(self, grad_output: Value, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(grad_output):
+ assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+ else:
+ grad_output = get_op_result_or_value(grad_output)
+ assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(kernel_size):
+ kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size))
+ kernel_size = torch_dialect.PrimListConstructOp(kernel_size)
+ else:
+ kernel_size = get_op_result_or_value(kernel_size)
+ assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(dilation):
+ dilation = list(map(torch_dialect.ConstantIntOp, dilation))
+ dilation = torch_dialect.PrimListConstructOp(dilation)
+ else:
+ dilation = get_op_result_or_value(dilation)
+ assert str(dilation.type) == '!torch.list', f'`dilation` should be a !torch.list but is {type(dilation).__module__}.{type(dilation).__name__}'
+
+ if not is_mlir_value(ceil_mode):
+ ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode)
+ else:
+ ceil_mode = get_op_result_or_value(ceil_mode)
+ assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}'
+
+ if not is_mlir_value(indices):
+ assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}'
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaxPool2dWithIndicesBackwardOp, self).__init__(result_type, grad_output, self_, kernel_size, stride, padding, dilation, ceil_mode, indices, loc=loc, ip=ip)
+
+
+class AtenAvgPool2dOp:
+ def __init__(self, self_: Value, kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(kernel_size):
+ kernel_size = list(map(torch_dialect.ConstantIntOp, kernel_size))
+ kernel_size = torch_dialect.PrimListConstructOp(kernel_size)
+ else:
+ kernel_size = get_op_result_or_value(kernel_size)
+ assert str(kernel_size.type) == '!torch.list', f'`kernel_size` should be a !torch.list but is {type(kernel_size).__module__}.{type(kernel_size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(padding):
+ padding = list(map(torch_dialect.ConstantIntOp, padding))
+ padding = torch_dialect.PrimListConstructOp(padding)
+ else:
+ padding = get_op_result_or_value(padding)
+ assert str(padding.type) == '!torch.list', f'`padding` should be a !torch.list but is {type(padding).__module__}.{type(padding).__name__}'
+
+ if not is_mlir_value(ceil_mode):
+ ceil_mode = torch_dialect.ConstantBoolOp(ceil_mode)
+ else:
+ ceil_mode = get_op_result_or_value(ceil_mode)
+ assert str(ceil_mode.type) == '!torch.bool', f'`ceil_mode` should be a !torch.bool but is {type(ceil_mode).__module__}.{type(ceil_mode).__name__}'
+
+ if not is_mlir_value(count_include_pad):
+ count_include_pad = torch_dialect.ConstantBoolOp(count_include_pad)
+ else:
+ count_include_pad = get_op_result_or_value(count_include_pad)
+ assert str(count_include_pad.type) == '!torch.bool', f'`count_include_pad` should be a !torch.bool but is {type(count_include_pad).__module__}.{type(count_include_pad).__name__}'
+
+ if not is_mlir_value(divisor_override):
+ if divisor_override is not None:
+ divisor_override = torch_dialect.ConstantIntOp(divisor_override)
+ else:
+ divisor_override = torch_dialect.ConstantNoneOp()
+ else:
+ divisor_override = get_op_result_or_value(divisor_override)
+ assert str(divisor_override.type) == '!torch.int', f'`divisor_override` should be a !torch.int but is {type(divisor_override).__module__}.{type(divisor_override).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAvgPool2dOp, self).__init__(result_type, self_, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, loc=loc, ip=ip)
+
+
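+# Note on AtenAvgPool2dOp above (a hedged sketch, not generated code):
+# `divisor_override` is an Optional[int]; a Python int is wrapped in ConstantIntOp,
+# while `None` becomes a ConstantNoneOp. E.g. with a hypothetical !torch.vtensor
+# Value `t`:
+#
+#   pooled = AtenAvgPool2dOp(t, kernel_size=[2, 2], stride=[2, 2], padding=[0, 0],
+#                            ceil_mode=False, count_include_pad=True,
+#                            divisor_override=None)
+
+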
+class AtenSoftmaxIntOp:
+ def __init__(self, self_: Value, dim: int, dtype: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(dtype):
+ if dtype is not None:
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = torch_dialect.ConstantNoneOp()
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSoftmaxIntOp, self).__init__(result_type, self_, dim, dtype, loc=loc, ip=ip)
+
+
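+# Note on AtenSoftmaxIntOp above (a hedged sketch, not generated code): `dtype` is an
+# Optional[int] dtype code (torch-mlir's integer ScalarType convention, as far as the
+# wrapper is concerned); `None` keeps the input dtype and is lowered to a
+# ConstantNoneOp. With a hypothetical !torch.vtensor Value `t`:
+#
+#   sm = AtenSoftmaxIntOp(t, dim=-1, dtype=None)
+
+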
+class AtenLogSoftmaxIntOp:
+ def __init__(self, self_: Value, dim: int, dtype: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(dtype):
+ if dtype is not None:
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = torch_dialect.ConstantNoneOp()
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogSoftmaxIntOp, self).__init__(result_type, self_, dim, dtype, loc=loc, ip=ip)
+
+
+class Aten_LogSoftmaxOp:
+ def __init__(self, self_: Value, dim: int, half_to_float: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(half_to_float):
+ half_to_float = torch_dialect.ConstantBoolOp(half_to_float)
+ else:
+ half_to_float = get_op_result_or_value(half_to_float)
+ assert str(half_to_float.type) == '!torch.bool', f'`half_to_float` should be a !torch.bool but is {type(half_to_float).__module__}.{type(half_to_float).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_LogSoftmaxOp, self).__init__(result_type, self_, dim, half_to_float, loc=loc, ip=ip)
+
+
+class AtenAdaptiveAvgPool2dOp:
+ def __init__(self, self_: Value, output_size: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(output_size):
+ output_size = list(map(torch_dialect.ConstantIntOp, output_size))
+ output_size = torch_dialect.PrimListConstructOp(output_size)
+ else:
+ output_size = get_op_result_or_value(output_size)
+ assert str(output_size.type) == '!torch.list', f'`output_size` should be a !torch.list but is {type(output_size).__module__}.{type(output_size).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAdaptiveAvgPool2dOp, self).__init__(result_type, self_, output_size, loc=loc, ip=ip)
+
+
+class AtenTopkOp:
+ def __init__(self, self_: Value, k: int, dim: int, largest: bool, sorted: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(k):
+ k = torch_dialect.ConstantIntOp(k)
+ else:
+ k = get_op_result_or_value(k)
+ assert str(k.type) == '!torch.int', f'`k` should be a !torch.int but is {type(k).__module__}.{type(k).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(largest):
+ largest = torch_dialect.ConstantBoolOp(largest)
+ else:
+ largest = get_op_result_or_value(largest)
+ assert str(largest.type) == '!torch.bool', f'`largest` should be a !torch.bool but is {type(largest).__module__}.{type(largest).__name__}'
+
+ if not is_mlir_value(sorted):
+ sorted = torch_dialect.ConstantBoolOp(sorted)
+ else:
+ sorted = get_op_result_or_value(sorted)
+ assert str(sorted.type) == '!torch.bool', f'`sorted` should be a !torch.bool but is {type(sorted).__module__}.{type(sorted).__name__}'
+
+ values_type = Type.parse("!torch.vtensor")
+ indices_type = Type.parse("!torch.vtensor")
+ super(AtenTopkOp, self).__init__(values_type, indices_type, self_, k, dim, largest, sorted, loc=loc, ip=ip)
+
+
+class AtenTransposeIntOp:
+ def __init__(self, self_: Value, dim0: int, dim1: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim0):
+ dim0 = torch_dialect.ConstantIntOp(dim0)
+ else:
+ dim0 = get_op_result_or_value(dim0)
+ assert str(dim0.type) == '!torch.int', f'`dim0` should be a !torch.int but is {type(dim0).__module__}.{type(dim0).__name__}'
+
+ if not is_mlir_value(dim1):
+ dim1 = torch_dialect.ConstantIntOp(dim1)
+ else:
+ dim1 = get_op_result_or_value(dim1)
+ assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTransposeIntOp, self).__init__(result_type, self_, dim0, dim1, loc=loc, ip=ip)
+
+
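+# Note on AtenTransposeIntOp above (a hedged sketch, not generated code): every
+# operand may be either a plain Python value or an already-constructed MLIR Value;
+# the `else` branches route existing Values through get_op_result_or_value and only
+# type-check them. So both of these are accepted (with `t` a hypothetical
+# !torch.vtensor Value and `one` an existing !torch.int Value):
+#
+#   AtenTransposeIntOp(t, 0, 1)      # ints promoted via ConstantIntOp
+#   AtenTransposeIntOp(t, one, 1)    # mix an existing Value with a Python literal
+
+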
+class AtenPermuteOp:
+ def __init__(self, self_: Value, dims: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dims):
+ dims = list(map(torch_dialect.ConstantIntOp, dims))
+ dims = torch_dialect.PrimListConstructOp(dims)
+ else:
+ dims = get_op_result_or_value(dims)
+ assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenPermuteOp, self).__init__(result_type, self_, dims, loc=loc, ip=ip)
+
+
+class AtenBmmOp:
+ def __init__(self, self_: Value, mat2: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mat2):
+ assert is_mlir_value(mat2), f'`mat2` should be a Value but is {type(mat2).__module__}.{type(mat2).__name__}'
+ else:
+ mat2 = get_op_result_or_value(mat2)
+ assert str(mat2.type).startswith("!torch.vtensor"), f'`mat2` should be a torch.vtensor but is {type(mat2).__module__}.{type(mat2).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBmmOp, self).__init__(result_type, self_, mat2, loc=loc, ip=ip)
+
+
+class AtenCumsumOp:
+ def __init__(self, self_: Value, dim: int, dtype: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(dtype):
+ if dtype is not None:
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = torch_dialect.ConstantNoneOp()
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCumsumOp, self).__init__(result_type, self_, dim, dtype, loc=loc, ip=ip)
+
+
+class AtenFloorDivideScalarOp:
+ def __init__(self, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFloorDivideScalarOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
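+# Note on AtenFloorDivideScalarOp above (a hedged sketch, not generated code): the
+# scalar `other` is a "Number", so a Python int or float is wrapped in
+# ConstantNumberOp, while an existing Value must already be !torch.int or
+# !torch.float. With a hypothetical !torch.vtensor Value `t`:
+#
+#   AtenFloorDivideScalarOp(t, 2)      # int scalar
+#   AtenFloorDivideScalarOp(t, 2.5)    # float scalar
+
+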
+class AtenLogsumexpOp:
+ def __init__(self, self_: Value, dim: List[int], keepdim: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = list(map(torch_dialect.ConstantIntOp, dim))
+ dim = torch_dialect.PrimListConstructOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.list', f'`dim` should be a !torch.list but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(keepdim):
+ keepdim = torch_dialect.ConstantBoolOp(keepdim)
+ else:
+ keepdim = get_op_result_or_value(keepdim)
+ assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLogsumexpOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip)
+
+
+class Aten__And__TensorOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten__And__TensorOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class Aten_SoftmaxOp:
+ def __init__(self, self_: Value, dim: int, half_to_float: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(half_to_float):
+ half_to_float = torch_dialect.ConstantBoolOp(half_to_float)
+ else:
+ half_to_float = get_op_result_or_value(half_to_float)
+ assert str(half_to_float.type) == '!torch.bool', f'`half_to_float` should be a !torch.bool but is {type(half_to_float).__module__}.{type(half_to_float).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_SoftmaxOp, self).__init__(result_type, self_, dim, half_to_float, loc=loc, ip=ip)
+
+
+class AtenMeanOp:
+ def __init__(self, self_: Value, dtype: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dtype):
+ if dtype is not None:
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = torch_dialect.ConstantNoneOp()
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMeanOp, self).__init__(result_type, self_, dtype, loc=loc, ip=ip)
+
+
+class AtenStdOp:
+ def __init__(self, self_: Value, unbiased: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(unbiased):
+ unbiased = torch_dialect.ConstantBoolOp(unbiased)
+ else:
+ unbiased = get_op_result_or_value(unbiased)
+ assert str(unbiased.type) == '!torch.bool', f'`unbiased` should be a !torch.bool but is {type(unbiased).__module__}.{type(unbiased).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenStdOp, self).__init__(result_type, self_, unbiased, loc=loc, ip=ip)
+
+
+class AtenVarOp:
+ def __init__(self, self_: Value, unbiased: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(unbiased):
+ unbiased = torch_dialect.ConstantBoolOp(unbiased)
+ else:
+ unbiased = get_op_result_or_value(unbiased)
+ assert str(unbiased.type) == '!torch.bool', f'`unbiased` should be a !torch.bool but is {type(unbiased).__module__}.{type(unbiased).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenVarOp, self).__init__(result_type, self_, unbiased, loc=loc, ip=ip)
+
+
+class AtenVarMeanOp:
+ def __init__(self, self_: Value, unbiased: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(unbiased):
+ unbiased = torch_dialect.ConstantBoolOp(unbiased)
+ else:
+ unbiased = get_op_result_or_value(unbiased)
+ assert str(unbiased.type) == '!torch.bool', f'`unbiased` should be a !torch.bool but is {type(unbiased).__module__}.{type(unbiased).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ super(AtenVarMeanOp, self).__init__(result0_type, result1_type, self_, unbiased, loc=loc, ip=ip)
+
+
+class AtenNllLossForwardOp:
+ def __init__(self, self_: Value, target: Value, weight: Optional[Value], reduction: int, ignore_index: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(target):
+ assert is_mlir_value(target), f'`target` should be a Value but is {type(target).__module__}.{type(target).__name__}'
+ else:
+ target = get_op_result_or_value(target)
+ assert str(target.type).startswith("!torch.vtensor"), f'`target` should be a torch.vtensor but is {type(target).__module__}.{type(target).__name__}'
+
+ if not is_mlir_value(weight):
+ if weight is not None:
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = torch_dialect.ConstantNoneOp()
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(reduction):
+ reduction = torch_dialect.ConstantIntOp(reduction)
+ else:
+ reduction = get_op_result_or_value(reduction)
+ assert str(reduction.type) == '!torch.int', f'`reduction` should be a !torch.int but is {type(reduction).__module__}.{type(reduction).__name__}'
+
+ if not is_mlir_value(ignore_index):
+ ignore_index = torch_dialect.ConstantIntOp(ignore_index)
+ else:
+ ignore_index = get_op_result_or_value(ignore_index)
+ assert str(ignore_index.type) == '!torch.int', f'`ignore_index` should be a !torch.int but is {type(ignore_index).__module__}.{type(ignore_index).__name__}'
+
+ output_type = Type.parse("!torch.vtensor")
+ total_weight_type = Type.parse("!torch.vtensor")
+ super(AtenNllLossForwardOp, self).__init__(output_type, total_weight_type, self_, target, weight, reduction, ignore_index, loc=loc, ip=ip)
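+
+# Optional tensor operands such as `weight` map Python None to
+# torch.constant.none; otherwise a !torch.vtensor Value is expected. A sketch,
+# assuming `logits` and `target` are !torch.vtensor Values:
+#
+#   output, total_weight = AtenNllLossForwardOp(
+#       logits, target, weight=None, reduction=1, ignore_index=-100).results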
+
+
+class AtenNllLossBackwardOp:
+ def __init__(self, grad_output: Value, self_: Value, target: Value, weight: Optional[Value], reduction: int, ignore_index: int, total_weight: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(grad_output):
+ assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+ else:
+ grad_output = get_op_result_or_value(grad_output)
+ assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(target):
+ assert is_mlir_value(target), f'`target` should be a Value but is {type(target).__module__}.{type(target).__name__}'
+ else:
+ target = get_op_result_or_value(target)
+ assert str(target.type).startswith("!torch.vtensor"), f'`target` should be a torch.vtensor but is {type(target).__module__}.{type(target).__name__}'
+
+ if not is_mlir_value(weight):
+ if weight is not None:
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = torch_dialect.ConstantNoneOp()
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(reduction):
+ reduction = torch_dialect.ConstantIntOp(reduction)
+ else:
+ reduction = get_op_result_or_value(reduction)
+ assert str(reduction.type) == '!torch.int', f'`reduction` should be a !torch.int but is {type(reduction).__module__}.{type(reduction).__name__}'
+
+ if not is_mlir_value(ignore_index):
+ ignore_index = torch_dialect.ConstantIntOp(ignore_index)
+ else:
+ ignore_index = get_op_result_or_value(ignore_index)
+ assert str(ignore_index.type) == '!torch.int', f'`ignore_index` should be a !torch.int but is {type(ignore_index).__module__}.{type(ignore_index).__name__}'
+
+ if not is_mlir_value(total_weight):
+ assert is_mlir_value(total_weight), f'`total_weight` should be a Value but is {type(total_weight).__module__}.{type(total_weight).__name__}'
+ else:
+ total_weight = get_op_result_or_value(total_weight)
+ assert str(total_weight.type).startswith("!torch.vtensor"), f'`total_weight` should be a torch.vtensor but is {type(total_weight).__module__}.{type(total_weight).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNllLossBackwardOp, self).__init__(result_type, grad_output, self_, target, weight, reduction, ignore_index, total_weight, loc=loc, ip=ip)
+
+
+class AtenBincountOp:
+ def __init__(self, self_: Value, weights: Optional[Value], minlength: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(weights):
+ if weights is not None:
+ assert is_mlir_value(weights), f'`weights` should be a Value but is {type(weights).__module__}.{type(weights).__name__}'
+ else:
+ weights = torch_dialect.ConstantNoneOp()
+ else:
+ weights = get_op_result_or_value(weights)
+ assert str(weights.type).startswith("!torch.vtensor"), f'`weights` should be a torch.vtensor but is {type(weights).__module__}.{type(weights).__name__}'
+
+ if not is_mlir_value(minlength):
+ minlength = torch_dialect.ConstantIntOp(minlength)
+ else:
+ minlength = get_op_result_or_value(minlength)
+ assert str(minlength.type) == '!torch.int', f'`minlength` should be a !torch.int but is {type(minlength).__module__}.{type(minlength).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBincountOp, self).__init__(result_type, self_, weights, minlength, loc=loc, ip=ip)
+
+
+class AtenFrobeniusNormDimOp:
+ def __init__(self, self_: Value, dim: List[int], keepdim: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = list(map(torch_dialect.ConstantIntOp, dim))
+ dim = torch_dialect.PrimListConstructOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.list', f'`dim` should be a !torch.list but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(keepdim):
+ keepdim = torch_dialect.ConstantBoolOp(keepdim)
+ else:
+ keepdim = get_op_result_or_value(keepdim)
+ assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFrobeniusNormDimOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip)
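+
+# Python int lists such as `dim` are lowered element-wise through ConstantIntOp
+# and collected into a !torch.list with PrimListConstructOp. A sketch, assuming
+# `t` is a !torch.vtensor Value:
+#
+#   norm = AtenFrobeniusNormDimOp(t, dim=[-2, -1], keepdim=False)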
+
+
+class AtenMseLossOp:
+ def __init__(self, self_: Value, target: Value, reduction: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(target):
+ assert is_mlir_value(target), f'`target` should be a Value but is {type(target).__module__}.{type(target).__name__}'
+ else:
+ target = get_op_result_or_value(target)
+ assert str(target.type).startswith("!torch.vtensor"), f'`target` should be a torch.vtensor but is {type(target).__module__}.{type(target).__name__}'
+
+ if not is_mlir_value(reduction):
+ reduction = torch_dialect.ConstantIntOp(reduction)
+ else:
+ reduction = get_op_result_or_value(reduction)
+ assert str(reduction.type) == '!torch.int', f'`reduction` should be a !torch.int but is {type(reduction).__module__}.{type(reduction).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMseLossOp, self).__init__(result_type, self_, target, reduction, loc=loc, ip=ip)
+
+
+class AtenUpsampleNearest2dBackwardOp:
+ def __init__(self, grad_output: Value, output_size: List[int], input_size: List[int], scales_h: Optional[float], scales_w: Optional[float], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(grad_output):
+ assert is_mlir_value(grad_output), f'`grad_output` should be a Value but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+ else:
+ grad_output = get_op_result_or_value(grad_output)
+ assert str(grad_output.type).startswith("!torch.vtensor"), f'`grad_output` should be a torch.vtensor but is {type(grad_output).__module__}.{type(grad_output).__name__}'
+
+ if not is_mlir_value(output_size):
+ output_size = list(map(torch_dialect.ConstantIntOp, output_size))
+ output_size = torch_dialect.PrimListConstructOp(output_size)
+ else:
+ output_size = get_op_result_or_value(output_size)
+ assert str(output_size.type) == '!torch.list', f'`output_size` should be a !torch.list but is {type(output_size).__module__}.{type(output_size).__name__}'
+
+ if not is_mlir_value(input_size):
+ input_size = list(map(torch_dialect.ConstantIntOp, input_size))
+ input_size = torch_dialect.PrimListConstructOp(input_size)
+ else:
+ input_size = get_op_result_or_value(input_size)
+ assert str(input_size.type) == '!torch.list', f'`input_size` should be a !torch.list but is {type(input_size).__module__}.{type(input_size).__name__}'
+
+ if not is_mlir_value(scales_h):
+ if scales_h is not None:
+ scales_h = torch_dialect.ConstantFloatOp(scales_h)
+ else:
+ scales_h = torch_dialect.ConstantNoneOp()
+ else:
+ scales_h = get_op_result_or_value(scales_h)
+ assert str(scales_h.type) == '!torch.float', f'`scales_h` should be a !torch.float but is {type(scales_h).__module__}.{type(scales_h).__name__}'
+
+ if not is_mlir_value(scales_w):
+ if scales_w is not None:
+ scales_w = torch_dialect.ConstantFloatOp(scales_w)
+ else:
+ scales_w = torch_dialect.ConstantNoneOp()
+ else:
+ scales_w = get_op_result_or_value(scales_w)
+ assert str(scales_w.type) == '!torch.float', f'`scales_w` should be a !torch.float but is {type(scales_w).__module__}.{type(scales_w).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenUpsampleNearest2dBackwardOp, self).__init__(result_type, grad_output, output_size, input_size, scales_h, scales_w, loc=loc, ip=ip)
+
+
+class AtenConstantPadNdOp:
+ def __init__(self, self_: Value, pad: List[int], value: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(pad):
+ pad = list(map(torch_dialect.ConstantIntOp, pad))
+ pad = torch_dialect.PrimListConstructOp(pad)
+ else:
+ pad = get_op_result_or_value(pad)
+ assert str(pad.type) == '!torch.list', f'`pad` should be a !torch.list but is {type(pad).__module__}.{type(pad).__name__}'
+
+ if not is_mlir_value(value):
+ value = torch_dialect.ConstantNumberOp(value)
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) in {'!torch.float', '!torch.int'}, f'`value` should be a !torch.number but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenConstantPadNdOp, self).__init__(result_type, self_, pad, value, loc=loc, ip=ip)
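+
+# "Number" operands accept either a Python int or float and are wrapped with
+# ConstantNumberOp; a pre-built Value must already be !torch.int or !torch.float.
+# A sketch, assuming `t` is a !torch.vtensor Value:
+#
+#   padded = AtenConstantPadNdOp(t, pad=[1, 1, 2, 2], value=0.0)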
+
+
+class AtenPadOp:
+ def __init__(self, self_: Value, pad: List[int], mode: str, value: Optional[float], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(pad):
+ pad = list(map(torch_dialect.ConstantIntOp, pad))
+ pad = torch_dialect.PrimListConstructOp(pad)
+ else:
+ pad = get_op_result_or_value(pad)
+ assert str(pad.type) == '!torch.list', f'`pad` should be a !torch.list but is {type(pad).__module__}.{type(pad).__name__}'
+
+ if not is_mlir_value(mode):
+ mode = torch_dialect.ConstantStrOp(mode)
+ else:
+ mode = get_op_result_or_value(mode)
+ assert str(mode.type) == '!torch.str', f'`mode` should be a !torch.str but is {type(mode).__module__}.{type(mode).__name__}'
+
+ if not is_mlir_value(value):
+ if value is not None:
+ value = torch_dialect.ConstantFloatOp(value)
+ else:
+ value = torch_dialect.ConstantNoneOp()
+ else:
+ value = get_op_result_or_value(value)
+ assert str(value.type) == '!torch.float', f'`value` should be a !torch.float but is {type(value).__module__}.{type(value).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenPadOp, self).__init__(result_type, self_, pad, mode, value, loc=loc, ip=ip)
+
+
+class AtenSqueezeDimOp:
+ def __init__(self, self_: Value, dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSqueezeDimOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip)
+
+
+class AtenSqueezeOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSqueezeOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenFlattenUsingIntsOp:
+ def __init__(self, self_: Value, start_dim: int, end_dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(start_dim):
+ start_dim = torch_dialect.ConstantIntOp(start_dim)
+ else:
+ start_dim = get_op_result_or_value(start_dim)
+ assert str(start_dim.type) == '!torch.int', f'`start_dim` should be a !torch.int but is {type(start_dim).__module__}.{type(start_dim).__name__}'
+
+ if not is_mlir_value(end_dim):
+ end_dim = torch_dialect.ConstantIntOp(end_dim)
+ else:
+ end_dim = get_op_result_or_value(end_dim)
+ assert str(end_dim.type) == '!torch.int', f'`end_dim` should be a !torch.int but is {type(end_dim).__module__}.{type(end_dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFlattenUsingIntsOp, self).__init__(result_type, self_, start_dim, end_dim, loc=loc, ip=ip)
+
+
+class AtenDimOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ super(AtenDimOp, self).__init__(self_, loc=loc, ip=ip)
+
+
+class AtenSizeOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.list")
+ super(AtenSizeOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenBoolTensorOp:
+ def __init__(self, a: Value, *, loc=None, ip=None):
+ if not is_mlir_value(a):
+ assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}'
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}'
+
+ super(AtenBoolTensorOp, self).__init__(a, loc=loc, ip=ip)
+
+
+class AtenIsFloatingPointOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ super(AtenIsFloatingPointOp, self).__init__(self_, loc=loc, ip=ip)
+
+
+class Aten_ShapeAsTensorOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_ShapeAsTensorOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAllOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAllOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAllBoolOp:
+ def __init__(self, self_: List[bool], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ self_ = list(map(torch_dialect.ConstantBoolOp, self_))
+ self_ = torch_dialect.PrimListConstructOp(self_)
+ else:
+ self_ = get_op_result_or_value(self_)
+            # expected to be !torch.list<bool>; no element-type assertion is emitted for list operands
+            pass
+
+ super(AtenAllBoolOp, self).__init__(self_, loc=loc, ip=ip)
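+
+# aten.all.bool takes a Python bool list rather than a tensor: each element is
+# wrapped with ConstantBoolOp and the list with PrimListConstructOp. A sketch:
+#
+#   all_true = AtenAllBoolOp([True, True, False])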
+
+
+class AtenAnyOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAnyOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAnyDimOp:
+ def __init__(self, self_: Value, dim: int, keepdim: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(keepdim):
+ keepdim = torch_dialect.ConstantBoolOp(keepdim)
+ else:
+ keepdim = get_op_result_or_value(keepdim)
+ assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAnyDimOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip)
+
+
+class AtenArangeStartOutOp:
+ def __init__(self, start: "Number", end: "Number", step: "Number", out: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(start):
+ start = torch_dialect.ConstantNumberOp(start)
+ else:
+ start = get_op_result_or_value(start)
+ assert str(start.type) in {'!torch.float', '!torch.int'}, f'`start` should be a !torch.number but is {type(start).__module__}.{type(start).__name__}'
+
+ if not is_mlir_value(end):
+ end = torch_dialect.ConstantNumberOp(end)
+ else:
+ end = get_op_result_or_value(end)
+ assert str(end.type) in {'!torch.float', '!torch.int'}, f'`end` should be a !torch.number but is {type(end).__module__}.{type(end).__name__}'
+
+ if not is_mlir_value(step):
+ step = torch_dialect.ConstantNumberOp(step)
+ else:
+ step = get_op_result_or_value(step)
+ assert str(step.type) in {'!torch.float', '!torch.int'}, f'`step` should be a !torch.number but is {type(step).__module__}.{type(step).__name__}'
+
+ if not is_mlir_value(out):
+ assert is_mlir_value(out), f'`out` should be a Value but is {type(out).__module__}.{type(out).__name__}'
+ else:
+ out = get_op_result_or_value(out)
+ assert str(out.type).startswith("!torch.vtensor"), f'`out` should be a torch.vtensor but is {type(out).__module__}.{type(out).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenArangeStartOutOp, self).__init__(result_type, start, end, step, out, loc=loc, ip=ip)
+
+
+class AtenArgmaxOp:
+ def __init__(self, self_: Value, dim: Optional[int], keepdim: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ if dim is not None:
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = torch_dialect.ConstantNoneOp()
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(keepdim):
+ keepdim = torch_dialect.ConstantBoolOp(keepdim)
+ else:
+ keepdim = get_op_result_or_value(keepdim)
+ assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenArgmaxOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip)
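+
+# Optional scalars such as `dim` become either ConstantIntOp or ConstantNoneOp;
+# when a pre-built Value is passed, only the !torch.int case is asserted.
+# A sketch, assuming `t` is a !torch.vtensor Value:
+#
+#   flat_argmax = AtenArgmaxOp(t, dim=None, keepdim=False)  # reduce over all dims
+#   row_argmax = AtenArgmaxOp(t, dim=1, keepdim=True)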
+
+
+class AtenBucketizeTensorOp:
+ def __init__(self, self_: Value, boundaries: Value, out_int32: bool, right: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(boundaries):
+ assert is_mlir_value(boundaries), f'`boundaries` should be a Value but is {type(boundaries).__module__}.{type(boundaries).__name__}'
+ else:
+ boundaries = get_op_result_or_value(boundaries)
+ assert str(boundaries.type).startswith("!torch.vtensor"), f'`boundaries` should be a torch.vtensor but is {type(boundaries).__module__}.{type(boundaries).__name__}'
+
+ if not is_mlir_value(out_int32):
+ out_int32 = torch_dialect.ConstantBoolOp(out_int32)
+ else:
+ out_int32 = get_op_result_or_value(out_int32)
+ assert str(out_int32.type) == '!torch.bool', f'`out_int32` should be a !torch.bool but is {type(out_int32).__module__}.{type(out_int32).__name__}'
+
+ if not is_mlir_value(right):
+ right = torch_dialect.ConstantBoolOp(right)
+ else:
+ right = get_op_result_or_value(right)
+ assert str(right.type) == '!torch.bool', f'`right` should be a !torch.bool but is {type(right).__module__}.{type(right).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBucketizeTensorOp, self).__init__(result_type, self_, boundaries, out_int32, right, loc=loc, ip=ip)
+
+
+class AtenCloneOp:
+ def __init__(self, self_: Value, memory_format: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(memory_format):
+ if memory_format is not None:
+ memory_format = torch_dialect.ConstantIntOp(memory_format)
+ else:
+ memory_format = torch_dialect.ConstantNoneOp()
+ else:
+ memory_format = get_op_result_or_value(memory_format)
+ assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCloneOp, self).__init__(result_type, self_, memory_format, loc=loc, ip=ip)
+
+
+class AtenLiftFreshCopyOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenLiftFreshCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenContiguousOp:
+ def __init__(self, self_: Value, memory_format: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(memory_format):
+ memory_format = torch_dialect.ConstantIntOp(memory_format)
+ else:
+ memory_format = get_op_result_or_value(memory_format)
+ assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenContiguousOp, self).__init__(result_type, self_, memory_format, loc=loc, ip=ip)
+
+
+class AtenCopyOp:
+ def __init__(self, self_: Value, src: Value, non_blocking: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ if not is_mlir_value(non_blocking):
+ non_blocking = torch_dialect.ConstantBoolOp(non_blocking)
+ else:
+ non_blocking = get_op_result_or_value(non_blocking)
+ assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCopyOp, self).__init__(result_type, self_, src, non_blocking, loc=loc, ip=ip)
+
+
+class AtenCopy_Op:
+ def __init__(self, self_: Value, src: Value, non_blocking: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ if not is_mlir_value(non_blocking):
+ non_blocking = torch_dialect.ConstantBoolOp(non_blocking)
+ else:
+ non_blocking = get_op_result_or_value(non_blocking)
+ assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCopy_Op, self).__init__(result_type, self_, src, non_blocking, loc=loc, ip=ip)
+
+
+class AtenDetachOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDetachOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenEmbeddingOp:
+ def __init__(self, weight: Value, indices: Value, padding_idx: int, scale_grad_by_freq: bool, sparse: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(indices):
+ assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}'
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}'
+
+ if not is_mlir_value(padding_idx):
+ padding_idx = torch_dialect.ConstantIntOp(padding_idx)
+ else:
+ padding_idx = get_op_result_or_value(padding_idx)
+ assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is {type(padding_idx).__module__}.{type(padding_idx).__name__}'
+
+ if not is_mlir_value(scale_grad_by_freq):
+ scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq)
+ else:
+ scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq)
+ assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}'
+
+ if not is_mlir_value(sparse):
+ sparse = torch_dialect.ConstantBoolOp(sparse)
+ else:
+ sparse = get_op_result_or_value(sparse)
+ assert str(sparse.type) == '!torch.bool', f'`sparse` should be a !torch.bool but is {type(sparse).__module__}.{type(sparse).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenEmbeddingOp, self).__init__(result_type, weight, indices, padding_idx, scale_grad_by_freq, sparse, loc=loc, ip=ip)
+
+
+class AtenEmbeddingBagPaddingIdxOp:
+ def __init__(self, weight: Value, indices: Value, offsets: Value, scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[Value], include_last_offset: bool, padding_idx: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(indices):
+ assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}'
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}'
+
+ if not is_mlir_value(offsets):
+ assert is_mlir_value(offsets), f'`offsets` should be a Value but is {type(offsets).__module__}.{type(offsets).__name__}'
+ else:
+ offsets = get_op_result_or_value(offsets)
+ assert str(offsets.type).startswith("!torch.vtensor"), f'`offsets` should be a torch.vtensor but is {type(offsets).__module__}.{type(offsets).__name__}'
+
+ if not is_mlir_value(scale_grad_by_freq):
+ scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq)
+ else:
+ scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq)
+ assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}'
+
+ if not is_mlir_value(mode):
+ mode = torch_dialect.ConstantIntOp(mode)
+ else:
+ mode = get_op_result_or_value(mode)
+ assert str(mode.type) == '!torch.int', f'`mode` should be a !torch.int but is {type(mode).__module__}.{type(mode).__name__}'
+
+ if not is_mlir_value(sparse):
+ sparse = torch_dialect.ConstantBoolOp(sparse)
+ else:
+ sparse = get_op_result_or_value(sparse)
+ assert str(sparse.type) == '!torch.bool', f'`sparse` should be a !torch.bool but is {type(sparse).__module__}.{type(sparse).__name__}'
+
+ if not is_mlir_value(per_sample_weights):
+ if per_sample_weights is not None:
+ assert is_mlir_value(per_sample_weights), f'`per_sample_weights` should be a Value but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}'
+ else:
+ per_sample_weights = torch_dialect.ConstantNoneOp()
+ else:
+ per_sample_weights = get_op_result_or_value(per_sample_weights)
+ assert str(per_sample_weights.type).startswith("!torch.vtensor"), f'`per_sample_weights` should be a torch.vtensor but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}'
+
+ if not is_mlir_value(include_last_offset):
+ include_last_offset = torch_dialect.ConstantBoolOp(include_last_offset)
+ else:
+ include_last_offset = get_op_result_or_value(include_last_offset)
+ assert str(include_last_offset.type) == '!torch.bool', f'`include_last_offset` should be a !torch.bool but is {type(include_last_offset).__module__}.{type(include_last_offset).__name__}'
+
+ if not is_mlir_value(padding_idx):
+ if padding_idx is not None:
+ padding_idx = torch_dialect.ConstantIntOp(padding_idx)
+ else:
+ padding_idx = torch_dialect.ConstantNoneOp()
+ else:
+ padding_idx = get_op_result_or_value(padding_idx)
+ assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is {type(padding_idx).__module__}.{type(padding_idx).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ result2_type = Type.parse("!torch.vtensor")
+ result3_type = Type.parse("!torch.vtensor")
+ super(AtenEmbeddingBagPaddingIdxOp, self).__init__(result0_type, result1_type, result2_type, result3_type, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, loc=loc, ip=ip)
+
+
+class Aten_EmbeddingBagOp:
+ def __init__(self, weight: Value, indices: Value, offsets: Value, scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[Value], include_last_offset: bool, padding_idx: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(weight):
+ assert is_mlir_value(weight), f'`weight` should be a Value but is {type(weight).__module__}.{type(weight).__name__}'
+ else:
+ weight = get_op_result_or_value(weight)
+ assert str(weight.type).startswith("!torch.vtensor"), f'`weight` should be a torch.vtensor but is {type(weight).__module__}.{type(weight).__name__}'
+
+ if not is_mlir_value(indices):
+ assert is_mlir_value(indices), f'`indices` should be a Value but is {type(indices).__module__}.{type(indices).__name__}'
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type).startswith("!torch.vtensor"), f'`indices` should be a torch.vtensor but is {type(indices).__module__}.{type(indices).__name__}'
+
+ if not is_mlir_value(offsets):
+ assert is_mlir_value(offsets), f'`offsets` should be a Value but is {type(offsets).__module__}.{type(offsets).__name__}'
+ else:
+ offsets = get_op_result_or_value(offsets)
+ assert str(offsets.type).startswith("!torch.vtensor"), f'`offsets` should be a torch.vtensor but is {type(offsets).__module__}.{type(offsets).__name__}'
+
+ if not is_mlir_value(scale_grad_by_freq):
+ scale_grad_by_freq = torch_dialect.ConstantBoolOp(scale_grad_by_freq)
+ else:
+ scale_grad_by_freq = get_op_result_or_value(scale_grad_by_freq)
+ assert str(scale_grad_by_freq.type) == '!torch.bool', f'`scale_grad_by_freq` should be a !torch.bool but is {type(scale_grad_by_freq).__module__}.{type(scale_grad_by_freq).__name__}'
+
+ if not is_mlir_value(mode):
+ mode = torch_dialect.ConstantIntOp(mode)
+ else:
+ mode = get_op_result_or_value(mode)
+ assert str(mode.type) == '!torch.int', f'`mode` should be a !torch.int but is {type(mode).__module__}.{type(mode).__name__}'
+
+ if not is_mlir_value(sparse):
+ sparse = torch_dialect.ConstantBoolOp(sparse)
+ else:
+ sparse = get_op_result_or_value(sparse)
+ assert str(sparse.type) == '!torch.bool', f'`sparse` should be a !torch.bool but is {type(sparse).__module__}.{type(sparse).__name__}'
+
+ if not is_mlir_value(per_sample_weights):
+ if per_sample_weights is not None:
+ assert is_mlir_value(per_sample_weights), f'`per_sample_weights` should be a Value but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}'
+ else:
+ per_sample_weights = torch_dialect.ConstantNoneOp()
+ else:
+ per_sample_weights = get_op_result_or_value(per_sample_weights)
+ assert str(per_sample_weights.type).startswith("!torch.vtensor"), f'`per_sample_weights` should be a torch.vtensor but is {type(per_sample_weights).__module__}.{type(per_sample_weights).__name__}'
+
+ if not is_mlir_value(include_last_offset):
+ include_last_offset = torch_dialect.ConstantBoolOp(include_last_offset)
+ else:
+ include_last_offset = get_op_result_or_value(include_last_offset)
+ assert str(include_last_offset.type) == '!torch.bool', f'`include_last_offset` should be a !torch.bool but is {type(include_last_offset).__module__}.{type(include_last_offset).__name__}'
+
+ if not is_mlir_value(padding_idx):
+ padding_idx = torch_dialect.ConstantIntOp(padding_idx)
+ else:
+ padding_idx = get_op_result_or_value(padding_idx)
+ assert str(padding_idx.type) == '!torch.int', f'`padding_idx` should be a !torch.int but is {type(padding_idx).__module__}.{type(padding_idx).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ result2_type = Type.parse("!torch.vtensor")
+ result3_type = Type.parse("!torch.vtensor")
+ super(Aten_EmbeddingBagOp, self).__init__(result0_type, result1_type, result2_type, result3_type, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, loc=loc, ip=ip)
+
+
+class AtenExpandOp:
+ def __init__(self, self_: Value, size: List[int], implicit: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(implicit):
+ implicit = torch_dialect.ConstantBoolOp(implicit)
+ else:
+ implicit = get_op_result_or_value(implicit)
+ assert str(implicit.type) == '!torch.bool', f'`implicit` should be a !torch.bool but is {type(implicit).__module__}.{type(implicit).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExpandOp, self).__init__(result_type, self_, size, implicit, loc=loc, ip=ip)
+
+
+class AtenExpandAsOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExpandAsOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenBroadcastToOp:
+ def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBroadcastToOp, self).__init__(result_type, self_, size, loc=loc, ip=ip)
+
+
+class AtenIndexTensorHackedTwinOp:
+ def __init__(self, self_: Value, indices: List[Value], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(indices):
+ indices = torch_dialect.PrimListConstructOp(indices)
+ else:
+ indices = get_op_result_or_value(indices)
+ assert str(indices.type) == '!torch.list', f'`indices` should be a !torch.list but is {type(indices).__module__}.{type(indices).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenIndexTensorHackedTwinOp, self).__init__(result_type, self_, indices, loc=loc, ip=ip)
+
+
+class AtenIndexSelectOp:
+ def __init__(self, self_: Value, dim: int, index: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(index):
+ assert is_mlir_value(index), f'`index` should be a Value but is {type(index).__module__}.{type(index).__name__}'
+ else:
+ index = get_op_result_or_value(index)
+ assert str(index.type).startswith("!torch.vtensor"), f'`index` should be a torch.vtensor but is {type(index).__module__}.{type(index).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenIndexSelectOp, self).__init__(result_type, self_, dim, index, loc=loc, ip=ip)
+
+
+class AtenItemOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ super(AtenItemOp, self).__init__(self_, loc=loc, ip=ip)
+
+
+class AtenMaskedSelectOp:
+ def __init__(self, self_: Value, mask: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(mask):
+ assert is_mlir_value(mask), f'`mask` should be a Value but is {type(mask).__module__}.{type(mask).__name__}'
+ else:
+ mask = get_op_result_or_value(mask)
+ assert str(mask.type).startswith("!torch.vtensor"), f'`mask` should be a torch.vtensor but is {type(mask).__module__}.{type(mask).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaskedSelectOp, self).__init__(result_type, self_, mask, loc=loc, ip=ip)
+
+
+class AtenNumelOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ super(AtenNumelOp, self).__init__(self_, loc=loc, ip=ip)
+
+
+class AtenRepeatOp:
+ def __init__(self, self_: Value, repeats: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(repeats):
+ repeats = list(map(torch_dialect.ConstantIntOp, repeats))
+ repeats = torch_dialect.PrimListConstructOp(repeats)
+ else:
+ repeats = get_op_result_or_value(repeats)
+ assert str(repeats.type) == '!torch.list', f'`repeats` should be a !torch.list but is {type(repeats).__module__}.{type(repeats).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenRepeatOp, self).__init__(result_type, self_, repeats, loc=loc, ip=ip)
+
+
+class AtenReshapeOp:
+ def __init__(self, self_: Value, shape: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(shape):
+ shape = list(map(torch_dialect.ConstantIntOp, shape))
+ shape = torch_dialect.PrimListConstructOp(shape)
+ else:
+ shape = get_op_result_or_value(shape)
+ assert str(shape.type) == '!torch.list', f'`shape` should be a !torch.list but is {type(shape).__module__}.{type(shape).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenReshapeOp, self).__init__(result_type, self_, shape, loc=loc, ip=ip)
+
+
+class Aten_ReshapeAliasOp:
+ def __init__(self, self_: Value, size: List[int], stride: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_ReshapeAliasOp, self).__init__(result_type, self_, size, stride, loc=loc, ip=ip)
+
+
+class AtenResize_Op:
+ def __init__(self, self_: Value, size: List[int], memory_format: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(memory_format):
+ if memory_format is not None:
+ memory_format = torch_dialect.ConstantIntOp(memory_format)
+ else:
+ memory_format = torch_dialect.ConstantNoneOp()
+ else:
+ memory_format = get_op_result_or_value(memory_format)
+ assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenResize_Op, self).__init__(result_type, self_, size, memory_format, loc=loc, ip=ip)
+
+
+class AtenSelectIntOp:
+ def __init__(self, self_: Value, dim: int, index: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(index):
+ index = torch_dialect.ConstantIntOp(index)
+ else:
+ index = get_op_result_or_value(index)
+ assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSelectIntOp, self).__init__(result_type, self_, dim, index, loc=loc, ip=ip)
+
+
+class AtenSizeIntOp:
+ def __init__(self, self_: Value, dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ super(AtenSizeIntOp, self).__init__(self_, dim, loc=loc, ip=ip)
+
+
+class AtenStackOp:
+ def __init__(self, tensors: List[Value], dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(tensors):
+ tensors = torch_dialect.PrimListConstructOp(tensors)
+ else:
+ tensors = get_op_result_or_value(tensors)
+ assert str(tensors.type) == '!torch.list', f'`tensors` should be a !torch.list but is {type(tensors).__module__}.{type(tensors).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenStackOp, self).__init__(result_type, tensors, dim, loc=loc, ip=ip)
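+
+# For tensor-list operands the wrapper only builds the PrimListConstructOp; the
+# elements must already be !torch.vtensor Values. A sketch, assuming `a` and `b`
+# are !torch.vtensor Values:
+#
+#   stacked = AtenStackOp([a, b], dim=0)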
+
+
+class AtenSumOp:
+ def __init__(self, self_: Value, dtype: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dtype):
+ if dtype is not None:
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = torch_dialect.ConstantNoneOp()
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSumOp, self).__init__(result_type, self_, dtype, loc=loc, ip=ip)
+
+
+class AtenMaxOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenMaxOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenMaxDimOp:
+ def __init__(self, self_: Value, dim: int, keepdim: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(keepdim):
+ keepdim = torch_dialect.ConstantBoolOp(keepdim)
+ else:
+ keepdim = get_op_result_or_value(keepdim)
+ assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}'
+
+ values_type = Type.parse("!torch.vtensor")
+ indices_type = Type.parse("!torch.vtensor")
+ super(AtenMaxDimOp, self).__init__(values_type, indices_type, self_, dim, keepdim, loc=loc, ip=ip)
+
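+# Usage sketch (illustrative only): this wrapper builds an op with two !torch.vtensor
+# results (values and indices); `keepdim` given as a Python bool is wrapped in a
+# torch.constant.bool. Accessing the results assumes the usual OpView interface:
+#
+#   max_dim = AtenMaxDimOp(t, 1, False)
+#   values, indices = max_dim.results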
+
+class AtenAmaxOp:
+ def __init__(self, self_: Value, dim: List[int], keepdim: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = list(map(torch_dialect.ConstantIntOp, dim))
+ dim = torch_dialect.PrimListConstructOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.list', f'`dim` should be a !torch.list but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(keepdim):
+ keepdim = torch_dialect.ConstantBoolOp(keepdim)
+ else:
+ keepdim = get_op_result_or_value(keepdim)
+ assert str(keepdim.type) == '!torch.bool', f'`keepdim` should be a !torch.bool but is {type(keepdim).__module__}.{type(keepdim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAmaxOp, self).__init__(result_type, self_, dim, keepdim, loc=loc, ip=ip)
+
+
+class AtenToDtypeOp:
+ def __init__(self, self_: Value, dtype: int, non_blocking: bool, copy: bool, memory_format: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dtype):
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ if not is_mlir_value(non_blocking):
+ non_blocking = torch_dialect.ConstantBoolOp(non_blocking)
+ else:
+ non_blocking = get_op_result_or_value(non_blocking)
+ assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}'
+
+ if not is_mlir_value(copy):
+ copy = torch_dialect.ConstantBoolOp(copy)
+ else:
+ copy = get_op_result_or_value(copy)
+ assert str(copy.type) == '!torch.bool', f'`copy` should be a !torch.bool but is {type(copy).__module__}.{type(copy).__name__}'
+
+ if not is_mlir_value(memory_format):
+ if memory_format is not None:
+ memory_format = torch_dialect.ConstantIntOp(memory_format)
+ else:
+ memory_format = torch_dialect.ConstantNoneOp()
+ else:
+ memory_format = get_op_result_or_value(memory_format)
+ assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenToDtypeOp, self).__init__(result_type, self_, dtype, non_blocking, copy, memory_format, loc=loc, ip=ip)
+
+
+class AtenToOtherOp:
+ def __init__(self, self_: Value, other: Value, non_blocking: bool, copy: bool, memory_format: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ if not is_mlir_value(non_blocking):
+ non_blocking = torch_dialect.ConstantBoolOp(non_blocking)
+ else:
+ non_blocking = get_op_result_or_value(non_blocking)
+ assert str(non_blocking.type) == '!torch.bool', f'`non_blocking` should be a !torch.bool but is {type(non_blocking).__module__}.{type(non_blocking).__name__}'
+
+ if not is_mlir_value(copy):
+ copy = torch_dialect.ConstantBoolOp(copy)
+ else:
+ copy = get_op_result_or_value(copy)
+ assert str(copy.type) == '!torch.bool', f'`copy` should be a !torch.bool but is {type(copy).__module__}.{type(copy).__name__}'
+
+ if not is_mlir_value(memory_format):
+ if memory_format is not None:
+ memory_format = torch_dialect.ConstantIntOp(memory_format)
+ else:
+ memory_format = torch_dialect.ConstantNoneOp()
+ else:
+ memory_format = get_op_result_or_value(memory_format)
+ assert str(memory_format.type) == '!torch.int', f'`memory_format` should be a !torch.int but is {type(memory_format).__module__}.{type(memory_format).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenToOtherOp, self).__init__(result_type, self_, other, non_blocking, copy, memory_format, loc=loc, ip=ip)
+
+
+class AtenTypeAsOp:
+ def __init__(self, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTypeAsOp, self).__init__(result_type, self_, other, loc=loc, ip=ip)
+
+
+class AtenViewOp:
+ def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenViewOp, self).__init__(result_type, self_, size, loc=loc, ip=ip)
+
+
+class Aten_UnsafeViewOp:
+ def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_UnsafeViewOp, self).__init__(result_type, self_, size, loc=loc, ip=ip)
+
+
+class AtenWhereSelfOp:
+ def __init__(self, condition: Value, self_: Value, other: Value, *, loc=None, ip=None):
+ if not is_mlir_value(condition):
+ assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}'
+ else:
+ condition = get_op_result_or_value(condition)
+ assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}'
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenWhereSelfOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip)
+
+
+class AtenWhereScalarOp:
+ def __init__(self, condition: Value, self_: "Number", other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(condition):
+ assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}'
+ else:
+ condition = get_op_result_or_value(condition)
+ assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}'
+
+ if not is_mlir_value(self_):
+ self_ = torch_dialect.ConstantNumberOp(self_)
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type) in {'!torch.float', '!torch.int'}, f'`self_` should be a !torch.number but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenWhereScalarOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip)
+
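+# Usage sketch (illustrative only): Python numbers for `self_`/`other` are wrapped via
+# ConstantNumberOp (yielding a !torch.float or !torch.int constant), e.g.:
+#
+#   out = AtenWhereScalarOp(cond, 1.0, 0.0)  # cond: a Value of !torch.vtensor type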
+
+class AtenWhereScalarOtherOp:
+ def __init__(self, condition: Value, self_: Value, other: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(condition):
+ assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}'
+ else:
+ condition = get_op_result_or_value(condition)
+ assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}'
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ other = torch_dialect.ConstantNumberOp(other)
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type) in {'!torch.float', '!torch.int'}, f'`other` should be a !torch.number but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenWhereScalarOtherOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip)
+
+
+class AtenWhereScalarSelfOp:
+ def __init__(self, condition: Value, self_: "Number", other: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(condition):
+ assert is_mlir_value(condition), f'`condition` should be a Value but is {type(condition).__module__}.{type(condition).__name__}'
+ else:
+ condition = get_op_result_or_value(condition)
+ assert str(condition.type).startswith("!torch.vtensor"), f'`condition` should be a torch.vtensor but is {type(condition).__module__}.{type(condition).__name__}'
+
+ if not is_mlir_value(self_):
+ self_ = torch_dialect.ConstantNumberOp(self_)
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type) in {'!torch.float', '!torch.int'}, f'`self_` should be a !torch.number but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(other):
+ assert is_mlir_value(other), f'`other` should be a Value but is {type(other).__module__}.{type(other).__name__}'
+ else:
+ other = get_op_result_or_value(other)
+ assert str(other.type).startswith("!torch.vtensor"), f'`other` should be a torch.vtensor but is {type(other).__module__}.{type(other).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenWhereScalarSelfOp, self).__init__(result_type, condition, self_, other, loc=loc, ip=ip)
+
+
+class AtenSliceTensorOp:
+ def __init__(self, self_: Value, dim: int, start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(start):
+ if start is not None:
+ start = torch_dialect.ConstantIntOp(start)
+ else:
+ start = torch_dialect.ConstantNoneOp()
+ else:
+ start = get_op_result_or_value(start)
+ assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}'
+
+ if not is_mlir_value(end):
+ if end is not None:
+ end = torch_dialect.ConstantIntOp(end)
+ else:
+ end = torch_dialect.ConstantNoneOp()
+ else:
+ end = get_op_result_or_value(end)
+ assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}'
+
+ if not is_mlir_value(step):
+ step = torch_dialect.ConstantIntOp(step)
+ else:
+ step = get_op_result_or_value(step)
+ assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSliceTensorOp, self).__init__(result_type, self_, dim, start, end, step, loc=loc, ip=ip)
+
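+# Usage sketch (illustrative only): `start`/`end` accept either None (lowered to a
+# torch.constant.none) or a Python int (lowered to a torch.constant.int), so a
+# full-range slice along dim 0 with step 2 could be written as:
+#
+#   sliced = AtenSliceTensorOp(t, 0, None, None, 2)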
+
+class AtenLenTensorOp:
+ def __init__(self, t: Value, *, loc=None, ip=None):
+ if not is_mlir_value(t):
+ assert is_mlir_value(t), f'`t` should be a Value but is {type(t).__module__}.{type(t).__name__}'
+ else:
+ t = get_op_result_or_value(t)
+ assert str(t.type).startswith("!torch.vtensor"), f'`t` should be a torch.vtensor but is {type(t).__module__}.{type(t).__name__}'
+
+ super(AtenLenTensorOp, self).__init__(t, loc=loc, ip=ip)
+
+
+class AtenCpuOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCpuOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenGatherOp:
+ def __init__(self, self_: Value, dim: int, index: Value, sparse_grad: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(index):
+ assert is_mlir_value(index), f'`index` should be a Value but is {type(index).__module__}.{type(index).__name__}'
+ else:
+ index = get_op_result_or_value(index)
+ assert str(index.type).startswith("!torch.vtensor"), f'`index` should be a torch.vtensor but is {type(index).__module__}.{type(index).__name__}'
+
+ if not is_mlir_value(sparse_grad):
+ sparse_grad = torch_dialect.ConstantBoolOp(sparse_grad)
+ else:
+ sparse_grad = get_op_result_or_value(sparse_grad)
+ assert str(sparse_grad.type) == '!torch.bool', f'`sparse_grad` should be a !torch.bool but is {type(sparse_grad).__module__}.{type(sparse_grad).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenGatherOp, self).__init__(result_type, self_, dim, index, sparse_grad, loc=loc, ip=ip)
+
+
+class AtenScatterAddOp:
+ def __init__(self, self_: Value, dim: int, index: Value, src: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(index):
+ assert is_mlir_value(index), f'`index` should be a Value but is {type(index).__module__}.{type(index).__name__}'
+ else:
+ index = get_op_result_or_value(index)
+ assert str(index.type).startswith("!torch.vtensor"), f'`index` should be a torch.vtensor but is {type(index).__module__}.{type(index).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenScatterAddOp, self).__init__(result_type, self_, dim, index, src, loc=loc, ip=ip)
+
+
+class AtenIntImplicitOp:
+ def __init__(self, a: Value, *, loc=None, ip=None):
+ if not is_mlir_value(a):
+ assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}'
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}'
+
+ super(AtenIntImplicitOp, self).__init__(a, loc=loc, ip=ip)
+
+
+class AtenFloatImplicitOp:
+ def __init__(self, a: Value, *, loc=None, ip=None):
+ if not is_mlir_value(a):
+ assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}'
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}'
+
+ super(AtenFloatImplicitOp, self).__init__(a, loc=loc, ip=ip)
+
+
+class AtenIntTensorOp:
+ def __init__(self, a: Value, *, loc=None, ip=None):
+ if not is_mlir_value(a):
+ assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}'
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}'
+
+ super(AtenIntTensorOp, self).__init__(a, loc=loc, ip=ip)
+
+
+class AtenFloatTensorOp:
+ def __init__(self, a: Value, *, loc=None, ip=None):
+ if not is_mlir_value(a):
+ assert is_mlir_value(a), f'`a` should be a Value but is {type(a).__module__}.{type(a).__name__}'
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.vtensor"), f'`a` should be a torch.vtensor but is {type(a).__module__}.{type(a).__name__}'
+
+ super(AtenFloatTensorOp, self).__init__(a, loc=loc, ip=ip)
+
+
+class AtenDropoutOp:
+ def __init__(self, input: Value, p: float, train: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(p):
+ p = torch_dialect.ConstantFloatOp(p)
+ else:
+ p = get_op_result_or_value(p)
+ assert str(p.type) == '!torch.float', f'`p` should be a !torch.float but is {type(p).__module__}.{type(p).__name__}'
+
+ if not is_mlir_value(train):
+ train = torch_dialect.ConstantBoolOp(train)
+ else:
+ train = get_op_result_or_value(train)
+ assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDropoutOp, self).__init__(result_type, input, p, train, loc=loc, ip=ip)
+
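+# Usage sketch (illustrative only): a Python float `p` becomes a torch.constant.float
+# and a Python bool `train` becomes a torch.constant.bool, e.g.:
+#
+#   dropped = AtenDropoutOp(x, 0.5, True)  # x: a Value of !torch.vtensor type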
+
+class AtenDropout_Op:
+ def __init__(self, self_: Value, p: float, train: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(p):
+ p = torch_dialect.ConstantFloatOp(p)
+ else:
+ p = get_op_result_or_value(p)
+ assert str(p.type) == '!torch.float', f'`p` should be a !torch.float but is {type(p).__module__}.{type(p).__name__}'
+
+ if not is_mlir_value(train):
+ train = torch_dialect.ConstantBoolOp(train)
+ else:
+ train = get_op_result_or_value(train)
+ assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDropout_Op, self).__init__(result_type, self_, p, train, loc=loc, ip=ip)
+
+
+class AtenNativeDropoutOp:
+ def __init__(self, input: Value, p: float, train: Optional[bool], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(input):
+ assert is_mlir_value(input), f'`input` should be a Value but is {type(input).__module__}.{type(input).__name__}'
+ else:
+ input = get_op_result_or_value(input)
+ assert str(input.type).startswith("!torch.vtensor"), f'`input` should be a torch.vtensor but is {type(input).__module__}.{type(input).__name__}'
+
+ if not is_mlir_value(p):
+ p = torch_dialect.ConstantFloatOp(p)
+ else:
+ p = get_op_result_or_value(p)
+ assert str(p.type) == '!torch.float', f'`p` should be a !torch.float but is {type(p).__module__}.{type(p).__name__}'
+
+ if not is_mlir_value(train):
+ if train is not None:
+ train = torch_dialect.ConstantBoolOp(train)
+ else:
+ train = torch_dialect.ConstantNoneOp()
+ else:
+ train = get_op_result_or_value(train)
+ assert str(train.type) == '!torch.bool', f'`train` should be a !torch.bool but is {type(train).__module__}.{type(train).__name__}'
+
+ result0_type = Type.parse("!torch.vtensor")
+ result1_type = Type.parse("!torch.vtensor")
+ super(AtenNativeDropoutOp, self).__init__(result0_type, result1_type, input, p, train, loc=loc, ip=ip)
+
+
+class AtenTOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenNumpyTOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenNumpyTOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenBaddbmmOp:
+ def __init__(self, self_: Value, batch1: Value, batch2: Value, beta: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(batch1):
+ assert is_mlir_value(batch1), f'`batch1` should be a Value but is {type(batch1).__module__}.{type(batch1).__name__}'
+ else:
+ batch1 = get_op_result_or_value(batch1)
+ assert str(batch1.type).startswith("!torch.vtensor"), f'`batch1` should be a torch.vtensor but is {type(batch1).__module__}.{type(batch1).__name__}'
+
+ if not is_mlir_value(batch2):
+ assert is_mlir_value(batch2), f'`batch2` should be a Value but is {type(batch2).__module__}.{type(batch2).__name__}'
+ else:
+ batch2 = get_op_result_or_value(batch2)
+ assert str(batch2.type).startswith("!torch.vtensor"), f'`batch2` should be a torch.vtensor but is {type(batch2).__module__}.{type(batch2).__name__}'
+
+ if not is_mlir_value(beta):
+ beta = torch_dialect.ConstantNumberOp(beta)
+ else:
+ beta = get_op_result_or_value(beta)
+ assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBaddbmmOp, self).__init__(result_type, self_, batch1, batch2, beta, alpha, loc=loc, ip=ip)
+
+
+class AtenBaddbmm_Op:
+ def __init__(self, self_: Value, batch1: Value, batch2: Value, beta: "Number", alpha: "Number", *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(batch1):
+ assert is_mlir_value(batch1), f'`batch1` should be a Value but is {type(batch1).__module__}.{type(batch1).__name__}'
+ else:
+ batch1 = get_op_result_or_value(batch1)
+ assert str(batch1.type).startswith("!torch.vtensor"), f'`batch1` should be a torch.vtensor but is {type(batch1).__module__}.{type(batch1).__name__}'
+
+ if not is_mlir_value(batch2):
+ assert is_mlir_value(batch2), f'`batch2` should be a Value but is {type(batch2).__module__}.{type(batch2).__name__}'
+ else:
+ batch2 = get_op_result_or_value(batch2)
+ assert str(batch2.type).startswith("!torch.vtensor"), f'`batch2` should be a torch.vtensor but is {type(batch2).__module__}.{type(batch2).__name__}'
+
+ if not is_mlir_value(beta):
+ beta = torch_dialect.ConstantNumberOp(beta)
+ else:
+ beta = get_op_result_or_value(beta)
+ assert str(beta.type) in {'!torch.float', '!torch.int'}, f'`beta` should be a !torch.number but is {type(beta).__module__}.{type(beta).__name__}'
+
+ if not is_mlir_value(alpha):
+ alpha = torch_dialect.ConstantNumberOp(alpha)
+ else:
+ alpha = get_op_result_or_value(alpha)
+ assert str(alpha.type) in {'!torch.float', '!torch.int'}, f'`alpha` should be a !torch.number but is {type(alpha).__module__}.{type(alpha).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenBaddbmm_Op, self).__init__(result_type, self_, batch1, batch2, beta, alpha, loc=loc, ip=ip)
+
+
+class AtenFftFftOp:
+ def __init__(self, self_: Value, n: Optional[int], dim: int, norm: Optional[str], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(n):
+ if n is not None:
+ n = torch_dialect.ConstantIntOp(n)
+ else:
+ n = torch_dialect.ConstantNoneOp()
+ else:
+ n = get_op_result_or_value(n)
+ assert str(n.type) == '!torch.int', f'`n` should be a !torch.int but is {type(n).__module__}.{type(n).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(norm):
+ if norm is not None:
+ norm = torch_dialect.ConstantStrOp(norm)
+ else:
+ norm = torch_dialect.ConstantNoneOp()
+ else:
+ norm = get_op_result_or_value(norm)
+ assert str(norm.type) == '!torch.str', f'`norm` should be a !torch.str but is {type(norm).__module__}.{type(norm).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenFftFftOp, self).__init__(result_type, self_, n, dim, norm, loc=loc, ip=ip)
+
+
+class AtenAliasCopyOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAliasCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenAsStridedCopyOp:
+ def __init__(self, self_: Value, size: List[int], stride: List[int], storage_offset: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(storage_offset):
+ if storage_offset is not None:
+ storage_offset = torch_dialect.ConstantIntOp(storage_offset)
+ else:
+ storage_offset = torch_dialect.ConstantNoneOp()
+ else:
+ storage_offset = get_op_result_or_value(storage_offset)
+ assert str(storage_offset.type) == '!torch.int', f'`storage_offset` should be a !torch.int but is {type(storage_offset).__module__}.{type(storage_offset).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAsStridedCopyOp, self).__init__(result_type, self_, size, stride, storage_offset, loc=loc, ip=ip)
+
+
+class AtenDiagonalCopyOp:
+ def __init__(self, self_: Value, offset: int, dim1: int, dim2: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(offset):
+ offset = torch_dialect.ConstantIntOp(offset)
+ else:
+ offset = get_op_result_or_value(offset)
+ assert str(offset.type) == '!torch.int', f'`offset` should be a !torch.int but is {type(offset).__module__}.{type(offset).__name__}'
+
+ if not is_mlir_value(dim1):
+ dim1 = torch_dialect.ConstantIntOp(dim1)
+ else:
+ dim1 = get_op_result_or_value(dim1)
+ assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}'
+
+ if not is_mlir_value(dim2):
+ dim2 = torch_dialect.ConstantIntOp(dim2)
+ else:
+ dim2 = get_op_result_or_value(dim2)
+ assert str(dim2.type) == '!torch.int', f'`dim2` should be a !torch.int but is {type(dim2).__module__}.{type(dim2).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDiagonalCopyOp, self).__init__(result_type, self_, offset, dim1, dim2, loc=loc, ip=ip)
+
+
+class AtenExpandCopyOp:
+ def __init__(self, self_: Value, size: List[int], implicit: bool, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(implicit):
+ implicit = torch_dialect.ConstantBoolOp(implicit)
+ else:
+ implicit = get_op_result_or_value(implicit)
+ assert str(implicit.type) == '!torch.bool', f'`implicit` should be a !torch.bool but is {type(implicit).__module__}.{type(implicit).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenExpandCopyOp, self).__init__(result_type, self_, size, implicit, loc=loc, ip=ip)
+
+
+class AtenPermuteCopyOp:
+ def __init__(self, self_: Value, dims: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dims):
+ dims = list(map(torch_dialect.ConstantIntOp, dims))
+ dims = torch_dialect.PrimListConstructOp(dims)
+ else:
+ dims = get_op_result_or_value(dims)
+ assert str(dims.type) == '!torch.list', f'`dims` should be a !torch.list but is {type(dims).__module__}.{type(dims).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenPermuteCopyOp, self).__init__(result_type, self_, dims, loc=loc, ip=ip)
+
+
+class Aten_ReshapeAliasCopyOp:
+ def __init__(self, self_: Value, size: List[int], stride: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type) == '!torch.list', f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(Aten_ReshapeAliasCopyOp, self).__init__(result_type, self_, size, stride, loc=loc, ip=ip)
+
+
+class AtenSelectCopyIntOp:
+ def __init__(self, self_: Value, dim: int, index: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(index):
+ index = torch_dialect.ConstantIntOp(index)
+ else:
+ index = get_op_result_or_value(index)
+ assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSelectCopyIntOp, self).__init__(result_type, self_, dim, index, loc=loc, ip=ip)
+
+
+class AtenDetachCopyOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDetachCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSliceCopyTensorOp:
+ def __init__(self, self_: Value, dim: int, start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(start):
+ if start is not None:
+ start = torch_dialect.ConstantIntOp(start)
+ else:
+ start = torch_dialect.ConstantNoneOp()
+ else:
+ start = get_op_result_or_value(start)
+ assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}'
+
+ if not is_mlir_value(end):
+ if end is not None:
+ end = torch_dialect.ConstantIntOp(end)
+ else:
+ end = torch_dialect.ConstantNoneOp()
+ else:
+ end = get_op_result_or_value(end)
+ assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}'
+
+ if not is_mlir_value(step):
+ step = torch_dialect.ConstantIntOp(step)
+ else:
+ step = get_op_result_or_value(step)
+ assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSliceCopyTensorOp, self).__init__(result_type, self_, dim, start, end, step, loc=loc, ip=ip)
+
+
+class AtenSqueezeCopyOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSqueezeCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenSqueezeCopyDimOp:
+ def __init__(self, self_: Value, dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSqueezeCopyDimOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip)
+
+
+class AtenTCopyOp:
+ def __init__(self, self_: Value, *, loc=None, ip=None):
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTCopyOp, self).__init__(result_type, self_, loc=loc, ip=ip)
+
+
+class AtenTransposeCopyIntOp:
+ def __init__(self, self_: Value, dim0: int, dim1: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim0):
+ dim0 = torch_dialect.ConstantIntOp(dim0)
+ else:
+ dim0 = get_op_result_or_value(dim0)
+ assert str(dim0.type) == '!torch.int', f'`dim0` should be a !torch.int but is {type(dim0).__module__}.{type(dim0).__name__}'
+
+ if not is_mlir_value(dim1):
+ dim1 = torch_dialect.ConstantIntOp(dim1)
+ else:
+ dim1 = get_op_result_or_value(dim1)
+ assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenTransposeCopyIntOp, self).__init__(result_type, self_, dim0, dim1, loc=loc, ip=ip)
+
+
+class AtenUnsqueezeCopyOp:
+ def __init__(self, self_: Value, dim: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenUnsqueezeCopyOp, self).__init__(result_type, self_, dim, loc=loc, ip=ip)
+
+
+class AtenViewCopyOp:
+ def __init__(self, self_: Value, size: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.list', f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenViewCopyOp, self).__init__(result_type, self_, size, loc=loc, ip=ip)
+
+
+class AtenViewCopyDtypeOp:
+ def __init__(self, self_: Value, dtype: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dtype):
+ dtype = torch_dialect.ConstantIntOp(dtype)
+ else:
+ dtype = get_op_result_or_value(dtype)
+ assert str(dtype.type) == '!torch.int', f'`dtype` should be a !torch.int but is {type(dtype).__module__}.{type(dtype).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenViewCopyDtypeOp, self).__init__(result_type, self_, dtype, loc=loc, ip=ip)
+
+
+class AtenUnfoldCopyOp:
+ def __init__(self, self_: Value, dimension: int, size: int, step: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(dimension):
+ dimension = torch_dialect.ConstantIntOp(dimension)
+ else:
+ dimension = get_op_result_or_value(dimension)
+ assert str(dimension.type) == '!torch.int', f'`dimension` should be a !torch.int but is {type(dimension).__module__}.{type(dimension).__name__}'
+
+ if not is_mlir_value(size):
+ size = torch_dialect.ConstantIntOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type) == '!torch.int', f'`size` should be a !torch.int but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(step):
+ step = torch_dialect.ConstantIntOp(step)
+ else:
+ step = get_op_result_or_value(step)
+ assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenUnfoldCopyOp, self).__init__(result_type, self_, dimension, size, step, loc=loc, ip=ip)
+
+
+class AtenSelectScatterOp:
+ def __init__(self, self_: Value, src: Value, dim: int, index: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(index):
+ index = torch_dialect.ConstantIntOp(index)
+ else:
+ index = get_op_result_or_value(index)
+ assert str(index.type) == '!torch.int', f'`index` should be a !torch.int but is {type(index).__module__}.{type(index).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSelectScatterOp, self).__init__(result_type, self_, src, dim, index, loc=loc, ip=ip)
+
+
+class AtenSliceScatterOp:
+ def __init__(self, self_: Value, src: Value, dim: int, start: Optional[int], end: Optional[int], step: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ if not is_mlir_value(start):
+ if start is not None:
+ start = torch_dialect.ConstantIntOp(start)
+ else:
+ start = torch_dialect.ConstantNoneOp()
+ else:
+ start = get_op_result_or_value(start)
+ assert str(start.type) == '!torch.int', f'`start` should be a !torch.int but is {type(start).__module__}.{type(start).__name__}'
+
+ if not is_mlir_value(end):
+ if end is not None:
+ end = torch_dialect.ConstantIntOp(end)
+ else:
+ end = torch_dialect.ConstantNoneOp()
+ else:
+ end = get_op_result_or_value(end)
+ assert str(end.type) == '!torch.int', f'`end` should be a !torch.int but is {type(end).__module__}.{type(end).__name__}'
+
+ if not is_mlir_value(step):
+ step = torch_dialect.ConstantIntOp(step)
+ else:
+ step = get_op_result_or_value(step)
+ assert str(step.type) == '!torch.int', f'`step` should be a !torch.int but is {type(step).__module__}.{type(step).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenSliceScatterOp, self).__init__(result_type, self_, src, dim, start, end, step, loc=loc, ip=ip)
+
+
+class AtenDiagonalScatterOp:
+ def __init__(self, self_: Value, src: Value, offset: int, dim1: int, dim2: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ if not is_mlir_value(offset):
+ offset = torch_dialect.ConstantIntOp(offset)
+ else:
+ offset = get_op_result_or_value(offset)
+ assert str(offset.type) == '!torch.int', f'`offset` should be a !torch.int but is {type(offset).__module__}.{type(offset).__name__}'
+
+ if not is_mlir_value(dim1):
+ dim1 = torch_dialect.ConstantIntOp(dim1)
+ else:
+ dim1 = get_op_result_or_value(dim1)
+ assert str(dim1.type) == '!torch.int', f'`dim1` should be a !torch.int but is {type(dim1).__module__}.{type(dim1).__name__}'
+
+ if not is_mlir_value(dim2):
+ dim2 = torch_dialect.ConstantIntOp(dim2)
+ else:
+ dim2 = get_op_result_or_value(dim2)
+ assert str(dim2.type) == '!torch.int', f'`dim2` should be a !torch.int but is {type(dim2).__module__}.{type(dim2).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenDiagonalScatterOp, self).__init__(result_type, self_, src, offset, dim1, dim2, loc=loc, ip=ip)
+
+
+class AtenAsStridedScatterOp:
+ def __init__(self, self_: Value, src: Value, size: List[int], stride: List[int], storage_offset: Optional[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(src):
+ assert is_mlir_value(src), f'`src` should be a Value but is {type(src).__module__}.{type(src).__name__}'
+ else:
+ src = get_op_result_or_value(src)
+ assert str(src.type).startswith("!torch.vtensor"), f'`src` should be a torch.vtensor but is {type(src).__module__}.{type(src).__name__}'
+
+ if not is_mlir_value(size):
+ size = list(map(torch_dialect.ConstantIntOp, size))
+ size = torch_dialect.PrimListConstructOp(size)
+ else:
+ size = get_op_result_or_value(size)
+ assert str(size.type).startswith("!torch.list"), f'`size` should be a !torch.list but is {type(size).__module__}.{type(size).__name__}'
+
+ if not is_mlir_value(stride):
+ stride = list(map(torch_dialect.ConstantIntOp, stride))
+ stride = torch_dialect.PrimListConstructOp(stride)
+ else:
+ stride = get_op_result_or_value(stride)
+ assert str(stride.type).startswith("!torch.list"), f'`stride` should be a !torch.list but is {type(stride).__module__}.{type(stride).__name__}'
+
+ if not is_mlir_value(storage_offset):
+ if storage_offset is not None:
+ storage_offset = torch_dialect.ConstantIntOp(storage_offset)
+ else:
+ storage_offset = torch_dialect.ConstantNoneOp()
+ else:
+ storage_offset = get_op_result_or_value(storage_offset)
+ assert str(storage_offset.type) == '!torch.int', f'`storage_offset` should be a !torch.int but is {type(storage_offset).__module__}.{type(storage_offset).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenAsStridedScatterOp, self).__init__(result_type, self_, src, size, stride, storage_offset, loc=loc, ip=ip)
+
+
+class AtenUpsampleNearest2dOp:
+ def __init__(self, self_: Value, output_size: List[int], scales_h: Optional[float], scales_w: Optional[float], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ assert is_mlir_value(self_), f'`self_` should be a Value but is {type(self_).__module__}.{type(self_).__name__}'
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.vtensor"), f'`self_` should be a torch.vtensor but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(output_size):
+ output_size = list(map(torch_dialect.ConstantIntOp, output_size))
+ output_size = torch_dialect.PrimListConstructOp(output_size)
+ else:
+ output_size = get_op_result_or_value(output_size)
+ assert str(output_size.type).startswith("!torch.list"), f'`output_size` should be a !torch.list but is {type(output_size).__module__}.{type(output_size).__name__}'
+
+ if not is_mlir_value(scales_h):
+ if scales_h is not None:
+ scales_h = torch_dialect.ConstantFloatOp(scales_h)
+ else:
+ scales_h = torch_dialect.ConstantNoneOp()
+ else:
+ scales_h = get_op_result_or_value(scales_h)
+ assert str(scales_h.type) == '!torch.float', f'`scales_h` should be a !torch.float but is {type(scales_h).__module__}.{type(scales_h).__name__}'
+
+ if not is_mlir_value(scales_w):
+ if scales_w is not None:
+ scales_w = torch_dialect.ConstantFloatOp(scales_w)
+ else:
+ scales_w = torch_dialect.ConstantNoneOp()
+ else:
+ scales_w = get_op_result_or_value(scales_w)
+ assert str(scales_w.type) == '!torch.float', f'`scales_w` should be a !torch.float but is {type(scales_w).__module__}.{type(scales_w).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenUpsampleNearest2dOp, self).__init__(result_type, self_, output_size, scales_h, scales_w, loc=loc, ip=ip)
+
+
+class Aten__Contains__IntListOp:
+ def __init__(self, l: List[int], item: int, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(l):
+ l = list(map(torch_dialect.ConstantIntOp, l))
+ l = torch_dialect.PrimListConstructOp(l)
+ else:
+ l = get_op_result_or_value(l)
+ assert str(l.type).startswith("!torch.list"), f'`l` should be a !torch.list but is {type(l).__module__}.{type(l).__name__}'
+
+ if not is_mlir_value(item):
+ item = torch_dialect.ConstantIntOp(item)
+ else:
+ item = get_op_result_or_value(item)
+ assert str(item.type) == '!torch.int', f'`item` should be a !torch.int but is {type(item).__module__}.{type(item).__name__}'
+
+ super(Aten__Contains__IntListOp, self).__init__(l, item, loc=loc, ip=ip)
+
+
+class AtenCatOp:
+ def __init__(self, tensors: List[Value], dim: int, *, loc=None, ip=None):
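+ """Builder for the torch-mlir `aten.cat` op (op name inferred from the class name).
+ A Python list of tensor Values is packed into a PrimListConstructOp; an existing !torch.list
+ value is passed through unchanged. The result is an unranked !torch.vtensor."""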
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(tensors):
+ tensors = torch_dialect.PrimListConstructOp(tensors)
+ else:
+ tensors = get_op_result_or_value(tensors)
+ assert str(tensors.type).startswith("!torch.list"), f'`tensors` should be a !torch.list but is {type(tensors).__module__}.{type(tensors).__name__}'
+
+ if not is_mlir_value(dim):
+ dim = torch_dialect.ConstantIntOp(dim)
+ else:
+ dim = get_op_result_or_value(dim)
+ assert str(dim.type) == '!torch.int', f'`dim` should be a !torch.int but is {type(dim).__module__}.{type(dim).__name__}'
+
+ result_type = Type.parse("!torch.vtensor")
+ super(AtenCatOp, self).__init__(result_type, tensors, dim, loc=loc, ip=ip)
+
+
+class AtenAppendTOp:
+ def __init__(self, self_: List[Value], el: Value, *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(self_):
+ self_ = torch_dialect.PrimListConstructOp(self_)
+ else:
+ self_ = get_op_result_or_value(self_)
+ assert str(self_.type).startswith("!torch.list"), f'`self_` should be a !torch.list but is {type(self_).__module__}.{type(self_).__name__}'
+
+ if not is_mlir_value(el):
+ assert is_mlir_value(el), f'`el` should be a Value but is {type(el).__module__}.{type(el).__name__}'
+ else:
+ el = get_op_result_or_value(el)
+ assert str(el.type).startswith("!torch.vtensor"), f'`el` should be a torch.vtensor but is {type(el).__module__}.{type(el).__name__}'
+
+ result_type = Type.parse("!torch.list")
+ super(AtenAppendTOp, self).__init__(result_type, self_, el, loc=loc, ip=ip)
+
+
+class AtenAddTOp:
+ def __init__(self, a: List[Value], b: List[Value], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(a):
+ a = torch_dialect.PrimListConstructOp(a)
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.list"), f'`a` should be a !torch.list but is {type(a).__module__}.{type(a).__name__}'
+
+ if not is_mlir_value(b):
+ b = torch_dialect.PrimListConstructOp(b)
+ else:
+ b = get_op_result_or_value(b)
+ assert str(b.type).startswith("!torch.list"), f'`b` should be a !torch.list but is {type(b).__module__}.{type(b).__name__}'
+
+ result_type = Type.parse("!torch.list")
+ super(AtenAddTOp, self).__init__(result_type, a, b, loc=loc, ip=ip)
+
+
+class AtenEqIntListOp:
+ def __init__(self, a: List[int], b: List[int], *, loc=None, ip=None):
+ from torch_mlir.dialects import torch as torch_dialect
+
+ if not is_mlir_value(a):
+ a = list(map(torch_dialect.ConstantIntOp, a))
+ a = torch_dialect.PrimListConstructOp(a)
+ else:
+ a = get_op_result_or_value(a)
+ assert str(a.type).startswith("!torch.list"), f'`a` should be a !torch.list but is {type(a).__module__}.{type(a).__name__}'
+
+ if not is_mlir_value(b):
+ b = list(map(torch_dialect.ConstantIntOp, b))
+ b = torch_dialect.PrimListConstructOp(b)
+ else:
+ b = get_op_result_or_value(b)
+ assert str(b.type) == '!torch.list