From e43ec2f0ad7bb6fdfd1ee1a5a99f067885011234 Mon Sep 17 00:00:00 2001 From: Thomas Alexander Date: Mon, 29 Jul 2024 22:03:40 -0300 Subject: [PATCH] Update parameter/constant op handling for performance (#335) This PR updates how parameters are handled. They are now directly read from "parameter_load" ops which removes the need for QUIR variable analysis and casting which significantly improves performance of large parameter programs by reducing the total number of operations in the program. This also cuts down on the total number of arguments to each circuit/sequence by storing the parameter loads/constants directly in the sequence itself For example before for the program below ```qasm OPENQASM 3; qubit $0; qubit $1; qubit $2; gate sx q { } gate rz(phi) q { } input float[64] theta = 3.14159265358979; input angle phi; sx $0; rz(theta) $0; sx $0; rz(phi) $1; rz(3.141592) $1; rz(theta) $2; rz(phi) $2; bit b; b = measure $0; ``` The MLIR generated is now: ```mlir module { func.func @sx(%arg0: !quir.qubit<1>) attributes {quir.classicalOnly = false} { return } func.func @rz(%arg0: !quir.qubit<1>, %arg1: !quir.angle<64>) attributes {quir.classicalOnly = false} { return } quir.circuit @circuit_0(%arg0: !quir.qubit<1> {quir.physicalId = 0 : i32}, %arg1: !quir.qubit<1> {quir.physicalId = 1 : i32}, %arg2: !quir.qubit<1> {quir.physicalId = 2 : i32}) -> i1 attributes {quir.classicalOnly = false, quir.physicalIds = [0 : i32, 1 : i32, 2 : i32]} { quir.call_gate @sx(%arg0) : (!quir.qubit<1>) -> () %0 = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.14159265358979 : f64} quir.call_gate @rz(%arg0, %0) : (!quir.qubit<1>, !quir.angle<64>) -> () quir.call_gate @sx(%arg0) : (!quir.qubit<1>) -> () %1 = qcs.parameter_load "phi" : !quir.angle<64> quir.call_gate @rz(%arg1, %1) : (!quir.qubit<1>, !quir.angle<64>) -> () %angle = quir.constant #quir.angle<3.1415920000000002> : !quir.angle<64> quir.call_gate @rz(%arg1, %angle) : (!quir.qubit<1>, !quir.angle<64>) -> () %2 = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.14159265358979 : f64} quir.call_gate @rz(%arg2, %2) : (!quir.qubit<1>, !quir.angle<64>) -> () %3 = qcs.parameter_load "phi" : !quir.angle<64> quir.call_gate @rz(%arg2, %3) : (!quir.qubit<1>, !quir.angle<64>) -> () %4 = quir.measure(%arg0) {quir.noFastPathComm, quir.noJunoComm, quir.noJunoUse} : (!quir.qubit<1>) -> i1 quir.return %4 : i1 } func.func @main() -> i32 attributes {quir.classicalOnly = false} { %c0_i32 = arith.constant 0 : i32 %dur = quir.constant #quir.duration<4.500000e+06> : !quir.duration
%c1 = arith.constant 1 : index %c1000 = arith.constant 1000 : index %c0 = arith.constant 0 : index qcs.init scf.for %arg0 = %c0 to %c1000 step %c1 { quir.delay %dur, () : !quir.duration
, () -> () qcs.shot_init {qcs.num_shots = 1000 : i32} %0 = quir.declare_qubit {id = 0 : i32} : !quir.qubit<1> %1 = quir.declare_qubit {id = 1 : i32} : !quir.qubit<1> %2 = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> %3 = quir.call_circuit @circuit_0(%0, %1, %2) : (!quir.qubit<1>, !quir.qubit<1>, !quir.qubit<1>) -> i1 } {qcs.shot_loop, quir.classicalOnly = false, quir.physicalIds = [0 : i32, 1 : i32, 2 : i32]} qcs.finalize return %c0_i32 : i32 } } ``` --- conan/qasm/conandata.yml | 2 +- conan/qasm/conanfile.py | 2 +- conandata.yml | 2 +- include/Conversion/QUIRToPulse/QUIRToPulse.h | 2 +- .../Dialect/QUIR/Transforms/ExtractCircuits.h | 14 +-- .../Frontend/OpenQASM3/QUIRGenQASM3Visitor.h | 3 + .../Frontend/OpenQASM3/QUIRVariableBuilder.h | 2 +- lib/Conversion/QUIRToPulse/LoadPulseCals.cpp | 2 +- lib/Conversion/QUIRToPulse/QUIRToPulse.cpp | 99 ++++++++++++------- lib/Dialect/Pulse/IR/PulseOps.cpp | 6 +- lib/Dialect/Pulse/Transforms/Scheduling.cpp | 1 + lib/Dialect/QUIR/IR/QUIROps.cpp | 5 +- .../QUIR/Transforms/ExtractCircuits.cpp | 96 ++++++++++-------- .../QUIR/Transforms/ReorderMeasurements.cpp | 19 +++- .../QUIR/Transforms/UnusedVariable.cpp | 63 ++++-------- .../OpenQASM3/QUIRGenQASM3Visitor.cpp | 60 +++++------ .../OpenQASM3/QUIRVariableBuilder.cpp | 50 +++------- ...e-parameter-handling-cfa04a0bd7250401.yaml | 11 +++ .../QUIR/Transforms/extract-circuits.mlir | 6 +- .../QUIR/Transforms/reorder-measurements.mlir | 10 +- .../OpenQASM3/input-output-variables.qasm | 3 - .../OpenQASM3/input-parameters-if.qasm | 7 +- .../OpenQASM3/input-parameters-while.qasm | 3 +- test/Frontend/OpenQASM3/input-parameters.qasm | 22 ++--- test/unittest/quir-dialect.cpp | 18 ++-- 25 files changed, 261 insertions(+), 247 deletions(-) create mode 100644 releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml diff --git a/conan/qasm/conandata.yml b/conan/qasm/conandata.yml index a13e4e302..128270ff4 100644 --- a/conan/qasm/conandata.yml +++ b/conan/qasm/conandata.yml @@ -1,5 +1,5 @@ sources: - hash: "f6d695fd9f18462e65f6290d05ccb4ccb371b288" + hash: "ec7731bf645240a597cd9ebb2c395b114f155ed2" requirements: - "gmp/6.3.0" - "mpfr/4.1.0" diff --git a/conan/qasm/conanfile.py b/conan/qasm/conanfile.py index 3df4f1cc7..95532601a 100644 --- a/conan/qasm/conanfile.py +++ b/conan/qasm/conanfile.py @@ -17,7 +17,7 @@ class QasmConan(ConanFile): name = "qasm" - version = "0.3.2" + version = "0.3.3" url = "https://github.com/openqasm/qe-qasm.git" settings = "os", "compiler", "build_type", "arch" options = {"shared": [True, False], "examples": [True, False]} diff --git a/conandata.yml b/conandata.yml index 24d31df0a..4b85ea0ce 100644 --- a/conandata.yml +++ b/conandata.yml @@ -7,4 +7,4 @@ requirements: - pybind11/2.11.1 - clang-tools-extra/17.0.5-0@ - llvm/17.0.5-0@ - - qasm/0.3.2@qss/stable + - qasm/0.3.3@qss/stable diff --git a/include/Conversion/QUIRToPulse/QUIRToPulse.h b/include/Conversion/QUIRToPulse/QUIRToPulse.h index 8eb26ffb8..f80a66d1f 100644 --- a/include/Conversion/QUIRToPulse/QUIRToPulse.h +++ b/include/Conversion/QUIRToPulse/QUIRToPulse.h @@ -127,7 +127,7 @@ struct QUIRToPulsePass mlir::func::FuncOp &mainFunc); // map of the hashed location of quir angle/duration ops to their converted // pulse ops - std::unordered_map + std::unordered_map classicalQUIROpLocToConvertedPulseOpMap; // port name to Port_CreateOp map diff --git a/include/Dialect/QUIR/Transforms/ExtractCircuits.h b/include/Dialect/QUIR/Transforms/ExtractCircuits.h index d22feace0..935482a16 100644 --- a/include/Dialect/QUIR/Transforms/ExtractCircuits.h +++ b/include/Dialect/QUIR/Transforms/ExtractCircuits.h @@ -25,11 +25,11 @@ #include "Utils/SymbolCacheAnalysis.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/SmallVector.h" - +#include #include namespace mlir::quir { @@ -49,14 +49,14 @@ struct ExtractCircuitsPass OpBuilder circuitBuilder); OpBuilder startCircuit(mlir::Location location, OpBuilder topLevelBuilder); void endCircuit(mlir::Operation *firstOp, mlir::Operation *lastOp, - OpBuilder topLevelBuilder, OpBuilder circuitBuilder, - llvm::SmallVector &eraseList); - void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder, - llvm::SmallVector &eraseList); + OpBuilder topLevelBuilder, OpBuilder circuitBuilder); + void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder); + uint64_t circuitCount = 0; qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr}; mlir::quir::CircuitOp currentCircuitOp = nullptr; + mlir::IRMapping currentCircuitMapper; mlir::quir::CallCircuitOp newCallCircuitOp; llvm::SmallVector inputTypes; @@ -68,6 +68,8 @@ struct ExtractCircuitsPass std::unordered_map circuitOperands; llvm::SmallVector originalResults; + std::set eraseConstSet; + std::set eraseOpSet; }; // struct ExtractCircuitsPass } // namespace mlir::quir diff --git a/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h b/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h index 61222f601..0551782ef 100644 --- a/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h +++ b/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h @@ -321,6 +321,9 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor { mlir::Type getQUIRTypeFromDeclaration(const QASM::ASTDeclarationNode *); bool enableParametersWarningEmitted = false; + + /// Cached dummy value for error handling + mlir::Value voidValue; }; } // namespace qssc::frontend::openqasm3 diff --git a/include/Frontend/OpenQASM3/QUIRVariableBuilder.h b/include/Frontend/OpenQASM3/QUIRVariableBuilder.h index 49397e74c..f5085db00 100644 --- a/include/Frontend/OpenQASM3/QUIRVariableBuilder.h +++ b/include/Frontend/OpenQASM3/QUIRVariableBuilder.h @@ -68,7 +68,7 @@ class QUIRVariableBuilder { mlir::Value generateParameterLoad(mlir::Location location, llvm::StringRef variableName, - mlir::Value assignedValue); + double initialValue); /// Generate code for declaring an array (at the builder's current insertion /// point). diff --git a/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp b/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp index e6d7d33b4..c57889a53 100644 --- a/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp +++ b/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp @@ -152,7 +152,7 @@ void LoadPulseCalsPass::loadPulseCals(CallCircuitOp callCircuitOp, LLVM_DEBUG(llvm::dbgs() << "no pulse cal loading needed for " << op); assert((!op->hasTrait() and !op->hasTrait()) && - "unkown operation"); + "unknown operation"); } }); } diff --git a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp index 8b65fc517..3330bdc6e 100644 --- a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp +++ b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp @@ -100,8 +100,10 @@ void QUIRToPulsePass::runOnOperation() { moduleOp->walk([&](CallCircuitOp callCircOp) { if (isa(callCircOp->getParentOp())) return; + auto convertedPulseCallSequenceOp = convertCircuitToSequence(callCircOp, mainFunc, moduleOp); + if (!callCircOp->use_empty()) callCircOp->replaceAllUsesWith(convertedPulseCallSequenceOp); callCircOp->erase(); @@ -229,8 +231,9 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, auto *newDelayCyclesOp = builder.clone(*quirOp); newDelayCyclesOp->moveAfter(callCircuitOp); } else - assert(((isa(quirOp) or isa(quirOp) or - isa(quirOp))) && + assert(((isa(quirOp) || + isa(quirOp) || + isa(quirOp) || isa(quirOp))) && "quir op is not allowed in this pass."); }); @@ -251,6 +254,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, convertedPulseSequenceOp, convertedPulseSequenceOpArgs); convertedPulseCallSequenceOp->moveAfter(callCircuitOp); + return convertedPulseCallSequenceOp; } @@ -286,7 +290,7 @@ void QUIRToPulsePass::processCircuitArgs( } else if (argumentType.isa()) { auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp(); } else - llvm_unreachable("unkown circuit argument."); + llvm_unreachable("unknown circuit argument."); } } @@ -339,7 +343,7 @@ void QUIRToPulsePass::processPulseCalArgs( } else if (argumentType.isa()) { assert(argAttr[index].dyn_cast().getValue().str() == "angle" && - "unkown argument."); + "unknown argument."); assert(angleOperands.size() && "no angle operand found."); auto nextAngle = angleOperands.front(); LLVM_DEBUG(llvm::dbgs() << "angle argument "); @@ -350,7 +354,7 @@ void QUIRToPulsePass::processPulseCalArgs( } else if (argumentType.isa()) { assert(argAttr[index].dyn_cast().getValue().str() == "duration" && - "unkown argument."); + "unknown argument."); assert(durationOperands.size() && "no duration operand found."); auto nextDuration = durationOperands.front(); LLVM_DEBUG(llvm::dbgs() << "duration argument "); @@ -359,7 +363,7 @@ void QUIRToPulsePass::processPulseCalArgs( pulseCalSequenceArgs, builder); durationOperands.pop(); } else - llvm_unreachable("unkown argument type."); + llvm_unreachable("unknown argument type."); } } @@ -379,12 +383,13 @@ void QUIRToPulsePass::getQUIROpClassicalOperands( } for (auto operand : classicalOperands) - if (operand.getType().isa()) + if (operand.getType().isa() || + operand.getType().isa()) angleOperands.push(operand); else if (operand.getType().isa()) durationOperands.push(operand); else - llvm_unreachable("unkown operand."); + llvm_unreachable("unknown operand."); } void QUIRToPulsePass::processMixFrameOpArg( @@ -463,21 +468,38 @@ void QUIRToPulsePass::processAngleArg(Value nextAngleOperand, pulseCalSequenceArgs.push_back( convertedPulseSequenceOp .getArguments()[circuitArgToConvertedSequenceArgMap[circNum]]); - } else { - auto angleOp = nextAngleOperand.getDefiningOp(); - std::string const angleLocHash = - std::to_string(mlir::hash_value(angleOp->getLoc())); - if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) == + } else if (auto angleOp = + nextAngleOperand.getDefiningOp()) { + auto *op = angleOp.getOperation(); + if (classicalQUIROpLocToConvertedPulseOpMap.find(op) == classicalQUIROpLocToConvertedPulseOpMap.end()) { double const angleVal = angleOp.getAngleValueFromConstant().convertToDouble(); auto f64Angle = entryBuilder.create( angleOp.getLoc(), entryBuilder.getFloatAttr(entryBuilder.getF64Type(), llvm::APFloat(angleVal))); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle; + classicalQUIROpLocToConvertedPulseOpMap[op] = f64Angle; } - pulseCalSequenceArgs.push_back( - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]); + pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]); + } else if (auto paramOp = + nextAngleOperand.getDefiningOp()) { + auto *op = paramOp.getOperation(); + if (classicalQUIROpLocToConvertedPulseOpMap.find(op) == + classicalQUIROpLocToConvertedPulseOpMap.end()) { + + auto newParam = entryBuilder.create( + paramOp->getLoc(), entryBuilder.getF64Type(), + paramOp.getParameterName()); + if (paramOp->hasAttr("initialValue")) { + auto initAttr = paramOp->getAttr("initialValue").dyn_cast(); + if (initAttr) + newParam->setAttr("initialValue", initAttr); + } + + classicalQUIROpLocToConvertedPulseOpMap[op] = newParam; + } + + pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]); } } @@ -501,25 +523,23 @@ void QUIRToPulsePass::processDurationArg( TimeUnits::dt && "this pass only accepts durations with dt unit"); - if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) == + auto *op = durationOp.getOperation(); + if (classicalQUIROpLocToConvertedPulseOpMap.find(op) == classicalQUIROpLocToConvertedPulseOpMap.end()) { auto dur64 = entryBuilder.create( durationOp.getLoc(), entryBuilder.getIntegerAttr(entryBuilder.getI64Type(), uint64_t(durVal))); - classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = dur64; + classicalQUIROpLocToConvertedPulseOpMap[op] = dur64; } - pulseCalSequenceArgs.push_back( - classicalQUIROpLocToConvertedPulseOpMap[durLocHash]); + pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]); } } mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, mlir::OpBuilder &builder) { assert(angleOp && "angle op is null"); - std::string const angleLocHash = - std::to_string(mlir::hash_value(angleOp->getLoc())); - if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) == + if (classicalQUIROpLocToConvertedPulseOpMap.find(angleOp) == classicalQUIROpLocToConvertedPulseOpMap.end()) { if (auto castOp = dyn_cast(angleOp)) { double const angleVal = @@ -528,12 +548,19 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, castOp->getLoc(), builder.getFloatAttr(builder.getF64Type(), llvm::APFloat(angleVal))); f64Angle->moveAfter(castOp); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle; + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = f64Angle; } else if (auto castOp = dyn_cast(angleOp)) { - auto angleCastedOp = builder.create( - castOp->getLoc(), builder.getF64Type(), castOp.getRes()); - angleCastedOp->moveAfter(castOp); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp; + // Just convert to an f64 directly + auto newParam = builder.create( + angleOp->getLoc(), builder.getF64Type(), castOp.getParameterName()); + if (castOp->hasAttr("initialValue")) { + auto initAttr = castOp->getAttr("initialValue").dyn_cast(); + if (initAttr) + newParam->setAttr("initialValue", initAttr); + } + newParam->moveAfter(castOp); + + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = newParam; } else if (auto castOp = dyn_cast(angleOp)) { auto castOpArg = castOp.getArg(); if (auto paramCastOp = @@ -541,28 +568,26 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, auto angleCastedOp = builder.create( paramCastOp->getLoc(), builder.getF64Type(), paramCastOp.getRes()); angleCastedOp->moveAfter(paramCastOp); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp; + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = angleCastedOp; } else if (auto constOp = dyn_cast(castOpArg.getDefiningOp())) { // if cast from float64 then use directly assert(constOp.getType() == builder.getF64Type() && "expected angle type to be float 64"); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = constOp; + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = constOp; } else llvm_unreachable("castOp arg unknown"); } else llvm_unreachable("angleOp unknown"); } - return classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]; + return classicalQUIROpLocToConvertedPulseOpMap[angleOp]; } mlir::Value QUIRToPulsePass::convertDurationToI64( mlir::quir::CallCircuitOp &callCircuitOp, Operation *durationOp, uint &cnt, mlir::OpBuilder &builder, mlir::func::FuncOp &mainFunc) { assert(durationOp && "duration op is null"); - std::string const durLocHash = - std::to_string(mlir::hash_value(durationOp->getLoc())); - if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) == + if (classicalQUIROpLocToConvertedPulseOpMap.find(durationOp) == classicalQUIROpLocToConvertedPulseOpMap.end()) { if (auto castOp = dyn_cast(durationOp)) { auto durVal = @@ -575,11 +600,11 @@ mlir::Value QUIRToPulsePass::convertDurationToI64( castOp->getLoc(), builder.getIntegerAttr(builder.getI64Type(), uint64_t(durVal))); I64Dur->moveAfter(castOp); - classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = I64Dur; + classicalQUIROpLocToConvertedPulseOpMap[durationOp] = I64Dur; } else - llvm_unreachable("unkown duration op"); + llvm_unreachable("unknown duration op"); } - return classicalQUIROpLocToConvertedPulseOpMap[durLocHash]; + return classicalQUIROpLocToConvertedPulseOpMap[durationOp]; } mlir::pulse::Port_CreateOp diff --git a/lib/Dialect/Pulse/IR/PulseOps.cpp b/lib/Dialect/Pulse/IR/PulseOps.cpp index 6b2d46c22..3c9717037 100644 --- a/lib/Dialect/Pulse/IR/PulseOps.cpp +++ b/lib/Dialect/Pulse/IR/PulseOps.cpp @@ -17,6 +17,7 @@ #include "Dialect/Pulse/IR/PulseOps.h" #include "Dialect/Pulse/IR/PulseTraits.h" +#include "Dialect/QCS/IR/QCSOps.h" #include "Dialect/QUIR/IR/QUIROps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -356,8 +357,9 @@ LogicalResult verifyClassical_(SequenceOp op) { mlir::Operation *classicalOp = nullptr; WalkResult const result = op->walk([&](Operation *subOp) { if (isa(subOp) || isa(subOp) || - isa(subOp) || isa(subOp) || - isa(subOp) || isa(subOp) || + isa(subOp) || isa(subOp) || + isa(subOp) || isa(subOp) || + isa(subOp) || subOp->hasTrait() || subOp->hasTrait()) return WalkResult::advance(); diff --git a/lib/Dialect/Pulse/Transforms/Scheduling.cpp b/lib/Dialect/Pulse/Transforms/Scheduling.cpp index 8117ed369..547ba3524 100644 --- a/lib/Dialect/Pulse/Transforms/Scheduling.cpp +++ b/lib/Dialect/Pulse/Transforms/Scheduling.cpp @@ -112,6 +112,7 @@ void QuantumCircuitPulseSchedulingPass::scheduleAlap( opEnd = quantumCircuitSequenceOpBlock->rend(); opIt != opEnd; ++opIt) { auto &op = *opIt; + if (auto quantumGateCallSequenceOp = dyn_cast(op)) { // find quantum gate SequenceOp diff --git a/lib/Dialect/QUIR/IR/QUIROps.cpp b/lib/Dialect/QUIR/IR/QUIROps.cpp index 29bef7d33..efe3c9bfd 100644 --- a/lib/Dialect/QUIR/IR/QUIROps.cpp +++ b/lib/Dialect/QUIR/IR/QUIROps.cpp @@ -380,8 +380,9 @@ LogicalResult verifyClassical_(CircuitOp op) { mlir::Operation *classicalOp = nullptr; WalkResult const result = op->walk([&](Operation *subOp) { if (isa(subOp) || isa(subOp) || - isa(subOp) || isa(subOp) || - isa(subOp) || subOp->hasTrait() || + isa(subOp) || isa(subOp) || + isa(subOp) || isa(subOp) || + subOp->hasTrait() || subOp->hasTrait()) return WalkResult::advance(); classicalOp = subOp; diff --git a/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp b/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp index 3d9c02462..b0bf5df26 100644 --- a/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp +++ b/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp @@ -31,12 +31,12 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" @@ -93,6 +93,7 @@ OpBuilder ExtractCircuitsPass::startCircuit(Location location, topLevelBuilder.getFunctionType( /*inputs=*/ArrayRef(), /*results=*/ArrayRef())); + currentCircuitMapper = IRMapping(); currentCircuitOp.addEntryBlock(); symbolCache->addCallee(currentCircuitOp); @@ -107,35 +108,51 @@ OpBuilder ExtractCircuitsPass::startCircuit(Location location, return circuitBuilder; } -void ExtractCircuitsPass::addToCircuit( - Operation *currentOp, OpBuilder circuitBuilder, - llvm::SmallVector &eraseList) { +void ExtractCircuitsPass::addToCircuit(Operation *currentOp, + OpBuilder circuitBuilder) { - IRMapping mapper; // add operands to circuit input list for (auto operand : currentOp->getOperands()) { auto *defOp = operand.getDefiningOp(); auto search = circuitOperands.find(defOp); uint argumentIndex = 0; + mlir::Value mappedValue; if (search == circuitOperands.end()) { - argumentIndex = inputValues.size(); - inputValues.push_back(operand); - inputTypes.push_back(operand.getType()); - circuitOperands[defOp] = argumentIndex; - currentCircuitOp.getBody().addArgument(operand.getType(), - currentOp->getLoc()); - if (isa(defOp)) { - auto id = defOp->getAttrOfType("id").getInt(); - phyiscalIds.push_back(id); - argToId[argumentIndex] = id; + // Check if we should embed in the circuit + auto constantLike = (isa(defOp) || + isa(defOp)); + if (constantLike) { + // Don't clone/map if we already have + if (currentCircuitMapper.contains(operand)) + continue; + auto *newDefOp = circuitBuilder.clone(*defOp, currentCircuitMapper); + mappedValue = newDefOp->getResult(0); + // May be used multiple times so we must remove all users + // before erasing. + eraseConstSet.insert(defOp); + } else { + // Otherwise we add to the circuit signature + argumentIndex = inputValues.size(); + inputValues.push_back(operand); + inputTypes.push_back(operand.getType()); + circuitOperands[defOp] = argumentIndex; + currentCircuitOp.getBody().addArgument(operand.getType(), + currentOp->getLoc()); + + if (isa(defOp)) { + auto id = defOp->getAttrOfType("id").getInt(); + phyiscalIds.push_back(id); + argToId[argumentIndex] = id; + } + mappedValue = currentCircuitOp.getArgument(argumentIndex); } } else { argumentIndex = search->second; + mappedValue = currentCircuitOp.getArgument(argumentIndex); } - - mapper.map(operand, currentCircuitOp.getArgument(argumentIndex)); + currentCircuitMapper.map(operand, mappedValue); } - auto *newOp = circuitBuilder.clone(*currentOp, mapper); + auto *newOp = circuitBuilder.clone(*currentOp, currentCircuitMapper); outputTypes.append(newOp->getResultTypes().begin(), newOp->getResultTypes().end()); @@ -143,12 +160,12 @@ void ExtractCircuitsPass::addToCircuit( originalResults.append(currentOp->getResults().begin(), currentOp->getResults().end()); - eraseList.push_back(currentOp); + eraseOpSet.insert(currentOp); } -void ExtractCircuitsPass::endCircuit( - Operation *firstOp, Operation *lastOp, OpBuilder topLevelBuilder, - OpBuilder circuitBuilder, llvm::SmallVector &eraseList) { +void ExtractCircuitsPass::endCircuit(Operation *firstOp, Operation *lastOp, + OpBuilder topLevelBuilder, + OpBuilder circuitBuilder) { LLVM_DEBUG(llvm::dbgs() << "Ending circuit " << currentCircuitOp.getSymName() << "\n"); @@ -189,16 +206,6 @@ void ExtractCircuitsPass::endCircuit( assert(originalResults[cnt].use_empty() && "usage expected to be empty"); } - // erase operations - while (!eraseList.empty()) { - auto *op = eraseList.back(); - eraseList.pop_back(); - assert(op->use_empty() && "operation usage expected to be empty"); - LLVM_DEBUG(llvm::dbgs() << "Erasing: "); - LLVM_DEBUG(op->dump()); - op->erase(); - } - currentCircuitOp = nullptr; } @@ -212,7 +219,6 @@ void ExtractCircuitsPass::processRegion(mlir::Region ®ion, void ExtractCircuitsPass::processBlock(mlir::Block &block, OpBuilder topLevelBuilder, OpBuilder circuitBuilder) { - llvm::SmallVector eraseList; Operation *firstQuantumOp = nullptr; Operation *lastQuantumOp = nullptr; @@ -244,7 +250,7 @@ void ExtractCircuitsPass::processBlock(mlir::Block &block, circuitBuilder = startCircuit(firstQuantumOp->getLoc(), topLevelBuilder); } - addToCircuit(¤tOp, circuitBuilder, eraseList); + addToCircuit(¤tOp, circuitBuilder); continue; } if (terminatesCircuit(currentOp)) { @@ -252,7 +258,7 @@ void ExtractCircuitsPass::processBlock(mlir::Block &block, // progress there is an in progress circuit to be ended. if (currentCircuitOp) { endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, - circuitBuilder, eraseList); + circuitBuilder); } // handle control flow by recursively calling processBlock for control @@ -262,10 +268,8 @@ void ExtractCircuitsPass::processBlock(mlir::Block &block, } } // End of block complete the circuit - if (currentCircuitOp) { - endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, circuitBuilder, - eraseList); - } + if (currentCircuitOp) + endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, circuitBuilder); } void ExtractCircuitsPass::runOnOperation() { @@ -284,6 +288,20 @@ void ExtractCircuitsPass::runOnOperation() { auto const builder = OpBuilder(mainFunc); processRegion(mainFunc.getRegion(), builder, builder); + + // erase operations + for (auto *op : eraseOpSet) { + LLVM_DEBUG(llvm::dbgs() << "Erasing: "); + LLVM_DEBUG(op->dump()); + op->erase(); + } + for (auto *op : eraseConstSet) { + assert(op->use_empty() && "operation usage expected to be empty"); + LLVM_DEBUG(llvm::dbgs() << "Erasing: "); + LLVM_DEBUG(op->dump()); + op->erase(); + } + } // runOnOperation llvm::StringRef ExtractCircuitsPass::getArgument() const { diff --git a/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp b/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp index bf0126ac6..f74304456 100644 --- a/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp +++ b/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp @@ -21,6 +21,7 @@ #include "Dialect/QUIR/Transforms/ReorderMeasurements.h" #include "Dialect/OQ3/IR/OQ3Ops.h" +#include "Dialect/QCS/IR/QCSOps.h" #include "Dialect/QUIR/IR/QUIRInterfaces.h" #include "Dialect/QUIR/IR/QUIROps.h" #include "Dialect/QUIR/IR/QUIRTraits.h" @@ -85,11 +86,14 @@ bool mayMoveVariableLoadOp(MeasureOp measureOp, bool mayMoveCastOp(MeasureOp measureOp, oq3::CastOp castOp, MoveListVec &moveList) { bool moveCastOp = false; - auto variableLoadOp = - dyn_cast(castOp.getArg().getDefiningOp()); - if (variableLoadOp) + + auto *definingOp = castOp.getArg().getDefiningOp(); + if (auto variableLoadOp = dyn_cast(definingOp)) moveCastOp = mayMoveVariableLoadOp(measureOp, variableLoadOp, moveList); - auto castMeasureOp = dyn_cast(castOp.getArg().getDefiningOp()); + else if (isa(definingOp)) + moveCastOp = true; + + auto castMeasureOp = dyn_cast(definingOp); if (castMeasureOp) moveCastOp = ((castMeasureOp != measureOp) && (castMeasureOp->isBeforeInBlock(measureOp) || @@ -170,6 +174,13 @@ struct ReorderMeasureAndNonMeasurePat : public OpRewritePattern { mayMoveVariableLoadOp(measureOp, variableLoadOp, moveList); } + // if the defining op is a parameter load op we are are safe + // to move + if (auto parameterLoadOp = dyn_cast(defOp)) { + moveOps = true; + moveList.push_back(parameterLoadOp); + } + auto castOp = dyn_cast(defOp); if (castOp) moveOps = mayMoveCastOp(measureOp, castOp, moveList); diff --git a/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp b/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp index 4c8ad09d4..b19ce10f3 100644 --- a/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp +++ b/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp @@ -23,74 +23,45 @@ #include "Dialect/OQ3/IR/OQ3Ops.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" -#include - using namespace mlir; using namespace quir; using namespace oq3; -namespace { -/// This pattern matches on variable declarations that are not marked 'output' -/// and are not followed by a use of the same variable, and removes them -struct UnusedVariablePat : public OpRewritePattern { - UnusedVariablePat(MLIRContext *context, mlir::SymbolUserMap &symbolUses) - : OpRewritePattern(context, /*benefit=*/1), - symbolUses(symbolUses) {} - mlir::SymbolUserMap &symbolUses; - LogicalResult - matchAndRewrite(DeclareVariableOp declOp, - mlir::PatternRewriter &rewriter) const override { +/// +/// \brief Entry point for the pass. +void UnusedVariablePass::runOnOperation() { + mlir::SymbolTableCollection symbolTable; + mlir::SymbolUserMap symbolUsers(symbolTable, getOperation()); + + getOperation()->walk([&](DeclareVariableOp declOp) { if (declOp.isOutputVariable()) - return failure(); + return mlir::WalkResult::advance(); // iterate through uses - for (auto *useOp : symbolUses.getUsers(declOp)) { + for (auto *useOp : symbolUsers.getUsers(declOp)) { if (auto useVariable = dyn_cast(useOp)) { if (!useVariable || !useVariable.use_empty()) - return failure(); + return mlir::WalkResult::advance(); } } // No uses found, so now we can erase all references (just stores) and the // declaration - for (auto *useOp : symbolUses.getUsers(declOp)) - rewriter.eraseOp(useOp); - - rewriter.eraseOp(declOp); - return success(); - } // matchAndRewrite - -}; // struct UnusedVariablePat -} // anonymous namespace - -/// -/// \brief Entry point for the pass. -void UnusedVariablePass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - mlir::GreedyRewriteConfig config; - mlir::SymbolTableCollection symbolTable; - mlir::SymbolUserMap symbolUsers(symbolTable, getOperation()); - - // use cheaper top-down traversal (in this case, bottom-up would not behave - // any differently) - config.useTopDownTraversal = true; - // Disable to improve performance - config.enableRegionSimplification = false; + for (auto *useOp : symbolUsers.getUsers(declOp)) + useOp->erase(); + ; - patterns.add(&getContext(), symbolUsers); + declOp->erase(); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) - signalPassFailure(); + return mlir::WalkResult::advance(); + }); } llvm::StringRef UnusedVariablePass::getArgument() const { diff --git a/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp b/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp index 1069ccaf5..506ab7036 100644 --- a/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp +++ b/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp @@ -839,11 +839,13 @@ ExpressionValueType QUIRGenQASM3Visitor::visit_(const ASTGateNode *node) { // must be a normal angle variable use if (!assign(pos, param->GetGateParamName())) { if (const auto *const ident = param->GetValueIdentifier()) { - pos = varHandler.generateVariableUse(getLocation(node), ident); - if (pos.getType() != builder.getType(64)) { - pos = circuitParentBuilder.create( - pos.getLoc(), builder.getType(64), pos); - } + + double initialValue = 0.0; + if (param->IsNumber()) + initialValue = param->AsDouble(); + + pos = varHandler.generateParameterLoad( + getLocation(node), ident->GetName(), initialValue); ssaOtherValues.push_back(pos); } else { reportError(node, mlir::DiagnosticSeverity::Error) @@ -1143,27 +1145,15 @@ void QUIRGenQASM3Visitor::visit(const ASTDeclarationNode *node) { case ASTTypeMPDecimal: case ASTTypeMPComplex: { switchCircuit(false, getLocation(node)); - auto variableType = varHandler.resolveQUIRVariableType(node); - auto valOrError = visitAndGetExpressionValue(node->GetExpression()); - varHandler.generateVariableDeclaration( - loc, idNode->GetName(), variableType, - node->GetModifierType() == QASM::ASTTypeInputModifier, - node->GetModifierType() == QASM::ASTTypeOutputModifier); - - if (!valOrError) { - assert(hasFailed && "visitAndGetExpressionValue returned error but did " - "not set state to failed."); - return; - } - auto val = valOrError.get(); + auto variableType = varHandler.resolveQUIRVariableType(node); // generate variable assignment so that they are reinitialized on every // shot. bool genVariableWithVal = true; - // parameter support currently limited to quir::AngleType + // parameter support currently limited to quir::AngleType/Float64Type if (node->GetModifierType() == QASM::ASTTypeInputModifier) { bool genParameter = true; if (!enableParameters) { @@ -1183,17 +1173,26 @@ void QUIRGenQASM3Visitor::visit(const ASTDeclarationNode *node) { genParameter = false; } - if (genParameter) { - auto load = - varHandler.generateParameterLoad(loc, idNode->GetName(), val); - varHandler.generateVariableAssignment(loc, idNode->GetName(), load); + if (genParameter) genVariableWithVal = false; - } } - if (genVariableWithVal) + if (genVariableWithVal) { + auto valOrError = visitAndGetExpressionValue(node->GetExpression()); + if (!valOrError) { + assert(hasFailed && "visitAndGetExpressionValue returned error but did " + "not set state to failed."); + return; + } + auto val = valOrError.get(); varHandler.generateVariableAssignment(loc, idNode->GetName(), val); + varHandler.generateVariableDeclaration( + loc, idNode->GetName(), variableType, + node->GetModifierType() == QASM::ASTTypeInputModifier, + node->GetModifierType() == QASM::ASTTypeOutputModifier); + } + return; } @@ -1442,7 +1441,7 @@ QUIRGenQASM3Visitor::handleAssign(const ASTBinaryOpNode *node) { "set state to failed."); return rightRefOrError; } - Value const rightRef = rightRefOrError.get(); + const Value rightRef = rightRefOrError.get(); return handleAssign(node, rightRef); } @@ -1553,6 +1552,7 @@ QUIRGenQASM3Visitor::visitAndGetExpressionValue(const ASTExpressionNode *node) { BaseQASM3Visitor::visit(node); if (expression) ssaOtherValues.push_back((expression.get())); + return std::move(expression); } @@ -2255,8 +2255,12 @@ QUIRGenQASM3Visitor::visit_(const ASTCastExpressionNode *node) { } mlir::Value QUIRGenQASM3Visitor::createVoidValue(mlir::Location location) { - return builder.create( - location, builder.getZeroAttr(builder.getI1Type())); + // Only create void value for error propagation reasons once + // to avoid adding many unused operations to the program. + if (!voidValue) + voidValue = builder.create( + location, builder.getZeroAttr(builder.getI1Type())); + return voidValue; } mlir::Value QUIRGenQASM3Visitor::createVoidValue(QASM::ASTBase const *node) { diff --git a/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp b/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp index faa63bb46..b4185fbe6 100644 --- a/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp +++ b/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp @@ -23,7 +23,6 @@ #include "Dialect/OQ3/IR/OQ3Ops.h" #include "Dialect/QCS/IR/QCSOps.h" -#include "Dialect/QUIR/IR/QUIRAttributes.h" #include "Dialect/QUIR/IR/QUIROps.h" #include "Dialect/QUIR/IR/QUIRTypes.h" @@ -54,6 +53,11 @@ void QUIRVariableBuilder::generateVariableDeclaration( mlir::Location location, llvm::StringRef variableName, mlir::Type type, bool isInputVariable, bool isOutputVariable) { + // Input variables are not used as parameter loads replace them + // for performance reasons. + // TODO: Replace many parameters with array accesses. + if (isInputVariable) + return; // variables are symbols and thus need to be placed directly in a surrounding // Op that contains a symbol table. mlir::OpBuilder::InsertionGuard const g(builder); @@ -73,8 +77,6 @@ void QUIRVariableBuilder::generateVariableDeclaration( lastDeclaration[surroundingModuleOp] = declareOp; // save this to insert after - if (isInputVariable) - declareOp.setInputAttr(builder.getUnitAttr()); if (isOutputVariable) declareOp.setOutputAttr(builder.getUnitAttr()); variables.emplace(variableName.str(), type); @@ -120,49 +122,19 @@ void QUIRVariableBuilder::generateParameterDeclaration( mlir::Value QUIRVariableBuilder::generateParameterLoad(mlir::Location location, llvm::StringRef variableName, - mlir::Value assignedValue) { + double initialValue) { - if (auto constantOp = mlir::dyn_cast( - assignedValue.getDefiningOp())) { - auto op = getClassicalBuilder().create( - location, builder.getType(64), - variableName.str()); - - double initialValue = 0.0; - - auto constFloatAttr = constantOp.getValue().dyn_cast(); - if (constFloatAttr) { - initialValue = constFloatAttr.getValueAsDouble(); - } else { - auto constAngleAttr = - constantOp.getValue().dyn_cast(); - if (constAngleAttr) - initialValue = constAngleAttr.getValue().convertToDouble(); - } + auto op = getClassicalBuilder().create( + location, builder.getType(64), variableName.str()); + // Only store initial value if it is not zero for performance reasons. + if (initialValue != 0.0) { mlir::FloatAttr const floatAttr = getClassicalBuilder().getF64FloatAttr(initialValue); op->setAttr("initialValue", floatAttr); - return op; - } - - // if the source is a arith::ConstantOp cast to angle - if (auto constantOp = mlir::dyn_cast( - assignedValue.getDefiningOp())) { - auto loadOp = getClassicalBuilder().create( - location, constantOp.getType(), variableName.str()); - double initialValue = 0.0; - auto constAttr = constantOp.getValue().dyn_cast(); - if (constAttr) - initialValue = constAttr.getValueAsDouble(); - mlir::FloatAttr const floatAttr = - getClassicalBuilder().getF64FloatAttr(initialValue); - loadOp->setAttr("initialValue", floatAttr); - return loadOp; } - llvm_unreachable( - "Unsupported defining value operation for parameter variable"); + return op; } void QUIRVariableBuilder::generateArrayVariableDeclaration( diff --git a/releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml b/releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml new file mode 100644 index 000000000..f4f63319d --- /dev/null +++ b/releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml @@ -0,0 +1,11 @@ +--- +features: + - | + Handling of ``qcs.parameter_load`` operations has been modified to be more direct + with reads straight from the angle variables. This brings significant performance enhancements + as a large number of MLIR operations have been removed. The consequence is that if an OpenQASM 3 + input parameter value is written to this value will not be dynamically resolved. This could be + fixed in later versions of the compiler by using memref like semantics for parameters. +fixes: + - | + Significant performance enhancements for both constant and parameter gate angles. diff --git a/test/Dialect/QUIR/Transforms/extract-circuits.mlir b/test/Dialect/QUIR/Transforms/extract-circuits.mlir index 00f6a90bc..072f4d115 100644 --- a/test/Dialect/QUIR/Transforms/extract-circuits.mlir +++ b/test/Dialect/QUIR/Transforms/extract-circuits.mlir @@ -18,8 +18,8 @@ module { return } // CHECK: quir.circuit @circuit_0 - // CHECK: quir.delay %arg0, (%arg1) - // CHECK: %0:2 = quir.measure(%arg2, %arg3) + // CHECK: quir.delay %dur, (%arg0) + // CHECK: %0:2 = quir.measure(%arg1, %arg2) // CHECK: quir.return %0#0, %0#1 : i1, i1 // CHECK: quir.circuit @circuit_1 // CHECK: quir.call_gate @x(%arg0) @@ -47,7 +47,7 @@ module { %4:2 = quir.measure(%0, %2) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) // CHECK-NOT: quir.delay %dur_0, (%1) : !quir.duration
, (!quir.qubit<1>) -> () // CHECK-NOT: %4:2 = quir.measure(%0, %2) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) - // CHECK: %4:2 = quir.call_circuit @circuit_0(%dur_0, %1, %0, %2) + // CHECK: %4:2 = quir.call_circuit @circuit_0(%1, %0, %2) qcs.parallel_control_flow { // CHECK: qcs.parallel_control_flow scf.if %4#0 { diff --git a/test/Dialect/QUIR/Transforms/reorder-measurements.mlir b/test/Dialect/QUIR/Transforms/reorder-measurements.mlir index da18a489d..602a0ca88 100644 --- a/test/Dialect/QUIR/Transforms/reorder-measurements.mlir +++ b/test/Dialect/QUIR/Transforms/reorder-measurements.mlir @@ -28,10 +28,13 @@ func.func @three(%c : memref<1xi1>, %ind : index, %angle_0 : !quir.angle<64>) { quir.call_gate @rz(%q1, %angle_0) : (!quir.qubit<1>, !quir.angle<64>) -> () %res1 = quir.measure(%q1) : (!quir.qubit<1>) -> (i1) memref.store %res1, %c[%ind] : memref<1xi1> - quir.call_gate @rz(%q2, %angle_0) : (!quir.qubit<1>, !quir.angle<64>) -> () + %angle_1 = "qcs.parameter_load"() {parameter_name = "test"} : () -> !quir.angle<64> + quir.call_gate @rz(%q2, %angle_1) : (!quir.qubit<1>, !quir.angle<64>) -> () quir.call_gate @sx(%q2) : (!quir.qubit<1>) -> () - quir.call_gate @rz(%q2, %angle_0) : (!quir.qubit<1>, !quir.angle<64>) -> () + %angle_2 = quir.constant #quir.angle<3.0> : !quir.angle<64> + quir.call_gate @rz(%q2, %angle_2) : (!quir.qubit<1>, !quir.angle<64>) -> () %res2 = quir.measure(%q2) : (!quir.qubit<1>) -> (i1) +// CHECK: {{.*}} = quir.constant #quir.angle<3.000000e+00> : !quir.angle<64> // CHECK: [[Q00:%.*]] = quir.declare_qubit {id = 0 : i32} : !quir.qubit<1> // CHECK: [[Q01:%.*]] = quir.declare_qubit {id = 1 : i32} : !quir.qubit<1> // CHECK: [[Q02:%.*]] = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> @@ -41,7 +44,8 @@ func.func @three(%c : memref<1xi1>, %ind : index, %angle_0 : !quir.angle<64>) { // CHECK-NEXT: quir.call_gate @rz([[Q01]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK-NEXT: quir.call_gate @sx([[Q01]]) : (!quir.qubit<1>) -> () // CHECK-NEXT: quir.call_gate @rz([[Q01]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () -// CHECK-NEXT: quir.call_gate @rz([[Q02]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () +// CHECK-NEXT: [[ANGLE:%.*]] = qcs.parameter_load "test" : !quir.angle<64> +// CHECK-NEXT: quir.call_gate @rz([[Q02]], [[ANGLE]]) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK-NEXT: quir.call_gate @sx([[Q02]]) : (!quir.qubit<1>) -> () // CHECK-NEXT: quir.call_gate @rz([[Q02]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK-NEXT: [[RES00:%.*]] = quir.measure([[Q00]]) : (!quir.qubit<1>) -> i1 diff --git a/test/Frontend/OpenQASM3/input-output-variables.qasm b/test/Frontend/OpenQASM3/input-output-variables.qasm index 6304b97a8..86f8485ad 100644 --- a/test/Frontend/OpenQASM3/input-output-variables.qasm +++ b/test/Frontend/OpenQASM3/input-output-variables.qasm @@ -1,6 +1,5 @@ OPENQASM 3.0; // RUN: qss-compiler -X=qasm --emit=ast-pretty %s | FileCheck %s --match-full-lines --check-prefix AST-PRETTY -// RUN: qss-compiler -X=qasm --emit=mlir %s --enable-parameters=false | FileCheck %s --match-full-lines --check-prefix MLIR // RUN: (! qss-compiler -X=qasm --emit=mlir --enable-parameters --enable-circuits-from-qasm %s 2>&1 ) | FileCheck %s --check-prefix CIRCUITS // @@ -28,12 +27,10 @@ input int basis; // CIRCUITS: error: Input parameter basis type error. Input parameters must be angle or float[64]. // AST-PRETTY: DeclarationNode(type=ASTTypeBitset, CBitNode(name=flags, bits=32), inputVariable) -// MLIR-DAG: oq3.declare_variable {input} @flags : !quir.cbit<32> input bit[32] flags; // CIRCUITS: error: Input parameter flags type error. Input parameters must be angle or float[64]. // AST-PRETTY: DeclarationNode(type=ASTTypeBitset, CBitNode(name=result, bits=1), outputVariable) -// MLIR-DAG: oq3.declare_variable {output} @result : !quir.cbit<1> output bit result; // TODO diff --git a/test/Frontend/OpenQASM3/input-parameters-if.qasm b/test/Frontend/OpenQASM3/input-parameters-if.qasm index 20eb19a82..18d77904e 100644 --- a/test/Frontend/OpenQASM3/input-parameters-if.qasm +++ b/test/Frontend/OpenQASM3/input-parameters-if.qasm @@ -25,7 +25,7 @@ bit result; gate x q { } gate rz(phi) q { } -input angle theta = 3.141; +input angle theta; x $2; rz(theta) $2; @@ -50,8 +50,7 @@ is_excited = measure $2; // CHECK: [[QUBIT2:%.*]] = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> // CHECK: [[QUBIT3:%.*]] = quir.declare_qubit {id = 3 : i32} : !quir.qubit<1> -// CHECK: [[PARAM:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} -// CHECK: oq3.variable_assign @theta : !quir.angle<64> = [[PARAM]] +// CHECK: [[PARAM:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> // CHECK: [[EXCITED:%.*]] = oq3.variable_load @is_excited : !quir.cbit<1> // CHECK: [[CONST:%[0-9a-z_]+]] = arith.constant 1 : i32 @@ -68,7 +67,7 @@ if (is_excited == 1) { // CHECK: [[COND1:%.*]] = arith.cmpi eq, [[OTHERCAST]], [[CONST]] : i32 // CHECK: scf.if [[COND1]] { if (other == 1){ -// CHECK: [[THETA:%.*]] = oq3.variable_load @theta : !quir.angle<64> +// CHECK: [[THETA:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> // CHECK: quir.call_circuit @circuit_2([[QUBIT2]], [[THETA]]) : (!quir.qubit<1>, !quir.angle<64>) -> () x $2; rz(theta) $2; diff --git a/test/Frontend/OpenQASM3/input-parameters-while.qasm b/test/Frontend/OpenQASM3/input-parameters-while.qasm index d3a73228f..317786f07 100644 --- a/test/Frontend/OpenQASM3/input-parameters-while.qasm +++ b/test/Frontend/OpenQASM3/input-parameters-while.qasm @@ -58,7 +58,6 @@ bit is_excited; // CHECK: func.func @main() -> i32 { // CHECK: scf.for %arg0 = %c0 to %c1000 step %c1 { -// CHECK: {{.*}} = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} // CHECK: [[QUBIT:%.*]] = quir.declare_qubit {id = 0 : i32} : !quir.qubit<1> // CHECK: scf.while : () -> () { // CHECK: [[N:%.*]] = oq3.variable_load @n : i32 @@ -76,7 +75,7 @@ while (n != 0) { // CHECK: scf.if [[COND2]] { if (is_excited) { - // CHECK: [[THETA:%.*]] = oq3.variable_load @theta : !quir.angle<64> + // CHECK: [[THETA:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} // CHECK: quir.call_circuit @circuit_2([[QUBIT]], [[THETA]]) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK: } h $0; diff --git a/test/Frontend/OpenQASM3/input-parameters.qasm b/test/Frontend/OpenQASM3/input-parameters.qasm index f057a31a4..55879b660 100644 --- a/test/Frontend/OpenQASM3/input-parameters.qasm +++ b/test/Frontend/OpenQASM3/input-parameters.qasm @@ -1,5 +1,5 @@ OPENQASM 3; -// RUN: qss-compiler -X=qasm --emit=mlir --enable-parameters --enable-circuits-from-qasm %s | FileCheck %s --check-prefixes=CHECK,CHECK-XX +// RUN: qss-compiler -X=qasm --emit=mlir --enable-parameters --enable-circuits-from-qasm %s | FileCheck %s --check-prefixes=CHECK // // This code is part of Qiskit. @@ -66,16 +66,10 @@ c = measure $0; // CHECK: %1 = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> // CHECK: %2 = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} -// CHECK: oq3.variable_assign @theta : !quir.angle<64> = %2 -// CHECK: %3 = qcs.parameter_load "theta2" : f64 {initialValue = 1.560000e+00 : f64} -// CHECK: oq3.variable_assign @theta2 : f64 = %3 -// CHECK-XX: quir.reset %0 : !quir.qubit<1> -// CHECK-NOT: oq3.variable_assign @theta : !quir.angle<64> = %angle - -// CHECK: quir.call_circuit @circuit_0(%0, %4) : (!quir.qubit<1>, !quir.angle<64>) -> () -// CHECK: %6 = quir.call_circuit @circuit_1(%0) : (!quir.qubit<1>) -> i1 -// CHECK: oq3.cbit_assign_bit @b<1> [0] : i1 = %6 - -// CHECK: %7 = oq3.variable_load @theta2 : f64 -// CHECK: %8 = "oq3.cast"(%7) : (f64) -> !quir.angle<64> -// CHECK: quir.call_circuit @circuit_2(%0, %8) : (!quir.qubit<1>, !quir.angle<64>) -> () + +// CHECK: quir.call_circuit @circuit_0(%0, %2) : (!quir.qubit<1>, !quir.angle<64>) -> () +// CHECK: %4 = quir.call_circuit @circuit_1(%0) : (!quir.qubit<1>) -> i1 +// CHECK: oq3.cbit_assign_bit @b<1> [0] : i1 = %4 + +// CHECK: %5 = qcs.parameter_load "theta2" : !quir.angle<64> {initialValue = 1.560000e+00 : f64} +// CHECK: quir.call_circuit @circuit_2(%0, %5) : (!quir.qubit<1>, !quir.angle<64>) -> () diff --git a/test/unittest/quir-dialect.cpp b/test/unittest/quir-dialect.cpp index c3110bcb0..563fcd387 100644 --- a/test/unittest/quir-dialect.cpp +++ b/test/unittest/quir-dialect.cpp @@ -36,13 +36,13 @@ namespace { class QUIRDialect : public ::testing::Test { protected: mlir::MLIRContext ctx; - mlir::UnknownLoc unkownLoc; + mlir::UnknownLoc unknownLoc; mlir::ModuleOp rootModule; mlir::OpBuilder builder; QUIRDialect() - : unkownLoc(mlir::UnknownLoc::get(&ctx)), - rootModule(mlir::ModuleOp::create(unkownLoc)), builder(rootModule) { + : unknownLoc(mlir::UnknownLoc::get(&ctx)), + rootModule(mlir::ModuleOp::create(unknownLoc)), builder(rootModule) { mlir::DialectRegistry registry; registry.insert(); ctx.appendDialectRegistry(registry); @@ -55,10 +55,10 @@ class QUIRDialect : public ::testing::Test { TEST_F(QUIRDialect, CPTPOpTrait) { auto declareQubitOp = builder.create( - unkownLoc, builder.getType(1), + unknownLoc, builder.getType(1), builder.getIntegerAttr(builder.getI32Type(), 0)); auto reset = builder.create( - unkownLoc, mlir::ValueRange{declareQubitOp.getResult()}); + unknownLoc, mlir::ValueRange{declareQubitOp.getResult()}); EXPECT_FALSE(declareQubitOp->hasTrait()); EXPECT_FALSE(declareQubitOp->hasTrait()); @@ -73,10 +73,10 @@ TEST_F(QUIRDialect, CPTPOpTrait) { TEST_F(QUIRDialect, UnitaryOpTrait) { auto declareQubitOp = builder.create( - unkownLoc, builder.getType(1), + unknownLoc, builder.getType(1), builder.getIntegerAttr(builder.getI32Type(), 0)); auto barrier = builder.create( - unkownLoc, mlir::ValueRange{declareQubitOp.getResult()}); + unknownLoc, mlir::ValueRange{declareQubitOp.getResult()}); EXPECT_TRUE(barrier->hasTrait()); EXPECT_FALSE(barrier->hasTrait()); @@ -88,11 +88,11 @@ TEST_F(QUIRDialect, UnitaryOpTrait) { TEST_F(QUIRDialect, MeasureSideEffects) { auto qubitDecl = builder.create( - unkownLoc, builder.getType(1), + unknownLoc, builder.getType(1), builder.getIntegerAttr(builder.getI32Type(), 0)); auto measureOp = builder.create( - unkownLoc, builder.getI1Type(), qubitDecl.getRes()); + unknownLoc, builder.getI1Type(), qubitDecl.getRes()); EXPECT_TRUE(measureOp);