Skip to content

Commit

Permalink
Merge branch 'main' into math411-patch-6
Browse files Browse the repository at this point in the history
  • Loading branch information
AbeCoull authored May 21, 2024
2 parents f6068da + 7ba54ce commit dc56f9e
Show file tree
Hide file tree
Showing 22 changed files with 551 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
pip install tox
- name: Run code format checks
run: |
tox -e linters_check
tox -e linters_check -p auto
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# Changelog

## v1.79.1 (2024-05-08)

### Bug Fixes and Other Changes

* check the qubit set length against observables

## v1.79.0 (2024-05-06)

### Features

* Direct Reservation context manager

### Documentation Changes

* correct the example in the measure docstring

## v1.78.0 (2024-04-18)

### Features
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ To run linters and doc generators and unit tests:
tox
```

or if your machine can handle multithreaded workloads, run them in parallel with:

```bash
tox -p auto
```

### Integration Tests

First, configure a profile to use your account to interact with AWS. To learn more, see [Configure AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html).
Expand Down
17 changes: 14 additions & 3 deletions examples/reservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,25 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.aws import AwsDevice
from braket.aws import AwsDevice, DirectReservation
from braket.circuits import Circuit
from braket.devices import Devices

bell = Circuit().h(0).cnot(0, 1)
device = AwsDevice(Devices.IonQ.Aria1)

# To run a task in a device reservation, change the device to the one you reserved
# and fill in your reservation ARN
task = device.run(bell, shots=100, reservation_arn="reservation ARN")
# and fill in your reservation ARN.
with DirectReservation(device, reservation_arn="<my_reservation_arn>"):
task = device.run(bell, shots=100)
print(task.result().measurement_counts)

# Alternatively, you may start the reservation globally
reservation = DirectReservation(device, reservation_arn="<my_reservation_arn>").start()
task = device.run(bell, shots=100)
print(task.result().measurement_counts)
reservation.stop() # stop creating tasks in the reservation

# Lastly, you may pass the reservation ARN directly to a quantum task
task = device.run(bell, shots=100, reservation_arn="<my_reservation_arn>")
print(task.result().measurement_counts)
Binary file removed model.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ test=pytest
xfail_strict = true
# https://pytest-xdist.readthedocs.io/en/latest/known-limitations.html
addopts =
--verbose -n logical --durations=0 --durations-min=1
--verbose -n logical --durations=0 --durations-min=1 --dist worksteal
testpaths = test/unit_tests
filterwarnings=
# Issue #557 in `pytest-cov` (currently v4.x) has not moved for a while now,
Expand Down
2 changes: 1 addition & 1 deletion src/braket/_sdk/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "1.78.1.dev0"
__version__ = "1.79.2.dev0"
1 change: 1 addition & 0 deletions src/braket/aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from braket.aws.aws_quantum_task import AwsQuantumTask # noqa: F401
from braket.aws.aws_quantum_task_batch import AwsQuantumTaskBatch # noqa: F401
from braket.aws.aws_session import AwsSession # noqa: F401
from braket.aws.direct_reservations import DirectReservation # noqa: F401
29 changes: 29 additions & 0 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
IRType,
OpenQASMSerializationProperties,
QubitReferenceType,
SerializableProgram,
)
from braket.device_schema import GateModelParameters
from braket.device_schema.dwave import (
Expand Down Expand Up @@ -623,6 +624,34 @@ def _(
return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
def _(
serializable_program: SerializableProgram,
aws_session: AwsSession,
create_task_kwargs: dict[str, Any],
device_arn: str,
device_parameters: Union[dict, BraketSchemaBase],
_disable_qubit_rewiring: bool,
inputs: dict[str, float],
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]],
*args,
**kwargs,
) -> AwsQuantumTask:
openqasm_program = OpenQASMProgram(source=serializable_program.to_ir(ir_type=IRType.OPENQASM))
return _create_internal(
openqasm_program,
aws_session,
create_task_kwargs,
device_arn,
device_parameters,
_disable_qubit_rewiring,
inputs,
gate_definitions,
*args,
**kwargs,
)


@_create_internal.register
def _(
blackbird_program: BlackbirdProgram,
Expand Down
27 changes: 27 additions & 0 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import os.path
import re
import warnings
from functools import cache
from pathlib import Path
from typing import Any, NamedTuple, Optional
Expand Down Expand Up @@ -235,6 +236,32 @@ def create_quantum_task(self, **boto3_kwargs) -> str:
Returns:
str: The ARN of the quantum task.
"""
# Add reservation arn if available and device is correct.
context_device_arn = os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN")
context_reservation_arn = os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN")

# if the task has a reservation_arn and also context does, raise a warning
# Raise warning if reservation ARN is found in both context and task parameters
task_has_reservation = any(
item.get("type") == "RESERVATION_TIME_WINDOW_ARN"
for item in boto3_kwargs.get("associations", [])
)
if task_has_reservation and context_reservation_arn:
warnings.warn(
"A reservation ARN was passed to 'CreateQuantumTask', but it is being overridden "
"by a 'DirectReservation' context. If this was not intended, please review your "
"reservation ARN settings or the context in which 'CreateQuantumTask' is called."
)

# Ensure reservation only applies to specific device
if context_device_arn == boto3_kwargs["deviceArn"] and context_reservation_arn:
boto3_kwargs["associations"] = [
{
"arn": context_reservation_arn,
"type": "RESERVATION_TIME_WINDOW_ARN",
}
]

# Add job token to request, if available.
job_token = os.getenv("AMZN_BRAKET_JOB_TOKEN")
if job_token:
Expand Down
98 changes: 98 additions & 0 deletions src/braket/aws/direct_reservations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import os
import warnings
from contextlib import AbstractContextManager

from braket.aws.aws_device import AwsDevice
from braket.devices import Device


class DirectReservation(AbstractContextManager):
"""
Context manager that modifies AwsQuantumTasks created within the context to use a reservation
ARN for all tasks targeting the specified device. Note: this context manager only allows for
one reservation at a time.
Reservations are AWS account and device specific. Only the AWS account that created the
reservation can use your reservation ARN. Additionally, the reservation ARN is only valid on the
reserved device at the chosen start and end times.
Args:
device (Device | str | None): The Braket device for which you have a reservation ARN, or
optionally the device ARN.
reservation_arn (str | None): The Braket Direct reservation ARN to be applied to all
quantum tasks run within the context.
Examples:
As a context manager
>>> with DirectReservation(device_arn, reservation_arn="<my_reservation_arn>"):
... task1 = device.run(circuit, shots)
... task2 = device.run(circuit, shots)
or start the reservation
>>> DirectReservation(device_arn, reservation_arn="<my_reservation_arn>").start()
... task1 = device.run(circuit, shots)
... task2 = device.run(circuit, shots)
References:
[1] https://docs.aws.amazon.com/braket/latest/developerguide/braket-reservations.html
"""

_is_active = False # Class variable to track active reservation context

def __init__(self, device: Device | str | None, reservation_arn: str | None):
if isinstance(device, AwsDevice):
self.device_arn = device.arn
elif isinstance(device, str):
self.device_arn = AwsDevice(device).arn # validate ARN early
elif isinstance(device, Device) or device is None: # LocalSimulator
warnings.warn(
"Using a local simulator with the reservation. For a reservation on a QPU, please "
"ensure the device matches the reserved Braket device."
)
self.device_arn = "" # instead of None, use empty string
else:
raise TypeError("Device must be an AwsDevice or its ARN, or a local simulator device.")

self.reservation_arn = reservation_arn

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.stop()

def start(self) -> None:
"""Start the reservation context."""
if DirectReservation._is_active:
raise RuntimeError("Another reservation is already active.")

os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] = self.device_arn
if self.reservation_arn:
os.environ["AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN"] = self.reservation_arn
DirectReservation._is_active = True

def stop(self) -> None:
"""Stop the reservation context."""
if not DirectReservation._is_active:
warnings.warn("Reservation context is not active.")
return
os.environ.pop("AMZN_BRAKET_RESERVATION_DEVICE_ARN", None)
os.environ.pop("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN", None)
DirectReservation._is_active = False
1 change: 0 additions & 1 deletion src/braket/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,6 @@ def measure(self, target_qubits: QubitSetInput) -> Circuit:
[Instruction('operator': H('qubit_count': 1), 'target': QubitSet([Qubit(0)]),
Instruction('operator': CNot('qubit_count': 2), 'target': QubitSet([Qubit(0),
Qubit(1)]),
Instruction('operator': H('qubit_count': 1), 'target': QubitSet([Qubit(2)]),
Instruction('operator': Measure, 'target': QubitSet([Qubit(0)])]
"""
if not isinstance(target_qubits, Iterable):
Expand Down
2 changes: 1 addition & 1 deletion src/braket/circuits/result_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(
"target length is equal to the observable term's qubits count."
)
self._target = [QubitSet(term_target) for term_target in target]
for term_target, obs in zip(target, observable.summands):
for term_target, obs in zip(self._target, observable.summands):
if obs.qubit_count != len(term_target):
raise ValueError(
"Sum observable's target shape must be a nested list where each term's "
Expand Down
21 changes: 21 additions & 0 deletions src/braket/circuits/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum

Expand All @@ -32,6 +33,26 @@ class QubitReferenceType(str, Enum):
PHYSICAL = "PHYSICAL"


class SerializableProgram(ABC):
@abstractmethod
def to_ir(
self,
ir_type: IRType = IRType.OPENQASM,
) -> str:
"""Serializes the program into an intermediate representation.
Args:
ir_type (IRType): The IRType to use for converting the program to its
IR representation. Defaults to IRType.OPENQASM.
Raises:
ValueError: Raised if the supplied `ir_type` is not supported.
Returns:
str: A representation of the program in the `ir_type` format.
"""


@dataclass
class OpenQASMSerializationProperties:
"""Properties for serializing a circuit to OpenQASM.
Expand Down
Loading

0 comments on commit dc56f9e

Please sign in to comment.