Skip to content

Commit

Permalink
Parameters Compiling Performance Improvements (#330)
Browse files Browse the repository at this point in the history
Improves performance of compilation when using parameters by dropping
the use of the DeclareParameterOp and resulting symbol lookup. The
initial value has been moved to an attribute of the ParameterLoadOp.
Removes the use of ParameterInitialValueAnalysis.
  • Loading branch information
bcdonovan authored May 24, 2024
1 parent 1f506a7 commit 7d71243
Show file tree
Hide file tree
Showing 15 changed files with 76 additions and 238 deletions.
8 changes: 2 additions & 6 deletions include/Dialect/QCS/IR/QCSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def QCS_DeclareParameterOp : QCS_Op<"declare_parameter", [Symbol]> {
}

def QCS_ParameterLoadOp : QCS_Op<"parameter_load",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
[]> {
let summary = "Use the current value of a parameter";
let description = [{
The operation `qcs.parameter_load` returns the current value of the
Expand All @@ -281,16 +281,12 @@ def QCS_ParameterLoadOp : QCS_Op<"parameter_load",
}];

let arguments = (ins
FlatSymbolRefAttr:$parameter_name
StrAttr:$parameter_name
);

let results = (outs AnyClassical:$res);

let extraClassDeclaration = [{
// Return the initial value - using ParameterInitialValueAnalysis
ParameterType getInitialValue(llvm::StringMap<ParameterType> &parameterNames);

// Return the initial value - slower SymbolTable version
ParameterType getInitialValue();
}];

Expand Down
4 changes: 2 additions & 2 deletions include/Dialect/QCS/Utils/ParameterInitialValueAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ using InitialValueType = llvm::StringMap<ParameterType>;

class ParameterInitialValueAnalysis {
private:
InitialValueType initial_values_;
InitialValueType initialValues_;
bool invalid_{true};

public:
ParameterInitialValueAnalysis(mlir::Operation *op);
InitialValueType &getNames() { return initial_values_; }
InitialValueType &getNames() { return initialValues_; }
void invalidate() { invalid_ = true; }
bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
return invalid_;
Expand Down
5 changes: 1 addition & 4 deletions include/Dialect/QUIR/Transforms/QUIRCircuitAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#define QUIR_CIRCUITS_ANALYSIS_H

#include "Dialect/Pulse/IR/PulseOps.h"
#include "Dialect/QCS/Utils/ParameterInitialValueAnalysis.h"
#include "Dialect/QUIR/IR/QUIROps.h"

#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -53,8 +52,7 @@ class QUIRCircuitAnalysis {
}

private:
double getAngleValue(mlir::Value operand,
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis);
double getAngleValue(mlir::Value operand);
llvm::StringRef getParameterName(mlir::Value operand);
quir::DurationAttr getDuration(mlir::Value operand);
};
Expand All @@ -75,7 +73,6 @@ struct QUIRCircuitAnalysisPass

llvm::Expected<double>
angleValToDouble(mlir::Value inVal,
mlir::qcs::ParameterInitialValueAnalysis *nameAnalysis,
mlir::quir::QUIRCircuitAnalysis *circuitAnalysis = nullptr);

} // namespace mlir::quir
Expand Down
2 changes: 2 additions & 0 deletions include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor {
ExpressionValueType getValueFromLiteral(const QASM::ASTMPDecimalNode *);

mlir::Type getQUIRTypeFromDeclaration(const QASM::ASTDeclarationNode *);

bool enableParametersWarningEmitted = false;
};

} // namespace qssc::frontend::openqasm3
Expand Down
135 changes: 6 additions & 129 deletions lib/Dialect/QCS/IR/QCSOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,156 +21,33 @@
#include "Dialect/QCS/IR/QCSOps.h"

#include "Dialect/QCS/IR/QCSTypes.h"
#include "Dialect/QUIR/IR/QUIRAttributes.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Support/LogicalResult.h>

#include "llvm/ADT/StringMap.h"

#include <cassert>

using namespace mlir;
using namespace mlir::qcs;

#define GET_OP_CLASSES
// NOLINTNEXTLINE(misc-include-cleaner): Required for MLIR registrations
#include "Dialect/QCS/IR/QCSOps.cpp.inc"

namespace {
LogicalResult
verifyQCSParameterOpSymbolUses(SymbolTableCollection &symbolTable,
mlir::Operation *op,
bool operandMustMatchSymbolType = false) {
assert(op);

// Check that op has attribute variable_name
auto paramRefAttr = op->getAttrOfType<FlatSymbolRefAttr>("parameter_name");
if (!paramRefAttr)
return op->emitOpError(
"requires a symbol reference attribute 'parameter_name'");

// Check that symbol reference resolves to a parameter declaration
auto declOp =
symbolTable.lookupNearestSymbolFrom<DeclareParameterOp>(op, paramRefAttr);

// check higher level modules
if (!declOp) {
auto targetModuleOp = op->getParentOfType<mlir::ModuleOp>();
if (targetModuleOp) {
auto topLevelModuleOp = targetModuleOp->getParentOfType<mlir::ModuleOp>();
if (!declOp && topLevelModuleOp)
declOp = symbolTable.lookupNearestSymbolFrom<DeclareParameterOp>(
topLevelModuleOp, paramRefAttr);
}
}

if (!declOp)
return op->emitOpError() << "no valid reference to a parameter '"
<< paramRefAttr.getValue() << "'";

assert(op->getNumResults() <= 1 && "assume none or single result");

// Check that type of variables matches result type of this Op
if (op->getNumResults() == 1) {
if (op->getResult(0).getType() != declOp.getType())
return op->emitOpError(
"type mismatch between variable declaration and variable use");
}

if (op->getNumOperands() > 0 && operandMustMatchSymbolType) {
assert(op->getNumOperands() == 1 &&
"type check only supported for a single operand");
if (op->getOperand(0).getType() != declOp.getType())
return op->emitOpError(
"type mismatch between variable declaration and variable assignment");
}
return success();
}

} // anonymous namespace

//===----------------------------------------------------------------------===//
// ParameterLoadOp
//===----------------------------------------------------------------------===//

LogicalResult
ParameterLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyQCSParameterOpSymbolUses(symbolTable, getOperation(), true);
}

// Returns the float value from the initial value of this parameter
ParameterType ParameterLoadOp::getInitialValue() {
auto *op = getOperation();
auto paramRefAttr =
op->getAttrOfType<mlir::FlatSymbolRefAttr>("parameter_name");
auto declOp =
mlir::SymbolTable::lookupNearestSymbolFrom<mlir::qcs::DeclareParameterOp>(
op, paramRefAttr);

// check higher level modules

auto currentScopeOp = op->getParentOfType<mlir::ModuleOp>();
do {
declOp = mlir::SymbolTable::lookupNearestSymbolFrom<
mlir::qcs::DeclareParameterOp>(currentScopeOp, paramRefAttr);
if (declOp)
break;
currentScopeOp = currentScopeOp->getParentOfType<mlir::ModuleOp>();
assert(currentScopeOp);
} while (!declOp);

assert(declOp);

double retVal;

auto iniValue = declOp.getInitialValue();
if (iniValue.has_value()) {
auto angleAttr = iniValue.value().dyn_cast<mlir::quir::AngleAttr>();

auto floatAttr = iniValue.value().dyn_cast<FloatAttr>();

if (!(angleAttr || floatAttr)) {
op->emitError(
"Parameters are currently limited to angles or float[64] only.");
return 0.0;
}

if (angleAttr)
retVal = angleAttr.getValue().convertToDouble();

if (floatAttr)
retVal = floatAttr.getValue().convertToDouble();

return retVal;
}

op->emitError("Does not have initial value set.");
return 0.0;
}

// Returns the float value from the initial value of this parameter
// this version uses a precomputed map of parameter_name to the initial_value
// in order to avoid slow SymbolTable lookups
ParameterType ParameterLoadOp::getInitialValue(
llvm::StringMap<ParameterType> &declareParametersMap) {
auto *op = getOperation();
auto paramRefAttr =
op->getAttrOfType<mlir::FlatSymbolRefAttr>("parameter_name");

auto paramOpEntry = declareParametersMap.find(paramRefAttr.getValue());

if (paramOpEntry == declareParametersMap.end()) {
op->emitError("Could not find declare parameter op " +
paramRefAttr.getValue().str());
return 0.0;
double retVal = 0.0;
if (op->hasAttr("initialValue")) {
auto initAttr = op->getAttr("initialValue").dyn_cast<FloatAttr>();
if (initAttr)
retVal = initAttr.getValue().convertToDouble();
}

return paramOpEntry->second;
return retVal;
}

//===----------------------------------------------------------------------===//
Expand Down
34 changes: 8 additions & 26 deletions lib/Dialect/QCS/Utils/ParameterInitialValueAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "Dialect/QCS/Utils/ParameterInitialValueAnalysis.h"

#include "Dialect/QCS/IR/QCSOps.h"
#include "Dialect/QUIR/IR/QUIRAttributes.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -60,30 +59,13 @@ ParameterInitialValueAnalysis::ParameterInitialValueAnalysis(
for (auto &region : moduleOp->getRegions())
for (auto &block : region.getBlocks())
for (auto &op : block.getOperations()) {
auto declareParameterOp = dyn_cast<DeclareParameterOp>(op);
if (!declareParameterOp)
auto parameterLoadOp = dyn_cast<ParameterLoadOp>(op);
if (!parameterLoadOp)
continue;

double initial_value = 0.0;
if (declareParameterOp.getInitialValue().has_value()) {
auto angleAttr = declareParameterOp.getInitialValue()
.value()
.dyn_cast<mlir::quir::AngleAttr>();
auto floatAttr = declareParameterOp.getInitialValue()
.value()
.dyn_cast<FloatAttr>();
if (!(angleAttr || floatAttr))
declareParameterOp.emitError(
"Parameters are currently limited to "
"angles or float[64] only.");

if (angleAttr)
initial_value = angleAttr.getValue().convertToDouble();

if (floatAttr)
initial_value = floatAttr.getValue().convertToDouble();
}
initial_values_[declareParameterOp.getSymName()] = initial_value;
const double initialValue =
std::get<double>(parameterLoadOp.getInitialValue());
initialValues_[parameterLoadOp.getParameterName()] = initialValue;
foundParameters = true;
}
if (!foundParameters) {
Expand All @@ -98,9 +80,9 @@ ParameterInitialValueAnalysis::ParameterInitialValueAnalysis(

// debugging / test print out
if (printAnalysisEntries) {
for (auto &initial_value : initial_values_) {
llvm::outs() << initial_value.first() << " = "
<< std::get<double>(initial_value.second) << "\n";
for (auto &initialValue : initialValues_) {
llvm::outs() << initialValue.first() << " = "
<< std::get<double>(initialValue.second) << "\n";
}
}
}
Expand Down
12 changes: 0 additions & 12 deletions lib/Dialect/QUIR/Transforms/LoadElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "Dialect/QUIR/Transforms/LoadElimination.h"

#include "Dialect/OQ3/IR/OQ3Ops.h"
#include "Dialect/QUIR/IR/QUIRAttributes.h"

#include "mlir/IR/Dominance.h"
#include "mlir/IR/SymbolTable.h"
Expand Down Expand Up @@ -81,17 +80,6 @@ void LoadEliminationPass::runOnOperation() {

auto varAssignmentOp = mlir::cast<mlir::oq3::VariableAssignOp>(assignment);

// Transfer marker for input parameters
// Note: for arith.constant operations, canonicalization will drop these
// attributes and we need to find another way (to be specific:
// canonicalization will move constants to the begin of ops like Functions
// by means of dialect->materializeConstant(...) that creates new
// constants). For now and for angle constants, this approach is good-enough
// while not satisfying.
if (decl.isInputVariable())
varAssignmentOp.getAssignedValue().getDefiningOp()->setAttr(
mlir::quir::getInputParameterAttrName(), decl.getNameAttr());

for (auto *userOp : symbolUses) {

if (!mlir::isa<mlir::oq3::VariableLoadOp>(userOp))
Expand Down
Loading

0 comments on commit 7d71243

Please sign in to comment.