-
Notifications
You must be signed in to change notification settings - Fork 0
/
populate-nlb-tg-vpce.py
338 lines (300 loc) · 13 KB
/
populate-nlb-tg-vpce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import json
import tempfile
from datetime import datetime
import os
import random
import sys
from botocore.exceptions import ClientError
import boto3
import dns.resolver
'''
This function checks the DNS records of a VPC endpoint to find its IP addresses
and populates a Network Load Balancer's target group with those IP addresses.
WARNING: This function performs multiple DNS lookups per invocation. It is not guaranteed that all
VPC endpoint IPs will be detected by a single invocation. However, the result converges
as more invocations are triggered. This function performs registration aggressively
and deregistration cautiously.
Configure these environment variables in your Lambda environment (CloudFormation Inputs):
1. DNS_NAME - The full DNS name of the VPC endpoint
2. DNS_LISTENERS - Comma-separated list of the VPC endpoint's listener ports
3. S3_BUCKET - Bucket used to track changes between Lambda invocations
4. NLB_TG_ARN - The ARN of the Network Load Balancer's target group
5. MAX_LOOKUP_PER_INVOCATION - The maximum number of DNS lookups per invocation
6. INVOCATIONS_BEFORE_DEREGISTRATION - The number of consecutive invocations in which an IP must be absent from DNS before it is deregistered
7. CW_METRIC_FLAG_IP_COUNT - Set to "true" to publish the detected IP count as a CloudWatch metric
'''
# Runtime configuration, supplied through the Lambda environment.
DNS_NAME = os.environ['DNS_NAME']
DNS_LISTENERS = os.environ['DNS_LISTENERS']
S3_BUCKET = os.environ['S3_BUCKET']
NLB_TG_ARN = os.environ['NLB_TG_ARN']
MAX_LOOKUP_PER_INVOCATION = int(os.environ['MAX_LOOKUP_PER_INVOCATION'])
INVOCATIONS_BEFORE_DEREGISTRATION = int(os.environ['INVOCATIONS_BEFORE_DEREGISTRATION'])
CW_METRIC_FLAG_IP_COUNT = os.environ['CW_METRIC_FLAG_IP_COUNT']

# S3 object names for the state kept between invocations. The misspelling
# "deregisteration" is part of the stored object names; changing it would
# orphan state written by earlier deployments, so it is kept on purpose.
ACTIVE_FILENAME = 'Active IP list of {}.json'.format(DNS_NAME)
PENDING_DEREGISTRATION_FILENAME = 'Pending deregisteration IP list of {}.json'.format(DNS_NAME)
ACTIVE_IP_LIST_KEY = "{}-active-registered-IPs/{}".format(DNS_NAME, ACTIVE_FILENAME)
PENDING_IP_LIST_KEY = "{}-pending-deregisteration-IPs/{}".format(
    DNS_NAME, PENDING_DEREGISTRATION_FILENAME)

# NOTE: evaluated once at import time (Lambda cold start). Warm containers keep
# reusing this value, so per-invocation timestamps should not rely on it.
TIME = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')


def _connect(factory, service_name):
    """
    Create a boto3 client/resource for service_name, or exit the process.

    Client construction failures (e.g. no region configured, unknown service)
    are botocore errors that are not ClientError and carry no .response, so a
    broad catch with the exception's own message is used at this boundary.
    """
    try:
        return factory(service_name)
    except Exception as e:
        print("ERROR: failed to connect to {}: {}".format(service_name, e))
        sys.exit(1)


s3 = _connect(boto3.resource, 's3')
cwclient = _connect(boto3.client, 'cloudwatch')
elbv2client = _connect(boto3.client, 'elbv2')
def put_metric_data(ip_dict, namespace='AWS/ApplicationELB'):
    """
    Publish the detected IP count as the CloudWatch metric LoadBalancerIPCount.

    :param ip_dict: dict providing 'LoadBalancerName' (used as the value of the
        LoadBalancerName dimension) and 'IPCount' (the metric value).
    :param namespace: CloudWatch namespace to publish into. Defaults to the
        original 'AWS/ApplicationELB'. NOTE(review): CloudWatch documents the
        'AWS/' prefix as reserved for AWS services; passing a custom namespace
        is recommended.
    :raises KeyError: if ip_dict lacks one of the required keys.
    CloudWatch API errors are printed and swallowed (best-effort metric).
    """
    metric = {
        'MetricName': "LoadBalancerIPCount",
        'Dimensions': [
            {
                'Name': 'LoadBalancerName',
                'Value': ip_dict['LoadBalancerName'],
            },
        ],
        'Value': float(ip_dict['IPCount']),
        'Unit': 'Count',
    }
    try:
        cwclient.put_metric_data(Namespace=namespace, MetricData=[metric])
    except ClientError as e:
        print(e.response['Error']['Message'])
def upload_ip_list(s3_bucket, file_name, json_object, object_key):
    """
    Serialize json_object as JSON and upload it to s3://s3_bucket/object_key.

    :param s3_bucket: destination bucket name.
    :param file_name: human-readable name of the list; used in log messages.
    :param json_object: value to serialize. lambda_handler passes text that is
        already JSON, so the stored object is a JSON string containing JSON;
        the download side decodes twice to match.
    :param object_key: destination key.
    Upload failures are printed and not raised (best-effort state persistence).
    """
    # Write through the temp file's own handle and flush before uploading;
    # the context manager guarantees the file is closed and removed afterwards.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as temp_file:
        json.dump(json_object, temp_file)
        temp_file.flush()
        try:
            s3.meta.client.upload_file(temp_file.name, s3_bucket, object_key)
        except Exception as e:
            # upload_file can raise exceptions without a .response attribute
            # (e.g. boto3's S3UploadFailedError), so print the exception itself.
            print("ERROR: Failed to upload {} to s3://{}/{}: {}".format(
                file_name, s3_bucket, object_key, e))
def download_ip_list(s3_bucket, object_key):
    """
    Download an IP list from S3 and return it as JSON text.

    Objects written by upload_ip_list are JSON strings that themselves contain
    JSON (the caller passes pre-serialized text), so one json.loads here yields
    the inner JSON text, which the caller decodes again. If the object instead
    decodes directly to a dict/list, it is re-serialized so the return value is
    always JSON text, just like the '{}' fallback.

    :param s3_bucket: bucket holding the list.
    :param object_key: key of the list.
    :returns: JSON text; '{}' when the S3 client cannot be created, the object
        cannot be fetched (normal on the first invocation), or its content is
        not valid JSON.
    """
    try:
        s3client = boto3.client('s3')
    except Exception as e:
        print("ERROR: failed to connect to S3")
        print(e)
        return '{}'
    try:
        response = s3client.get_object(Bucket=s3_bucket, Key=object_key)
    except Exception as e:
        print("ERROR: Failed to download IP list from S3. "
              "It is normal to see this message "
              "if it is the first time the Lambda function is triggered.")
        print(e)
        return '{}'
    try:
        decoded = json.loads(response['Body'].read())
    except ValueError as e:
        print("ERROR: IP list s3://{}/{} is not valid JSON".format(s3_bucket, object_key))
        print(e)
        return '{}'
    # Keep the return type consistent: callers always json.loads() the result.
    if isinstance(decoded, (dict, list)):
        return json.dumps(decoded)
    return decoded
def register_target(tg_arn, new_target_list):
    """
    Add the given targets to the NLB target group.

    :param tg_arn: ARN of the target group to modify.
    :param new_target_list: list of {'Id': ip, 'Port': port} target dicts.
    API errors (ClientError) are printed rather than raised.
    """
    message = "INFO: Register new_target_list:{}".format(new_target_list)
    print(message)
    request = {'TargetGroupArn': tg_arn, 'Targets': new_target_list}
    try:
        elbv2client.register_targets(**request)
    except ClientError as err:
        print(err.response['Error']['Message'])
def deregister_target(tg_arn, new_target_list):
    """
    Remove the given targets from the NLB target group.

    :param tg_arn: ARN of the target group to modify.
    :param new_target_list: list of {'Id': ip, 'Port': port} target dicts.
        An empty (or None) list is a no-op, so no pointless API call (which
        the service may reject) is made.
    API errors (ClientError) are printed rather than raised.
    """
    if not new_target_list:
        print("INFO: No targets to deregister")
        return
    print("INFO: Deregistering targets: {}".format(new_target_list))
    try:
        elbv2client.deregister_targets(
            TargetGroupArn=tg_arn,
            Targets=new_target_list
        )
    except ClientError as e:
        print(e.response['Error']['Message'])
def target_group_list(ip_list, listeners=None):
    """
    Build the Targets payload for register/deregister calls.

    One target is produced for every (listener port, IP) pair, grouped by port
    in the order the ports are listed.

    :param ip_list: iterable of IP address strings (consumed once, so a
        generator works with several ports).
    :param listeners: comma-separated port string; defaults to DNS_LISTENERS.
        Blank entries (e.g. from a trailing comma) are ignored.
    :returns: list of {'Id': ip, 'Port': int} dicts; empty if ip_list is empty.
    :raises ValueError: if a non-blank listener entry is not an integer.
    """
    ips = list(ip_list)
    if not ips:
        return []
    if listeners is None:
        listeners = DNS_LISTENERS
    ports = [int(port) for port in listeners.split(",") if port.strip()]
    return [{'Id': ip, 'Port': port} for port in ports for ip in ips]
def describe_target_health(tg_arn):
    """
    Return the Ids (IP addresses) of the targets registered in the target group.

    There is one entry per target description, so an IP registered on several
    ports appears several times. On a ClientError the message is printed and
    the (normally empty) list collected so far is returned.
    """
    found_ips = []
    try:
        descriptions = elbv2client.describe_target_health(
            TargetGroupArn=tg_arn)['TargetHealthDescriptions']
        print("{} {}".format("INFO: Number of currently registered IP: ", len(descriptions)))
        found_ips.extend(desc['Target']['Id'] for desc in descriptions)
    except ClientError as err:
        print(err.response['Error']['Message'])
    return found_ips
def dns_lookup(domainname, record_type, *dnsserver):
    """
    Resolve domainname and return the answers as strings.

    :param domainname: name to resolve.
    :param record_type: record type to query, e.g. "A" or "NS".
    :param dnsserver: optional nameserver IPs to query instead of the system
        resolver. Each argument may be a single IP string or a list/tuple/set
        of IPs (lambda_handler passes one list). The servers are shuffled so
        repeated calls start at a random server; the others remain fallbacks.
        If no server is given (or only empty lists), the system resolver is used.
    :returns: list of str(answer) for every record in the answer.
    :raises: dnspython resolver exceptions (e.g. NXDOMAIN, NoAnswer) propagate.
    """
    resolver = dns.resolver.Resolver()
    servers = []
    for entry in dnsserver:
        if isinstance(entry, (list, tuple, set)):
            servers.extend(entry)
        else:
            servers.append(entry)
    if servers:
        random.shuffle(servers)
        resolver.nameservers = servers
    # dnspython 2.x renamed query() to resolve(); fall back for older versions.
    lookup = getattr(resolver, 'resolve', None) or resolver.query
    answers = lookup(domainname, record_type)
    return [str(answer) for answer in answers]
def _resolve_endpoint_ips():
    """
    Look up DNS_NAME's A records on its zone's authoritative name servers.

    Repeats the lookup up to MAX_LOOKUP_PER_INVOCATION times and returns the
    union of all answers as a set of IP strings.
    """
    # e.g. "a.b.us-east-1.vpce.amazonaws.com" -> "us-east-1.vpce.amazonaws.com"
    regional_name = '.'.join(DNS_NAME.split('.')[2:])
    authoritative_server_ip_list = []
    for nameserver_domain in set(dns_lookup(regional_name, "NS")):
        authoritative_server_ip_list += dns_lookup(nameserver_domain, "A")
    print("INFO: Authoritative name server: {}".format(authoritative_server_ip_list))
    # NOTE(review): the cut-off of 8 presumably reflects the maximum number of
    # A records returned per answer; a shorter answer is assumed to be complete.
    max_records_per_answer = 8
    detected = set()
    for _ in range(MAX_LOOKUP_PER_INVOCATION):
        dns_lookup_result = dns_lookup(DNS_NAME, "A", authoritative_server_ip_list)
        detected |= set(dns_lookup_result)
        if len(dns_lookup_result) < max_records_per_answer:
            break
    return detected


def _next_pending_counts(deregister_ip_set, old_pending_ip_dict):
    """
    Return the new {ip: count} map of IPs pending deregistration.

    Each candidate's count is its previous count plus one (new candidates start
    at 1); IPs that are no longer candidates are dropped, so a count tracks
    consecutive invocations.
    """
    if old_pending_ip_dict:
        old_pending_ip_set = set(old_pending_ip_dict)
        print("INFO: Pending deregistration IPs from last invocation - {}"
              .format(old_pending_ip_set))
        print("INFO: Additional pending IPs "
              "(pending IPs in the current but not the last invocation) - {}"
              .format(deregister_ip_set - old_pending_ip_set))
        print("INFO: Existing pending IPs (pending "
              "IPs in both current and the last invocation) - {}"
              .format(deregister_ip_set & old_pending_ip_set))
        print("INFO: Missing pending IPs (pending "
              "IPs in the last but not the current invocation) - {}"
              .format(old_pending_ip_set - deregister_ip_set))
    return {ip: old_pending_ip_dict.get(ip, 0) + 1 for ip in deregister_ip_set}


def lambda_handler(event, context):
    """
    Main Lambda handler: reconcile the NLB target group with the endpoint's DNS.

    1. Validate the counters; exit if either is not positive.
    2. Resolve the endpoint's current IPs; exit rather than touch the target
       group if none are found.
    3. Optionally publish the IP count to CloudWatch.
    4. Load the previous active/pending state from S3, compute the IPs to
       register and the updated pending-deregistration counters, save state.
    5. Register new IPs immediately (aggressive); deregister IPs whose pending
       counter reached INVOCATIONS_BEFORE_DEREGISTRATION (cautious).

    :param event: unused.
    :param context: unused.
    """
    if MAX_LOOKUP_PER_INVOCATION <= 0:
        print("ERROR: MAX_LOOKUP_PER_INVOCATION is negative or zero, try again")
        sys.exit(1)
    if INVOCATIONS_BEFORE_DEREGISTRATION <= 0:
        print("ERROR: INVOCATIONS_BEFORE_DEREGISTRATION is negative or zero, try again")
        sys.exit(1)

    registered_ip_set = set(describe_target_health(NLB_TG_ARN))
    regular_record_set = _resolve_endpoint_ips()
    print("INFO: IPs detected by DNS lookup: {}".format(regular_record_set))
    print("INFO: Number of IPs detected by DNS lookup:  {}".format(len(regular_record_set)))
    # An endpoint should never have zero IPs in DNS; if it looks like that,
    # bail out instead of deregistering everything.
    if not regular_record_set:
        print("ERROR: The number of IPs in DNS for the ALB is"
              " showing up as zero. This cannot be correct.")
        print("ERROR: Script will not proceed with "
              "making changes to the NLB target group.")
        sys.exit(1)

    # The module-level TIME is fixed at cold start and reused by warm
    # containers, so take a fresh timestamp for every invocation.
    timestamp = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
    new_active_ip_dict = {
        "LoadBalancerName": DNS_NAME,
        "TimeStamp": timestamp,
        "IPList": list(regular_record_set),
        "IPCount": len(regular_record_set),
    }
    active_ip_json = json.dumps(new_active_ip_dict)
    if CW_METRIC_FLAG_IP_COUNT.lower() == "true":
        put_metric_data(new_active_ip_dict)

    new_active_ip_set = set(regular_record_set)
    # download_ip_list returns JSON text ('{}' when nothing is stored yet).
    old_active_ip_dict = json.loads(download_ip_list(S3_BUCKET, ACTIVE_IP_LIST_KEY))
    old_pending_ip_dict = json.loads(download_ip_list(S3_BUCKET, PENDING_IP_LIST_KEY))
    print("INFO: Active IPs from last invocation: {}".format(old_active_ip_dict))
    print("INFO:Pending deregistration IP from last invocation: {}".format(old_pending_ip_dict))
    print("INFO: Active IPs from the current invocation {}".format(new_active_ip_dict))

    # Registration: IPs in DNS that are not registered, plus IPs that are new
    # compared with the last invocation's active list.
    registration_ip_set = new_active_ip_set - registered_ip_set
    new_pending_ip_dict = {}
    if old_active_ip_dict:
        old_active_ip_set = set(old_active_ip_dict['IPList'])
        registration_ip_set |= new_active_ip_set - old_active_ip_set
        # Deregistration candidates: previously active or currently registered
        # IPs that are no longer in DNS.
        deregister_ip_set = ((old_active_ip_set - new_active_ip_set)
                             | (registered_ip_set - new_active_ip_set))
        print("INFO: Pending deregistration IPs from current invocation - {}"
              .format(deregister_ip_set))
        new_pending_ip_dict = _next_pending_counts(deregister_ip_set, old_pending_ip_dict)
        print("INFO: New pending deregisration IP- {}".format(new_pending_ip_dict))
    else:
        print("INFO: No active IP List from last invocation")
    registration_ip_list = list(registration_ip_set)

    pending_ip_json = json.dumps(new_pending_ip_dict)
    upload_ip_list(S3_BUCKET, ACTIVE_FILENAME, active_ip_json, ACTIVE_IP_LIST_KEY)
    upload_ip_list(S3_BUCKET, PENDING_DEREGISTRATION_FILENAME, pending_ip_json, PENDING_IP_LIST_KEY)

    if registration_ip_list:
        register_target(NLB_TG_ARN, target_group_list(registration_ip_list))
        print("INFO: Registering {}".format(registration_ip_list))
    else:
        print("INFO: No new target registered")

    deregistration_ip_list = [ip for ip, count in new_pending_ip_dict.items()
                              if count >= INVOCATIONS_BEFORE_DEREGISTRATION]
    if deregistration_ip_list:
        for ip in deregistration_ip_list:
            print("INFO: Deregistering IP: {}".format(ip))
        deregister_target(NLB_TG_ARN, target_group_list(deregistration_ip_list))
    else:
        print("INFO: No old target deregistered")