From 035409addb9a3f7bc15626b99dec2e22519bbfac Mon Sep 17 00:00:00 2001 From: Abe Coull <85974725+math411@users.noreply.github.com> Date: Wed, 22 May 2024 09:11:37 -0700 Subject: [PATCH] feat: add support for the ARN region (#977) Co-authored-by: Coull Co-authored-by: Tim (Yi-Ting) --- src/braket/aws/aws_device.py | 2 +- src/braket/devices/devices.py | 4 ++ src/braket/jobs/image_uri_config/base.json | 3 +- .../jobs/image_uri_config/pl_pytorch.json | 3 +- .../jobs/image_uri_config/pl_tensorflow.json | 3 +- test/integ_tests/conftest.py | 55 ++++++++++++------- test/integ_tests/test_cost_tracking.py | 37 +++++++------ test/unit_tests/braket/aws/test_aws_device.py | 24 +++++++- tox.ini | 1 + 9 files changed, 89 insertions(+), 43 deletions(-) diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index 70fc3c078..041098f5a 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -62,7 +62,7 @@ class AwsDevice(Device): device. """ - REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2") + REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2", "eu-north-1") DEFAULT_SHOTS_QPU = 1000 DEFAULT_SHOTS_SIMULATOR = 0 diff --git a/src/braket/devices/devices.py b/src/braket/devices/devices.py index f4fe8ff8d..fa2c6d025 100644 --- a/src/braket/devices/devices.py +++ b/src/braket/devices/devices.py @@ -27,6 +27,9 @@ class _DWave(str, Enum): _Advantage6 = "arn:aws:braket:us-west-2::device/qpu/d-wave/Advantage_system6" _DW2000Q6 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" + class _IQM(str, Enum): + Garnet = "arn:aws:braket:eu-north-1::device/qpu/iqm/Garnet" + class _IonQ(str, Enum): Harmony = "arn:aws:braket:us-east-1::device/qpu/ionq/Harmony" Aria1 = "arn:aws:braket:us-east-1::device/qpu/ionq/Aria-1" @@ -54,6 +57,7 @@ class _Xanadu(str, Enum): Amazon = _Amazon # DWave = _DWave IonQ = _IonQ + IQM = _IQM OQC = _OQC QuEra = _QuEra Rigetti = _Rigetti diff --git a/src/braket/jobs/image_uri_config/base.json b/src/braket/jobs/image_uri_config/base.json index eb71e60fd..c7aef2be2 100644 --- a/src/braket/jobs/image_uri_config/base.json +++ b/src/braket/jobs/image_uri_config/base.json @@ -5,6 +5,7 @@ "us-east-1", "us-west-1", "us-west-2", - "eu-west-2" + "eu-west-2", + "eu-north-1" ] } diff --git a/src/braket/jobs/image_uri_config/pl_pytorch.json b/src/braket/jobs/image_uri_config/pl_pytorch.json index c7e28fbde..0a00e8537 100644 --- a/src/braket/jobs/image_uri_config/pl_pytorch.json +++ b/src/braket/jobs/image_uri_config/pl_pytorch.json @@ -5,6 +5,7 @@ "us-east-1", "us-west-1", "us-west-2", - "eu-west-2" + "eu-west-2", + "eu-north-1" ] } diff --git a/src/braket/jobs/image_uri_config/pl_tensorflow.json b/src/braket/jobs/image_uri_config/pl_tensorflow.json index 3278a8712..c43792e8a 100644 --- a/src/braket/jobs/image_uri_config/pl_tensorflow.json +++ b/src/braket/jobs/image_uri_config/pl_tensorflow.json @@ -5,6 +5,7 @@ "us-east-1", "us-west-1", "us-west-2", - "eu-west-2" + "eu-west-2", + "eu-north-1" ] } diff --git a/test/integ_tests/conftest.py b/test/integ_tests/conftest.py index b8d0d43df..80a8f5c8f 100644 --- a/test/integ_tests/conftest.py +++ b/test/integ_tests/conftest.py @@ -38,6 +38,7 @@ def pytest_configure_node(node): node.workerinput["JOB_FAILED_NAME"] = job_fail_name if endpoint := os.getenv("BRAKET_ENDPOINT"): node.workerinput["BRAKET_ENDPOINT"] = endpoint + node.workerinput["AWS_REGION"] = os.getenv("AWS_REGION") def pytest_xdist_node_collection_finished(ids): @@ -48,8 +49,11 @@ def pytest_xdist_node_collection_finished(ids): """ run_jobs = any("job" in test for test in ids) profile_name = os.environ["AWS_PROFILE"] - aws_session = AwsSession(boto3.session.Session(profile_name=profile_name)) - if run_jobs and os.getenv("JOBS_STARTED") is None: + region_name = os.getenv("AWS_REGION") + aws_session = AwsSession( + boto3.session.Session(profile_name=profile_name, region_name=region_name) + ) + if run_jobs and os.getenv("JOBS_STARTED") is None and region_name != "eu-north-1": AwsQuantumJob.create( "arn:aws:braket:::device/quantum-simulator/amazon/sv1", job_name=job_fail_name, @@ -72,9 +76,10 @@ def pytest_xdist_node_collection_finished(ids): @pytest.fixture(scope="session") -def boto_session(): +def boto_session(request): profile_name = os.environ["AWS_PROFILE"] - return boto3.session.Session(profile_name=profile_name) + region_name = request.config.workerinput["AWS_REGION"] + return boto3.session.Session(profile_name=profile_name, region_name=region_name) @pytest.fixture(scope="session") @@ -137,9 +142,11 @@ def s3_destination_folder(s3_bucket, s3_prefix): @pytest.fixture(scope="session") def braket_simulators(aws_session): - return { - simulator_arn: AwsDevice(simulator_arn, aws_session) for simulator_arn in SIMULATOR_ARNS - } + return ( + {simulator_arn: AwsDevice(simulator_arn, aws_session) for simulator_arn in SIMULATOR_ARNS} + if aws_session.region != "eu-north-1" + else None + ) @pytest.fixture(scope="session") @@ -164,21 +171,29 @@ def job_failed_name(request): @pytest.fixture(scope="session", autouse=True) def completed_quantum_job(job_completed_name): - job_arn = [ - job["jobArn"] - for job in boto3.client("braket").search_jobs(filters=[])["jobs"] - if job["jobName"] == job_completed_name - ][0] + job_arn = ( + [ + job["jobArn"] + for job in boto3.client("braket").search_jobs(filters=[])["jobs"] + if job["jobName"] == job_completed_name + ][0] + if os.getenv("JOBS_STARTED") + else None + ) - return AwsQuantumJob(arn=job_arn) + return AwsQuantumJob(arn=job_arn) if os.getenv("JOBS_STARTED") else None @pytest.fixture(scope="session", autouse=True) def failed_quantum_job(job_failed_name): - job_arn = [ - job["jobArn"] - for job in boto3.client("braket").search_jobs(filters=[])["jobs"] - if job["jobName"] == job_failed_name - ][0] - - return AwsQuantumJob(arn=job_arn) + job_arn = ( + [ + job["jobArn"] + for job in boto3.client("braket").search_jobs(filters=[])["jobs"] + if job["jobName"] == job_failed_name + ][0] + if os.getenv("JOBS_STARTED") + else None + ) + + return AwsQuantumJob(arn=job_arn) if os.getenv("JOBS_STARTED") else None diff --git a/test/integ_tests/test_cost_tracking.py b/test/integ_tests/test_cost_tracking.py index d638f80b0..0c06f297a 100644 --- a/test/integ_tests/test_cost_tracking.py +++ b/test/integ_tests/test_cost_tracking.py @@ -18,7 +18,7 @@ import pytest from botocore.exceptions import ClientError -from braket.aws import AwsDevice, AwsSession +from braket.aws import AwsDevice, AwsDeviceType, AwsSession from braket.circuits import Circuit from braket.tracking import Tracker from braket.tracking.tracker import MIN_SIMULATOR_DURATION @@ -93,23 +93,26 @@ def test_all_devices_price_search(): s = AwsSession(boto3.Session(region_name=region)) # Skip devices with empty execution windows for device in [device for device in devices if device.properties.service.executionWindows]: - try: - s.get_device(device.arn) - - # If we are here, device can create tasks in region - details = { - "shots": 100, - "device": device.arn, - "billed_duration": MIN_SIMULATOR_DURATION, - "job_task": False, - "status": "COMPLETED", - } - tasks[f"task:for:{device.name}:{region}"] = details.copy() - details["job_task"] = True - tasks[f"jobtask:for:{device.name}:{region}"] = details - except s.braket_client.exceptions.ResourceNotFoundException: - # device does not exist in region, so nothing to test + if region == "eu-north-1" and device.type == AwsDeviceType.SIMULATOR: pass + else: + try: + s.get_device(device.arn) + + # If we are here, device can create tasks in region + details = { + "shots": 100, + "device": device.arn, + "billed_duration": MIN_SIMULATOR_DURATION, + "job_task": False, + "status": "COMPLETED", + } + tasks[f"task:for:{device.name}:{region}"] = details.copy() + details["job_task"] = True + tasks[f"jobtask:for:{device.name}:{region}"] = details + except s.braket_client.exceptions.ResourceNotFoundException: + # device does not exist in region, so nothing to test + pass t = Tracker() t._resources = tasks diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index 7f9e6179d..a778dfb5f 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -1753,6 +1753,16 @@ def test_get_devices(mock_copy_session, aws_session): "providerName": "OQC", } ], + # eu-north-1 + [ + { + "deviceArn": SV1_ARN, + "deviceName": "SV1", + "deviceType": "SIMULATOR", + "deviceStatus": "ONLINE", + "providerName": "Amazon Braket", + }, + ], # Only two regions to search outside of current ValueError("should not be reachable"), ] @@ -1763,7 +1773,7 @@ def test_get_devices(mock_copy_session, aws_session): ValueError("should not be reachable"), ] mock_copy_session.return_value = session_for_region - # Search order: us-east-1, us-west-1, us-west-2, eu-west-2 + # Search order: us-east-1, us-west-1, us-west-2, eu-west-2, eu-north-1 results = AwsDevice.get_devices( arns=[SV1_ARN, DWAVE_ARN, IONQ_ARN, OQC_ARN], provider_names=["Amazon Braket", "D-Wave", "IonQ", "OQC"], @@ -1858,6 +1868,16 @@ def test_get_devices_with_error_in_region(mock_copy_session, aws_session): "providerName": "OQC", } ], + # eu-north-1 + [ + { + "deviceArn": SV1_ARN, + "deviceName": "SV1", + "deviceType": "SIMULATOR", + "deviceStatus": "ONLINE", + "providerName": "Amazon Braket", + }, + ], # Only two regions to search outside of current ValueError("should not be reachable"), ] @@ -1867,7 +1887,7 @@ def test_get_devices_with_error_in_region(mock_copy_session, aws_session): ValueError("should not be reachable"), ] mock_copy_session.return_value = session_for_region - # Search order: us-east-1, us-west-1, us-west-2, eu-west-2 + # Search order: us-east-1, us-west-1, us-west-2, eu-west-2, eu-north-1 results = AwsDevice.get_devices( statuses=["ONLINE"], aws_session=aws_session, diff --git a/tox.ini b/tox.ini index d4fae64c5..95c862106 100644 --- a/tox.ini +++ b/tox.ini @@ -29,6 +29,7 @@ deps = {[test-deps]deps} passenv = AWS_PROFILE + AWS_REGION BRAKET_ENDPOINT commands = pytest test/integ_tests {posargs}