diff --git a/setup.py b/setup.py index 36eaeff..5f0676f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name = "awsenv", - version = "1.1.0", + version = "1.2.0", packages = find_packages('src'), package_dir = { '': 'src'}, author = "Naftuli Kay", diff --git a/src/awsenv/__init__.py b/src/awsenv/__init__.py index e504124..c9bd966 100644 --- a/src/awsenv/__init__.py +++ b/src/awsenv/__init__.py @@ -122,9 +122,11 @@ def load(cls): for name in result.keys(): profile = result[name] - key_id, secret_key = profile.get('aws_access_key_id'), profile.get('aws_secret_access_key') + key_id, secret_key, session_token = profile.get('aws_access_key_id'), profile.get('aws_secret_access_key'), profile.get('aws_session_token') - if len(key_id or '') > 0 and len(secret_key or '') > 0: + if len(key_id or '') > 0 and len(secret_key or '') > 0 and len(session_token or '') > 0: + profile_map[name] = AWSProfile(name=name, key_id=key_id, secret_key=secret_key, session_token=session_token) + elif len(key_id or '') > 0 and len(secret_key or ''): profile_map[name] = AWSProfile(name=name, key_id=key_id, secret_key=secret_key) return AWSCredentials(**profile_map) @@ -148,18 +150,25 @@ def ls(self): class AWSProfile(object): - def __init__(self, name, key_id, secret_key): + def __init__(self, name, key_id, secret_key, session_token=None): self.name = name self.key_id = key_id self.secret_key = secret_key + self.session_token = session_token def format(self, export=True): """Formats the AWS credentials for the shell.""" - return "\n".join([ - "{}AWS_ACCESS_KEY_ID={}".format("export " if export else "", self.aws_access_key_id), - "{}AWS_SECRET_ACCESS_KEY={}".format("export " if export else "", self.aws_secret_access_key) - ]) - + if self.aws_session_token: + return "\n".join([ + "{}AWS_ACCESS_KEY_ID={}".format("export " if export else "", self.aws_access_key_id), + "{}AWS_SECRET_ACCESS_KEY={}".format("export " if export else "", self.aws_secret_access_key), + "{}AWS_SESSION_TOKEN={}".format("export " if export else "", self.aws_session_token) + ]) + else: + return "\n".join([ + "{}AWS_ACCESS_KEY_ID={}".format("export " if export else "", self.aws_access_key_id), + "{}AWS_SECRET_ACCESS_KEY={}".format("export " if export else "", self.aws_secret_access_key) + ]) @property def aws_access_key_id(self): return self.key_id @@ -168,6 +177,10 @@ def aws_access_key_id(self): def aws_secret_access_key(self): return self.secret_key + @property + def aws_session_token(self): + return self.session_token + def main(): diff --git a/src/awsenv/tests.py b/src/awsenv/tests.py index 609b57e..5a69f0c 100644 --- a/src/awsenv/tests.py +++ b/src/awsenv/tests.py @@ -62,6 +62,11 @@ def setUp(self): 'aws_access_key_id': 'another', 'aws_secret_access_key': 'thing', }, + 'three': { + 'aws_access_key_id': 'another', + 'aws_secret_access_key': 'thing', + 'aws_session_token': 'here', + }, 'blank_id': { 'aws_secret_access_key': 'value' }, @@ -134,38 +139,43 @@ def test_add(self): def test_get(self): - result = AWSCredentials(one=AWSProfile('one', 'key one', 'key two')) + result = AWSCredentials(one=AWSProfile('one', 'key one', 'key two', 'key three')) test = result.get('one') self.assertIsNotNone(test) self.assertTrue(isinstance(test, AWSProfile)) self.assertEqual('key one', test.aws_access_key_id) self.assertEqual('key two', test.aws_secret_access_key) + self.assertEqual('key three', test.aws_session_token) def test_ls(self): - result = AWSCredentials(one=AWSProfile('one', 'a', 'b'), two=AWSProfile('two', 'a', 'b')) - self.assertEqual(set(['one', 'two']), set(result.ls())) + result = AWSCredentials(one=AWSProfile('one', 'a', 'b'), two=AWSProfile('two', 'a', 'b'), three=AWSProfile('three', 'a', 'b', 'c')) + self.assertEqual(set(['one', 'two', 'three']), set(result.ls())) class AWSProfileTestCase(unittest.TestCase): def test_constructor(self): - fixture = AWSProfile('profile one', 'access key id', 'secret access key') + fixture = AWSProfile('profile one', 'access key id', 'secret access key', 'session token') self.assertEqual('profile one', fixture.name) self.assertEqual('access key id', fixture.key_id) self.assertEqual('secret access key', fixture.secret_key) + self.assertEqual('session token', fixture.session_token) def test_format(self): - fixture = AWSProfile(None, 'a', 'b') - result_export = "export AWS_ACCESS_KEY_ID=a\nexport AWS_SECRET_ACCESS_KEY=b" - result_no_export = "AWS_ACCESS_KEY_ID=a\nAWS_SECRET_ACCESS_KEY=b" + fixture = AWSProfile(None, 'a', 'b', 'c') + result_export = "export AWS_ACCESS_KEY_ID=a\nexport AWS_SECRET_ACCESS_KEY=b\nexport AWS_SESSION_TOKEN=c" + result_no_export = "AWS_ACCESS_KEY_ID=a\nAWS_SECRET_ACCESS_KEY=b\nAWS_SESSION_TOKEN=c" self.assertEqual(result_export, fixture.format()) self.assertEqual(result_no_export, fixture.format(export=False)) def test_access_key_id(self): - self.assertEqual('access key id', AWSProfile(None, 'access key id', None).aws_access_key_id) + self.assertEqual('access key id', AWSProfile(None, 'access key id', None, None).aws_access_key_id) def test_secret_access_key(self): - self.assertEqual('secret access key', AWSProfile(None, None, 'secret access key').aws_secret_access_key) + self.assertEqual('secret access key', AWSProfile(None, None, 'secret access key', None).aws_secret_access_key) + + def test_session_token(self): + self.assertEqual('session token', AWSProfile(None, None, None, 'session token').aws_session_token)