diff --git a/aws_google_auth/__init__.py b/aws_google_auth/__init__.py index 7fca14c..07d6a12 100644 --- a/aws_google_auth/__init__.py +++ b/aws_google_auth/__init__.py @@ -30,6 +30,7 @@ def parse_args(args): parser.add_argument('-R', '--region', help='AWS region endpoint ($AWS_DEFAULT_REGION)') parser.add_argument('-d', '--duration', type=int, help='Credential duration ($DURATION)') parser.add_argument('-p', '--profile', help='AWS profile (defaults to value of $AWS_PROFILE, then falls back to \'sts\')') + parser.add_argument('-A', '--account', help='Filter for specific AWS account.') parser.add_argument('-D', '--disable-u2f', action='store_true', help='Disable U2F functionality.') parser.add_argument('-q', '--quiet', action='store_true', help='Quiet output') parser.add_argument('--bg-response', help='Override default bgresponse challenge token.') @@ -155,6 +156,12 @@ def resolve_config(args): os.getenv('GOOGLE_USERNAME'), config.username) + # Account (Option priority = ARGS, ENV_VAR, DEFAULT) + config.account = coalesce( + args.account, + os.getenv('AWS_ACCOUNT'), + config.account) + config.keyring = coalesce( args.keyring, config.keyring) @@ -247,7 +254,12 @@ def process_auth(args, config): if config.role_arn in roles and not config.ask_role: config.provider = roles[config.role_arn] else: - if config.resolve_aliases: + if config.account and config.resolve_aliases: + aliases = amazon_client.resolve_aws_aliases(roles) + config.role_arn, config.provider = util.Util.pick_a_role(roles, aliases, config.account) + elif config.account: + config.role_arn, config.provider = util.Util.pick_a_role(roles, account=config.account) + elif config.resolve_aliases: aliases = amazon_client.resolve_aws_aliases(roles) config.role_arn, config.provider = util.Util.pick_a_role(roles, aliases) else: diff --git a/aws_google_auth/configuration.py b/aws_google_auth/configuration.py index 4319863..74756a2 100644 --- a/aws_google_auth/configuration.py +++ b/aws_google_auth/configuration.py @@ -37,6 +37,7 @@ def __init__(self, **kwargs): self.print_creds = False self.quiet = False self.bg_response = None + self.account = "" # For the "~/.aws/config" file, we use the format "[profile testing]" # for the 'testing' profile. The credential file will just be "[testing]" @@ -138,6 +139,9 @@ def raise_if_invalid(self): # quiet assert (self.quiet.__class__ is bool), "Expected quiet to be a boolean. Got {}.".format(self.quiet.__class__) + # account + assert (self.account.__class__ is str), "Expected account to be string. Got {}".format(self.account.__class__) + # Write the configuration (and credentials) out to disk. This allows for # regular AWS tooling (aws cli and boto) to use the credentials in the # profile the user specified. @@ -259,6 +263,10 @@ def read(self, profile): read_bg_response = unicode_to_string(config_parser[profile_string].get('google_config.bg_response', None)) self.bg_response = coalesce(read_bg_response, self.bg_response) + # Account + read_account = unicode_to_string(config_parser[profile_string].get('account', None)) + self.account = coalesce(read_account, self.account) + # SAML Cache try: with open(self.saml_cache_file, 'r') as f: diff --git a/aws_google_auth/tests/test_args_parser.py b/aws_google_auth/tests/test_args_parser.py index 8fbc7a0..1ede993 100644 --- a/aws_google_auth/tests/test_args_parser.py +++ b/aws_google_auth/tests/test_args_parser.py @@ -32,6 +32,7 @@ def test_no_arguments(self): self.assertEqual(parser.username, None) self.assertEqual(parser.quiet, False) self.assertEqual(parser.bg_response, None) + self.assertEqual(parser.account, None) self.assertFalse(parser.save_failure_html) @@ -54,6 +55,7 @@ def test_username(self): self.assertEqual(parser.region, None) self.assertEqual(parser.role_arn, None) self.assertEqual(parser.username, 'username@gmail.com') + self.assertEqual(parser.account, None) def test_nocache(self): @@ -70,6 +72,7 @@ def test_nocache(self): self.assertEqual(parser.region, None) self.assertEqual(parser.role_arn, None) self.assertEqual(parser.username, None) + self.assertEqual(parser.account, None) def test_resolvealiases(self): @@ -86,6 +89,7 @@ def test_resolvealiases(self): self.assertEqual(parser.region, None) self.assertEqual(parser.role_arn, None) self.assertEqual(parser.username, None) + self.assertEqual(parser.account, None) def test_ask_and_supply_role(self): diff --git a/aws_google_auth/tests/test_backwards_compatibility.py b/aws_google_auth/tests/test_backwards_compatibility.py index b3ff051..f8962f6 100644 --- a/aws_google_auth/tests/test_backwards_compatibility.py +++ b/aws_google_auth/tests/test_backwards_compatibility.py @@ -29,6 +29,7 @@ def setUp(self): self.c.sp_id = "sample_sp_id" self.c.u2f_disabled = False self.c.username = "sample_username" + self.c.account = "123456789012" self.c.raise_if_invalid() self.c.write(None) diff --git a/aws_google_auth/tests/test_config_parser.py b/aws_google_auth/tests/test_config_parser.py index 87dd9ab..beccb4a 100644 --- a/aws_google_auth/tests/test_config_parser.py +++ b/aws_google_auth/tests/test_config_parser.py @@ -249,3 +249,27 @@ def test_with_environment(self): args = parse_args([]) config = resolve_config(args) self.assertEqual(config.bg_response, 'foo') + + +class TestAccountProcessing(unittest.TestCase): + + @nottest + def test_default(self): + args = parse_args([]) + config = resolve_config(args) + self.assertEqual(None, config.account) + + def test_cli_param_supplied(self): + args = parse_args(['--account', "123456789012"]) + config = resolve_config(args) + self.assertEqual("123456789012", config.account) + + @mock.patch.dict(os.environ, {'AWS_ACCOUNT': '123456789012'}) + def test_with_environment(self): + args = parse_args([]) + config = resolve_config(args) + self.assertEqual("123456789012", config.account) + + args = parse_args(['--region', "123456789012"]) + config = resolve_config(args) + self.assertEqual("123456789012", config.account) diff --git a/aws_google_auth/tests/test_configuration_persistence.py b/aws_google_auth/tests/test_configuration_persistence.py index da07150..d05135b 100644 --- a/aws_google_auth/tests/test_configuration_persistence.py +++ b/aws_google_auth/tests/test_configuration_persistence.py @@ -33,6 +33,7 @@ def setUp(self): self.c.bg_response = "foo" self.c.raise_if_invalid() self.c.write(None) + self.c.account = "123456789012" self.config_parser = configparser.RawConfigParser() self.config_parser.read(self.c.config_file) diff --git a/aws_google_auth/tests/test_init.py b/aws_google_auth/tests/test_init.py index e97af53..7090ad6 100644 --- a/aws_google_auth/tests/test_init.py +++ b/aws_google_auth/tests/test_init.py @@ -61,7 +61,8 @@ def test_main_method_chaining(self, process_auth, resolve_config, exit_if_unsupp print_creds=False, username=None, quiet=False, - bg_response=None)) + bg_response=None, + account=None)) ], resolve_config.mock_calls) @@ -82,7 +83,8 @@ def test_main_method_chaining(self, process_auth, resolve_config, exit_if_unsupp print_creds=False, username=None, quiet=False, - bg_response=None), + bg_response=None, + account=None), mock_config) ], process_auth.mock_calls) @@ -100,6 +102,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util): mock_config.idp_id = None mock_config.sp_id = None mock_config.return_value = None + mock_config.account = None mock_config.region = None mock_amazon_client = Mock() @@ -174,6 +177,7 @@ def test_process_auth_print_creds(self, mock_google, mock_amazon, mock_util): mock_config.sp_id = None mock_config.return_value = None mock_config.print_creds = True + mock_config.account = None mock_amazon_client = Mock() mock_google_client = Mock() @@ -321,6 +325,7 @@ def test_process_auth_dont_resolve_alias(self, mock_google, mock_amazon, mock_ut mock_config.sp_id = None mock_config.return_value = None mock_config.keyring = False + mock_config.account = None mock_amazon_client = Mock() mock_google_client = Mock() @@ -354,6 +359,7 @@ def test_process_auth_dont_resolve_alias(self, mock_google, mock_amazon, mock_ut self.assertEqual(mock_config.password, "pass") self.assertEqual(mock_config.provider, "da_provider") self.assertEqual(mock_config.role_arn, "da_role") + self.assertEqual(mock_config.account, None) # Assert calls occur self.assertEqual([call.Util.get_input('Google username: '), @@ -392,6 +398,7 @@ def test_process_auth_with_profile(self, mock_google, mock_amazon, mock_util): mock_config.profile = "blart" mock_config.return_value = None mock_config.role_arn = 'arn:aws:iam::123456789012:role/admin' + mock_config.account = None mock_amazon_client = Mock() mock_google_client = Mock() @@ -464,6 +471,7 @@ def test_process_auth_with_saml_cache(self, mock_google, mock_amazon, mock_util) mock_config.password = None mock_config.return_value = None mock_config.role_arn = 'arn:aws:iam::123456789012:role/admin' + mock_config.account = None mock_amazon_client = Mock() mock_google_client = Mock() diff --git a/aws_google_auth/util.py b/aws_google_auth/util.py index 0e42ea9..4aacac6 100644 --- a/aws_google_auth/util.py +++ b/aws_google_auth/util.py @@ -18,10 +18,15 @@ def get_input(prompt): return input(prompt) @staticmethod - def pick_a_role(roles, aliases=None): + def pick_a_role(roles, aliases=None, account=None): + if account: + filtered_roles = {role: principal for role, principal in roles.items() if(account in role)} + else: + filtered_roles = roles + if aliases: enriched_roles = {} - for role, principal in roles.items(): + for role, principal in filtered_roles.items(): enriched_roles[role] = [ aliases[role.split(':')[4]], role.split('role/')[1], @@ -48,14 +53,14 @@ def pick_a_role(roles, aliases=None): print("Invalid choice, try again.") else: while True: - for i, role in enumerate(roles): + for i, role in enumerate(filtered_roles): print("[{:>3d}] {}".format(i + 1, role)) - prompt = 'Type the number (1 - {:d}) of the role to assume: '.format(len(roles)) + prompt = 'Type the number (1 - {:d}) of the role to assume: '.format(len(filtered_roles)) choice = Util.get_input(prompt) try: - return list(roles.items())[int(choice) - 1] + return list(filtered_roles.items())[int(choice) - 1] except (IndexError, ValueError): print("Invalid choice, try again.")