Skip to content

Commit

Permalink
fix: iterative variable assignments in AutoQASM (#930)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer authored Apr 2, 2024
1 parent 61ae9dd commit 790ede5
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 38 deletions.
27 changes: 15 additions & 12 deletions src/braket/experimental/autoqasm/operators/assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,20 @@ def assign_stmt(target_name: str, value: Any) -> Any:

value = types.wrap_value(value)

if not isinstance(value, oqpy.base.Var):
return value

if is_target_name_used:
if is_target_name_used and isinstance(value, (oqpy.base.Var, oqpy.base.OQPyExpression)):
target = _get_oqpy_program_variable(target_name)
_validate_assignment_types(target, value)
else:
elif isinstance(value, oqpy.base.Var):
target = copy.copy(value)
target.init_expression = None
target.name = target_name
else:
return value

oqpy_program = program_conversion_context.get_oqpy_program()
if is_value_name_used or value.init_expression is None:

value_init_expression = value.init_expression if isinstance(value, oqpy.base.Var) else None
if is_value_name_used or value_init_expression is None:
# Directly assign the value to the target.
# For example:
# a = b;
Expand All @@ -170,17 +171,17 @@ def assign_stmt(target_name: str, value: Any) -> Any:
# For example:
# int[32] a = 10;
# where `a` is at the root scope of the function (not inside any if/for/while block).
target.init_expression = value.init_expression
oqpy_program.declare(target)
target.init_expression = value_init_expression
oqpy_program._add_var(target)
else:
# Set to `value.init_expression` to avoid declaring an unnecessary variable.
# Set to `value_init_expression` to avoid declaring an unnecessary variable.
# The variable will be set in the current scope and auto-declared at the root scope.
# For example, the `a = 1` and `a = 0` statements in the following:
# int[32] a;
# if (b == True) { a = 1; }
# else { a = 0; }
# where `b` is previously declared.
oqpy_program.set(target, value.init_expression)
oqpy_program.set(target, value_init_expression)

return target

Expand Down Expand Up @@ -211,12 +212,14 @@ def _validate_assignment_types(var1: oqpy.base.Var, var2: oqpy.base.Var) -> None
"Variables in assignment statements must have the same type"
)

if isinstance(var1, oqpy.ArrayVar):
if isinstance(var1, oqpy.ArrayVar) and isinstance(var2, oqpy.ArrayVar):
if var1.dimensions != var2.dimensions:
raise errors.InvalidAssignmentStatement(
"Arrays in assignment statements must have the same dimensions"
)
elif isinstance(var1, oqpy.classical_types._SizedVar):
elif isinstance(var1, oqpy.classical_types._SizedVar) and isinstance(
var2, oqpy.classical_types._SizedVar
):
var1_size = var1.size or 1
var2_size = var2.size or 1
if var1_size != var2_size:
Expand Down
21 changes: 19 additions & 2 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import oqpy.base
import pygments
from openpulse import ast
from openqasm_pygments import OpenQASM3Lexer
from pygments.formatters.terminal import TerminalFormatter
from sympy import Symbol
Expand Down Expand Up @@ -509,10 +510,26 @@ def add_io_declarations(self) -> None:
root_oqpy_program.undeclared_vars[parameter.name]._needs_declaration = True
else:
root_oqpy_program._add_var(parameter)

for parameter_name, parameter in self._output_parameters.items():
# Before adding the output variable to the program, remove any existing reference
root_oqpy_program.undeclared_vars.pop(parameter_name, None)
root_oqpy_program.declared_vars.pop(parameter_name, None)
popped_undeclared = root_oqpy_program.undeclared_vars.pop(parameter_name, None)
popped_declared = root_oqpy_program.declared_vars.pop(parameter_name, None)

# Verify that we didn't find it in both lists
assert popped_undeclared is None or popped_declared is None

popped = popped_undeclared if popped_undeclared is not None else popped_declared
if popped is not None and popped.init_expression is not None:
# Add an assignment statement to the beginning of the program to initialize
# the output parameter to the desired value.
# TODO: This logic uses oqpy internals - should it be moved into oqpy?
init_stmt = ast.ClassicalAssignment(
ast.Identifier(name=parameter_name),
ast.AssignmentOperator["="],
oqpy.base.to_ast(root_oqpy_program, popped.init_expression),
)
root_oqpy_program._state.body.insert(0, init_stmt)

parameter.name = parameter_name
root_oqpy_program._add_var(parameter)
Expand Down
16 changes: 6 additions & 10 deletions test/unit_tests/braket/experimental/autoqasm/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def bell_measurement_declared() -> None:

def test_bell_measurement_declared(bell_measurement_declared) -> None:
expected = """OPENQASM 3.0;
qubit[2] __qubits__;
bit[2] c = "00";
qubit[2] __qubits__;
h __qubits__[0];
cnot __qubits__[0], __qubits__[1];
bit[2] __bit_1__ = "00";
Expand Down Expand Up @@ -863,16 +863,13 @@ def classical_variables_types() -> None:

def test_classical_variables_types(classical_variables_types):
expected = """OPENQASM 3.0;
bit a = 0;
a = 1;
bit a = 1;
int[32] i = 1;
bit[2] a_array = "00";
int[32] b = 15;
float[64] c = 3.4;
a_array[0] = 0;
a_array[i] = 1;
int[32] b = 10;
b = 15;
float[64] c = 1.2;
c = 3.4;"""
a_array[i] = 1;"""
assert classical_variables_types.build().to_ir() == expected


Expand All @@ -889,9 +886,8 @@ def prog() -> None:
a = b # declared target, declared value # noqa: F841

expected = """OPENQASM 3.0;
int[32] a = 2;
int[32] b;
int[32] a = 1;
a = 2;
b = a;
a = b;"""
assert prog.build().to_ir() == expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,12 @@ def fn() -> None:

qasm = program_conversion_context.make_program().to_ir()
expected_qasm = """OPENQASM 3.0;
int[32] e;
int[32] a = 5;
int[32] a = 1;
float[64] b = 1.2;
a = 1;
e = a;
int[32] e;
bool f = false;
bool g = true;
e = a;
g = f;"""
assert qasm == expected_qasm

Expand Down
52 changes: 49 additions & 3 deletions test/unit_tests/braket/experimental/autoqasm/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import braket.experimental.autoqasm as aq
from braket.experimental.autoqasm import errors
from braket.experimental.autoqasm.errors import UnsupportedConditionalExpressionError
from braket.experimental.autoqasm.instructions import cnot, h, measure, x
from braket.experimental.autoqasm.instructions import cnot, h, measure, rx, x


@pytest.fixture
Expand Down Expand Up @@ -162,8 +162,8 @@ def branch_assignment_declared():
a = aq.IntVar(7) # noqa: F841

expected = """OPENQASM 3.0;
bool __bool_1__ = true;
int[32] a = 5;
bool __bool_1__ = true;
if (__bool_1__) {
a = 6;
} else {
Expand All @@ -173,6 +173,52 @@ def branch_assignment_declared():
assert branch_assignment_declared.build().to_ir() == expected


def test_iterative_assignment() -> None:
"""Tests a for loop where a variable is updated on each iteration."""

@aq.main(num_qubits=3)
def iterative_assignment():
val = aq.FloatVar(0.5)
for q in aq.qubits:
val = val + measure(q)
rx(0, val)

expected = """OPENQASM 3.0;
float[64] val = 0.5;
qubit[3] __qubits__;
for int q in [0:3 - 1] {
bit __bit_1__;
__bit_1__ = measure __qubits__[q];
val = val + __bit_1__;
rx(val) __qubits__[0];
}"""

assert iterative_assignment.build().to_ir() == expected


def test_iterative_output_assignment() -> None:
"""Tests a for loop where an output variable is updated on each iteration."""

@aq.main(num_qubits=3)
def iterative_output_assignment():
val = aq.FloatVar(0.5)
for q in aq.range(3):
val = val + measure(q)
return val

expected = """OPENQASM 3.0;
output float[64] val;
val = 0.5;
qubit[3] __qubits__;
for int q in [0:3 - 1] {
bit __bit_1__;
__bit_1__ = measure __qubits__[q];
val = val + __bit_1__;
}"""

assert iterative_output_assignment.build().to_ir() == expected


def for_body(i: aq.Qubit) -> None:
h(i)

Expand Down Expand Up @@ -655,9 +701,9 @@ def measure_to_slice():
b0[3] = c

expected = """OPENQASM 3.0;
bit[10] b0 = "0000000000";
bit c;
qubit[1] __qubits__;
bit[10] b0 = "0000000000";
bit __bit_1__;
__bit_1__ = measure __qubits__[0];
c = __bit_1__;
Expand Down
13 changes: 10 additions & 3 deletions test/unit_tests/braket/experimental/autoqasm/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,20 @@ def parametric_explicit():
with pytest.raises(RuntimeError, match="conflicting variables with name alpha"):
parametric_explicit.build()


def test_assignment_to_input_variable_name():
"""Test assigning to overwrite an input variable within the program."""

@aq.main
def parametric(alpha):
alpha = aq.FloatVar(1.2) # noqa: F841
alpha = aq.FloatVar(1.2)
rx(0, alpha)

with pytest.raises(RuntimeError, match="conflicting variables with name alpha"):
parametric.build()
expected = """OPENQASM 3.0;
float[64] alpha = 1.2;
qubit[1] __qubits__;
rx(alpha) __qubits__[0];"""
assert parametric.build().to_ir() == expected


def test_binding_variable_fails():
Expand Down
8 changes: 4 additions & 4 deletions test/unit_tests/braket/experimental/autoqasm/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def ret_test() -> int:
def add(int[32] a, int[32] b) -> int[32] {
return a + b;
}
output int[32] return_value;
int[32] a = 5;
int[32] b = 6;
output int[32] return_value;
int[32] __int_2__;
__int_2__ = add(a, b);
return_value = __int_2__;"""
Expand Down Expand Up @@ -194,8 +194,8 @@ def declare_array():

expected = """OPENQASM 3.0;
array[int[32], 3] a = {1, 2, 3};
a[0] = 11;
array[int[32], 3] b = {4, 5, 6};
a[0] = 11;
b[2] = 14;
b = a;"""

Expand Down Expand Up @@ -517,9 +517,9 @@ def main():

expected_qasm = """OPENQASM 3.0;
def retval_recursive() -> int[32] {
int[32] retval_ = 1;
int[32] __int_1__;
__int_1__ = retval_recursive();
int[32] retval_ = 1;
return retval_;
}
int[32] __int_3__;
Expand All @@ -543,10 +543,10 @@ def main():
expected_qasm = """OPENQASM 3.0;
def retval_recursive() -> int[32] {
int[32] a;
int[32] retval_ = 1;
int[32] __int_1__;
__int_1__ = retval_recursive();
a = __int_1__;
int[32] retval_ = 1;
return retval_;
}
int[32] __int_3__;
Expand Down

0 comments on commit 790ede5

Please sign in to comment.