Skip to content

Commit

Permalink
Merge pull request #162 from pmarkert/region
Browse files Browse the repository at this point in the history
Remove default region for #156
  • Loading branch information
stevemac007 authored Apr 13, 2020
2 parents 2d90c20 + 0291539 commit b4bbe98
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 6 deletions.
4 changes: 4 additions & 0 deletions aws_google_auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def process_auth(args, config):
# Set up logging
logging.getLogger().setLevel(getattr(logging, args.log_level.upper(), None))

if config.region is None:
config.region = util.Util.get_input("AWS Region: ")
logging.debug('%s: region is: %s', __name__, config.region)

# If there is a valid cache and the user opted to use it, use that instead
# of prompting the user for input (it will also ignroe any set variables
# such as username or sp_id and idp_id, as those are built into the SAML
Expand Down
2 changes: 1 addition & 1 deletion aws_google_auth/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, **kwargs):
self.idp_id = None
self.password = None
self.profile = "sts"
self.region = "ap-southeast-2"
self.region = None
self.role_arn = None
self.__saml_cache = None
self.sp_id = None
Expand Down
41 changes: 38 additions & 3 deletions aws_google_auth/tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@ def test_config_profile(self):
def test_duration_invalid_values(self):
# Duration must be an integer
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.sp_id = "sample_sp_id"
c.username = "sample_username"
c.duration = "bad_type"
c.region = "sample_region"
with self.assertRaises(AssertionError) as e:
c.raise_if_invalid()
self.assertIn("Expected duration to be an integer.", str(e.exception))

# Duration can not be negative
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -45,6 +48,7 @@ def test_duration_invalid_values(self):
valid.username = "sample_username"
valid.duration = 100
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -56,6 +60,7 @@ def test_duration_invalid_values(self):

def test_duration_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -72,6 +77,7 @@ def test_duration_valid_values(self):

def test_duration_defaults_to_max_duration(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -82,6 +88,7 @@ def test_duration_defaults_to_max_duration(self):
def test_ask_role_invalid_values(self):
# ask_role must be a boolean
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -93,6 +100,7 @@ def test_ask_role_invalid_values(self):

def test_ask_role_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -101,6 +109,7 @@ def test_ask_role_valid_values(self):
self.assertTrue(c.ask_role)
c.raise_if_invalid()
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.sp_id = "sample_sp_id"
Expand All @@ -111,6 +120,7 @@ def test_ask_role_valid_values(self):

def test_ask_role_optional(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -121,6 +131,7 @@ def test_ask_role_optional(self):
def test_idp_id_invalid_values(self):
# idp_id must not be None
c = configuration.Configuration()
c.region = "sample_region"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
c.username = "sample_username"
Expand All @@ -130,6 +141,7 @@ def test_idp_id_invalid_values(self):

def test_idp_id_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -143,6 +155,7 @@ def test_idp_id_valid_values(self):
def test_sp_id_invalid_values(self):
# sp_id must not be None
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.username = "sample_username"
Expand All @@ -152,6 +165,7 @@ def test_sp_id_invalid_values(self):

def test_username_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.password = "hunter2"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
Expand All @@ -165,6 +179,7 @@ def test_username_valid_values(self):
def test_username_invalid_values(self):
# username must be set
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.sp_id = "sample_sp_id"
Expand All @@ -173,6 +188,7 @@ def test_username_invalid_values(self):
self.assertIn("Expected username to be a string.", str(e.exception))
# username must be be string
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -183,6 +199,7 @@ def test_username_invalid_values(self):

def test_password_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.password = "hunter2"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
Expand All @@ -196,6 +213,7 @@ def test_password_valid_values(self):
def test_password_invalid_values(self):
# password must be set
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.username = "sample_username"
c.sp_id = "sample_sp_id"
Expand All @@ -204,6 +222,7 @@ def test_password_invalid_values(self):
self.assertIn("Expected password to be a string.", str(e.exception))
# password must be be string
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = 123456
Expand All @@ -214,6 +233,7 @@ def test_password_invalid_values(self):

def test_sp_id_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.username = "sample_username"
Expand All @@ -226,6 +246,7 @@ def test_sp_id_valid_values(self):

def test_profile_defaults_to_sts(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.sp_id = "sample_sp_id"
Expand All @@ -236,6 +257,7 @@ def test_profile_defaults_to_sts(self):
def test_profile_invalid_values(self):
# profile must be a string
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -247,6 +269,7 @@ def test_profile_invalid_values(self):

def test_profile_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.sp_id = "sample_sp_id"
Expand All @@ -260,6 +283,7 @@ def test_profile_valid_values(self):

def test_profile_defaults(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.password = "hunter2"
c.sp_id = "sample_sp_id"
Expand Down Expand Up @@ -292,18 +316,21 @@ def test_region_valid_values(self):
self.assertEqual(c.region, "us-west-2")
c.raise_if_invalid()

def test_region_defaults_to_ap_southeast_2(self):
def test_region_defaults_to_none(self):
c = configuration.Configuration()
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.username = "sample_username"
c.password = "hunter2"
self.assertEqual(c.region, "ap-southeast-2")
c.raise_if_invalid()
self.assertEqual(c.region, None)
with self.assertRaises(AssertionError) as e:
c.raise_if_invalid()
self.assertIn("Expected region to be a string.", str(e.exception))

def test_role_arn_invalid_values(self):
# role_arn must be a string
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -315,6 +342,7 @@ def test_role_arn_invalid_values(self):

# role_arn be a arn-looking string
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -326,6 +354,7 @@ def test_role_arn_invalid_values(self):

def test_role_arn_is_optional(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.password = "hunter2"
Expand All @@ -335,6 +364,7 @@ def test_role_arn_is_optional(self):

def test_role_arn_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.username = "sample_username"
Expand All @@ -349,6 +379,7 @@ def test_role_arn_valid_values(self):
def test_u2f_disabled_invalid_values(self):
# u2f_disabled must be a boolean
c = configuration.Configuration()
c.region = "sample_region"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
c.username = "sample_username"
Expand All @@ -360,6 +391,7 @@ def test_u2f_disabled_invalid_values(self):

def test_u2f_disabled_valid_values(self):
c = configuration.Configuration()
c.region = "sample_region"
c.password = "hunter2"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
Expand All @@ -368,6 +400,7 @@ def test_u2f_disabled_valid_values(self):
self.assertTrue(c.u2f_disabled)
c.raise_if_invalid()
c = configuration.Configuration()
c.region = "sample_region"
c.password = "hunter2"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
Expand All @@ -378,6 +411,7 @@ def test_u2f_disabled_valid_values(self):

def test_u2f_disabled_is_optional(self):
c = configuration.Configuration()
c.region = "sample_region"
c.password = "hunter2"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
Expand All @@ -387,6 +421,7 @@ def test_u2f_disabled_is_optional(self):

def test_unicode_password(self):
c = configuration.Configuration()
c.region = "sample_region"
c.password = u"hunter2"
c.idp_id = "sample_idp_id"
c.sp_id = "sample_sp_id"
Expand Down
7 changes: 5 additions & 2 deletions aws_google_auth/tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util):
mock_config.idp_id = None
mock_config.sp_id = None
mock_config.return_value = None
mock_config.region = None

mock_amazon_client = Mock()
mock_google_client = Mock()
Expand All @@ -111,7 +112,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util):

mock_util_obj = MagicMock()
mock_util_obj.pick_a_role = MagicMock(return_value=("da_role", "da_provider"))
mock_util_obj.get_input = MagicMock(side_effect=["input", "input2", "input3"])
mock_util_obj.get_input = MagicMock(side_effect=["region_input", "input", "input2", "input3"])
mock_util_obj.get_password = MagicMock(return_value="pass")

mock_util.Util = mock_util_obj
Expand All @@ -127,6 +128,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util):
aws_google_auth.process_auth(args, mock_config)

# Assert values collected
self.assertEqual(mock_config.region, "region_input")
self.assertEqual(mock_config.username, "input")
self.assertEqual(mock_config.idp_id, "input2")
self.assertEqual(mock_config.sp_id, "input3")
Expand All @@ -135,7 +137,8 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util):
self.assertEqual(mock_config.role_arn, "da_role")

# Assert calls occur
self.assertEqual([call.Util.get_input('Google username: '),
self.assertEqual([call.Util.get_input('AWS Region: '),
call.Util.get_input('Google username: '),
call.Util.get_input('Google IDP ID: '),
call.Util.get_input('Google SP ID: '),
call.Util.get_password('Google Password: '),
Expand Down

0 comments on commit b4bbe98

Please sign in to comment.