From e603bf0eeda6bca8c05d855368b38a68006e23dc Mon Sep 17 00:00:00 2001 From: Phillip Markert Date: Thu, 2 Jan 2020 15:27:34 +0100 Subject: [PATCH 1/3] Remove default region for #156 --- aws_google_auth/__init__.py | 4 +++ aws_google_auth/configuration.py | 2 +- aws_google_auth/tests/test_configuration.py | 40 +++++++++++++++++++-- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/aws_google_auth/__init__.py b/aws_google_auth/__init__.py index 71a1dfb..c3c7e9f 100644 --- a/aws_google_auth/__init__.py +++ b/aws_google_auth/__init__.py @@ -178,6 +178,10 @@ def process_auth(args, config): # Set up logging logging.getLogger().setLevel(getattr(logging, args.log_level.upper(), None)) + if config.region is None: + config.region = util.Util.get_input("AWS Region: ") + logging.debug('%s: region is: %s', __name__, config.region) + # If there is a valid cache and the user opted to use it, use that instead # of prompting the user for input (it will also ignroe any set variables # such as username or sp_id and idp_id, as those are built into the SAML diff --git a/aws_google_auth/configuration.py b/aws_google_auth/configuration.py index f1643b2..be30e8d 100644 --- a/aws_google_auth/configuration.py +++ b/aws_google_auth/configuration.py @@ -26,7 +26,7 @@ def __init__(self, **kwargs): self.idp_id = None self.password = None self.profile = "sts" - self.region = "ap-southeast-2" + self.region = None self.role_arn = None self.__saml_cache = None self.sp_id = None diff --git a/aws_google_auth/tests/test_configuration.py b/aws_google_auth/tests/test_configuration.py index 459f91f..de60b8a 100644 --- a/aws_google_auth/tests/test_configuration.py +++ b/aws_google_auth/tests/test_configuration.py @@ -17,17 +17,20 @@ def test_config_profile(self): def test_duration_invalid_values(self): # Duration must be an integer c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.sp_id = "sample_sp_id" c.username = "sample_username" c.duration = "bad_type" + c.region = "sample_region" with self.assertRaises(AssertionError) as e: c.raise_if_invalid() self.assertIn("Expected duration to be an integer.", str(e.exception)) # Duration can not be negative c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -45,6 +48,7 @@ def test_duration_invalid_values(self): valid.username = "sample_username" valid.duration = 100 c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -56,6 +60,7 @@ def test_duration_invalid_values(self): def test_duration_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -72,6 +77,7 @@ def test_duration_valid_values(self): def test_duration_defaults_to_max_duration(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -82,6 +88,7 @@ def test_duration_defaults_to_max_duration(self): def test_ask_role_invalid_values(self): # ask_role must be a boolean c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -93,6 +100,7 @@ def test_ask_role_invalid_values(self): def test_ask_role_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -101,6 +109,7 @@ def test_ask_role_valid_values(self): self.assertTrue(c.ask_role) c.raise_if_invalid() c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.sp_id = "sample_sp_id" @@ -111,6 +120,7 @@ def test_ask_role_valid_values(self): def test_ask_role_optional(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -121,6 +131,7 @@ def test_ask_role_optional(self): def test_idp_id_invalid_values(self): # idp_id must not be None c = configuration.Configuration() + c.region = "sample_region" c.sp_id = "sample_sp_id" c.password = "hunter2" c.username = "sample_username" @@ -130,6 +141,7 @@ def test_idp_id_invalid_values(self): def test_idp_id_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -143,6 +155,7 @@ def test_idp_id_valid_values(self): def test_sp_id_invalid_values(self): # sp_id must not be None c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.username = "sample_username" @@ -152,6 +165,7 @@ def test_sp_id_invalid_values(self): def test_username_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.password = "hunter2" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" @@ -165,6 +179,7 @@ def test_username_valid_values(self): def test_username_invalid_values(self): # username must be set c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.sp_id = "sample_sp_id" @@ -173,6 +188,7 @@ def test_username_invalid_values(self): self.assertIn("Expected username to be a string.", str(e.exception)) # username must be be string c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -183,6 +199,7 @@ def test_username_invalid_values(self): def test_password_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.password = "hunter2" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" @@ -196,6 +213,7 @@ def test_password_valid_values(self): def test_password_invalid_values(self): # password must be set c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.username = "sample_username" c.sp_id = "sample_sp_id" @@ -204,6 +222,7 @@ def test_password_invalid_values(self): self.assertIn("Expected password to be a string.", str(e.exception)) # password must be be string c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = 123456 @@ -214,6 +233,7 @@ def test_password_invalid_values(self): def test_sp_id_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.username = "sample_username" @@ -226,6 +246,7 @@ def test_sp_id_valid_values(self): def test_profile_defaults_to_sts(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.sp_id = "sample_sp_id" @@ -236,6 +257,7 @@ def test_profile_defaults_to_sts(self): def test_profile_invalid_values(self): # profile must be a string c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -247,6 +269,7 @@ def test_profile_invalid_values(self): def test_profile_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.sp_id = "sample_sp_id" @@ -260,6 +283,7 @@ def test_profile_valid_values(self): def test_profile_defaults(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.password = "hunter2" c.sp_id = "sample_sp_id" @@ -292,18 +316,20 @@ def test_region_valid_values(self): self.assertEqual(c.region, "us-west-2") c.raise_if_invalid() - def test_region_defaults_to_ap_southeast_2(self): + def test_region_defaults_to_none(self): c = configuration.Configuration() c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.username = "sample_username" c.password = "hunter2" - self.assertEqual(c.region, "ap-southeast-2") - c.raise_if_invalid() + self.assertEqual(c.region, None) + with self.assertRaises(AssertionError) as e: + c.raise_if_invalid() def test_role_arn_invalid_values(self): # role_arn must be a string c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -315,6 +341,7 @@ def test_role_arn_invalid_values(self): # role_arn be a arn-looking string c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -326,6 +353,7 @@ def test_role_arn_invalid_values(self): def test_role_arn_is_optional(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.password = "hunter2" @@ -335,6 +363,7 @@ def test_role_arn_is_optional(self): def test_role_arn_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.username = "sample_username" @@ -349,6 +378,7 @@ def test_role_arn_valid_values(self): def test_u2f_disabled_invalid_values(self): # u2f_disabled must be a boolean c = configuration.Configuration() + c.region = "sample_region" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" c.username = "sample_username" @@ -360,6 +390,7 @@ def test_u2f_disabled_invalid_values(self): def test_u2f_disabled_valid_values(self): c = configuration.Configuration() + c.region = "sample_region" c.password = "hunter2" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" @@ -368,6 +399,7 @@ def test_u2f_disabled_valid_values(self): self.assertTrue(c.u2f_disabled) c.raise_if_invalid() c = configuration.Configuration() + c.region = "sample_region" c.password = "hunter2" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" @@ -378,6 +410,7 @@ def test_u2f_disabled_valid_values(self): def test_u2f_disabled_is_optional(self): c = configuration.Configuration() + c.region = "sample_region" c.password = "hunter2" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" @@ -387,6 +420,7 @@ def test_u2f_disabled_is_optional(self): def test_unicode_password(self): c = configuration.Configuration() + c.region = "sample_region" c.password = u"hunter2" c.idp_id = "sample_idp_id" c.sp_id = "sample_sp_id" From 9460c71244b4148c1ede8cfa48da7c2ba0627665 Mon Sep 17 00:00:00 2001 From: Phillip Markert Date: Fri, 3 Jan 2020 10:02:02 +0100 Subject: [PATCH 2/3] Unit test for region default --- aws_google_auth/tests/test_configuration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aws_google_auth/tests/test_configuration.py b/aws_google_auth/tests/test_configuration.py index de60b8a..cf75c26 100644 --- a/aws_google_auth/tests/test_configuration.py +++ b/aws_google_auth/tests/test_configuration.py @@ -325,6 +325,7 @@ def test_region_defaults_to_none(self): self.assertEqual(c.region, None) with self.assertRaises(AssertionError) as e: c.raise_if_invalid() + self.assertIn("Expected region to be a string.", str(e.exception)) def test_role_arn_invalid_values(self): # role_arn must be a string From 029153992b6568972e1595d0e25d02d05cd4ef05 Mon Sep 17 00:00:00 2001 From: Phillip Markert Date: Mon, 6 Jan 2020 11:33:23 +0100 Subject: [PATCH 3/3] Update to unit-test for coverage Mock wasn't returning None by default. --- aws_google_auth/tests/test_init.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aws_google_auth/tests/test_init.py b/aws_google_auth/tests/test_init.py index 3450e4e..b8f0949 100644 --- a/aws_google_auth/tests/test_init.py +++ b/aws_google_auth/tests/test_init.py @@ -98,6 +98,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util): mock_config.idp_id = None mock_config.sp_id = None mock_config.return_value = None + mock_config.region = None mock_amazon_client = Mock() mock_google_client = Mock() @@ -109,7 +110,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util): mock_util_obj = MagicMock() mock_util_obj.pick_a_role = MagicMock(return_value=("da_role", "da_provider")) - mock_util_obj.get_input = MagicMock(side_effect=["input", "input2", "input3"]) + mock_util_obj.get_input = MagicMock(side_effect=["region_input", "input", "input2", "input3"]) mock_util_obj.get_password = MagicMock(return_value="pass") mock_util.Util = mock_util_obj @@ -125,6 +126,7 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util): aws_google_auth.process_auth(args, mock_config) # Assert values collected + self.assertEqual(mock_config.region, "region_input") self.assertEqual(mock_config.username, "input") self.assertEqual(mock_config.idp_id, "input2") self.assertEqual(mock_config.sp_id, "input3") @@ -133,7 +135,8 @@ def test_process_auth_standard(self, mock_google, mock_amazon, mock_util): self.assertEqual(mock_config.role_arn, "da_role") # Assert calls occur - self.assertEqual([call.Util.get_input('Google username: '), + self.assertEqual([call.Util.get_input('AWS Region: '), + call.Util.get_input('Google username: '), call.Util.get_input('Google IDP ID: '), call.Util.get_input('Google SP ID: '), call.Util.get_password('Google Password: '),