From cf0b38786c322de28e59f21e66cc08388db1fb3c Mon Sep 17 00:00:00 2001 From: reza-j <23619106+reza-j@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:45:28 -0500 Subject: [PATCH] performance improvements for quir to pulse (#314) improves the performance of quir to pulse --- include/Conversion/QUIRToPulse/QUIRToPulse.h | 21 ++- lib/Conversion/QUIRToPulse/QUIRToPulse.cpp | 130 +++++------------- .../QUIRToPulse/convert-quir-to-pulse.mlir | 26 ++-- 3 files changed, 52 insertions(+), 125 deletions(-) diff --git a/include/Conversion/QUIRToPulse/QUIRToPulse.h b/include/Conversion/QUIRToPulse/QUIRToPulse.h index bd5aff79a..8eb26ffb8 100644 --- a/include/Conversion/QUIRToPulse/QUIRToPulse.h +++ b/include/Conversion/QUIRToPulse/QUIRToPulse.h @@ -30,6 +30,7 @@ #include "mlir/Pass/Pass.h" #include +#include namespace mlir::pulse { @@ -70,9 +71,9 @@ struct QUIRToPulsePass // will be reset every time convertCircuitToSequence is called and will be // used by several functions that are called within that function uint convertedSequenceOpArgIndex; - std::map circuitArgToConvertedSequenceArgMap; + std::unordered_map circuitArgToConvertedSequenceArgMap; SmallVector convertedPulseSequenceOpArgs; - std::vector convertedPulseCallSequenceOpOperandNames; + std::unordered_map operandNameToIndexMap; // process the args of the circuit op, and add corresponding args to the // converted pulse sequence op @@ -126,14 +127,15 @@ struct QUIRToPulsePass mlir::func::FuncOp &mainFunc); // map of the hashed location of quir angle/duration ops to their converted // pulse ops - std::map classicalQUIROpLocToConvertedPulseOpMap; + std::unordered_map + classicalQUIROpLocToConvertedPulseOpMap; // port name to Port_CreateOp map - std::map openedPorts; + std::unordered_map openedPorts; // mixframe name to MixFrameOp map - std::map openedMixFrames; + std::unordered_map openedMixFrames; // waveform name to Waveform_CreateOp map - std::map openedWfrs; + std::unordered_map openedWfrs; // add a port to IR if it's not already added and return the Port_CreateOp mlir::pulse::Port_CreateOp addPortOpToIR(std::string const &portName, mlir::func::FuncOp &mainFunc, @@ -149,14 +151,9 @@ struct QUIRToPulsePass mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder); - void addCircuitToEraseList(mlir::Operation *op); - void addCircuitOperandToEraseList(mlir::Operation *op); - std::vector quirCircuitEraseList; - std::vector quirCircuitOperandEraseList; - // parse the waveform containers and add them to pulseNameToWaveformMap void parsePulseWaveformContainerOps(std::string &waveformContainerPath); - std::map pulseNameToWaveformMap; + std::unordered_map pulseNameToWaveformMap; qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr}; }; diff --git a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp index c2be7e8ab..6bfd4e7b3 100644 --- a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp +++ b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp @@ -58,10 +58,8 @@ #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" -#include #include #include -#include #include #include #include @@ -109,26 +107,17 @@ void QUIRToPulsePass::runOnOperation() { callCircOp->erase(); }); - // erase the quir circuits - LLVM_DEBUG(llvm::dbgs() << "\nErasing quir circuits:\n"); - for (auto *op : quirCircuitEraseList) { - LLVM_DEBUG(op->dump()); - op->erase(); - } - - // erase quir barriers before erasing the operands - moduleOp->walk([&](mlir::quir::BarrierOp barrierOp) { barrierOp->erase(); }); - - // erase the quir circuit operands - LLVM_DEBUG(llvm::dbgs() << "\nErasing quir circuit operands:\n"); - for (auto *op : quirCircuitOperandEraseList) { - LLVM_DEBUG(op->dump()); - op->erase(); - } + // erase circuit ops + moduleOp->walk([&](CircuitOp circOp) { circOp->erase(); }); - // erase the rest of quir.declare_qubits (unused in the input program) - moduleOp->walk([&](mlir::quir::DeclareQubitOp declareQubitOp) { - declareQubitOp->erase(); + // erase qubit ops and constant angle ops + moduleOp->walk([&](Operation *op) { + if (isa(op)) + op->erase(); + else if (auto castOp = dyn_cast(op)) { + if (castOp.getType().isa<::mlir::quir::AngleType>()) + op->erase(); + } }); } @@ -144,7 +133,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, LLVM_DEBUG(llvm::dbgs() << "\nConverting QUIR circuit " << circName << ":\n"); assert(callCircuitOp && "callCircuit op is null"); assert(circuitOp && "circuit op is null"); - addCircuitToEraseList(circuitOp); // build an empty pulse sequence SmallVector arguments; @@ -162,7 +150,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, convertedSequenceOpArgIndex = 0; circuitArgToConvertedSequenceArgMap.clear(); convertedPulseSequenceOpArgs.clear(); - convertedPulseCallSequenceOpOperandNames.clear(); + operandNameToIndexMap.clear(); // convert quir circuit args if not already converted, and add the converted // args to the the converted pulse sequence @@ -200,9 +188,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, attr.getValue()); } - pulseCalCallSequenceOp->setAttr( - "pulse.operands", - pulseCalSequenceOp->getAttrOfType("pulse.args")); for (auto type : pulseCalCallSequenceOp.getResultTypes()) convertedPulseSequenceOpReturnTypes.push_back(type); for (auto val : pulseCalCallSequenceOp.getRes()) @@ -254,12 +239,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, convertedPulseSequenceOp, convertedPulseSequenceOpArgs); convertedPulseCallSequenceOp->moveAfter(callCircuitOp); - convertedPulseSequenceOp->setAttr( - "pulse.args", - builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames)); - convertedPulseCallSequenceOp->setAttr( - "pulse.operands", - builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames)); return convertedPulseCallSequenceOp; } @@ -276,35 +255,26 @@ void QUIRToPulsePass::processCircuitArgs( auto *angleOp = callCircuitOp.getOperand(cnt).getDefiningOp(); LLVM_DEBUG(llvm::dbgs() << "angle argument "); LLVM_DEBUG(angleOp->dump()); - convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex, - builder.getF64Type(), dictArg, - arg.getLoc()); + convertedPulseSequenceOp.getBody().addArgument(builder.getF64Type(), + arg.getLoc()); circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex; auto convertedAngleToF64 = convertAngleToF64(angleOp, builder); convertedSequenceOpArgIndex += 1; - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr("angle")); convertedPulseSequenceOpArgs.push_back(convertedAngleToF64); } else if (argumentType.isa()) { auto *durationOp = callCircuitOp.getOperand(cnt).getDefiningOp(); LLVM_DEBUG(llvm::dbgs() << "duration argument "); LLVM_DEBUG(durationOp->dump()); - convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex, - builder.getI64Type(), dictArg, - arg.getLoc()); + convertedPulseSequenceOp.getBody().addArgument(builder.getI64Type(), + arg.getLoc()); circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex; auto convertedDurationToI64 = convertDurationToI64( callCircuitOp, durationOp, cnt, builder, mainFunc); convertedSequenceOpArgIndex += 1; - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr("duration")); convertedPulseSequenceOpArgs.push_back(convertedDurationToI64); } else if (argumentType.isa()) { auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp(); - addCircuitOperandToEraseList(qubitOp); - } - - else + } else llvm_unreachable("unkown circuit argument."); } } @@ -413,23 +383,16 @@ void QUIRToPulsePass::processMixFrameOpArg( mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) { auto mixedFrameOp = addMixFrameOpToIR(mixFrameName, portName, mainFunc, builder); - auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(), - convertedPulseCallSequenceOpOperandNames.end(), - builder.getStringAttr(mixFrameName)); - if (it == convertedPulseCallSequenceOpOperandNames.end()) { - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr(mixFrameName)); + if (operandNameToIndexMap.find(mixFrameName) == operandNameToIndexMap.end()) { + operandNameToIndexMap[mixFrameName] = convertedSequenceOpArgIndex; convertedPulseSequenceOpArgs.push_back(mixedFrameOp); - convertedPulseSequenceOp.insertArgument( - convertedSequenceOpArgIndex, - builder.getType(), DictionaryAttr{}, - argumentValue.getLoc()); + convertedPulseSequenceOp.getBody().addArgument( + builder.getType(), argumentValue.getLoc()); pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]); convertedSequenceOpArgIndex += 1; } else { - uint const mixFrameOperandIndex = - std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it); + uint const mixFrameOperandIndex = operandNameToIndexMap[mixFrameName]; pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[mixFrameOperandIndex]); } @@ -442,22 +405,16 @@ void QUIRToPulsePass::processPortOpArg(std::string const &portName, mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) { auto portOp = addPortOpToIR(portName, mainFunc, builder); - auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(), - convertedPulseCallSequenceOpOperandNames.end(), - builder.getStringAttr(portName)); - if (it == convertedPulseCallSequenceOpOperandNames.end()) { - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr(portName)); + if (operandNameToIndexMap.find(portName) == operandNameToIndexMap.end()) { + operandNameToIndexMap[portName] = convertedSequenceOpArgIndex; convertedPulseSequenceOpArgs.push_back(portOp); - convertedPulseSequenceOp.insertArgument( - convertedSequenceOpArgIndex, builder.getType(), - DictionaryAttr{}, argumentValue.getLoc()); + convertedPulseSequenceOp.getBody().addArgument( + builder.getType(), argumentValue.getLoc()); pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]); convertedSequenceOpArgIndex += 1; } else { - uint const portOperandIndex = - std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it); + uint const portOperandIndex = operandNameToIndexMap[portName]; pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[portOperandIndex]); } @@ -470,23 +427,16 @@ void QUIRToPulsePass::processWfrOpArg(std::string const &wfrName, mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) { auto wfrOp = addWfrOpToIR(wfrName, mainFunc, builder); - auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(), - convertedPulseCallSequenceOpOperandNames.end(), - builder.getStringAttr(wfrName)); - if (it == convertedPulseCallSequenceOpOperandNames.end()) { - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr(wfrName)); + if (operandNameToIndexMap.find(wfrName) == operandNameToIndexMap.end()) { + operandNameToIndexMap[wfrName] = convertedSequenceOpArgIndex; convertedPulseSequenceOpArgs.push_back(wfrOp); - convertedPulseSequenceOp.insertArgument( - convertedSequenceOpArgIndex, - builder.getType(), DictionaryAttr{}, - argumentValue.getLoc()); + convertedPulseSequenceOp.getBody().addArgument( + builder.getType(), argumentValue.getLoc()); pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]); convertedSequenceOpArgIndex += 1; } else { - uint const wfrOperandIndex = - std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it); + uint const wfrOperandIndex = operandNameToIndexMap[wfrName]; pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[wfrOperandIndex]); } @@ -561,7 +511,6 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) == classicalQUIROpLocToConvertedPulseOpMap.end()) { if (auto castOp = dyn_cast(angleOp)) { - addCircuitOperandToEraseList(angleOp); double const angleVal = castOp.getAngleValueFromConstant().convertToDouble(); auto f64Angle = builder.create( @@ -575,7 +524,6 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, angleCastedOp->moveAfter(castOp); classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp; } else if (auto castOp = dyn_cast(angleOp)) { - addCircuitOperandToEraseList(angleOp); auto castOpArg = castOp.getArg(); if (auto paramCastOp = dyn_cast(castOpArg.getDefiningOp())) { @@ -600,7 +548,6 @@ mlir::Value QUIRToPulsePass::convertDurationToI64( if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) == classicalQUIROpLocToConvertedPulseOpMap.end()) { if (auto castOp = dyn_cast(durationOp)) { - addCircuitOperandToEraseList(durationOp); auto durVal = quir::getDuration(castOp).get().getDuration().convertToDouble(); assert(castOp.getType().dyn_cast().getUnits() == @@ -658,21 +605,6 @@ QUIRToPulsePass::addWfrOpToIR(std::string const &wfrName, return openedWfrs[wfrName]; } -void QUIRToPulsePass::addCircuitToEraseList(mlir::Operation *op) { - assert(op && "caller requested adding a null op to erase list"); - if (std::find(quirCircuitEraseList.begin(), quirCircuitEraseList.end(), op) == - quirCircuitEraseList.end()) - quirCircuitEraseList.push_back(op); -} - -void QUIRToPulsePass::addCircuitOperandToEraseList(mlir::Operation *op) { - assert(op && "caller requested adding a null op to erase list"); - if (std::find(quirCircuitOperandEraseList.begin(), - quirCircuitOperandEraseList.end(), - op) == quirCircuitOperandEraseList.end()) - quirCircuitOperandEraseList.push_back(op); -} - void QUIRToPulsePass::parsePulseWaveformContainerOps( std::string &waveformContainerPath) { std::string errorMessage; diff --git a/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir b/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir index 687e4b973..3cb76b4f7 100644 --- a/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir +++ b/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir @@ -71,19 +71,19 @@ module { %false = arith.constant false pulse.return %false : i1 } - // CHECK: pulse.sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame) -> (i1, i1, i1, i1) attributes {pulse.args = ["q3-drive-mixframe", "q5-drive-mixframe", "q3-readout-mixframe", "q3-capture-mixframe", "q5-readout-mixframe", "q5-capture-mixframe"]} { - // CHECK: %0 = pulse.call_sequence @x_3(%arg0) {{{.*}} : (!pulse.mixed_frame) -> i1 - // CHECK: %1 = pulse.call_sequence @sx_5(%arg1) {{{.*}} : (!pulse.mixed_frame) -> i1 - // CHECK: %2:2 = pulse.call_sequence @measure_3_5(%arg2, %arg3, %arg4, %arg5) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) + // CHECK: pulse.sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame) -> (i1, i1, i1, i1) { + // CHECK: %0 = pulse.call_sequence @x_3(%arg0) : (!pulse.mixed_frame) -> i1 + // CHECK: %1 = pulse.call_sequence @sx_5(%arg1) : (!pulse.mixed_frame) -> i1 + // CHECK: %2:2 = pulse.call_sequence @measure_3_5(%arg2, %arg3, %arg4, %arg5) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) // CHECK: pulse.return %0, %1, %2#0, %2#1 : i1, i1, i1, i1 - // CHECK: pulse.sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame, %arg6: !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) attributes {pulse.args = ["q5-drive-mixframe", "q3-5-cx-mixframe", "q5-3-cx-mixframe", "q3-readout-mixframe", "q3-capture-mixframe", "q5-readout-mixframe", "q5-capture-mixframe"]} { + // CHECK: pulse.sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame, %arg6: !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) { // CHECK: %cst = arith.constant 1.5707963267948966 : f64 - // CHECK: %0 = pulse.call_sequence @rz_5(%cst, %arg0) {{{.*}} : (f64, !pulse.mixed_frame) -> i1 - // CHECK: %1 = pulse.call_sequence @sx_5(%arg0) {{{.*}} : (!pulse.mixed_frame) -> i1 - // CHECK: %2 = pulse.call_sequence @rz_5(%cst, %arg0) {{{.*}} : (f64, !pulse.mixed_frame) -> i1 - // CHECK: %3 = pulse.call_sequence @cx_5_3(%arg1, %arg2) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame) -> i1 - // CHECK: %4:2 = pulse.call_sequence @measure_3_5(%arg3, %arg4, %arg5, %arg6) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) + // CHECK: %0 = pulse.call_sequence @rz_5(%cst, %arg0) : (f64, !pulse.mixed_frame) -> i1 + // CHECK: %1 = pulse.call_sequence @sx_5(%arg0) : (!pulse.mixed_frame) -> i1 + // CHECK: %2 = pulse.call_sequence @rz_5(%cst, %arg0) : (f64, !pulse.mixed_frame) -> i1 + // CHECK: %3 = pulse.call_sequence @cx_5_3(%arg1, %arg2) : (!pulse.mixed_frame, !pulse.mixed_frame) -> i1 + // CHECK: %4:2 = pulse.call_sequence @measure_3_5(%arg3, %arg4, %arg5, %arg6) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) // CHECK: pulse.return %0, %1, %2, %3, %4#0, %4#1 : i1, i1, i1, i1, i1, i1 func.func @main() -> i32 attributes {quir.classicalOnly = false} { @@ -121,12 +121,10 @@ module { // CHECK-NOT: %5 = quir.declare_qubit {id = 5 : i32} : !quir.qubit<1> %7:2 = quir.call_circuit @circuit_0_q5_q3_circuit_1_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) // CHECK-NOT: %7:2 = quir.call_circuit @circuit_0_q5_q3_circuit_1_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) - // CHECK: %14:4 = pulse.call_sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%1, %3, %5, %7, %9, %11) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1) - quir.barrier %3, %5 : (!quir.qubit<1>, !quir.qubit<1>) -> () - // CHECK-NOT: %quir.barrier %3, %5 : (!quir.qubit<1>, !quir.qubit<1>) -> () + // CHECK: %14:4 = pulse.call_sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%1, %3, %5, %7, %9, %11) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1) %8:2 = quir.call_circuit @circuit_2_q5_q3_circuit_3_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) // CHECK-NOT: %8:2 = quir.call_circuit @circuit_2_q5_q3_circuit_3_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) - // CHECK: %15:6 = pulse.call_sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%3, %12, %13, %5, %7, %9, %11) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) + // CHECK: %15:6 = pulse.call_sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%3, %12, %13, %5, %7, %9, %11) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) } {qcs.shot_loop} return %c0_i32 : i32 }