Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory usage in pulse sequencing #336

Merged
merged 30 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4491c17
Bump qasm version to fix bug.
taalexander Jun 12, 2024
6b8a593
Try more efficient parameter loading.
taalexander Jun 14, 2024
d9c2b83
Remove initial value of parameter.
taalexander Jun 14, 2024
80b4bbf
Parameter parsing working with more efficient parsing.
taalexander Jun 17, 2024
b12b961
Fix dual use bug.
taalexander Jun 17, 2024
1d08977
Fix multi-circuit case.
taalexander Jun 17, 2024
46a0203
Fix operation removal.
taalexander Jun 17, 2024
049675e
Remove assert.
taalexander Jun 17, 2024
0c100de
Fix bugs from using locations as hashes which can be collide for oper…
taalexander Jun 17, 2024
8ecd7d1
Fix core tests.
taalexander Jun 17, 2024
2b6f999
Fixing tests.
taalexander Jun 18, 2024
4801bd3
Fixing test.
taalexander Jun 18, 2024
7fbe9c1
Fix additional tests.
taalexander Jun 18, 2024
4647c53
Tests passing
taalexander Jun 18, 2024
08b846f
Add reno.
taalexander Jun 18, 2024
fa489d5
Update qasm version.
taalexander Jun 18, 2024
2ce92ad
Tidying.
taalexander Jun 18, 2024
b5cf8b5
Fix bug in measurement ordering.
taalexander Jun 19, 2024
324b2e5
Tidying.
taalexander Jun 19, 2024
102fd9a
Tidying.
taalexander Jun 19, 2024
b649a47
Remove dummy values being created.
taalexander Jun 20, 2024
d9a89c2
Fix variable assignment.
taalexander Jun 20, 2024
2293ec1
Tidy
taalexander Jun 20, 2024
1fc1d70
Fix bug in error handling.
taalexander Jun 20, 2024
4b6e2da
Tidy
taalexander Jun 21, 2024
9d21372
Remove greedy rewrite from variable removal pass for performance.
taalexander Jun 26, 2024
16aed8d
Switch to llvm::DenseMap
taalexander Jun 24, 2024
353ba65
Add reno.
taalexander Jun 24, 2024
0e335cb
Fixup grammar.
taalexander Jun 25, 2024
3381a48
Optimize walks.
taalexander Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conan/qasm/conandata.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
sources:
hash: "f6d695fd9f18462e65f6290d05ccb4ccb371b288"
hash: "ec7731bf645240a597cd9ebb2c395b114f155ed2"
requirements:
- "gmp/6.3.0"
- "mpfr/4.1.0"
Expand Down
2 changes: 1 addition & 1 deletion conan/qasm/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class QasmConan(ConanFile):
name = "qasm"
version = "0.3.2"
version = "0.3.3"
url = "https://github.com/openqasm/qe-qasm.git"
settings = "os", "compiler", "build_type", "arch"
options = {"shared": [True, False], "examples": [True, False]}
Expand Down
2 changes: 1 addition & 1 deletion conandata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ requirements:
- pybind11/2.11.1
- clang-tools-extra/17.0.5-0@
- llvm/17.0.5-0@
- qasm/0.3.2@qss/stable
- qasm/0.3.3@qss/stable
2 changes: 1 addition & 1 deletion include/Conversion/QUIRToPulse/QUIRToPulse.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct QUIRToPulsePass
mlir::func::FuncOp &mainFunc);
// map of the hashed location of quir angle/duration ops to their converted
// pulse ops
std::unordered_map<std::string, mlir::Value>
std::unordered_map<Operation *, mlir::Value>
classicalQUIROpLocToConvertedPulseOpMap;

// port name to Port_CreateOp map
Expand Down
14 changes: 8 additions & 6 deletions include/Dialect/QUIR/Transforms/ExtractCircuits.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/SmallVector.h"

#include <set>
#include <unordered_map>

namespace mlir::quir {
Expand All @@ -49,14 +49,14 @@ struct ExtractCircuitsPass
OpBuilder circuitBuilder);
OpBuilder startCircuit(mlir::Location location, OpBuilder topLevelBuilder);
void endCircuit(mlir::Operation *firstOp, mlir::Operation *lastOp,
OpBuilder topLevelBuilder, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
OpBuilder topLevelBuilder, OpBuilder circuitBuilder);
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder);

uint64_t circuitCount = 0;
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};

mlir::quir::CircuitOp currentCircuitOp = nullptr;
mlir::IRMapping currentCircuitMapper;
mlir::quir::CallCircuitOp newCallCircuitOp;

llvm::SmallVector<Type> inputTypes;
Expand All @@ -68,6 +68,8 @@ struct ExtractCircuitsPass

std::unordered_map<Operation *, uint32_t> circuitOperands;
llvm::SmallVector<OpResult> originalResults;
std::set<Operation *> eraseConstSet;
std::set<Operation *> eraseOpSet;

}; // struct ExtractCircuitsPass
} // namespace mlir::quir
Expand Down
3 changes: 3 additions & 0 deletions include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor {
mlir::Type getQUIRTypeFromDeclaration(const QASM::ASTDeclarationNode *);

bool enableParametersWarningEmitted = false;

/// Cached dummy value for error handling
mlir::Value voidValue;
};

} // namespace qssc::frontend::openqasm3
Expand Down
2 changes: 1 addition & 1 deletion include/Frontend/OpenQASM3/QUIRVariableBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class QUIRVariableBuilder {

mlir::Value generateParameterLoad(mlir::Location location,
llvm::StringRef variableName,
mlir::Value assignedValue);
double initialValue);

/// Generate code for declaring an array (at the builder's current insertion
/// point).
Expand Down
8 changes: 5 additions & 3 deletions include/Utils/SymbolCacheAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ namespace qssc::utils {
// .addToCache<SequenceOp>();
//
// This analysis is intended to be used with MLIR's getAnalysis
// framework. It has been designed to reused the chached value
// framework. It has been designed to reuse the cached value
// and will not be invalidated automatically with each pass.
// If a pass manipulates the symbols that are cached with this
// analysis then it should use the addCallee method to update the
// map or call invalidate after appying updates.
// map or call invalidate after applying updates.
// Note this analysis should always be used by reference or
// via a pointer to ensure that updates are applied to the maps
// stored by the MLIR analysis framework.
Expand Down Expand Up @@ -95,6 +95,8 @@ class SymbolCacheAnalysis {

op->walk([&](CalleeOp op) {
symbolOpsMap[op.getSymName()] = op.getOperation();
// Don't recurse symbols
return mlir::WalkResult::skip();
});
cachedTypes.insert(typeName);
invalid = false;
Expand Down Expand Up @@ -193,7 +195,7 @@ class SymbolCacheAnalysis {

private:
llvm::StringMap<mlir::Operation *> symbolOpsMap;
std::unordered_map<mlir::Operation *, mlir::Operation *> callMap;
llvm::DenseMap<mlir::Operation *, mlir::Operation *> callMap;
std::unordered_set<std::string> cachedTypes;
mlir::Operation *topOp{nullptr};
bool invalid{true};
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/QUIRToPulse/LoadPulseCals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void LoadPulseCalsPass::loadPulseCals(CallCircuitOp callCircuitOp,
LLVM_DEBUG(llvm::dbgs() << "no pulse cal loading needed for " << op);
assert((!op->hasTrait<mlir::quir::UnitaryOp>() and
!op->hasTrait<mlir::quir::CPTPOp>()) &&
"unkown operation");
"unknown operation");
}
});
}
Expand Down
99 changes: 62 additions & 37 deletions lib/Conversion/QUIRToPulse/QUIRToPulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ void QUIRToPulsePass::runOnOperation() {
moduleOp->walk([&](CallCircuitOp callCircOp) {
if (isa<CircuitOp>(callCircOp->getParentOp()))
return;

auto convertedPulseCallSequenceOp =
convertCircuitToSequence(callCircOp, mainFunc, moduleOp);

if (!callCircOp->use_empty())
callCircOp->replaceAllUsesWith(convertedPulseCallSequenceOp);
callCircOp->erase();
Expand Down Expand Up @@ -229,8 +231,9 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
auto *newDelayCyclesOp = builder.clone(*quirOp);
newDelayCyclesOp->moveAfter(callCircuitOp);
} else
assert(((isa<quir::ConstantOp>(quirOp) or isa<quir::ReturnOp>(quirOp) or
isa<quir::CircuitOp>(quirOp))) &&
assert(((isa<quir::ConstantOp>(quirOp) ||
isa<qcs::ParameterLoadOp>(quirOp) ||
isa<quir::ReturnOp>(quirOp) || isa<quir::CircuitOp>(quirOp))) &&
"quir op is not allowed in this pass.");
});

Expand All @@ -251,6 +254,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
convertedPulseSequenceOp,
convertedPulseSequenceOpArgs);
convertedPulseCallSequenceOp->moveAfter(callCircuitOp);

return convertedPulseCallSequenceOp;
}

Expand Down Expand Up @@ -286,7 +290,7 @@ void QUIRToPulsePass::processCircuitArgs(
} else if (argumentType.isa<mlir::quir::QubitType>()) {
auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp();
} else
llvm_unreachable("unkown circuit argument.");
llvm_unreachable("unknown circuit argument.");
}
}

Expand Down Expand Up @@ -339,7 +343,7 @@ void QUIRToPulsePass::processPulseCalArgs(
} else if (argumentType.isa<FloatType>()) {
assert(argAttr[index].dyn_cast<StringAttr>().getValue().str() ==
"angle" &&
"unkown argument.");
"unknown argument.");
assert(angleOperands.size() && "no angle operand found.");
auto nextAngle = angleOperands.front();
LLVM_DEBUG(llvm::dbgs() << "angle argument ");
Expand All @@ -350,7 +354,7 @@ void QUIRToPulsePass::processPulseCalArgs(
} else if (argumentType.isa<IntegerType>()) {
assert(argAttr[index].dyn_cast<StringAttr>().getValue().str() ==
"duration" &&
"unkown argument.");
"unknown argument.");
assert(durationOperands.size() && "no duration operand found.");
auto nextDuration = durationOperands.front();
LLVM_DEBUG(llvm::dbgs() << "duration argument ");
Expand All @@ -359,7 +363,7 @@ void QUIRToPulsePass::processPulseCalArgs(
pulseCalSequenceArgs, builder);
durationOperands.pop();
} else
llvm_unreachable("unkown argument type.");
llvm_unreachable("unknown argument type.");
}
}

Expand All @@ -379,12 +383,13 @@ void QUIRToPulsePass::getQUIROpClassicalOperands(
}

for (auto operand : classicalOperands)
if (operand.getType().isa<mlir::quir::AngleType>())
if (operand.getType().isa<mlir::quir::AngleType>() ||
operand.getType().isa<FloatType>())
angleOperands.push(operand);
else if (operand.getType().isa<mlir::quir::DurationType>())
durationOperands.push(operand);
else
llvm_unreachable("unkown operand.");
llvm_unreachable("unknown operand.");
}

void QUIRToPulsePass::processMixFrameOpArg(
Expand Down Expand Up @@ -463,21 +468,38 @@ void QUIRToPulsePass::processAngleArg(Value nextAngleOperand,
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp
.getArguments()[circuitArgToConvertedSequenceArgMap[circNum]]);
} else {
auto angleOp = nextAngleOperand.getDefiningOp<mlir::quir::ConstantOp>();
std::string const angleLocHash =
std::to_string(mlir::hash_value(angleOp->getLoc()));
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) ==
} else if (auto angleOp =
nextAngleOperand.getDefiningOp<mlir::quir::ConstantOp>()) {
auto *op = angleOp.getOperation();
if (classicalQUIROpLocToConvertedPulseOpMap.find(op) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
double const angleVal =
angleOp.getAngleValueFromConstant().convertToDouble();
auto f64Angle = entryBuilder.create<mlir::arith::ConstantOp>(
angleOp.getLoc(), entryBuilder.getFloatAttr(entryBuilder.getF64Type(),
llvm::APFloat(angleVal)));
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle;
classicalQUIROpLocToConvertedPulseOpMap[op] = f64Angle;
}
pulseCalSequenceArgs.push_back(
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]);
pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]);
} else if (auto paramOp =
nextAngleOperand.getDefiningOp<mlir::qcs::ParameterLoadOp>()) {
auto *op = paramOp.getOperation();
if (classicalQUIROpLocToConvertedPulseOpMap.find(op) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {

auto newParam = entryBuilder.create<qcs::ParameterLoadOp>(
paramOp->getLoc(), entryBuilder.getF64Type(),
paramOp.getParameterName());
if (paramOp->hasAttr("initialValue")) {
auto initAttr = paramOp->getAttr("initialValue").dyn_cast<FloatAttr>();
if (initAttr)
newParam->setAttr("initialValue", initAttr);
}

classicalQUIROpLocToConvertedPulseOpMap[op] = newParam;
}

pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]);
}
}

Expand All @@ -501,25 +523,23 @@ void QUIRToPulsePass::processDurationArg(
TimeUnits::dt &&
"this pass only accepts durations with dt unit");

if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) ==
auto *op = durationOp.getOperation();
if (classicalQUIROpLocToConvertedPulseOpMap.find(op) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
auto dur64 = entryBuilder.create<mlir::arith::ConstantOp>(
durationOp.getLoc(),
entryBuilder.getIntegerAttr(entryBuilder.getI64Type(),
uint64_t(durVal)));
classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = dur64;
classicalQUIROpLocToConvertedPulseOpMap[op] = dur64;
}
pulseCalSequenceArgs.push_back(
classicalQUIROpLocToConvertedPulseOpMap[durLocHash]);
pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]);
}
}

mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
mlir::OpBuilder &builder) {
assert(angleOp && "angle op is null");
std::string const angleLocHash =
std::to_string(mlir::hash_value(angleOp->getLoc()));
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) ==
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleOp) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
if (auto castOp = dyn_cast<quir::ConstantOp>(angleOp)) {
double const angleVal =
Expand All @@ -528,41 +548,46 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
castOp->getLoc(),
builder.getFloatAttr(builder.getF64Type(), llvm::APFloat(angleVal)));
f64Angle->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle;
classicalQUIROpLocToConvertedPulseOpMap[angleOp] = f64Angle;
} else if (auto castOp = dyn_cast<qcs::ParameterLoadOp>(angleOp)) {
auto angleCastedOp = builder.create<oq3::CastOp>(
castOp->getLoc(), builder.getF64Type(), castOp.getRes());
angleCastedOp->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp;
// Just convert to an f64 directly
auto newParam = builder.create<qcs::ParameterLoadOp>(
angleOp->getLoc(), builder.getF64Type(), castOp.getParameterName());
if (castOp->hasAttr("initialValue")) {
auto initAttr = castOp->getAttr("initialValue").dyn_cast<FloatAttr>();
if (initAttr)
newParam->setAttr("initialValue", initAttr);
}
newParam->moveAfter(castOp);

classicalQUIROpLocToConvertedPulseOpMap[angleOp] = newParam;
} else if (auto castOp = dyn_cast<oq3::CastOp>(angleOp)) {
auto castOpArg = castOp.getArg();
if (auto paramCastOp =
dyn_cast<qcs::ParameterLoadOp>(castOpArg.getDefiningOp())) {
auto angleCastedOp = builder.create<oq3::CastOp>(
paramCastOp->getLoc(), builder.getF64Type(), paramCastOp.getRes());
angleCastedOp->moveAfter(paramCastOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp;
classicalQUIROpLocToConvertedPulseOpMap[angleOp] = angleCastedOp;
} else if (auto constOp =
dyn_cast<arith::ConstantOp>(castOpArg.getDefiningOp())) {
// if cast from float64 then use directly
assert(constOp.getType() == builder.getF64Type() &&
"expected angle type to be float 64");
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = constOp;
classicalQUIROpLocToConvertedPulseOpMap[angleOp] = constOp;
} else
llvm_unreachable("castOp arg unknown");
} else
llvm_unreachable("angleOp unknown");
}
return classicalQUIROpLocToConvertedPulseOpMap[angleLocHash];
return classicalQUIROpLocToConvertedPulseOpMap[angleOp];
}

mlir::Value QUIRToPulsePass::convertDurationToI64(
mlir::quir::CallCircuitOp &callCircuitOp, Operation *durationOp, uint &cnt,
mlir::OpBuilder &builder, mlir::func::FuncOp &mainFunc) {
assert(durationOp && "duration op is null");
std::string const durLocHash =
std::to_string(mlir::hash_value(durationOp->getLoc()));
if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) ==
if (classicalQUIROpLocToConvertedPulseOpMap.find(durationOp) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
if (auto castOp = dyn_cast<quir::ConstantOp>(durationOp)) {
auto durVal =
Expand All @@ -575,11 +600,11 @@ mlir::Value QUIRToPulsePass::convertDurationToI64(
castOp->getLoc(),
builder.getIntegerAttr(builder.getI64Type(), uint64_t(durVal)));
I64Dur->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = I64Dur;
classicalQUIROpLocToConvertedPulseOpMap[durationOp] = I64Dur;
} else
llvm_unreachable("unkown duration op");
llvm_unreachable("unknown duration op");
}
return classicalQUIROpLocToConvertedPulseOpMap[durLocHash];
return classicalQUIROpLocToConvertedPulseOpMap[durationOp];
}

mlir::pulse::Port_CreateOp
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/Pulse/IR/PulseOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "Dialect/Pulse/IR/PulseOps.h"

#include "Dialect/Pulse/IR/PulseTraits.h"
#include "Dialect/QCS/IR/QCSOps.h"
#include "Dialect/QUIR/IR/QUIROps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -356,8 +357,9 @@ LogicalResult verifyClassical_(SequenceOp op) {
mlir::Operation *classicalOp = nullptr;
WalkResult const result = op->walk([&](Operation *subOp) {
if (isa<mlir::arith::ConstantOp>(subOp) || isa<quir::ConstantOp>(subOp) ||
isa<CallSequenceOp>(subOp) || isa<pulse::ReturnOp>(subOp) ||
isa<SequenceOp>(subOp) || isa<mlir::complex::CreateOp>(subOp) ||
isa<qcs::ParameterLoadOp>(subOp) || isa<CallSequenceOp>(subOp) ||
isa<pulse::ReturnOp>(subOp) || isa<SequenceOp>(subOp) ||
isa<mlir::complex::CreateOp>(subOp) ||
subOp->hasTrait<mlir::pulse::SequenceAllowed>() ||
subOp->hasTrait<mlir::pulse::SequenceRequired>())
return WalkResult::advance();
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Pulse/Transforms/Scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ void QuantumCircuitPulseSchedulingPass::scheduleAlap(
opEnd = quantumCircuitSequenceOpBlock->rend();
opIt != opEnd; ++opIt) {
auto &op = *opIt;

if (auto quantumGateCallSequenceOp =
dyn_cast<mlir::pulse::CallSequenceOp>(op)) {
// find quantum gate SequenceOp
Expand Down
Loading
Loading