diff --git a/src/braket/experimental/autoqasm/operators/assignments.py b/src/braket/experimental/autoqasm/operators/assignments.py index 0419399cb..f85ddecd2 100644 --- a/src/braket/experimental/autoqasm/operators/assignments.py +++ b/src/braket/experimental/autoqasm/operators/assignments.py @@ -144,19 +144,20 @@ def assign_stmt(target_name: str, value: Any) -> Any: value = types.wrap_value(value) - if not isinstance(value, oqpy.base.Var): - return value - - if is_target_name_used: + if is_target_name_used and isinstance(value, (oqpy.base.Var, oqpy.base.OQPyExpression)): target = _get_oqpy_program_variable(target_name) _validate_assignment_types(target, value) - else: + elif isinstance(value, oqpy.base.Var): target = copy.copy(value) target.init_expression = None target.name = target_name + else: + return value oqpy_program = program_conversion_context.get_oqpy_program() - if is_value_name_used or value.init_expression is None: + + value_init_expression = value.init_expression if isinstance(value, oqpy.base.Var) else None + if is_value_name_used or value_init_expression is None: # Directly assign the value to the target. # For example: # a = b; @@ -170,17 +171,17 @@ def assign_stmt(target_name: str, value: Any) -> Any: # For example: # int[32] a = 10; # where `a` is at the root scope of the function (not inside any if/for/while block). - target.init_expression = value.init_expression - oqpy_program.declare(target) + target.init_expression = value_init_expression + oqpy_program._add_var(target) else: - # Set to `value.init_expression` to avoid declaring an unnecessary variable. + # Set to `value_init_expression` to avoid declaring an unnecessary variable. # The variable will be set in the current scope and auto-declared at the root scope. # For example, the `a = 1` and `a = 0` statements in the following: # int[32] a; # if (b == True) { a = 1; } # else { a = 0; } # where `b` is previously declared. - oqpy_program.set(target, value.init_expression) + oqpy_program.set(target, value_init_expression) return target @@ -211,12 +212,14 @@ def _validate_assignment_types(var1: oqpy.base.Var, var2: oqpy.base.Var) -> None "Variables in assignment statements must have the same type" ) - if isinstance(var1, oqpy.ArrayVar): + if isinstance(var1, oqpy.ArrayVar) and isinstance(var2, oqpy.ArrayVar): if var1.dimensions != var2.dimensions: raise errors.InvalidAssignmentStatement( "Arrays in assignment statements must have the same dimensions" ) - elif isinstance(var1, oqpy.classical_types._SizedVar): + elif isinstance(var1, oqpy.classical_types._SizedVar) and isinstance( + var2, oqpy.classical_types._SizedVar + ): var1_size = var1.size or 1 var2_size = var2.size or 1 if var1_size != var2_size: diff --git a/src/braket/experimental/autoqasm/program/program.py b/src/braket/experimental/autoqasm/program/program.py index 63da3b306..f19612d07 100644 --- a/src/braket/experimental/autoqasm/program/program.py +++ b/src/braket/experimental/autoqasm/program/program.py @@ -25,6 +25,7 @@ import oqpy.base import pygments +from openpulse import ast from openqasm_pygments import OpenQASM3Lexer from pygments.formatters.terminal import TerminalFormatter from sympy import Symbol @@ -509,10 +510,26 @@ def add_io_declarations(self) -> None: root_oqpy_program.undeclared_vars[parameter.name]._needs_declaration = True else: root_oqpy_program._add_var(parameter) + for parameter_name, parameter in self._output_parameters.items(): # Before adding the output variable to the program, remove any existing reference - root_oqpy_program.undeclared_vars.pop(parameter_name, None) - root_oqpy_program.declared_vars.pop(parameter_name, None) + popped_undeclared = root_oqpy_program.undeclared_vars.pop(parameter_name, None) + popped_declared = root_oqpy_program.declared_vars.pop(parameter_name, None) + + # Verify that we didn't find it in both lists + assert popped_undeclared is None or popped_declared is None + + popped = popped_undeclared if popped_undeclared is not None else popped_declared + if popped is not None and popped.init_expression is not None: + # Add an assignment statement to the beginning of the program to initialize + # the output parameter to the desired value. + # TODO: This logic uses oqpy internals - should it be moved into oqpy? + init_stmt = ast.ClassicalAssignment( + ast.Identifier(name=parameter_name), + ast.AssignmentOperator["="], + oqpy.base.to_ast(root_oqpy_program, popped.init_expression), + ) + root_oqpy_program._state.body.insert(0, init_stmt) parameter.name = parameter_name root_oqpy_program._add_var(parameter) diff --git a/test/unit_tests/braket/experimental/autoqasm/test_api.py b/test/unit_tests/braket/experimental/autoqasm/test_api.py index 6c7e7b561..8cc8d6c16 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_api.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_api.py @@ -251,8 +251,8 @@ def bell_measurement_declared() -> None: def test_bell_measurement_declared(bell_measurement_declared) -> None: expected = """OPENQASM 3.0; -qubit[2] __qubits__; bit[2] c = "00"; +qubit[2] __qubits__; h __qubits__[0]; cnot __qubits__[0], __qubits__[1]; bit[2] __bit_1__ = "00"; @@ -863,16 +863,13 @@ def classical_variables_types() -> None: def test_classical_variables_types(classical_variables_types): expected = """OPENQASM 3.0; -bit a = 0; -a = 1; +bit a = 1; int[32] i = 1; bit[2] a_array = "00"; +int[32] b = 15; +float[64] c = 3.4; a_array[0] = 0; -a_array[i] = 1; -int[32] b = 10; -b = 15; -float[64] c = 1.2; -c = 3.4;""" +a_array[i] = 1;""" assert classical_variables_types.build().to_ir() == expected @@ -889,9 +886,8 @@ def prog() -> None: a = b # declared target, declared value # noqa: F841 expected = """OPENQASM 3.0; +int[32] a = 2; int[32] b; -int[32] a = 1; -a = 2; b = a; a = b;""" assert prog.build().to_ir() == expected diff --git a/test/unit_tests/braket/experimental/autoqasm/test_converters.py b/test/unit_tests/braket/experimental/autoqasm/test_converters.py index a916f3224..534f44e97 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_converters.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_converters.py @@ -54,13 +54,12 @@ def fn() -> None: qasm = program_conversion_context.make_program().to_ir() expected_qasm = """OPENQASM 3.0; -int[32] e; -int[32] a = 5; +int[32] a = 1; float[64] b = 1.2; -a = 1; -e = a; +int[32] e; bool f = false; bool g = true; +e = a; g = f;""" assert qasm == expected_qasm diff --git a/test/unit_tests/braket/experimental/autoqasm/test_operators.py b/test/unit_tests/braket/experimental/autoqasm/test_operators.py index 78d4d095c..5280c0320 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_operators.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_operators.py @@ -23,7 +23,7 @@ import braket.experimental.autoqasm as aq from braket.experimental.autoqasm import errors from braket.experimental.autoqasm.errors import UnsupportedConditionalExpressionError -from braket.experimental.autoqasm.instructions import cnot, h, measure, x +from braket.experimental.autoqasm.instructions import cnot, h, measure, rx, x @pytest.fixture @@ -162,8 +162,8 @@ def branch_assignment_declared(): a = aq.IntVar(7) # noqa: F841 expected = """OPENQASM 3.0; -bool __bool_1__ = true; int[32] a = 5; +bool __bool_1__ = true; if (__bool_1__) { a = 6; } else { @@ -173,6 +173,52 @@ def branch_assignment_declared(): assert branch_assignment_declared.build().to_ir() == expected +def test_iterative_assignment() -> None: + """Tests a for loop where a variable is updated on each iteration.""" + + @aq.main(num_qubits=3) + def iterative_assignment(): + val = aq.FloatVar(0.5) + for q in aq.qubits: + val = val + measure(q) + rx(0, val) + + expected = """OPENQASM 3.0; +float[64] val = 0.5; +qubit[3] __qubits__; +for int q in [0:3 - 1] { + bit __bit_1__; + __bit_1__ = measure __qubits__[q]; + val = val + __bit_1__; + rx(val) __qubits__[0]; +}""" + + assert iterative_assignment.build().to_ir() == expected + + +def test_iterative_output_assignment() -> None: + """Tests a for loop where an output variable is updated on each iteration.""" + + @aq.main(num_qubits=3) + def iterative_output_assignment(): + val = aq.FloatVar(0.5) + for q in aq.range(3): + val = val + measure(q) + return val + + expected = """OPENQASM 3.0; +output float[64] val; +val = 0.5; +qubit[3] __qubits__; +for int q in [0:3 - 1] { + bit __bit_1__; + __bit_1__ = measure __qubits__[q]; + val = val + __bit_1__; +}""" + + assert iterative_output_assignment.build().to_ir() == expected + + def for_body(i: aq.Qubit) -> None: h(i) @@ -655,9 +701,9 @@ def measure_to_slice(): b0[3] = c expected = """OPENQASM 3.0; +bit[10] b0 = "0000000000"; bit c; qubit[1] __qubits__; -bit[10] b0 = "0000000000"; bit __bit_1__; __bit_1__ = measure __qubits__[0]; c = __bit_1__; diff --git a/test/unit_tests/braket/experimental/autoqasm/test_parameters.py b/test/unit_tests/braket/experimental/autoqasm/test_parameters.py index 3ad96774b..bf91b00b7 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_parameters.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_parameters.py @@ -529,13 +529,20 @@ def parametric_explicit(): with pytest.raises(RuntimeError, match="conflicting variables with name alpha"): parametric_explicit.build() + +def test_assignment_to_input_variable_name(): + """Test assigning to overwrite an input variable within the program.""" + @aq.main def parametric(alpha): - alpha = aq.FloatVar(1.2) # noqa: F841 + alpha = aq.FloatVar(1.2) rx(0, alpha) - with pytest.raises(RuntimeError, match="conflicting variables with name alpha"): - parametric.build() + expected = """OPENQASM 3.0; +float[64] alpha = 1.2; +qubit[1] __qubits__; +rx(alpha) __qubits__[0];""" + assert parametric.build().to_ir() == expected def test_binding_variable_fails(): diff --git a/test/unit_tests/braket/experimental/autoqasm/test_types.py b/test/unit_tests/braket/experimental/autoqasm/test_types.py index 03083d452..f3bad422c 100644 --- a/test/unit_tests/braket/experimental/autoqasm/test_types.py +++ b/test/unit_tests/braket/experimental/autoqasm/test_types.py @@ -159,9 +159,9 @@ def ret_test() -> int: def add(int[32] a, int[32] b) -> int[32] { return a + b; } -output int[32] return_value; int[32] a = 5; int[32] b = 6; +output int[32] return_value; int[32] __int_2__; __int_2__ = add(a, b); return_value = __int_2__;""" @@ -194,8 +194,8 @@ def declare_array(): expected = """OPENQASM 3.0; array[int[32], 3] a = {1, 2, 3}; -a[0] = 11; array[int[32], 3] b = {4, 5, 6}; +a[0] = 11; b[2] = 14; b = a;""" @@ -517,9 +517,9 @@ def main(): expected_qasm = """OPENQASM 3.0; def retval_recursive() -> int[32] { + int[32] retval_ = 1; int[32] __int_1__; __int_1__ = retval_recursive(); - int[32] retval_ = 1; return retval_; } int[32] __int_3__; @@ -543,10 +543,10 @@ def main(): expected_qasm = """OPENQASM 3.0; def retval_recursive() -> int[32] { int[32] a; + int[32] retval_ = 1; int[32] __int_1__; __int_1__ = retval_recursive(); a = __int_1__; - int[32] retval_ = 1; return retval_; } int[32] __int_3__;