Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tunneling sessions report #1226

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 99 additions & 62 deletions keepercommander/commands/discoveryrotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
get_router_url
from .record_edit import RecordEditMixin
from .tunnel.port_forward.endpoint import establish_symmetric_key, WebRTCConnection, TunnelEntrance, READ_TIMEOUT, \
find_open_port
find_open_port, CloseConnectionReasons
from .. import api, utils, vault_extensions, vault, record_management, attachment, record_facades
from ..display import bcolors
from ..error import CommandError
Expand Down Expand Up @@ -1627,7 +1627,14 @@ def gather_tabel_row_data(thread):


def clean_up_tunnel(params, convo_id):
tunnel_data = params.tunnel_threads.get(convo_id)
tunnel_data = None
index = None
for i, co in enumerate(params.tunnel_threads):
tmp_entrance = params.tunnel_threads[co].get('entrance', {})
if tmp_entrance and tmp_entrance.pc.endpoint_name == convo_id:
tunnel_data = params.tunnel_threads[co]
index = i
break
if tunnel_data:
kill_server_event = tunnel_data.get("kill_server_event")
if kill_server_event:
Expand All @@ -1638,18 +1645,18 @@ def clean_up_tunnel(params, convo_id):
p = tunnel_data.get("process", None)
if p and p.is_alive():
p.join()
if params.tunnel_threads.get(convo_id):
del params.tunnel_threads[convo_id]
if params.tunnel_threads_queue.get(convo_id):
del params.tunnel_threads_queue[convo_id]
if params.tunnel_threads.get(index):
del params.tunnel_threads[index]
if params.tunnel_threads_queue.get(index):
del params.tunnel_threads_queue[index]
else:
if params.debug:
print(f"{bcolors.WARNING}No tunnel data found to remove for {convo_id}{bcolors.ENDC}")


class PAMTunnelStopCommand(Command):
pam_cmd_parser = argparse.ArgumentParser(prog='pam tunnel stop')
pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Tunnel UID')
pam_cmd_parser.add_argument('uid', type=str, action='store', help='The Tunnel UID or Record UID')

def get_parser(self):
return PAMTunnelStopCommand.pam_cmd_parser
Expand All @@ -1658,10 +1665,17 @@ def execute(self, params, **kwargs):
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel stop', '"uid" argument is required')

tunnel_data = params.tunnel_threads.get(convo_id, None)
tunnel_data = []
for co in params.tunnel_threads:
tmp_entrance = params.tunnel_threads[co].get('entrance', {})
if tmp_entrance and tmp_entrance.pc.endpoint_name == convo_id:
tunnel_data.append(tmp_entrance)
elif tmp_entrance and tmp_entrance.pc.record_uid == convo_id:
tunnel_data.append(tmp_entrance)
if not tunnel_data:
raise CommandError('tunnel stop', f"No tunnel data to remove found for {convo_id}")
for co in tunnel_data:
clean_up_tunnel(params, co.pc.endpoint_name)
clean_up_tunnel(params, convo_id)

return
Expand All @@ -1678,17 +1692,25 @@ def execute(self, params, **kwargs):
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel tail', '"uid" argument is required')
if convo_id not in params.tunnel_threads:
tunnel_data = None
index = None
for i, co in enumerate(params.tunnel_threads):
tmp_entrance = params.tunnel_threads[co].get('entrance', {})
if tmp_entrance and tmp_entrance.pc.endpoint_name == convo_id:
tunnel_data = tmp_entrance
index = i
break
if not tunnel_data:
raise CommandError('tunnel tail', f"Tunnel UID {convo_id} not found")

log_queue = params.tunnel_threads_queue.get(convo_id)
log_queue = params.tunnel_threads_queue.get(index)

logger_level = logging.getLogger().getEffectiveLevel()
aio_log_level = logging.getLogger('aiortc').getEffectiveLevel()

logging.getLogger('aiortc').setLevel(logging.DEBUG)
logging.getLogger('aioice').setLevel(logging.DEBUG)
logging.getLogger(convo_id).setLevel(logging.DEBUG)
logging.getLogger(tunnel_data.pc.endpoint_name).setLevel(logging.DEBUG)

if log_queue:
try:
Expand All @@ -1705,7 +1727,7 @@ def execute(self, params, **kwargs):
finally:
logging.getLogger('aiortc').setLevel(aio_log_level)
logging.getLogger('aioice').setLevel(aio_log_level)
logging.getLogger(convo_id).setLevel(logger_level)
logging.getLogger(tunnel_data.pc.endpoint_name).setLevel(logger_level)
else:
print(f' {bcolors.FAIL}Invalid conversation ID{bcolors.ENDC}')
return
Expand Down Expand Up @@ -1904,11 +1926,11 @@ def setup_logging(self, convo_id, log_queue, logging_level):
logger.debug("Logging setup complete.")
return logger

async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
async def connect(self, params, record_uid, convo_num, gateway_uid, host, port,
log_queue, gateway_public_key_bytes, client_private_key):

# Setup custom logging to put logs into log_queue
logger = self.setup_logging(convo_id, log_queue, logging.getLogger().getEffectiveLevel())
logger = self.setup_logging(str(convo_num), log_queue, logging.getLogger().getEffectiveLevel())

print(f"{bcolors.HIGHINTENSITYWHITE}Establishing tunnel between Commander and Gateway. Please wait...{bcolors.ENDC}")
# get the keys
Expand Down Expand Up @@ -1940,7 +1962,7 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
# Set up the pc
print_ready_event = asyncio.Event()
kill_server_event = asyncio.Event()
pc = WebRTCConnection(endpoint_name=convo_id, params=params, record_uid=record_uid, gateway_uid=gateway_uid,
pc = WebRTCConnection(params=params, record_uid=record_uid, gateway_uid=gateway_uid,
symmetric_key=symmetric_key, print_ready_event=print_ready_event,
kill_server_event=kill_server_event, logger=logger, server=params.server)

Expand All @@ -1951,13 +1973,12 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,

logger.debug("starting private tunnel")

private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=convo_id, pc=pc,
print_ready_event=print_ready_event, logger=logger,
connect_task=params.tunnel_threads[convo_id].get("connect_task", None),
private_tunnel = TunnelEntrance(host=host, port=port, pc=pc, print_ready_event=print_ready_event, logger=logger,
connect_task=params.tunnel_threads[convo_num].get("connect_task", None),
kill_server_event=kill_server_event)

t1 = asyncio.create_task(private_tunnel.start_server())
params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel,
params.tunnel_threads[convo_num].update({"server": t1, "entrance": private_tunnel,
"kill_server_event": kill_server_event})

logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------")
Expand All @@ -1968,9 +1989,9 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
finally:
logger.debug("--> STOP LISTENING FOR MESSAGES FROM GATEWAY --------")

def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port,
def pre_connect(self, params, record_uid, convo_num, gateway_uid, host, port,
gateway_public_key_bytes, client_private_key):

tunnel_name = f"{convo_num}"
def custom_exception_handler(loop, context):
# Check if the exception is present in the context
if "exception" in context:
Expand All @@ -1986,17 +2007,17 @@ def custom_exception_handler(loop, context):
try:
# Create a new asyncio event loop and set the custom exception handler
loop = asyncio.new_event_loop()
params.tunnel_threads[convo_id].update({"loop": loop})
params.tunnel_threads[convo_num].update({"loop": loop})
asyncio.set_event_loop(loop)
loop.set_exception_handler(custom_exception_handler)
output_queue = queue.Queue(maxsize=500)
params.tunnel_threads_queue[convo_id] = output_queue
params.tunnel_threads_queue[convo_num] = output_queue
# Create a Task from the coroutine
connect_task = loop.create_task(
self.connect(
params=params,
record_uid=record_uid,
convo_id=convo_id,
convo_num=convo_num,
gateway_uid=gateway_uid,
host=host,
port=port,
Expand All @@ -2005,51 +2026,54 @@ def custom_exception_handler(loop, context):
client_private_key=client_private_key
)
)
params.tunnel_threads[convo_id].update({"connect_task": connect_task})
params.tunnel_threads[convo_num].update({"connect_task": connect_task})
try:
# Run the task until it is complete
loop.run_until_complete(connect_task)
except asyncio.CancelledError:
pass
except SocketNotConnectedException as es:
print(f"{bcolors.FAIL}Socket not connected exception in connection {convo_id}: {es}{bcolors.ENDC}")
print(f"{bcolors.FAIL}Socket not connected exception in connection {tunnel_name}: {es}{bcolors.ENDC}")
except KeyboardInterrupt:
print(f"{bcolors.OKBLUE}Exiting: {convo_id}{bcolors.ENDC}")
print(f"{bcolors.OKBLUE}Exiting: connection {tunnel_name}{bcolors.ENDC}")
except CommandError as ce:
print(f"{bcolors.FAIL}{ce}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in connection {convo_id}: {e}{bcolors.ENDC}")
print(f"{bcolors.FAIL}An exception occurred in connection {tunnel_name}: {e}{bcolors.ENDC}")
finally:
if loop:
try:
tunnel_data = params.tunnel_threads.get(convo_id, None)
tunnel_data = params.tunnel_threads.get(convo_num, None)
co_entrance = tunnel_data.get('entrance')
if co_entrance:
tunnel_name = co_entrance.pc.endpoint_name
if not tunnel_data:
logging.debug(f"{bcolors.WARNING}No tunnel data found for {convo_id}{bcolors.ENDC}")
logging.debug(f"{bcolors.WARNING}No tunnel data found for {tunnel_name}{bcolors.ENDC}")
return

if convo_id in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_id]
if convo_num in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_num]

entrance = tunnel_data.get("entrance", None)
if entrance:
loop.run_until_complete(entrance.stop_server())
loop.run_until_complete(entrance.stop_server(CloseConnectionReasons.ConnectionFailed))

del params.tunnel_threads[convo_id]
logging.debug(f"Cleaned up data for {convo_id}")
del params.tunnel_threads[convo_num]
logging.debug(f"Cleaned up data for {tunnel_name}")

try:
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
loop.close()
logging.debug(f"{convo_id} Loop cleaned up")
logging.debug(f"{tunnel_name} Loop cleaned up")
except Exception as e:
logging.debug(f"{bcolors.WARNING}Exception while stopping event loop: {e}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}")
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {tunnel_name}: {e}{bcolors.ENDC}")
finally:
clean_up_tunnel(params, convo_id)
print(f"{bcolors.OKBLUE}Tunnel {convo_id} closed.{bcolors.ENDC}")
clean_up_tunnel(params, convo_num)
print(f"{bcolors.OKBLUE}Tunnel {tunnel_name} closed.{bcolors.ENDC}")

def execute(self, params, **kwargs):
# https://pypi.org/project/aiortc/
Expand All @@ -2065,8 +2089,8 @@ def execute(self, params, **kwargs):
return

record_uid = kwargs.get('uid')
convo_id = GatewayAction.generate_conversation_id()
params.tunnel_threads[convo_id] = {}
convo_num = len(params.tunnel_threads)
params.tunnel_threads[convo_num] = {}
host = kwargs.get('host')
port = kwargs.get('port')
if port is not None and port > 0:
Expand Down Expand Up @@ -2145,42 +2169,50 @@ def execute(self, params, **kwargs):
print(f"{bcolors.FAIL}Could not retrieve public key for gateway {gateway_uid}{bcolors.ENDC}")
return

t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_id, gateway_uid, host, port,
t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_num, gateway_uid, host, port,
gateway_public_key_bytes, client_private_key_value)
)

# Setting the thread as a daemon thread
t.daemon = True
t.start()

if not params.tunnel_threads.get(convo_id):
params.tunnel_threads[convo_id] = {"convo_id": convo_id, "thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid}
if not params.tunnel_threads.get(convo_num):
params.tunnel_threads[convo_num] = {"thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid}
else:
params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port,
entrance = params.tunnel_threads[convo_num].get("entrance", None)
if entrance is not None:
endpoint_name = entrance.pc.endpoint_name
params.tunnel_threads[convo_num].update({"thread": t, "host": host, "port": port,
"started": datetime.now(), "record_uid": record_uid})
count = 0
wait_time = 120
entrance = None
while count < wait_time:
if params.tunnel_threads.get(convo_id):
entrance = params.tunnel_threads[convo_id].get("entrance", None)
if params.tunnel_threads.get(convo_num):
entrance = params.tunnel_threads[convo_num].get("entrance", None)
if entrance:
break
else:
break
count += .1
time.sleep(.1)

def print_fail():
fail_dynamic_length = len("| Endpoint ") + len(convo_id) + len(" failed to start..")
def print_fail(con_num):
con_name = ''
con_entrance = params.tunnel_threads[con_num].get("entrance", None)
fail_dynamic_length = len("| Endpoint ") + len(" failed to start..")
if con_entrance:
con_name = con_entrance.pc.endpoint_name
fail_dynamic_length = len("| Endpoint ") + len(con_name) + len(" failed to start..")

clean_up_tunnel(params, convo_id)
time.sleep(.5)
clean_up_tunnel(params, con_entrance.pc.endpoint_name)
time.sleep(.5)
# Dashed line adjusted to the length of the middle line
fail_dashed_line = '+' + '-' * fail_dynamic_length + '+'
print(f'\n{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}')
print(f'{bcolors.FAIL}| Endpoint {bcolors.ENDC}{convo_id}{bcolors.FAIL} failed to start..{bcolors.ENDC}')
print(f'{bcolors.FAIL}| Endpoint {bcolors.ENDC}{con_name}{bcolors.FAIL} failed to start..{bcolors.ENDC}')
print(f'{bcolors.FAIL}{fail_dashed_line}{bcolors.ENDC}\n')

if entrance is not None:
Expand All @@ -2196,30 +2228,35 @@ def print_fail():
host = host + ":" if host else ''
# Total length of the dynamic parts (endpoint name, host, and port)
dynamic_length = \
(len("| Endpoint : Listening on: ") + len(convo_id) + len(host) + len(str(entrance.port)))
(len("| Endpoint : Listening on: ") +
len(entrance.pc.endpoint_name) +
len(host) +
len(str(entrance.port)))

# Dashed line adjusted to the length of the middle line
dashed_line = '+' + '-' * dynamic_length + '+'

endpoint_name = entrance.pc.endpoint_name

# Print statements
print(f'\n{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}')
print(
f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{convo_id}{bcolors.ENDC}'
f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{endpoint_name}{bcolors.ENDC}'
f'{bcolors.OKGREEN}: Listening on: {bcolors.ENDC}'
f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{entrance.port}{bcolors.ENDC}{bcolors.OKGREEN} |{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}{dashed_line}{bcolors.ENDC}')
print(
f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}'
f'{bcolors.OKBLUE}pam tunnel tail ' +
(f'-- ' if convo_id[0] == '-' else '') +
f'{convo_id}{bcolors.ENDC}')
(f'-- ' if endpoint_name[0] == '-' else '') +
f'{endpoint_name}{bcolors.ENDC}')
print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}'
f'{bcolors.OKBLUE}pam tunnel stop ' +
(f'-- ' if convo_id[0] == '-' else '') +
f'{convo_id}{bcolors.ENDC}\n')
(f'-- ' if endpoint_name[0] == '-' else '') +
f'{endpoint_name}{bcolors.ENDC}\n')
else:
print_fail()
print_fail(convo_num)
else:
print_fail()
print_fail(convo_num)

Loading
Loading