Skip to content

Commit

Permalink
fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjaskula-aws committed Mar 18, 2024
1 parent 5b43536 commit fb1db5d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
13 changes: 0 additions & 13 deletions src/braket/parametric/free_parameter_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, sympy.Expr
Args:
expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use.
_type (Optional[ClassicalType]): The OpenQASM3 type associated with the expression.
Subtypes of openqasm3.ast.ClassicalType are used to specify how to express the
expression in the OpenQASM3 IR. Any type other than DurationType is considered
as FloatType.
Raises:
NotImplementedError: Raised if the expression is not of type
Expand All @@ -61,7 +57,6 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, sympy.Expr
ast.Pow: self.__pow__,
ast.USub: self.__neg__,
}
self._type = _type if _type is not None else FloatType()
if isinstance(expression, FreeParameterExpression):
self._expression = expression.expression
elif isinstance(expression, (Number, sympy.Expr)):
Expand All @@ -70,7 +65,6 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, sympy.Expr
self._expression = self._parse_string_expression(expression).expression
else:
raise NotImplementedError
self._validate_type()

@property
def expression(self) -> Union[Number, sympy.Expr]:
Expand Down Expand Up @@ -109,13 +103,6 @@ def subs(
else:
return FreeParameterExpression(subbed_expr)

def _validate_type(self) -> None:
if not isinstance(self._type, (FloatType, DurationType)):
raise TypeError(
"FreeParameterExpression must be of type openqasm3.ast.FloatType "
"or openqasm3.ast.DurationType"
)

def _parse_string_expression(self, expression: str) -> FreeParameterExpression:
return self._eval_operation(ast.parse(expression, mode="eval").body)

Expand Down
2 changes: 1 addition & 1 deletion src/braket/pulse/ast/free_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def visit_Identifier(
using the given parameter values.
Args:
identifier (_FreeParameterExpressionIdentifier): The identifier.
identifier (Identifier): The identifier.
Returns:
Union[Identifier, FloatLiteral]: The transformed identifier.
Expand Down
22 changes: 19 additions & 3 deletions src/braket/pulse/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@

import numpy as np
import openpulse.ast as ast
from oqpy import WaveformVar, bool_, complex128, declare_waveform_generator, duration, float64
from oqpy.base import OQPyExpression
from oqpy import (
WaveformVar,
bool_,
complex128,
convert_float_to_duration,
declare_waveform_generator,
duration,
float64,
)
from oqpy.base import OQPyExpression, to_ast

from braket.parametric.free_parameter import FreeParameter
from braket.parametric.free_parameter_expression import (
Expand Down Expand Up @@ -73,8 +81,16 @@ def _modify_oqpy_waveform_var(
self, key: str, value: Any, type_: ast.ClassicalType = float64
) -> None:
if self._pulse_sequence is not None:
self._pulse_sequence._register_free_parameters(value)
self._pulse_sequence._program.undeclared_vars[self.id].init_expression.args[key] = (
self._pulse_sequence._format_parameter_ast(value, type_)
to_ast(
self._pulse_sequence._program,
(
convert_float_to_duration(value)
if isinstance(type_, ast.DurationType)
else value
),
)
)

@abstractmethod
Expand Down

0 comments on commit fb1db5d

Please sign in to comment.