Skip to content

Commit

Permalink
Tunneling API message change, adding time tracking, using the correct…
Browse files Browse the repository at this point in the history
… id for lables
  • Loading branch information
miroberts authored and aaunario-keeper committed Apr 5, 2024
1 parent 4655557 commit 97abbaf
Show file tree
Hide file tree
Showing 3 changed files with 457 additions and 234 deletions.
161 changes: 99 additions & 62 deletions keepercommander/commands/discoveryrotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
get_router_url
from .record_edit import RecordEditMixin
from .tunnel.port_forward.endpoint import establish_symmetric_key, WebRTCConnection, TunnelEntrance, READ_TIMEOUT, \
find_open_port
find_open_port, CloseConnectionReasons
from .. import api, utils, vault_extensions, vault, record_management, attachment, record_facades
from ..display import bcolors
from ..error import CommandError
Expand Down Expand Up @@ -1627,7 +1627,14 @@ def gather_tabel_row_data(thread):


def clean_up_tunnel(params, convo_id):
tunnel_data = params.tunnel_threads.get(convo_id)
tunnel_data = None
index = None
for i, co in enumerate(params.tunnel_threads):
tmp_entrance = params.tunnel_threads[co].get('entrance', {})
if tmp_entrance and tmp_entrance.pc.endpoint_name == convo_id:
tunnel_data = params.tunnel_threads[co]
index = i
break
if tunnel_data:
kill_server_event = tunnel_data.get("kill_server_event")
if kill_server_event:
Expand All @@ -1638,18 +1645,18 @@ def clean_up_tunnel(params, convo_id):
p = tunnel_data.get("process", None)
if p and p.is_alive():
p.join()
if params.tunnel_threads.get(convo_id):
del params.tunnel_threads[convo_id]
if params.tunnel_threads_queue.get(convo_id):
del params.tunnel_threads_queue[convo_id]
if params.tunnel_threads.get(index):
del params.tunnel_threads[index]
if params.tunnel_threads_queue.get(index):
del params.tunnel_threads_queue[index]
else:
if params.debug:
print(f"{bcolors.WARNING}No tunnel data found to remove for {convo_id}{bcolors.ENDC}")


class PAMTunnelStopCommand(Command):
pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel stop')
pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Tunnel UID')
pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Tunnel UID or Record UID')

def get_parser(self):
return PAMTunnelStopCommand.pam_cmd_parser
Expand All @@ -1658,10 +1665,17 @@ def execute(self, params, **kwargs):
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel stop', '"uid" argument is required')

tunnel_data = params.tunnel_threads.get(convo_id, None)
tunnel_data = []
for co in params.tunnel_threads:
tmp_entrance = params.tunnel_threads[co].get('entrance', {})
if tmp_entrance and tmp_entrance.pc.endpoint_name == convo_id:
tunnel_data.append(tmp_entrance)
elif tmp_entrance and tmp_entrance.pc.record_uid == convo_id:
tunnel_data.append(tmp_entrance)
if not tunnel_data:
raise CommandError('tunnel stop', f"No tunnel data to remove found for {convo_id}")
for co in tunnel_data:
clean_up_tunnel(params, co.pc.endpoint_name)
clean_up_tunnel(params, convo_id)

return
Expand All @@ -1678,17 +1692,25 @@ def execute(self, params, **kwargs):
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel tail', '"uid" argument is required')
if convo_id not in params.tunnel_threads:
tunnel_data = None
index = None
for i, co in enumerate(params.tunnel_threads):
tmp_entrance = params.tunnel_threads[co].get('entrance', {})
if tmp_entrance and tmp_entrance.pc.endpoint_name == convo_id:
tunnel_data = tmp_entrance
index = i
break
if not tunnel_data:
raise CommandError('tunnel tail', f"Tunnel UID {convo_id} not found")

log_queue = params.tunnel_threads_queue.get(convo_id)
log_queue = params.tunnel_threads_queue.get(index)

logger_level = logging.getLogger().getEffectiveLevel()
aio_log_level = logging.getLogger('aiortc').getEffectiveLevel()

logging.getLogger('aiortc').setLevel(logging.DEBUG)
logging.getLogger('aioice').setLevel(logging.DEBUG)
logging.getLogger(convo_id).setLevel(logging.DEBUG)
logging.getLogger(tunnel_data.pc.endpoint_name).setLevel(logging.DEBUG)

if log_queue:
try:
Expand All @@ -1705,7 +1727,7 @@ def execute(self, params, **kwargs):
finally:
logging.getLogger('aiortc').setLevel(aio_log_level)
logging.getLogger('aioice').setLevel(aio_log_level)
logging.getLogger(convo_id).setLevel(logger_level)
logging.getLogger(tunnel_data.pc.endpoint_name).setLevel(logger_level)
else:
print(f' {bcolors.FAIL}Invalid conversation ID{bcolors.ENDC}')
return
Expand Down Expand Up @@ -1904,11 +1926,11 @@ def setup_logging(self, convo_id, log_queue, logging_level):
logger.debug("Logging setup complete.")
return logger

async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
async def connect(self, params, record_uid, convo_num, gateway_uid, host, port,
log_queue, gateway_public_key_bytes, client_private_key):

# Setup custom logging to put logs into log_queue
logger = self.setup_logging(convo_id, log_queue, logging.getLogger().getEffectiveLevel())
logger = self.setup_logging(str(convo_num), log_queue, logging.getLogger().getEffectiveLevel())

print(f"{bcolors.HIGHINTENSITYWHITE}Establishing tunnel between Commander and Gateway. Please wait...{bcolors.ENDC}")
# get the keys
Expand Down Expand Up @@ -1940,7 +1962,7 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
# Set up the pc
print_ready_event = asyncio.Event()
kill_server_event = asyncio.Event()
pc = WebRTCConnection(endpoint_name=convo_id, params=params, record_uid=record_uid, gateway_uid=gateway_uid,
pc = WebRTCConnection(params=params, record_uid=record_uid, gateway_uid=gateway_uid,
symmetric_key=symmetric_key, print_ready_event=print_ready_event,
kill_server_event=kill_server_event, logger=logger, server=params.server)

Expand All @@ -1951,13 +1973,12 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,

logger.debug("starting private tunnel")

private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=convo_id, pc=pc,
print_ready_event=print_ready_event, logger=logger,
connect_task=params.tunnel_threads[convo_id].get("connect_task", None),
private_tunnel = TunnelEntrance(host=host, port=port, pc=pc, print_ready_event=print_ready_event, logger=logger,
connect_task=params.tunnel_threads[convo_num].get("connect_task", None),
kill_server_event=kill_server_event)

t1 = asyncio.create_task(private_tunnel.start_server())
params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel,
params.tunnel_threads[convo_num].update({"server": t1, "entrance": private_tunnel,
"kill_server_event": kill_server_event})

logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------")
Expand All @@ -1968,9 +1989,9 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
finally:
logger.debug("--> STOP LISTENING FOR MESSAGES FROM GATEWAY --------")

def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port,
def pre_connect(self, params, record_uid, convo_num, gateway_uid, host, port,
gateway_public_key_bytes, client_private_key):

tunnel_name = f"{convo_num}"
def custom_exception_handler(loop, context):
# Check if the exception is present in the context
if "exception" in context:
Expand All @@ -1986,17 +2007,17 @@ def custom_exception_handler(loop, context):
try:
# Create a new asyncio event loop and set the custom exception handler
loop = asyncio.new_event_loop()
params.tunnel_threads[convo_id].update({"loop": loop})
params.tunnel_threads[convo_num].update({"loop": loop})
asyncio.set_event_loop(loop)
loop.set_exception_handler(custom_exception_handler)
output_queue = queue.Queue(maxsize=500)
params.tunnel_threads_queue[convo_id] = output_queue
params.tunnel_threads_queue[convo_num] = output_queue
# Create a Task from the coroutine
connect_task = loop.create_task(
self.connect(
params=params,
record_uid=record_uid,
convo_id=convo_id,
convo_num=convo_num,
gateway_uid=gateway_uid,
host=host,
port=port,
Expand All @@ -2005,51 +2026,54 @@ def custom_exception_handler(loop, context):
client_private_key=client_private_key
)
)
params.tunnel_threads[convo_id].update({"connect_task": connect_task})
params.tunnel_threads[convo_num].update({"connect_task": connect_task})
try:
# Run the task until it is complete
loop.run_until_complete(connect_task)
except asyncio.CancelledError:
pass
except SocketNotConnectedException as es:
print(f"{bcolors.FAIL}Socket not connected exception in connection {convo_id}: {es}{bcolors.ENDC}")
print(f"{bcolors.FAIL}Socket not connected exception in connection {tunnel_name}: {es}{bcolors.ENDC}")
except KeyboardInterrupt:
print(f"{bcolors.OKBLUE}Exiting: {convo_id}{bcolors.ENDC}")
print(f"{bcolors.OKBLUE}Exiting: connection {tunnel_name}{bcolors.ENDC}")
except CommandError as ce:
print(f"{bcolors.FAIL}{ce}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in connection {convo_id}: {e}{bcolors.ENDC}")
print(f"{bcolors.FAIL}An exception occurred in connection {tunnel_name}: {e}{bcolors.ENDC}")
finally:
if loop:
try:
tunnel_data = params.tunnel_threads.get(convo_id, None)
tunnel_data = params.tunnel_threads.get(convo_num, None)
co_entrance = tunnel_data.get('entrance')
if co_entrance:
tunnel_name = co_entrance.pc.endpoint_name
if not tunnel_data:
logging.debug(f"{bcolors.WARNING}No tunnel data found for {convo_id}{bcolors.ENDC}")
logging.debug(f"{bcolors.WARNING}No tunnel data found for {tunnel_name}{bcolors.ENDC}")
return

if convo_id in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_id]
if convo_num in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_num]

entrance = tunnel_data.get("entrance", None)
if entrance:
loop.run_until_complete(entrance.stop_server())
loop.run_until_complete(entrance.stop_server(CloseConnectionReasons.ConnectionFailed))

del params.tunnel_threads[convo_id]
logging.debug(f"Cleaned up data for {convo_id}")
del params.tunnel_threads[convo_num]
logging.debug(f"Cleaned up data for {tunnel_name}")

try:
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
loop.close()
logging.debug(f"{convo_id} Loop cleaned up")
logging.debug(f"{tunnel_name} Loop cleaned up")
except Exception as e:
logging.debug(f"{bcolors.WARNING}Exception while stopping event loop: {e}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}")
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {tunnel_name}: {e}{bcolors.ENDC}")
finally:
clean_up_tunnel(params, convo_id)
print(f"{bcolors.OKBLUE}Tunnel {convo_id} closed.{bcolors.ENDC}")
clean_up_tunnel(params, convo_num)
print(f"{bcolors.OKBLUE}Tunnel {tunnel_name} closed.{bcolors.ENDC}")

def execute(self, params, **kwargs):
# https://pypi.org/project/aiortc/
Expand All @@ -2065,8 +2089,8 @@ def execute(self, params, **kwargs):
return

record_uid = kwargs.get('uid')
convo_id = GatewayAction.generate_conversation_id()
params.tunnel_threads[convo_id] = {}
convo_num = len(params.tunnel_threads)
params.tunnel_threads[convo_num] = {}
host = kwargs.get('host')
port = kwargs.get('port')
if port is not None and port > 0:
Expand Down Expand Up @@ -2145,42 +2169,50 @@ def execute(self, params, **kwargs):
print(f"{bcolors.FAIL}Could not retrieve public key for gateway {gateway_uid}{bcolors.ENDC}")
return

t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_id, gateway_uid, host, port,
t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_num, gateway_uid, host, port,
gateway_public_key_bytes, client_private_key_value)
)

# Setting the thread as a daemon thread
t.daemon = True
t.start()

if not params.tunnel_threads.get(convo_id):
params.tunnel_threads[convo_id] = {"convo_id": convo_id, "thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid}
if not params.tunnel_threads.get(convo_num):
params.tunnel_threads[convo_num] = {"thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid}
else:
params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port,
entrance = params.tunnel_threads[convo_num].get("entrance", None)
if entrance is not None:
endpoint_name = entrance.pc.endpoint_name
params.tunnel_threads[convo_num].update({"thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid})
count = 0
wait_time = 120
entrance = None
while count < wait_time:
if params.tunnel_threads.get(convo_id):
entrance = params.tunnel_threads[convo_id].get("entrance", None)
if params.tunnel_threads.get(convo_num):
entrance = params.tunnel_threads[convo_num].get("entrance", None)
if entrance:
break
else:
break
count += .1
time.sleep(.1)

def print_fail():
fail_dynamic_length = len("| Endpoint ") + len(convo_id) + len(" failed to start..")
def print_fail(con_num):
con_name = ''
con_entrance = params.tunnel_threads[con_num].get("entrance", None)
fail_dynamic_length = len("| Endpoint ") + len(" failed to start..")
if con_entrance:
con_name = con_entrance.pc.endpoint_name
fail_dynamic_length = len("| Endpoint ") + len(con_name) + len(" failed to start..")

clean_up_tunnel(params, convo_id)
time.sleep(.5)
clean_up_tunnel(params, con_entrance.pc.endpoint_name)
time.sleep(.5)
# Dashed line adjusted to the length of the middle line
fail_dashed_line = '+' + '-' * fail_dynamic_length + '+'
print(f'\n{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}')
print(f'{bcolors.FAIL}| Endpoint {bcolors.ENDC}{convo_id}{bcolors.FAIL} failed to start..{bcolors.ENDC}')
print(f'{bcolors.FAIL}| Endpoint {bcolors.ENDC}{con_name}{bcolors.FAIL} failed to start..{bcolors.ENDC}')
print(f'{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}\n')

if entrance is not None:
Expand All @@ -2196,30 +2228,35 @@ def print_fail():
host = host + ":" if host else ''
# Total length of the dynamic parts (endpoint name, host, and port)
dynamic_length = \
(len("| Endpoint : Listening on: ") + len(convo_id) + len(host) + len(str(entrance.port)))
(len("| Endpoint : Listening on: ") +
len(entrance.pc.endpoint_name) +
len(host) +
len(str(entrance.port)))

# Dashed line adjusted to the length of the middle line
dashed_line = '+' + '-' * dynamic_length + '+'

endpoint_name = entrance.pc.endpoint_name

# Print statements
print(f'\n{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}')
print(
f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{convo_id}{bcolors.ENDC}'
f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{endpoint_name}{bcolors.ENDC}'
f'{bcolors.OKGREEN}: Listening on: {bcolors.ENDC}'
f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{entrance.port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}')
print(
f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}'
f'{bcolors.OKBLUE}pam tunnel tail ' +
(f'-- ' if convo_id[0] == '-' else '') +
f'{convo_id}{bcolors.ENDC}')
(f'-- ' if endpoint_name[0] == '-' else '') +
f'{endpoint_name}{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}'
f'{bcolors.OKBLUE}pam tunnel stop ' +
(f'-- ' if convo_id[0] == '-' else '') +
f'{convo_id}{bcolors.ENDC}\n')
(f'-- ' if endpoint_name[0] == '-' else '') +
f'{endpoint_name}{bcolors.ENDC}\n')
else:
print_fail()
print_fail(convo_num)
else:
print_fail()
print_fail(convo_num)

Loading

0 comments on commit 97abbaf

Please sign in to comment.