From 6029c560ea2c287c9a381bf6926ad4beb8426181 Mon Sep 17 00:00:00 2001 From: Abe Coull <85974725+math411@users.noreply.github.com> Date: Thu, 14 Mar 2024 12:06:50 -0700 Subject: [PATCH] fix: store account id if already accessed (#908) --- src/braket/aws/aws_session.py | 11 +++++++++-- test/unit_tests/braket/aws/test_aws_session.py | 7 ++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index a1ae0c7bb..d2f2099f4 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -85,7 +85,6 @@ def __init__( self.braket_client = self.boto_session.client( "braket", config=self._config, endpoint_url=os.environ.get("BRAKET_ENDPOINT") ) - self._update_user_agent() self._custom_default_bucket = bool(default_bucket) self._default_bucket = default_bucket or os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") @@ -101,6 +100,7 @@ def __init__( self._sts = None self._logs = None self._ecr = None + self._account_id = None @property def region(self) -> str: @@ -108,7 +108,14 @@ def region(self) -> str: @property def account_id(self) -> str: - return self.sts_client.get_caller_identity()["Account"] + """Gets the caller's account number. + + Returns: + str: The account number of the caller. + """ + if not self._account_id: + self._account_id = self.sts_client.get_caller_identity()["Account"] + return self._account_id @property def iam_client(self) -> client: diff --git a/test/unit_tests/braket/aws/test_aws_session.py b/test/unit_tests/braket/aws/test_aws_session.py index c61d22606..56d23b2e9 100644 --- a/test/unit_tests/braket/aws/test_aws_session.py +++ b/test/unit_tests/braket/aws/test_aws_session.py @@ -58,7 +58,6 @@ def aws_session(boto_session, braket_client, account_id): _aws_session._sts.get_caller_identity.return_value = { "Account": account_id, } - _aws_session._s3 = Mock() return _aws_session @@ -998,6 +997,12 @@ def test_upload_to_s3(aws_session): aws_session._s3.upload_file.assert_called_with(filename, bucket, key) +def test_account_id_idempotency(aws_session, account_id): + acc_id = aws_session.account_id + assert acc_id == aws_session.account_id + assert acc_id == account_id + + def test_upload_local_data(aws_session): with tempfile.TemporaryDirectory() as temp_dir: os.chdir(temp_dir)