Skip to content

Commit

Permalink
fix: store account id if already accessed (#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbeCoull authored Mar 14, 2024
1 parent 5e5b2f3 commit 6029c56
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
11 changes: 9 additions & 2 deletions src/braket/aws/aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
self.braket_client = self.boto_session.client(
"braket", config=self._config, endpoint_url=os.environ.get("BRAKET_ENDPOINT")
)

self._update_user_agent()
self._custom_default_bucket = bool(default_bucket)
self._default_bucket = default_bucket or os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET")
Expand All @@ -101,14 +100,22 @@ def __init__(
self._sts = None
self._logs = None
self._ecr = None
self._account_id = None

@property
def region(self) -> str:
return self.boto_session.region_name

@property
def account_id(self) -> str:
return self.sts_client.get_caller_identity()["Account"]
"""Gets the caller's account number.
Returns:
str: The account number of the caller.
"""
if not self._account_id:
self._account_id = self.sts_client.get_caller_identity()["Account"]
return self._account_id

@property
def iam_client(self) -> client:
Expand Down
7 changes: 6 additions & 1 deletion test/unit_tests/braket/aws/test_aws_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def aws_session(boto_session, braket_client, account_id):
_aws_session._sts.get_caller_identity.return_value = {
"Account": account_id,
}

_aws_session._s3 = Mock()
return _aws_session

Expand Down Expand Up @@ -998,6 +997,12 @@ def test_upload_to_s3(aws_session):
aws_session._s3.upload_file.assert_called_with(filename, bucket, key)


def test_account_id_idempotency(aws_session, account_id):
acc_id = aws_session.account_id
assert acc_id == aws_session.account_id
assert acc_id == account_id


def test_upload_local_data(aws_session):
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
Expand Down

0 comments on commit 6029c56

Please sign in to comment.