Skip to content

Commit

Permalink
Merge branch 'main' into jcjaskula-aws/add_in_place_modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjaskula-aws committed Mar 18, 2024
1 parent fb1db5d commit 09572f4
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 85 deletions.
22 changes: 11 additions & 11 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create(
disable_qubit_rewiring: bool = False,
tags: dict[str, str] | None = None,
inputs: dict[str, float] | None = None,
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None,
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None,
quiet: bool = False,
reservation_arn: str | None = None,
*args,
Expand Down Expand Up @@ -148,10 +148,9 @@ def create(
IR. If the IR supports inputs, the inputs will be updated with this value.
Default: {}.
gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None):
A `Dict` for user defined gate calibration. The calibration is defined for
for a particular `Gate` on a particular `QubitSet` and is represented by
a `PulseSequence`.
gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): A `dict`
of user defined gate calibrations. Each calibration is defined for a particular
`Gate` on a particular `QubitSet` and is represented by a `PulseSequence`.
Default: None.
quiet (bool): Sets the verbosity of the logger to low and does not report queue
Expand Down Expand Up @@ -190,6 +189,7 @@ def create(
if tags is not None:
create_task_kwargs.update({"tags": tags})
inputs = inputs or {}
gate_definitions = gate_definitions or {}

if reservation_arn:
create_task_kwargs.update(
Expand Down Expand Up @@ -561,7 +561,7 @@ def _create_internal(
device_parameters: Union[dict, BraketSchemaBase],
disable_qubit_rewiring: bool,
inputs: dict[str, float],
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
*args,
**kwargs,
) -> AwsQuantumTask:
Expand All @@ -577,7 +577,7 @@ def _(
_device_parameters: Union[dict, BraketSchemaBase], # Not currently used for OpenQasmProgram
_disable_qubit_rewiring: bool,
inputs: dict[str, float],
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
*args,
**kwargs,
) -> AwsQuantumTask:
Expand All @@ -600,7 +600,7 @@ def _(
device_parameters: Union[dict, BraketSchemaBase],
_disable_qubit_rewiring: bool,
inputs: dict[str, float],
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
*args,
**kwargs,
) -> AwsQuantumTask:
Expand Down Expand Up @@ -639,7 +639,7 @@ def _(
_device_parameters: Union[dict, BraketSchemaBase],
_disable_qubit_rewiring: bool,
inputs: dict[str, float],
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
*args,
**kwargs,
) -> AwsQuantumTask:
Expand All @@ -657,7 +657,7 @@ def _(
device_parameters: Union[dict, BraketSchemaBase],
disable_qubit_rewiring: bool,
inputs: dict[str, float],
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence],
*args,
**kwargs,
) -> AwsQuantumTask:
Expand All @@ -678,7 +678,7 @@ def _(
if (
disable_qubit_rewiring
or Instruction(StartVerbatimBox()) in circuit.instructions
or gate_definitions is not None
or gate_definitions
or any(isinstance(instruction.operator, PulseGate) for instruction in circuit.instructions)
):
qubit_reference_type = QubitReferenceType.PHYSICAL
Expand Down
96 changes: 66 additions & 30 deletions src/braket/aws/aws_quantum_task_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
from braket.aws.aws_quantum_task import AwsQuantumTask
from braket.aws.aws_session import AwsSession
from braket.circuits import Circuit
from braket.circuits.gate import Gate
from braket.ir.blackbird import Program as BlackbirdProgram
from braket.ir.openqasm import Program as OpenQasmProgram
from braket.pulse.pulse_sequence import PulseSequence
from braket.registers.qubit_set import QubitSet
from braket.tasks.quantum_task_batch import QuantumTaskBatch


Expand Down Expand Up @@ -61,6 +64,13 @@ def __init__(
poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
inputs: Union[dict[str, float], list[dict[str, float]]] | None = None,
gate_definitions: (
Union[
dict[tuple[Gate, QubitSet], PulseSequence],
list[dict[tuple[Gate, QubitSet], PulseSequence]],
]
| None
) = None,
reservation_arn: str | None = None,
*aws_quantum_task_args: Any,
**aws_quantum_task_kwargs: Any,
Expand Down Expand Up @@ -92,6 +102,9 @@ def __init__(
inputs (Union[dict[str, float], list[dict[str, float]]] | None): Inputs to be passed
along with the IR. If the IR supports inputs, the inputs will be updated
with this value. Default: {}.
gate_definitions (Union[dict[tuple[Gate, QubitSet], PulseSequence], list[dict[tuple[Gate, QubitSet], PulseSequence]]] | None): # noqa: E501
User-defined gate calibration. The calibration is defined for a particular `Gate` on a
particular `QubitSet` and is represented by a `PulseSequence`. Default: None.
reservation_arn (str | None): The reservation ARN provided by Braket Direct
to reserve exclusive usage for the device to run the quantum task on.
Note: If you are creating tasks in a job that itself was created reservation ARN,
Expand All @@ -111,6 +124,7 @@ def __init__(
poll_timeout_seconds,
poll_interval_seconds,
inputs,
gate_definitions,
reservation_arn,
*aws_quantum_task_args,
**aws_quantum_task_kwargs,
Expand All @@ -134,7 +148,7 @@ def __init__(
self._aws_quantum_task_kwargs = aws_quantum_task_kwargs

@staticmethod
def _tasks_and_inputs(
def _tasks_inputs_gatedefs(
task_specifications: Union[
Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation],
list[
Expand All @@ -144,45 +158,55 @@ def _tasks_and_inputs(
],
],
inputs: Union[dict[str, float], list[dict[str, float]]] = None,
gate_definitions: Union[
dict[tuple[Gate, QubitSet], PulseSequence],
list[dict[tuple[Gate, QubitSet], PulseSequence]],
] = None,
) -> list[
tuple[
Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation],
dict[str, float],
dict[tuple[Gate, QubitSet], PulseSequence],
]
]:
inputs = inputs or {}

max_inputs_tasks = 1
single_task = isinstance(
task_specifications,
(Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation),
)
single_input = isinstance(inputs, dict)

max_inputs_tasks = (
max(max_inputs_tasks, len(task_specifications)) if not single_task else max_inputs_tasks
)
max_inputs_tasks = (
max(max_inputs_tasks, len(inputs)) if not single_input else max_inputs_tasks
gate_definitions = gate_definitions or {}

single_task_type = (
Circuit,
Problem,
OpenQasmProgram,
BlackbirdProgram,
AnalogHamiltonianSimulation,
)
single_input_type = dict
single_gate_definitions_type = dict

if not single_task and not single_input:
if len(task_specifications) != len(inputs):
raise ValueError("Multiple inputs and task specifications must be equal in number.")
if single_task:
task_specifications = repeat(task_specifications, times=max_inputs_tasks)
args = [task_specifications, inputs, gate_definitions]
single_arg_types = [single_task_type, single_input_type, single_gate_definitions_type]

if single_input:
inputs = repeat(inputs, times=max_inputs_tasks)
batch_length = 1
arg_lengths = []
for arg, single_arg_type in zip(args, single_arg_types):
arg_length = 1 if isinstance(arg, single_arg_type) else len(arg)
arg_lengths.append(arg_length)

tasks_and_inputs = zip(task_specifications, inputs)
if arg_length != 1:
if batch_length != 1 and arg_length != batch_length:
raise ValueError(
"Multiple inputs, task specifications and gate definitions must "
"be equal in length."
)
else:
batch_length = arg_length

if single_task and single_input:
tasks_and_inputs = list(tasks_and_inputs)
for i, arg_length in enumerate(arg_lengths):
if arg_length == 1:
args[i] = repeat(args[i], batch_length)

tasks_and_inputs = list(tasks_and_inputs)
tasks_inputs_definitions = list(zip(*args))

for task_specification, input_map in tasks_and_inputs:
for task_specification, input_map, _gate_definitions in tasks_inputs_definitions:
if isinstance(task_specification, Circuit):
param_names = {param.name for param in task_specification.parameters}
unbounded_parameters = param_names - set(input_map.keys())
Expand All @@ -192,7 +216,7 @@ def _tasks_and_inputs(
f"{unbounded_parameters}"
)

return tasks_and_inputs
return tasks_inputs_definitions

@staticmethod
def _execute(
Expand All @@ -213,13 +237,22 @@ def _execute(
poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
inputs: Union[dict[str, float], list[dict[str, float]]] = None,
gate_definitions: (
Union[
dict[tuple[Gate, QubitSet], PulseSequence],
list[dict[tuple[Gate, QubitSet], PulseSequence]],
]
| None
) = None,
reservation_arn: str | None = None,
*args,
**kwargs,
) -> list[AwsQuantumTask]:
tasks_and_inputs = AwsQuantumTaskBatch._tasks_and_inputs(task_specifications, inputs)
tasks_inputs_gatedefs = AwsQuantumTaskBatch._tasks_inputs_gatedefs(
task_specifications, inputs, gate_definitions
)
max_threads = min(max_parallel, max_workers)
remaining = [0 for _ in tasks_and_inputs]
remaining = [0 for _ in tasks_inputs_gatedefs]
try:
with ThreadPoolExecutor(max_workers=max_threads) as executor:
task_futures = [
Expand All @@ -234,11 +267,12 @@ def _execute(
poll_timeout_seconds=poll_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
inputs=input_map,
gate_definitions=gatedefs,
reservation_arn=reservation_arn,
*args,
**kwargs,
)
for task, input_map in tasks_and_inputs
for task, input_map, gatedefs in tasks_inputs_gatedefs
]
except KeyboardInterrupt:
# If an exception is thrown before the thread pool has finished,
Expand Down Expand Up @@ -266,6 +300,7 @@ def _create_task(
shots: int,
poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
inputs: dict[str, float] = None,
gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None,
reservation_arn: str | None = None,
*args,
**kwargs,
Expand All @@ -278,6 +313,7 @@ def _create_task(
shots,
poll_interval_seconds=poll_interval_seconds,
inputs=inputs,
gate_definitions=gate_definitions,
reservation_arn=reservation_arn,
*args,
**kwargs,
Expand Down
Loading

0 comments on commit 09572f4

Please sign in to comment.