Skip to content

Commit

Permalink
Merge pull request #158 from trashguy/improve-account-specific
Browse files Browse the repository at this point in the history
Enhancement ability to filter for an account on the CLI with flag -A
  • Loading branch information
stevemac007 authored Apr 13, 2020
2 parents b4bbe98 + afa8032 commit a937d37
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 8 deletions.
14 changes: 13 additions & 1 deletion aws_google_auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def parse_args(args):
parser.add_argument('-R', '--region', help='AWS region endpoint ($AWS_DEFAULT_REGION)')
parser.add_argument('-d', '--duration', type=int, help='Credential duration ($DURATION)')
parser.add_argument('-p', '--profile', help='AWS profile (defaults to value of $AWS_PROFILE, then falls back to \'sts\')')
parser.add_argument('-A', '--account', help='Filter for specific AWS account.')
parser.add_argument('-D', '--disable-u2f', action='store_true', help='Disable U2F functionality.')
parser.add_argument('-q', '--quiet', action='store_true', help='Quiet output')
parser.add_argument('--bg-response', help='Override default bgresponse challenge token.')
Expand Down Expand Up @@ -155,6 +156,12 @@ def resolve_config(args):
os.getenv('GOOGLE_USERNAME'),
config.username)

# Account (Option priority = ARGS, ENV_VAR, DEFAULT)
config.account = coalesce(
args.account,
os.getenv('AWS_ACCOUNT'),
config.account)

config.keyring = coalesce(
args.keyring,
config.keyring)
Expand Down Expand Up @@ -247,7 +254,12 @@ def process_auth(args, config):
if config.role_arn in roles and not config.ask_role:
config.provider = roles[config.role_arn]
else:
if config.resolve_aliases:
if config.account and config.resolve_aliases:
aliases = amazon_client.resolve_aws_aliases(roles)
config.role_arn, config.provider = util.Util.pick_a_role(roles, aliases, config.account)
elif config.account:
config.role_arn, config.provider = util.Util.pick_a_role(roles, account=config.account)
elif config.resolve_aliases:
aliases = amazon_client.resolve_aws_aliases(roles)
config.role_arn, config.provider = util.Util.pick_a_role(roles, aliases)
else:
Expand Down
8 changes: 8 additions & 0 deletions aws_google_auth/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, **kwargs):
self.print_creds = False
self.quiet = False
self.bg_response = None
self.account = ""

# For the "~/.aws/config" file, we use the format "[profile testing]"
# for the 'testing' profile. The credential file will just be "[testing]"
Expand Down Expand Up @@ -138,6 +139,9 @@ def raise_if_invalid(self):
# quiet
assert (self.quiet.__class__ is bool), "Expected quiet to be a boolean. Got {}.".format(self.quiet.__class__)

# account
assert (self.account.__class__ is str), "Expected account to be string. Got {}".format(self.account.__class__)

# Write the configuration (and credentials) out to disk. This allows for
# regular AWS tooling (aws cli and boto) to use the credentials in the
# profile the user specified.
Expand Down Expand Up @@ -259,6 +263,10 @@ def read(self, profile):
read_bg_response = unicode_to_string(config_parser[profile_string].get('google_config.bg_response', None))
self.bg_response = coalesce(read_bg_response, self.bg_response)

# Account
read_account = unicode_to_string(config_parser[profile_string].get('account', None))
self.account = coalesce(read_account, self.account)

# SAML Cache
try:
with open(self.saml_cache_file, 'r') as f:
Expand Down
4 changes: 4 additions & 0 deletions aws_google_auth/tests/test_args_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_no_arguments(self):
self.assertEqual(parser.username, None)
self.assertEqual(parser.quiet, False)
self.assertEqual(parser.bg_response, None)
self.assertEqual(parser.account, None)

self.assertFalse(parser.save_failure_html)

Expand All @@ -54,6 +55,7 @@ def test_username(self):
self.assertEqual(parser.region, None)
self.assertEqual(parser.role_arn, None)
self.assertEqual(parser.username, 'username@gmail.com')
self.assertEqual(parser.account, None)

def test_nocache(self):

Expand All @@ -70,6 +72,7 @@ def test_nocache(self):
self.assertEqual(parser.region, None)
self.assertEqual(parser.role_arn, None)
self.assertEqual(parser.username, None)
self.assertEqual(parser.account, None)

def test_resolvealiases(self):

Expand All @@ -86,6 +89,7 @@ def test_resolvealiases(self):
self.assertEqual(parser.region, None)
self.assertEqual(parser.role_arn, None)
self.assertEqual(parser.username, None)
self.assertEqual(parser.account, None)

def test_ask_and_supply_role(self):

Expand Down
1 change: 1 addition & 0 deletions aws_google_auth/tests/test_backwards_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def setUp(self):
self.c.sp_id = "sample_sp_id"
self.c.u2f_disabled = False
self.c.username = "sample_username"
self.c.account = "123456789012"
self.c.raise_if_invalid()
self.c.write(None)

Expand Down
24 changes: 24 additions & 0 deletions aws_google_auth/tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,27 @@ def test_with_environment(self):
args = parse_args([])
config = resolve_config(args)
self.assertEqual(config.bg_response, 'foo')


class TestAccountProcessing(unittest.TestCase):

@nottest
def test_default(self):
args = parse_args([])
config = resolve_config(args)
self.assertEqual(None, config.account)

def test_cli_param_supplied(self):
args = parse_args(['--account', "123456789012"])
config = resolve_config(args)
self.assertEqual("123456789012", config.account)

@mock.patch.dict(os.environ, {'AWS_ACCOUNT': '123456789012'})
def test_with_environment(self):
args = parse_args([])
config = resolve_config(args)
self.assertEqual("123456789012", config.account)

args = parse_args(['--region', "123456789012"])
config = resolve_config(args)
self.assertEqual("123456789012", config.account)
1 change: 1 addition & 0 deletions aws_google_auth/tests/test_configuration_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def setUp(self):
self.c.bg_response = "foo"
self.c.raise_if_invalid()
self.c.write(None)
self.c.account = "123456789012"

self.config_parser = configparser.RawConfigParser()
self.config_parser.read(self.c.config_file)
Expand Down
12 changes: 10 additions & 2 deletions aws_google_auth/tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def test_main_method_chaining(self, process_auth, resolve_config, exit_if_unsupp
print_creds=False,
username=None,
quiet=False,
bg_response=None))
bg_response=None,
account=None))
],
resolve_config.mock_calls)

Expand All @@ -82,7 +83,8 @@ def test_main_method_chaining(self, process_auth, resolve_config, exit_if_unsupp
print_creds=False,
username=None,
quiet=False,
bg_response=None),
bg_response=None,
account=None),
mock_config)
],
process_auth.mock_calls)
Expand All @@ -100,6 +102,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util):
mock_config.idp_id = None
mock_config.sp_id = None
mock_config.return_value = None
mock_config.account = None
mock_config.region = None

mock_amazon_client = Mock()
Expand Down Expand Up @@ -174,6 +177,7 @@ def test_process_auth_print_creds(self, mock_google, mock_amazon, mock_util):
mock_config.sp_id = None
mock_config.return_value = None
mock_config.print_creds = True
mock_config.account = None

mock_amazon_client = Mock()
mock_google_client = Mock()
Expand Down Expand Up @@ -321,6 +325,7 @@ def test_process_auth_dont_resolve_alias(self, mock_google, mock_amazon, mock_ut
mock_config.sp_id = None
mock_config.return_value = None
mock_config.keyring = False
mock_config.account = None

mock_amazon_client = Mock()
mock_google_client = Mock()
Expand Down Expand Up @@ -354,6 +359,7 @@ def test_process_auth_dont_resolve_alias(self, mock_google, mock_amazon, mock_ut
self.assertEqual(mock_config.password, "pass")
self.assertEqual(mock_config.provider, "da_provider")
self.assertEqual(mock_config.role_arn, "da_role")
self.assertEqual(mock_config.account, None)

# Assert calls occur
self.assertEqual([call.Util.get_input('Google username: '),
Expand Down Expand Up @@ -392,6 +398,7 @@ def test_process_auth_with_profile(self, mock_google, mock_amazon, mock_util):
mock_config.profile = "blart"
mock_config.return_value = None
mock_config.role_arn = 'arn:aws:iam::123456789012:role/admin'
mock_config.account = None

mock_amazon_client = Mock()
mock_google_client = Mock()
Expand Down Expand Up @@ -464,6 +471,7 @@ def test_process_auth_with_saml_cache(self, mock_google, mock_amazon, mock_util)
mock_config.password = None
mock_config.return_value = None
mock_config.role_arn = 'arn:aws:iam::123456789012:role/admin'
mock_config.account = None

mock_amazon_client = Mock()
mock_google_client = Mock()
Expand Down
15 changes: 10 additions & 5 deletions aws_google_auth/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ def get_input(prompt):
return input(prompt)

@staticmethod
def pick_a_role(roles, aliases=None):
def pick_a_role(roles, aliases=None, account=None):
if account:
filtered_roles = {role: principal for role, principal in roles.items() if(account in role)}
else:
filtered_roles = roles

if aliases:
enriched_roles = {}
for role, principal in roles.items():
for role, principal in filtered_roles.items():
enriched_roles[role] = [
aliases[role.split(':')[4]],
role.split('role/')[1],
Expand All @@ -48,14 +53,14 @@ def pick_a_role(roles, aliases=None):
print("Invalid choice, try again.")
else:
while True:
for i, role in enumerate(roles):
for i, role in enumerate(filtered_roles):
print("[{:>3d}] {}".format(i + 1, role))

prompt = 'Type the number (1 - {:d}) of the role to assume: '.format(len(roles))
prompt = 'Type the number (1 - {:d}) of the role to assume: '.format(len(filtered_roles))
choice = Util.get_input(prompt)

try:
return list(roles.items())[int(choice) - 1]
return list(filtered_roles.items())[int(choice) - 1]
except (IndexError, ValueError):
print("Invalid choice, try again.")

Expand Down

0 comments on commit a937d37

Please sign in to comment.