diff --git a/nixops_aws/ec2_utils.py b/nixops_aws/ec2_utils.py index 294f4133..02dd189e 100644 --- a/nixops_aws/ec2_utils.py +++ b/nixops_aws/ec2_utils.py @@ -18,10 +18,14 @@ import mypy_boto3_rds -def fetch_aws_secret_key(access_key_id) -> Tuple[str, str]: +def fetch_aws_secret_key(access_key_id) -> Tuple[str, str, str]: """ Fetch the secret access key corresponding to the given access key ID from ~/.ec2-keys, or from ~/.aws/credentials, or from the environment (in that priority). + + If fetching from the environment, any session token which might be present due to 2FA + will also be returned. Using session tokens are not supported when fetching from + ~/.ec2-keys or ~/.aws/credentials. """ def parse_ec2_keys(): @@ -35,9 +39,9 @@ def parse_ec2_keys(): if len(w) < 2 or len(w) > 3: continue if len(w) == 3 and w[2] == access_key_id: - return (w[0], w[1]) + return (w[0], w[1], None) if w[0] == access_key_id: - return (access_key_id, w[1]) + return (access_key_id, w[1], None) return None def parse_aws_credentials(): @@ -48,16 +52,20 @@ def parse_aws_credentials(): conf = Config(os.path.expanduser(path)) if access_key_id == conf.get("default", "aws_access_key_id"): - return (access_key_id, conf.get("default", "aws_secret_access_key")) + return (access_key_id, conf.get("default", "aws_secret_access_key"), None) return ( conf.get(access_key_id, "aws_access_key_id"), conf.get(access_key_id, "aws_secret_access_key"), + None, ) def ec2_keys_from_env(): return ( access_key_id, os.environ.get("EC2_SECRET_KEY") or os.environ.get("AWS_SECRET_ACCESS_KEY"), + # Get first from AWS_SESSION_TOKEN as that is the new default but fall back to + # AWS_SECURITY_TOKEN as that was previously the standard. + os.environ.get("AWS_SESSION_TOKEN") or os.environ.get("AWS_SECURITY_TOKEN"), ) sources = ( @@ -84,11 +92,14 @@ def ec2_keys_from_env(): def connect(region, access_key_id): """Connect to the specified EC2 region using the given access key.""" assert region - (access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id) + (access_key_id, secret_access_key, session_token) = fetch_aws_secret_key( + access_key_id + ) conn = boto.ec2.connect_to_region( region_name=region, aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key, + aws_session_token=session_token, ) if not conn: raise Exception("invalid EC2 region ‘{0}’".format(region)) @@ -97,12 +108,15 @@ def connect(region, access_key_id): def connect_ec2_boto3(region, access_key_id): assert region - (access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id) + (access_key_id, secret_access_key, session_token) = fetch_aws_secret_key( + access_key_id + ) client = boto3.session.Session().client( "ec2", region_name=region, aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key, + aws_session_token=session_token, ) return client @@ -110,11 +124,14 @@ def connect_ec2_boto3(region, access_key_id): def connect_vpc(region, access_key_id): """Connect to the specified VPC region using the given access key.""" assert region - (access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id) + (access_key_id, secret_access_key, session_token) = fetch_aws_secret_key( + access_key_id + ) conn = boto.vpc.connect_to_region( region_name=region, aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key, + aws_session_token=session_token, ) if not conn: raise Exception("invalid VPC region ‘{0}’".format(region)) @@ -123,12 +140,15 @@ def connect_vpc(region, access_key_id): def connect_rds_boto3(region, access_key_id) -> "mypy_boto3_rds.RDSClient": assert region - (access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id) + (access_key_id, secret_access_key, session_token) = fetch_aws_secret_key( + access_key_id + ) client = boto3.session.Session().client( "rds", region_name=region, aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key, + aws_session_token=session_token, ) return client diff --git a/tests/unit/test_ec2_utils.py b/tests/unit/test_ec2_utils.py new file mode 100644 index 00000000..afc9180a --- /dev/null +++ b/tests/unit/test_ec2_utils.py @@ -0,0 +1,18 @@ +import os +import unittest + +from nixops.util import fetch_aws_secret_key + + +class TestEc2Utils(unittest.TestCase): + def test_session_token(self): + session_token = "DUMMY_SESSION_TOKEN" + security_token = "DUMMY_SECURITY_TOKEN" + self.assertIsNone(fetch_aws_secret_key("DUMMY_ACCESS_KEY")[2]) + + os.environ["AWS_SECURITY_TOKEN"] = security_token + self.assertIsEqual(fetch_aws_secret_key("DUMMY_ACCESS_KEY")[2], security_token) + + # SESSION_TOKEN should take priority if it's set + os.environ["AWS_SESSION_TOKEN"] = session_token + self.assertIsEqual(fetch_aws_secret_key("DUMMY_ACCES_KEY")[2], session_token)