Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions nixops_aws/ec2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
import mypy_boto3_rds


def fetch_aws_secret_key(access_key_id) -> Tuple[str, str]:
def fetch_aws_secret_key(access_key_id) -> Tuple[str, str, str]:
"""
Fetch the secret access key corresponding to the given access key ID from ~/.ec2-keys,
or from ~/.aws/credentials, or from the environment (in that priority).

If fetching from the environment, any session token which might be present due to 2FA
will also be returned. Using session tokens are not supported when fetching from
~/.ec2-keys or ~/.aws/credentials.
"""

def parse_ec2_keys():
Expand All @@ -35,9 +39,9 @@ def parse_ec2_keys():
if len(w) < 2 or len(w) > 3:
continue
if len(w) == 3 and w[2] == access_key_id:
return (w[0], w[1])
return (w[0], w[1], None)
if w[0] == access_key_id:
return (access_key_id, w[1])
return (access_key_id, w[1], None)
return None

def parse_aws_credentials():
Expand All @@ -48,16 +52,20 @@ def parse_aws_credentials():
conf = Config(os.path.expanduser(path))

if access_key_id == conf.get("default", "aws_access_key_id"):
return (access_key_id, conf.get("default", "aws_secret_access_key"))
return (access_key_id, conf.get("default", "aws_secret_access_key"), None)
return (
conf.get(access_key_id, "aws_access_key_id"),
conf.get(access_key_id, "aws_secret_access_key"),
None,
)

def ec2_keys_from_env():
return (
access_key_id,
os.environ.get("EC2_SECRET_KEY") or os.environ.get("AWS_SECRET_ACCESS_KEY"),
# Get first from AWS_SESSION_TOKEN as that is the new default but fall back to
# AWS_SECURITY_TOKEN as that was previously the standard.
os.environ.get("AWS_SESSION_TOKEN") or os.environ.get("AWS_SECURITY_TOKEN"),
)

sources = (
Expand All @@ -84,11 +92,14 @@ def ec2_keys_from_env():
def connect(region, access_key_id):
"""Connect to the specified EC2 region using the given access key."""
assert region
(access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id)
(access_key_id, secret_access_key, session_token) = fetch_aws_secret_key(
access_key_id
)
conn = boto.ec2.connect_to_region(
region_name=region,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
if not conn:
raise Exception("invalid EC2 region ‘{0}’".format(region))
Expand All @@ -97,24 +108,30 @@ def connect(region, access_key_id):

def connect_ec2_boto3(region, access_key_id):
assert region
(access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id)
(access_key_id, secret_access_key, session_token) = fetch_aws_secret_key(
access_key_id
)
client = boto3.session.Session().client(
"ec2",
region_name=region,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
return client


def connect_vpc(region, access_key_id):
"""Connect to the specified VPC region using the given access key."""
assert region
(access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id)
(access_key_id, secret_access_key, session_token) = fetch_aws_secret_key(
access_key_id
)
conn = boto.vpc.connect_to_region(
region_name=region,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
if not conn:
raise Exception("invalid VPC region ‘{0}’".format(region))
Expand All @@ -123,12 +140,15 @@ def connect_vpc(region, access_key_id):

def connect_rds_boto3(region, access_key_id) -> "mypy_boto3_rds.RDSClient":
assert region
(access_key_id, secret_access_key) = fetch_aws_secret_key(access_key_id)
(access_key_id, secret_access_key, session_token) = fetch_aws_secret_key(
access_key_id
)
client = boto3.session.Session().client(
"rds",
region_name=region,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
return client

Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_ec2_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import unittest

from nixops.util import fetch_aws_secret_key


class TestEc2Utils(unittest.TestCase):
def test_session_token(self):
session_token = "DUMMY_SESSION_TOKEN"
security_token = "DUMMY_SECURITY_TOKEN"
self.assertIsNone(fetch_aws_secret_key("DUMMY_ACCESS_KEY")[2])

os.environ["AWS_SECURITY_TOKEN"] = security_token
self.assertIsEqual(fetch_aws_secret_key("DUMMY_ACCESS_KEY")[2], security_token)

# SESSION_TOKEN should take priority if it's set
os.environ["AWS_SESSION_TOKEN"] = session_token
self.assertIsEqual(fetch_aws_secret_key("DUMMY_ACCES_KEY")[2], session_token)