From 0f98a14e87744fbacc50e854459a482d3664bcf3 Mon Sep 17 00:00:00 2001 From: Steve Kim <86316075+sbSteveK@users.noreply.github.com> Date: Wed, 19 Apr 2023 14:50:11 -0700 Subject: [PATCH] Secure tunnel WebSocket Protocol v3 Support (#84) * add support for Secure Tunnel WebSocket Protocol V3: https://github.com/aws-samples/aws-iot-securetunneling-localproxy/blob/main/V3WebSocketProtocolGuide.md --------- Co-authored-by: Michael Graeb --- .github/workflows/ci.yml | 10 +- include/aws/iotdevice/iotdevice.h | 5 +- .../iotdevice/private/secure_tunneling_impl.h | 53 +- .../private/secure_tunneling_operations.h | 37 +- include/aws/iotdevice/private/serializer.h | 5 +- include/aws/iotdevice/secure_tunneling.h | 85 +- source/iotdevice.c | 11 +- source/secure_tunneling.c | 1096 ++++++++++++-- source/secure_tunneling_operations.c | 413 ++++- source/serializer.c | 65 +- tests/CMakeLists.txt | 16 + tests/secure_tunnel_tests.c | 1341 +++++++++++++---- 12 files changed, 2584 insertions(+), 553 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 50c92ffc..227f465e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,7 @@ on: - 'main' env: - BUILDER_VERSION: v0.9.26 + BUILDER_VERSION: v0.9.40 BUILDER_SOURCE: releases BUILDER_HOST: https://d19elf31gohf1l.cloudfront.net PACKAGE_NAME: aws-c-iot @@ -105,12 +105,12 @@ jobs: runs-on: windows-2019 # windows-2019 is last env with Visual Studio 2015 (v14.0) toolset strategy: matrix: - arch: [Win32, x64] + arch: [x86, x64] steps: - name: Build ${{ env.PACKAGE_NAME }} + consumers run: | python -c "from urllib.request import urlretrieve; urlretrieve('${{ env.BUILDER_HOST }}/${{ env.BUILDER_SOURCE }}/${{ env.BUILDER_VERSION }}/builder.pyz?run=${{ env.RUN }}', 'builder.pyz')" - python builder.pyz build -p ${{ env.PACKAGE_NAME }} --cmake-extra=-Tv140 --cmake-extra=-A${{ matrix.arch }} + python builder.pyz build -p ${{ env.PACKAGE_NAME }} --target windows-${{ matrix.arch }} --compiler msvc-14 windows-shared-libs: runs-on: windows-2022 # latest @@ -128,8 +128,8 @@ jobs: python3 -c "from urllib.request import urlretrieve; urlretrieve('${{ env.BUILDER_HOST }}/${{ env.BUILDER_SOURCE }}/${{ env.BUILDER_VERSION }}/builder.pyz?run=${{ env.RUN }}', 'builder')" chmod a+x builder ./builder build -p ${{ env.PACKAGE_NAME }} - - + + # Test downstream repos. # This should not be required because we can run into a chicken and egg problem if there is a change that needs some fix in a downstream repo. downstream: diff --git a/include/aws/iotdevice/iotdevice.h b/include/aws/iotdevice/iotdevice.h index 1f18507c..230947d2 100644 --- a/include/aws/iotdevice/iotdevice.h +++ b/include/aws/iotdevice/iotdevice.h @@ -21,7 +21,9 @@ enum aws_iotdevice_error { AWS_ERROR_IOTDEVICE_DEFENDER_PUBLISH_FAILURE, AWS_ERROR_IOTDEVICE_DEFENDER_UNKNOWN_TASK_STATUS, - AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM_ID, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_SERVICE_ID, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_BAD_SERVICE_ID, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION, @@ -34,6 +36,7 @@ enum aws_iotdevice_error { AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_OPERATION_FAILED_DUE_TO_OFFLINE_QUEUE_POLICY, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_UNEXPECTED_HANGUP, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_PROTOCOL_VERSION_MISSMATCH, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_TERMINATED, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE, diff --git a/include/aws/iotdevice/private/secure_tunneling_impl.h b/include/aws/iotdevice/private/secure_tunneling_impl.h index 4cbfd8fb..d416848b 100644 --- a/include/aws/iotdevice/private/secure_tunneling_impl.h +++ b/include/aws/iotdevice/private/secure_tunneling_impl.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -99,10 +100,21 @@ struct data_tunnel_pair { struct aws_allocator *allocator; struct aws_byte_buf buf; struct aws_byte_cursor cur; + enum aws_secure_tunnel_message_type type; const struct aws_secure_tunnel *secure_tunnel; bool length_prefix_written; }; +struct aws_secure_tunnel_message_storage { + struct aws_allocator *allocator; + struct aws_secure_tunnel_message_view storage_view; + + struct aws_byte_cursor service_id; + struct aws_byte_cursor payload; + + struct aws_byte_buf storage; +}; + /* * Secure tunnel configuration */ @@ -120,21 +132,18 @@ struct aws_secure_tunnel_options_storage { struct aws_string *endpoint_host; - /* Stream related info */ - int32_t stream_id; - - struct aws_hash_table service_ids; - /* Callbacks */ aws_secure_tunnel_message_received_fn *on_message_received; aws_secure_tunneling_on_connection_complete_fn *on_connection_complete; aws_secure_tunneling_on_connection_shutdown_fn *on_connection_shutdown; aws_secure_tunneling_on_stream_start_fn *on_stream_start; aws_secure_tunneling_on_stream_reset_fn *on_stream_reset; + aws_secure_tunneling_on_connection_start_fn *on_connection_start; + aws_secure_tunneling_on_connection_reset_fn *on_connection_reset; aws_secure_tunneling_on_session_reset_fn *on_session_reset; aws_secure_tunneling_on_stopped_fn *on_stopped; + aws_secure_tunneling_on_send_message_complete_fn *on_send_message_complete; - aws_secure_tunneling_on_send_data_complete_fn *on_send_data_complete; aws_secure_tunneling_on_termination_complete_fn *on_termination_complete; void *secure_tunnel_on_termination_user_data; @@ -142,6 +151,23 @@ struct aws_secure_tunnel_options_storage { enum aws_secure_tunneling_local_proxy_mode local_proxy_mode; }; +struct aws_secure_tunnel_connections { + struct aws_allocator *allocator; + + uint8_t protocol_version; + + /* Used for streams not using multiplexing (service ids) */ + int32_t stream_id; + struct aws_hash_table connection_ids; + + /* Table containing streams using multiplexing (service ids) */ + struct aws_hash_table service_ids; + + /* Message used for initializing a stream upon a reconnect due to a protocol version missmatch */ + struct aws_secure_tunnel_message_storage *restore_stream_message_view; + struct aws_secure_tunnel_message_storage restore_stream_message; +}; + struct aws_secure_tunnel_vtable { /* aws_high_res_clock_get_ticks */ uint64_t (*get_current_time_fn)(void); @@ -175,9 +201,16 @@ struct aws_secure_tunnel { */ struct aws_secure_tunnel_options_storage *config; + /* + * Stores connection related information + */ + struct aws_secure_tunnel_connections *connections; + struct aws_tls_ctx *tls_ctx; struct aws_tls_connection_options tls_con_opt; + struct aws_host_resolution_config host_resolution_config; + /* * The recurrent task that runs all secure tunnel logic outside of external event callbacks. Bound to the secure * tunnel's event loop. @@ -267,6 +300,14 @@ AWS_IOTDEVICE_API void aws_secure_tunnel_set_vtable( */ AWS_IOTDEVICE_API const struct aws_secure_tunnel_vtable *aws_secure_tunnel_get_default_vtable(void); +/* + * For testing purposes. This message type should only be sent due to internal logic. + */ +AWS_IOTDEVICE_API +int aws_secure_tunnel_connection_reset( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options); + AWS_EXTERN_C_END #endif /* AWS_IOTDEVICE_SECURE_TUNNELING_IMPL_H */ diff --git a/include/aws/iotdevice/private/secure_tunneling_operations.h b/include/aws/iotdevice/private/secure_tunneling_operations.h index 6196351d..37c17d29 100644 --- a/include/aws/iotdevice/private/secure_tunneling_operations.h +++ b/include/aws/iotdevice/private/secure_tunneling_operations.h @@ -23,7 +23,9 @@ enum aws_secure_tunnel_operation_type { AWS_STOT_PING, AWS_STOT_MESSAGE, AWS_STOT_STREAM_RESET, - AWS_STOT_STREAM_START + AWS_STOT_STREAM_START, + AWS_STOT_CONNECTION_START, + AWS_STOT_CONNECTION_RESET, }; struct aws_service_id_element { @@ -31,18 +33,12 @@ struct aws_service_id_element { struct aws_byte_cursor service_id_cur; struct aws_string *service_id_string; int32_t stream_id; + struct aws_hash_table connection_ids; }; -struct aws_secure_tunnel_message_storage { +struct aws_connection_id_element { struct aws_allocator *allocator; - struct aws_secure_tunnel_message_view storage_view; - - bool ignorable; - int32_t stream_id; - struct aws_byte_cursor service_id; - struct aws_byte_cursor payload; - - struct aws_byte_buf storage; + uint32_t connection_id; }; /* Basic vtable for all secure tunnel operations. Implementations are per-message type */ @@ -57,10 +53,15 @@ struct aws_secure_tunnel_operation_vtable { struct aws_secure_tunnel_operation *operation, struct aws_secure_tunnel *secure_tunnel); - /* Set the stream id of outgoing st_msg to +1 of the currently set stream id */ + /* Set the stream id of outgoing STREAM START message to +1 of the currently set stream id */ int (*aws_secure_tunnel_operation_set_next_stream_id_fn)( struct aws_secure_tunnel_operation *operation, struct aws_secure_tunnel *secure_tunnel); + + /* Set the connection id of outbound CONNECTION START as active for the Source device */ + int (*aws_secure_tunnel_operation_set_connection_start_id)( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel); }; /** @@ -172,6 +173,12 @@ struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( struct aws_allocator *allocator, const struct aws_secure_tunnel_options *options); +AWS_IOTDEVICE_API +void aws_secure_tunnel_connections_destroy(struct aws_secure_tunnel_connections *storage); + +AWS_IOTDEVICE_API +struct aws_secure_tunnel_connections *aws_secure_tunnel_connections_new(struct aws_allocator *allocator); + AWS_IOTDEVICE_API void aws_secure_tunnel_options_storage_log( const struct aws_secure_tunnel_options_storage *options_storage, @@ -197,6 +204,14 @@ struct aws_service_id_element *aws_service_id_element_new( const struct aws_byte_cursor *service_id, int32_t stream_id); +AWS_IOTDEVICE_API +void aws_connection_id_destroy(void *data); + +AWS_IOTDEVICE_API +struct aws_connection_id_element *aws_connection_id_element_new( + struct aws_allocator *allocator, + uint32_t connection_id); + AWS_EXTERN_C_END #endif /* AWS_IOTDEVICE_SECURE_TUNNELING_OPERATION_H */ diff --git a/include/aws/iotdevice/private/serializer.h b/include/aws/iotdevice/private/serializer.h index 306f800c..0d1a3339 100644 --- a/include/aws/iotdevice/private/serializer.h +++ b/include/aws/iotdevice/private/serializer.h @@ -17,7 +17,7 @@ #define AWS_IOT_ST_MAXIMUM_1_BYTE_VARINT_VALUE 128 #define AWS_IOT_ST_MAXIMUM_2_BYTE_VARINT_VALUE 16384 #define AWS_IOT_ST_MAXIMUM_3_BYTE_VARINT_VALUE 2097152 -#define AWS_IOT_ST_MAX_MESSAGE_SIZE (64 * 1024) +#define AWS_IOT_ST_MAX_PAYLOAD_SIZE (63 * 1024) enum aws_secure_tunnel_field_number { AWS_SECURE_TUNNEL_FN_TYPE = 1, @@ -56,9 +56,6 @@ int aws_secure_tunnel_deserialize_message_from_cursor( struct aws_byte_cursor *cursor, aws_secure_tunnel_on_message_received_fn *on_message_received); -AWS_IOTDEVICE_API -const char *aws_secure_tunnel_message_type_to_c_string(enum aws_secure_tunnel_message_type message_type); - AWS_EXTERN_C_END #endif diff --git a/include/aws/iotdevice/secure_tunneling.h b/include/aws/iotdevice/secure_tunneling.h index d6e8f932..a55ecbba 100644 --- a/include/aws/iotdevice/secure_tunneling.h +++ b/include/aws/iotdevice/secure_tunneling.h @@ -67,7 +67,7 @@ enum aws_secure_tunnel_message_type { /** * ConnectionReset messages convey that the connection has ended, either in error, or closed intentionally for the - * tunnel peer. + * tunnel peer. These should not be manually sent from either Destination or Source clients. */ AWS_SECURE_TUNNEL_MT_CONNECTION_RESET = 7 }; @@ -86,6 +86,7 @@ struct aws_secure_tunnel_message_view { bool ignorable; int32_t stream_id; + uint32_t connection_id; /** * Secure tunnel multiplexing identifier @@ -113,23 +114,63 @@ struct aws_secure_tunnel_connection_view { */ typedef void( aws_secure_tunnel_message_received_fn)(const struct aws_secure_tunnel_message_view *message, void *user_data); - +/** + * Signature of callback to invoke on fully established connection to Secure Tunnel Service + */ typedef void(aws_secure_tunneling_on_connection_complete_fn)( const struct aws_secure_tunnel_connection_view *connection_view, int error_code, void *user_data); +/** + * Signature of callback to invoke on shutdown of connection to Secure Tunnel Service + */ typedef void(aws_secure_tunneling_on_connection_shutdown_fn)(int error_code, void *user_data); -typedef void(aws_secure_tunneling_on_send_data_complete_fn)(int error_code, void *user_data); +/** + * Signature of callback to invoke on completion of an outbound message + */ +typedef void(aws_secure_tunneling_on_send_message_complete_fn( + enum aws_secure_tunnel_message_type type, + int error_code, + void *user_data)); +/** + * Signature of callback to invoke on the start of a stream + */ typedef void(aws_secure_tunneling_on_stream_start_fn)( const struct aws_secure_tunnel_message_view *message, int error_code, void *user_data); +/** + * Signature of callback to invoke on a stream being reset + */ typedef void(aws_secure_tunneling_on_stream_reset_fn)( const struct aws_secure_tunnel_message_view *message, int error_code, void *user_data); +/** + * Signature of callback to invoke on start of a connection id stream + */ +typedef void(aws_secure_tunneling_on_connection_start_fn)( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data); +/** + * Signature of callback to invoke on a connection id stream being reset + */ +typedef void(aws_secure_tunneling_on_connection_reset_fn)( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data); +/** + * Signature of callback to invoke on session reset recieved from the Secure Tunnel Service + */ typedef void(aws_secure_tunneling_on_session_reset_fn)(void *user_data); +/** + * Signature of callback to invoke on Secure Tunnel reaching a STOPPED state + */ typedef void(aws_secure_tunneling_on_stopped_fn)(void *user_data); +/** + * Signature of callback to invoke on termination completion of the Native Secure Tunnel Client + */ typedef void(aws_secure_tunneling_on_termination_complete_fn)(void *user_data); /** @@ -153,6 +194,11 @@ struct aws_secure_tunnel_options { */ const struct aws_socket_options *socket_options; + /** + * (Optional) Tls options to use whenever this Secure Tunnel Client establishes a connection + */ + const struct aws_tls_connection_options *tls_options; + /** * (Optional) Http proxy options to use whenever this Secure Tunnel establishes a connection */ @@ -179,9 +225,11 @@ struct aws_secure_tunnel_options { aws_secure_tunneling_on_connection_complete_fn *on_connection_complete; aws_secure_tunneling_on_connection_shutdown_fn *on_connection_shutdown; - aws_secure_tunneling_on_send_data_complete_fn *on_send_data_complete; + aws_secure_tunneling_on_send_message_complete_fn *on_send_message_complete; aws_secure_tunneling_on_stream_start_fn *on_stream_start; aws_secure_tunneling_on_stream_reset_fn *on_stream_reset; + aws_secure_tunneling_on_connection_start_fn *on_connection_start; + aws_secure_tunneling_on_connection_reset_fn *on_connection_reset; aws_secure_tunneling_on_session_reset_fn *on_session_reset; aws_secure_tunneling_on_stopped_fn *on_stopped; @@ -192,19 +240,6 @@ struct aws_secure_tunnel_options { void *secure_tunnel_on_termination_user_data; }; -/** - * Signature of callback to invoke when secure tunnel enters a fully disconnected state - */ -typedef void(aws_secure_tunnel_disconnect_completion_fn)(int error_code, void *complete_ctx); - -/** - * Public completion callback options for the DISCONNECT operation - */ -struct aws_secure_tunnel_disconnect_completion_options { - aws_secure_tunnel_disconnect_completion_fn *completion_callback; - void *completion_user_data; -}; - AWS_EXTERN_C_BEGIN /** @@ -269,6 +304,15 @@ int aws_secure_tunnel_send_message( struct aws_secure_tunnel *secure_tunnel, const struct aws_secure_tunnel_message_view *message_options); +/** + * Get the const char description of a message type + * + * @param message_type message type used by a secure tunnel message + * @return const char translation of the message type + */ +AWS_IOTDEVICE_API +const char *aws_secure_tunnel_message_type_to_c_string(enum aws_secure_tunnel_message_type message_type); + //*********************************************************************************************************************** /* THIS API SHOULD ONLY BE USED FROM SOURCE MODE */ //*********************************************************************************************************************** @@ -277,8 +321,13 @@ int aws_secure_tunnel_stream_start( struct aws_secure_tunnel *secure_tunnel, const struct aws_secure_tunnel_message_view *message_options); +AWS_IOTDEVICE_API +int aws_secure_tunnel_connection_start( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options); + //*********************************************************************************************************************** -/* THIS API SHOULD NOT BE USED BY THE CUSTOMER AND SHOULD BE DEPRECATED */ +/* THIS API SHOULD NOT BE USED BY THE CUSTOMER AND IS DEPRECATED */ //*********************************************************************************************************************** AWS_IOTDEVICE_API int aws_secure_tunnel_stream_reset( diff --git a/source/iotdevice.c b/source/iotdevice.c index 28e5475d..3a317275 100644 --- a/source/iotdevice.c +++ b/source/iotdevice.c @@ -40,8 +40,14 @@ static struct aws_error_info s_errors[] = { "Device defender task was invoked with an unknown task status."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( - AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM_ID, "Secure Tunnel invalid stream id."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID, + "Secure Tunnel invalid connection id."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_SERVICE_ID, + "Secure Tunnel invalid service id."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE, "Secure Tunnel stream cannot be started while in Destination Mode."), @@ -78,6 +84,9 @@ static struct aws_error_info s_errors[] = { AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP, "Secure Tunnel connection interrupted by user request."), + AWS_DEFINE_ERROR_INFO_IOTDEVICE( + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_PROTOCOL_VERSION_MISSMATCH, + "Secure Tunnel connection interrupted due to a protocol version missmatch."), AWS_DEFINE_ERROR_INFO_IOTDEVICE( AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_TERMINATED, "Secure Tunnel terminated by user request."), diff --git a/source/secure_tunneling.c b/source/secure_tunneling.c index 508ebbfb..c46733eb 100644 --- a/source/secure_tunneling.c +++ b/source/secure_tunneling.c @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -32,7 +33,7 @@ #define WEBSOCKET_HEADER_NAME_ACCESS_TOKEN "access-token" #define WEBSOCKET_HEADER_NAME_CLIENT_TOKEN "client-token" #define WEBSOCKET_HEADER_NAME_PROTOCOL "Sec-WebSocket-Protocol" -#define WEBSOCKET_HEADER_PROTOCOL_VALUE "aws.iot.securetunneling-2.0" +#define WEBSOCKET_HEADER_PROTOCOL_VALUE "aws.iot.securetunneling-3.0" static void s_change_current_state(struct aws_secure_tunnel *secure_tunnel, enum aws_secure_tunnel_state next_state); void aws_secure_tunnel_operational_state_clean_up(struct aws_secure_tunnel *secure_tunnel); @@ -43,12 +44,14 @@ static void s_complete_operation_list( struct aws_secure_tunnel *secure_tunnel, struct aws_linked_list *operation_list, int error_code); - -static int s_secure_tunneling_send( +static void s_reevaluate_service_task(struct aws_secure_tunnel *secure_tunnel); +static void s_aws_secure_tunnel_connected_on_message_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view); +static int s_aws_secure_tunnel_remove_connection_id( struct aws_secure_tunnel *secure_tunnel, const struct aws_secure_tunnel_message_view *message_view); - -static void s_reevaluate_service_task(struct aws_secure_tunnel *secure_tunnel); +int reset_secure_tunnel_connection(struct aws_secure_tunnel *secure_tunnel); const char *aws_secure_tunnel_state_to_c_string(enum aws_secure_tunnel_state state) { switch (state) { @@ -78,6 +81,37 @@ const char *aws_secure_tunnel_state_to_c_string(enum aws_secure_tunnel_state sta } } +const char *aws_secure_tunnel_message_type_to_c_string(enum aws_secure_tunnel_message_type message_type) { + switch (message_type) { + case AWS_SECURE_TUNNEL_MT_UNKNOWN: + return "ST_MT_UNKNOWN"; + + case AWS_SECURE_TUNNEL_MT_DATA: + return "DATA"; + + case AWS_SECURE_TUNNEL_MT_STREAM_START: + return "STREAM START"; + + case AWS_SECURE_TUNNEL_MT_STREAM_RESET: + return "STREAM RESET"; + + case AWS_SECURE_TUNNEL_MT_SESSION_RESET: + return "SESSION RESET"; + + case AWS_SECURE_TUNNEL_MT_SERVICE_IDS: + return "SERVICE IDS"; + + case AWS_SECURE_TUNNEL_MT_CONNECTION_START: + return "CONNECTION START"; + + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + return "CONNECTION RESET"; + + default: + return "UNKNOWN"; + } +} + static const char *s_get_proxy_mode_string(enum aws_secure_tunneling_local_proxy_mode local_proxy_mode) { if (local_proxy_mode == AWS_SECURE_TUNNELING_SOURCE_MODE) { return "source"; @@ -85,13 +119,6 @@ static const char *s_get_proxy_mode_string(enum aws_secure_tunneling_local_proxy return "destination"; } -static int s_reset_service_id(void *context, struct aws_hash_element *p_element) { - (void)context; - struct aws_service_id_element *service_id_elem = p_element->value; - service_id_elem->stream_id = INVALID_STREAM_ID; - return AWS_COMMON_HASH_TABLE_ITER_CONTINUE; -} - /********************************************************************************************************************* * Secure Tunnel Clean Up ********************************************************************************************************************/ @@ -114,6 +141,7 @@ static void s_secure_tunnel_final_destroy(struct aws_secure_tunnel *secure_tunne aws_secure_tunnel_operational_state_clean_up(secure_tunnel); /* Clean up all memory */ + aws_secure_tunnel_connections_destroy(secure_tunnel->connections); aws_secure_tunnel_options_storage_destroy(secure_tunnel->config); aws_http_message_release(secure_tunnel->handshake_request); aws_byte_buf_clean_up(&secure_tunnel->received_data); @@ -132,37 +160,80 @@ static void s_on_secure_tunnel_zero_ref_count(void *user_data) { } /***************************************************************************************************************** - * RECEIVE MESSAGE HANDLING + * STREAM HANDLING *****************************************************************************************************************/ +static void s_set_absent_connection_id_to_one(struct aws_secure_tunnel_message_view *message, uint32_t *connection_id) { + if (message->connection_id == 0) { + *connection_id = 1; + } +} + +static int s_reset_service_id(void *context, struct aws_hash_element *p_element) { + (void)context; + struct aws_service_id_element *service_id_elem = p_element->value; + service_id_elem->stream_id = INVALID_STREAM_ID; + aws_hash_table_clear(&service_id_elem->connection_ids); + return AWS_COMMON_HASH_TABLE_ITER_CONTINUE; +} + /* * Close and reset all stream ids */ -static void s_reset_secure_tunnel(struct aws_secure_tunnel *secure_tunnel) { +static void s_reset_secure_tunnel_streams(struct aws_secure_tunnel *secure_tunnel) { AWS_LOGF_INFO(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: Secure tunnel session reset.", (void *)secure_tunnel); - secure_tunnel->config->stream_id = INVALID_STREAM_ID; - aws_hash_table_foreach(&secure_tunnel->config->service_ids, s_reset_service_id, NULL); + secure_tunnel->connections->protocol_version = 0; + secure_tunnel->connections->stream_id = INVALID_STREAM_ID; + aws_hash_table_clear(&secure_tunnel->connections->connection_ids); + aws_hash_table_foreach(&secure_tunnel->connections->service_ids, s_reset_service_id, NULL); secure_tunnel->received_data.len = 0; /* Drop any incomplete secure tunnel frame */ } -static bool s_aws_secure_tunnel_stream_id_check_match( +static uint8_t s_aws_secure_tunnel_message_min_protocol_check(const struct aws_secure_tunnel_message_view *message) { + uint8_t version = 1; + + if (message->service_id != NULL && message->service_id->len > 0) { + version = 2; + } + + if (message->connection_id > 0) { + version = 3; + } + + return version; +} + +static bool s_aws_secure_tunnel_protocol_version_match_check( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message) { + uint8_t message_protocol_version = s_aws_secure_tunnel_message_min_protocol_check(message); + if (secure_tunnel->connections->protocol_version != message_protocol_version) { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel is currently using Protocol V%d and received a message using Protocol V%d", + (void *)secure_tunnel, + (int)secure_tunnel->connections->protocol_version, + message_protocol_version); + return false; + } + return true; +} + +static bool s_aws_secure_tunnel_stream_id_match_check( struct aws_secure_tunnel *secure_tunnel, const struct aws_byte_cursor *service_id, int32_t stream_id) { - /* No service id means V1 protocol is being used */ - if (service_id->len == 0) { - return (secure_tunnel->config->stream_id == stream_id); + /* + * No service id means either V1 protocol is being used or V3 protocol is being used on a tunnel without service ids + */ + if (service_id == NULL || service_id->len == 0) { + return (secure_tunnel->connections->stream_id == stream_id); } struct aws_hash_element *elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, service_id, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, service_id, &elem); if (elem == NULL) { - AWS_LOGF_WARN( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: Secure tunnel stream id check request for unsupported service_id: " PRInSTR, - (void *)secure_tunnel, - AWS_BYTE_CURSOR_PRI(*service_id)); return false; } @@ -170,27 +241,97 @@ static bool s_aws_secure_tunnel_stream_id_check_match( return (stream_id == service_id_elem->stream_id); } -static int s_aws_secure_tunnel_set_stream_id( +static bool s_aws_secure_tunnel_active_stream_check( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_view) { + /* + * No service id means either V1 protocol is being used or V3 protocol is being used on a tunnel without service ids + */ + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + if (secure_tunnel->connections->stream_id != message_view->stream_id) { + return false; + } + + uint32_t connection_id = message_view->connection_id; + if (connection_id == 0) { + connection_id = 1; + } + + /* + * V1 and V2 connection id has been stored as 1. V3 can be any number > 0. Either way, connection id will be + * checked against stored connection_ids to confirm the stream is active. + */ + struct aws_hash_element *connection_id_elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->connection_ids, &connection_id, &connection_id_elem); + if (connection_id_elem == NULL) { + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID); + return false; + } + return true; + } + + /* Check if service id is being used by the secure tunnel */ + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, message_view->service_id, &elem); + if (elem == NULL) { + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_SERVICE_ID); + return false; + } + + /* Check if the stream id is the currently active one */ + struct aws_service_id_element *service_id_elem = elem->value; + if (message_view->stream_id != service_id_elem->stream_id) { + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM_ID); + return false; + } + + /* V1 and V2 will be considered active at this point with a matching stream id but V3 streams will need to have + * their connection id checked */ + if (secure_tunnel->connections->protocol_version == 3) { + struct aws_hash_element *connection_id_elem = NULL; + aws_hash_table_find(&service_id_elem->connection_ids, &message_view->connection_id, &connection_id_elem); + if (connection_id_elem == NULL) { + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID); + return false; + } + } + + return true; +} + +static int s_aws_secure_tunnel_set_stream( struct aws_secure_tunnel *secure_tunnel, const struct aws_byte_cursor *service_id, - int32_t stream_id) { + int32_t stream_id, + uint32_t connection_id) { /* No service id means V1 protocol is being used */ if (service_id == NULL || service_id->len == 0) { - secure_tunnel->config->stream_id = stream_id; + secure_tunnel->connections->stream_id = stream_id; + aws_hash_table_clear(&secure_tunnel->connections->connection_ids); + if (connection_id > 0) { + struct aws_connection_id_element *connection_id_elem = + aws_connection_id_element_new(secure_tunnel->allocator, connection_id); + aws_hash_table_put( + &secure_tunnel->connections->connection_ids, + &connection_id_elem->connection_id, + connection_id_elem, + NULL); + } AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: Secure tunnel stream_id set to %d", + "id=%p: Secure tunnel set to stream id (%d) with active connection id(%d)", (void *)secure_tunnel, - stream_id); + stream_id, + connection_id); return AWS_OP_SUCCESS; } struct aws_hash_element *elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, service_id, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, service_id, &elem); if (elem == NULL) { AWS_LOGF_WARN( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: Secure tunnel request for unsupported service_id: " PRInSTR, + "id=%p: Incomming stream set request for unsupported service_id: " PRInSTR, (void *)secure_tunnel, AWS_BYTE_CURSOR_PRI(*service_id)); return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_BAD_SERVICE_ID; @@ -199,21 +340,268 @@ static int s_aws_secure_tunnel_set_stream_id( struct aws_service_id_element *replacement_elem = aws_service_id_element_new(secure_tunnel->allocator, service_id, stream_id); - aws_hash_table_put(&secure_tunnel->config->service_ids, &replacement_elem->service_id_cur, replacement_elem, NULL); + if (connection_id > 0) { + struct aws_connection_id_element *connection_id_elem = + aws_connection_id_element_new(secure_tunnel->allocator, connection_id); + aws_hash_table_put( + &replacement_elem->connection_ids, &connection_id_elem->connection_id, connection_id_elem, NULL); + } + aws_hash_table_put( + &secure_tunnel->connections->service_ids, &replacement_elem->service_id_cur, replacement_elem, NULL); + AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: Secure tunnel service_id '" PRInSTR "' stream_id set to %d", + "id=%p: Secure Tunnel service id '" PRInSTR "' set to stream id (%d) with active connection_id (%d)", (void *)secure_tunnel, AWS_BYTE_CURSOR_PRI(*service_id), - stream_id); + stream_id, + connection_id); + + return AWS_OP_SUCCESS; +} + +static int s_aws_secure_tunnel_set_connection_id( + struct aws_secure_tunnel *secure_tunnel, + struct aws_byte_cursor *service_id, + uint32_t connection_id) { + struct aws_hash_table *table_to_put_in = NULL; + if (service_id == NULL || service_id->len == 0) { + table_to_put_in = &secure_tunnel->connections->connection_ids; + } else { + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, service_id, &elem); + if (elem == NULL) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: invalid service_id:'" PRInSTR + "' attempted to be used to start a stream using a connection id (%d)", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*service_id), + connection_id); + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_SERVICE_ID); + } else { + struct aws_service_id_element *service_id_elem = elem->value; + table_to_put_in = &service_id_elem->connection_ids; + } + } + + if (connection_id != 0) { + struct aws_connection_id_element *connection_id_elem = NULL; + connection_id_elem = aws_connection_id_element_new(secure_tunnel->allocator, connection_id); + struct aws_hash_element *preexisting_connection_id_elem = NULL; + + aws_hash_table_find(table_to_put_in, &connection_id_elem->connection_id, &preexisting_connection_id_elem); + + if (preexisting_connection_id_elem == NULL) { + aws_hash_table_put(table_to_put_in, &connection_id_elem->connection_id, connection_id_elem, NULL); + + if (service_id == NULL || service_id->len == 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Stream started using connection id (%d)", + (void *)secure_tunnel, + connection_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Stream started on service_id:'" PRInSTR "' using connection id (%d)", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*service_id), + connection_id); + } + + } else { + /* + * If the connection id is already stored something is wrong and this connection id must be removed and a + * connection reset must be sent for this connection id + */ + aws_connection_id_destroy(connection_id_elem); + if (service_id == NULL || service_id->len == 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Connection Start on existing connection id (%d) received. Closing active stream and " + "sending CONNECTION RESET.", + (void *)secure_tunnel, + connection_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Connection Start on service_id:'" PRInSTR + "' on existing connection id (%d) received. Closing active stream and sending CONNECTION RESET.", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*service_id), + connection_id); + } + struct aws_secure_tunnel_message_view reset_message = { + .type = AWS_SECURE_TUNNEL_MT_CONNECTION_RESET, + .service_id = service_id, + .connection_id = connection_id, + }; + + s_aws_secure_tunnel_remove_connection_id(secure_tunnel, &reset_message); + if (secure_tunnel->config->on_connection_reset) { + secure_tunnel->config->on_connection_reset( + &reset_message, + AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID, + secure_tunnel->config->user_data); + } + + aws_secure_tunnel_connection_reset(secure_tunnel, &reset_message); + + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID); + } + } else { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Connection Id can not be set to 0 on a CONNECTION START", + (void *)secure_tunnel); + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID); + } return AWS_OP_SUCCESS; } +static int s_aws_secure_tunnel_remove_connection_id( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_view) { + + if (s_aws_secure_tunnel_active_stream_check(secure_tunnel, message_view)) { + struct aws_hash_table *table_to_remove_from = NULL; + + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + table_to_remove_from = &secure_tunnel->connections->connection_ids; + } else { + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, message_view->service_id, &elem); + struct aws_service_id_element *service_id_elem = elem->value; + table_to_remove_from = &service_id_elem->connection_ids; + } + + aws_hash_table_remove(table_to_remove_from, &message_view->connection_id, NULL, NULL); + + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Stream using connection id (%d) closed", + (void *)secure_tunnel, + message_view->connection_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Stream on service_id:'" PRInSTR "' using connection id (%d) closed", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + message_view->connection_id); + } + } else { + return aws_last_error(); + } + + return AWS_OP_SUCCESS; +} + +/***************************************************************************************************************** + * RECEIVE MESSAGE HANDLING + *****************************************************************************************************************/ + +static void s_aws_secure_tunnel_on_data_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + + if (!s_aws_secure_tunnel_protocol_version_match_check(secure_tunnel, message_view)) { + /* + * Protocol missmatch results in a full disconnect/reconnect to the Secure Tunnel Service followed by + * initializing the stream that caused the missmatch + */ + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will be reset due to Protocol Version missmatch between previously established " + "Protocol Version and Protocol Version used by incoming STREAM START message.", + (void *)secure_tunnel); + reset_secure_tunnel_connection(secure_tunnel); + return; + } + + /* + * An absent connection ID in DESTINATION MODE will result in connection id being set to 1. + */ + if (secure_tunnel->config->local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE) { + s_set_absent_connection_id_to_one(message_view, &message_view->connection_id); + } + + if (s_aws_secure_tunnel_active_stream_check(secure_tunnel, message_view)) { + if (secure_tunnel->config->on_message_received) { + secure_tunnel->config->on_message_received(message_view, secure_tunnel->config->user_data); + } + } else { + if (message_view->service_id->len > 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Incomming DATA message on inactive stream with service id '" PRInSTR + "' stream id (%d) connection id (%d) ignored", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + message_view->stream_id, + message_view->connection_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Incomming DATA message on inactive stream with stream id (%d) connection id (%d) ignored", + (void *)secure_tunnel, + message_view->stream_id, + message_view->connection_id); + } + } +} + static void s_aws_secure_tunnel_on_stream_start_received( struct aws_secure_tunnel *secure_tunnel, struct aws_secure_tunnel_message_view *message_view) { - int result = s_aws_secure_tunnel_set_stream_id(secure_tunnel, message_view->service_id, message_view->stream_id); + /* + * If a protocol version hasn't been established yet, the first STREAM START determines the protocol version + * being used this session + */ + if (secure_tunnel->connections->protocol_version == 0) { + + uint8_t message_protocol_version = s_aws_secure_tunnel_message_min_protocol_check(message_view); + secure_tunnel->connections->protocol_version = message_protocol_version; + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel client Protocol set to V%d based on received STREAM START", + (void *)secure_tunnel, + secure_tunnel->connections->protocol_version); + } else if (!s_aws_secure_tunnel_protocol_version_match_check(secure_tunnel, message_view)) { + /* + * Protocol missmatch results in a full disconnect/reconnect to the Secure Tunnel Service followed by + * initializing the stream that caused the missmatch + */ + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will be reset due to Protocol Version missmatch between previously established " + "Protocol Version and Protocol Version used by incoming STREAM START message.", + (void *)secure_tunnel); + reset_secure_tunnel_connection(secure_tunnel); + aws_secure_tunnel_message_storage_init( + &secure_tunnel->connections->restore_stream_message, + secure_tunnel->allocator, + message_view, + AWS_STOT_STREAM_START); + secure_tunnel->connections->restore_stream_message_view = &secure_tunnel->connections->restore_stream_message; + return; + } + + uint32_t connection_id = message_view->connection_id; + + /* + * An absent connection ID will result in connection id being set to 1. The connection is considered a V1 + * connection at this point and the future existance of an unexpected connection ID will result in a full reset + * of the client as mixed protocol versions is not supported. + */ + s_set_absent_connection_id_to_one(message_view, &connection_id); + + int result = + s_aws_secure_tunnel_set_stream(secure_tunnel, message_view->service_id, message_view->stream_id, connection_id); + if (secure_tunnel->config->on_stream_start) { secure_tunnel->config->on_stream_start(message_view, result, secure_tunnel->config->user_data); } @@ -222,17 +610,44 @@ static void s_aws_secure_tunnel_on_stream_start_received( static void s_aws_secure_tunnel_on_stream_reset_received( struct aws_secure_tunnel *secure_tunnel, struct aws_secure_tunnel_message_view *message_view) { - int result = AWS_OP_SUCCESS; - if (s_aws_secure_tunnel_stream_id_check_match(secure_tunnel, message_view->service_id, message_view->stream_id)) { - result = s_aws_secure_tunnel_set_stream_id(secure_tunnel, message_view->service_id, INVALID_STREAM_ID); + + if (secure_tunnel->connections->protocol_version != 0 && + !s_aws_secure_tunnel_protocol_version_match_check(secure_tunnel, message_view)) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will be reset due to Protocol Version missmatch between previously established " + "Protocol Version and Protocol Version used by incoming STREAM RESET message.", + (void *)secure_tunnel); + reset_secure_tunnel_connection(secure_tunnel); + return; } - if (secure_tunnel->config->on_stream_reset) { - secure_tunnel->config->on_stream_reset(message_view, result, secure_tunnel->config->user_data); + + int result = AWS_OP_SUCCESS; + if (s_aws_secure_tunnel_stream_id_match_check(secure_tunnel, message_view->service_id, message_view->stream_id)) { + result = s_aws_secure_tunnel_set_stream(secure_tunnel, message_view->service_id, INVALID_STREAM_ID, 0); + if (secure_tunnel->config->on_stream_reset) { + secure_tunnel->config->on_stream_reset(message_view, result, secure_tunnel->config->user_data); + } + } else { + if (message_view->service_id->len > 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Incomming STREAM RESET on inactive stream with service id '" PRInSTR "' stream id (%d) ignored", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + message_view->stream_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Incomming STREAM RESET on inactive stream with stream id (%d) ignored", + (void *)secure_tunnel, + message_view->stream_id); + } } } static void s_aws_secure_tunnel_on_session_reset_received(struct aws_secure_tunnel *secure_tunnel) { - s_reset_secure_tunnel(secure_tunnel); + s_reset_secure_tunnel_streams(secure_tunnel); if (secure_tunnel->config->on_session_reset) { secure_tunnel->config->on_session_reset(secure_tunnel->config->user_data); } @@ -242,13 +657,13 @@ static void s_aws_secure_tunnel_on_service_ids_received( struct aws_secure_tunnel *secure_tunnel, struct aws_secure_tunnel_message_view *message_view) { - aws_hash_table_clear(&secure_tunnel->config->service_ids); + aws_hash_table_clear(&secure_tunnel->connections->service_ids); if (message_view->service_id != NULL) { struct aws_service_id_element *service_id_1_elem = aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id, INVALID_STREAM_ID); aws_hash_table_put( - &secure_tunnel->config->service_ids, &service_id_1_elem->service_id_cur, service_id_1_elem, NULL); + &secure_tunnel->connections->service_ids, &service_id_1_elem->service_id_cur, service_id_1_elem, NULL); AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: secure tunnel service id 1 set to: " PRInSTR, @@ -258,7 +673,7 @@ static void s_aws_secure_tunnel_on_service_ids_received( struct aws_service_id_element *service_id_2_elem = aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id_2, INVALID_STREAM_ID); aws_hash_table_put( - &secure_tunnel->config->service_ids, &service_id_2_elem->service_id_cur, service_id_2_elem, NULL); + &secure_tunnel->connections->service_ids, &service_id_2_elem->service_id_cur, service_id_2_elem, NULL); AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: secure tunnel service id 2 set to: " PRInSTR, @@ -268,7 +683,10 @@ static void s_aws_secure_tunnel_on_service_ids_received( struct aws_service_id_element *service_id_3_elem = aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id_3, INVALID_STREAM_ID); aws_hash_table_put( - &secure_tunnel->config->service_ids, &service_id_3_elem->service_id_cur, service_id_3_elem, NULL); + &secure_tunnel->connections->service_ids, + &service_id_3_elem->service_id_cur, + service_id_3_elem, + NULL); AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: secure tunnel service id 3 set to: " PRInSTR, @@ -284,37 +702,128 @@ static void s_aws_secure_tunnel_on_service_ids_received( connection_view.service_id_2 = message_view->service_id_2; connection_view.service_id_3 = message_view->service_id_3; - /* A connection can only be used once available service ids are established with the secure tunnel. */ - if (secure_tunnel->config->on_connection_complete) { - secure_tunnel->config->on_connection_complete( - &connection_view, AWS_ERROR_SUCCESS, secure_tunnel->config->user_data); + /* A connection can only be used once available service ids are established with the secure tunnel. */ + if (secure_tunnel->config->on_connection_complete) { + secure_tunnel->config->on_connection_complete( + &connection_view, AWS_ERROR_SUCCESS, secure_tunnel->config->user_data); + } + + /* Initialize stream if one is set to be started upon a reconnect */ + if (secure_tunnel->connections->restore_stream_message_view != NULL) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will process the STREAM START that caused reconnection due to changing protocol by " + "Source Device.", + (void *)secure_tunnel); + s_aws_secure_tunnel_connected_on_message_received( + secure_tunnel, &secure_tunnel->connections->restore_stream_message_view->storage_view); + aws_secure_tunnel_message_storage_clean_up(&secure_tunnel->connections->restore_stream_message); + secure_tunnel->connections->restore_stream_message_view = NULL; + } +} + +static void s_aws_secure_tunnel_on_connection_start_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + if (secure_tunnel->connections->protocol_version != 3) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will be reset due to Protocol Version missmatch between previously established " + "Protocol Version and Protocol Version used by incoming CONNECTION START message.", + (void *)secure_tunnel); + reset_secure_tunnel_connection(secure_tunnel); + return; + } + + /* + * An absent connection ID will result in connection id being set to 1. + */ + s_set_absent_connection_id_to_one(message_view, &message_view->connection_id); + + if (s_aws_secure_tunnel_stream_id_match_check(secure_tunnel, message_view->service_id, message_view->stream_id)) { + int result = + s_aws_secure_tunnel_set_connection_id(secure_tunnel, message_view->service_id, message_view->connection_id); + if (secure_tunnel->config->on_connection_start) { + secure_tunnel->config->on_connection_start(message_view, result, secure_tunnel->config->user_data); + } + } else { + if (message_view->service_id->len > 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Incomming CONNECTION START on inactive stream with service id '" PRInSTR + "' stream id (%d) ignored", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + message_view->stream_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Incomming CONNECTION START on inactive stream with stream id (%d) ignored", + (void *)secure_tunnel, + message_view->stream_id); + } + } +} + +static void s_aws_secure_tunnel_on_connection_reset_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + if (secure_tunnel->connections->protocol_version != 3) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will be reset due to Protocol Version missmatch between previously established " + "Protocol Version and Protocol Version used by incoming CONNECTION RESET message.", + (void *)secure_tunnel); + reset_secure_tunnel_connection(secure_tunnel); + return; + } + + /* + * An absent connection ID will result in connection id being set to 1. + */ + s_set_absent_connection_id_to_one(message_view, &message_view->connection_id); + + int result = s_aws_secure_tunnel_remove_connection_id(secure_tunnel, message_view); + + if (secure_tunnel->config->on_connection_reset) { + secure_tunnel->config->on_connection_reset(message_view, result, secure_tunnel->config->user_data); } } static void s_aws_secure_tunnel_connected_on_message_received( struct aws_secure_tunnel *secure_tunnel, struct aws_secure_tunnel_message_view *message_view) { + aws_secure_tunnel_message_view_log(message_view, AWS_LL_DEBUG); switch (message_view->type) { case AWS_SECURE_TUNNEL_MT_DATA: - if (secure_tunnel->config->on_message_received) { - secure_tunnel->config->on_message_received(message_view, secure_tunnel->config->user_data); - } + s_aws_secure_tunnel_on_data_received(secure_tunnel, message_view); break; + case AWS_SECURE_TUNNEL_MT_STREAM_START: s_aws_secure_tunnel_on_stream_start_received(secure_tunnel, message_view); break; + case AWS_SECURE_TUNNEL_MT_STREAM_RESET: s_aws_secure_tunnel_on_stream_reset_received(secure_tunnel, message_view); break; + case AWS_SECURE_TUNNEL_MT_SESSION_RESET: s_aws_secure_tunnel_on_session_reset_received(secure_tunnel); break; + case AWS_SECURE_TUNNEL_MT_SERVICE_IDS: s_aws_secure_tunnel_on_service_ids_received(secure_tunnel, message_view); break; + case AWS_SECURE_TUNNEL_MT_CONNECTION_START: + s_aws_secure_tunnel_on_connection_start_received(secure_tunnel, message_view); + break; + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + s_aws_secure_tunnel_on_connection_reset_received(secure_tunnel, message_view); + break; + case AWS_SECURE_TUNNEL_MT_UNKNOWN: default: if (!message_view->ignorable) { @@ -328,12 +837,13 @@ static void s_aws_secure_tunnel_connected_on_message_received( } static int s_process_received_data(struct aws_secure_tunnel *secure_tunnel) { + struct aws_byte_buf *received_data = &secure_tunnel->received_data; struct aws_byte_cursor cursor = aws_byte_cursor_from_buf(received_data); uint16_t data_length = 0; /* - * If there are at least two bytes for the data_length, but not enough data for a complete secure tunnel frame, we - * don't want to move `cursor`. + * If there are at least two bytes for the data_length, but not enough data for a complete secure tunnel frame, + * we don't want to move `cursor`. */ struct aws_byte_cursor tmp_cursor = cursor; while (aws_byte_cursor_read_be16(&tmp_cursor, &data_length) && tmp_cursor.len >= data_length) { @@ -376,8 +886,8 @@ static void s_secure_tunneling_websocket_on_send_data_complete_callback( (void)websocket; struct data_tunnel_pair *pair = user_data; struct aws_secure_tunnel *secure_tunnel = (struct aws_secure_tunnel *)pair->secure_tunnel; - if (secure_tunnel->config->on_send_data_complete) { - secure_tunnel->config->on_send_data_complete(error_code, pair->secure_tunnel->config->user_data); + if (secure_tunnel->config->on_send_message_complete) { + secure_tunnel->config->on_send_message_complete(pair->type, error_code, secure_tunnel->config->user_data); } aws_secure_tunnel_data_tunnel_pair_destroy(pair); secure_tunnel->pending_write_completion = false; @@ -539,6 +1049,13 @@ static void s_secure_tunnel_shutdown_websocket(struct aws_secure_tunnel *secure_ return; } + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: secure tunnel websocket shutdown invoked with error code %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + s_change_current_state(secure_tunnel, AWS_STS_WEBSOCKET_SHUTDOWN); } @@ -636,6 +1153,8 @@ void s_websocket_transform_complete_task_fn(struct aws_task *task, void *arg, en .on_incoming_frame_begin = s_on_websocket_incoming_frame_begin, .on_incoming_frame_payload = s_on_websocket_incoming_frame_payload, .on_incoming_frame_complete = s_on_websocket_incoming_frame_complete, + + .host_resolution_config = &secure_tunnel->host_resolution_config, }; if (secure_tunnel->config->http_proxy_config != NULL) { @@ -682,10 +1201,9 @@ static int s_handshake_add_header( } AWS_LOGF_TRACE( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: Added header " PRInSTR " " PRInSTR " to websocket request", + "id=%p: Added header " PRInSTR " to websocket request", (void *)secure_tunnel, - AWS_BYTE_CURSOR_PRI(header.name), - AWS_BYTE_CURSOR_PRI(header.value)); + AWS_BYTE_CURSOR_PRI(header.name)); return AWS_OP_SUCCESS; } @@ -799,7 +1317,7 @@ static void s_change_current_state_to_stopped(struct aws_secure_tunnel *secure_t secure_tunnel, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP); /* Stop works as a complete session wipe, and so the next time we connect, we want it to be clean */ - s_reset_secure_tunnel(secure_tunnel); + s_reset_secure_tunnel_streams(secure_tunnel); if (secure_tunnel->config->on_stopped) { secure_tunnel->config->on_stopped(secure_tunnel->config->user_data); @@ -862,23 +1380,53 @@ static void s_change_current_state_to_websocket_shutdown(struct aws_secure_tunne } } -static void s_update_reconnect_delay_for_pending_reconnect(struct aws_secure_tunnel *secure_tunnel) { +static uint64_t s_aws_secure_tunnel_compute_reconnect_backoff_no_jitter(struct aws_secure_tunnel *secure_tunnel) { + uint64_t retry_count = aws_min_u64(secure_tunnel->reconnect_count, 63); + return aws_mul_u64_saturating((uint64_t)1 << retry_count, MIN_RECONNECT_DELAY_MS); +} + +uint64_t aws_secure_tunnel_random_in_range(uint64_t from, uint64_t to) { + uint64_t max = aws_max_u64(from, to); + uint64_t min = aws_min_u64(from, to); + + /* Note: this contains several changes to the corresponding function in aws-c-io. Don't throw them away. + * + * 1. random range is now inclusive/closed: [from, to] rather than half-open [from, to) + * 2. as a corollary, diff == 0 => return min, not 0 + */ + uint64_t diff = max - min; + if (!diff) { + return min; + } + + uint64_t random_value = 0; + if (aws_device_random_u64(&random_value)) { + return min; + } + + if (diff == UINT64_MAX) { + return random_value; + } - uint64_t delay_ms = MIN_RECONNECT_DELAY_MS; - delay_ms = delay_ms << (int)secure_tunnel->reconnect_count; + return min + random_value % (diff + 1); /* + 1 is safe due to previous check */ +} + +static uint64_t s_aws_secure_tunnel_compute_reconnect_backoff_full_jitter(struct aws_secure_tunnel *secure_tunnel) { + uint64_t non_jittered = s_aws_secure_tunnel_compute_reconnect_backoff_no_jitter(secure_tunnel); + return aws_secure_tunnel_random_in_range(0, non_jittered); +} +static void s_update_reconnect_delay_for_pending_reconnect(struct aws_secure_tunnel *secure_tunnel) { + uint64_t delay_ms = s_aws_secure_tunnel_compute_reconnect_backoff_full_jitter(secure_tunnel); delay_ms = aws_min_u64(delay_ms, MAX_RECONNECT_DELAY_MS); uint64_t now = (*secure_tunnel->vtable->get_current_time_fn)(); - secure_tunnel->next_reconnect_time_ns = aws_add_u64_saturating(now, aws_timestamp_convert(delay_ms, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL)); - AWS_LOGF_DEBUG( AWS_LS_IOTDEVICE_SECURE_TUNNELING, "id=%p: next connection attempt in %" PRIu64 " milliseconds", (void *)secure_tunnel, delay_ms); - secure_tunnel->reconnect_count++; } @@ -962,6 +1510,11 @@ static void s_change_state_task_fn(struct aws_task *task, void *arg, enum aws_ta goto done; } + if (desired_state == AWS_STS_CLEAN_DISCONNECT) { + s_change_current_state(secure_tunnel, AWS_STS_CLEAN_DISCONNECT); + goto done; + } + if (secure_tunnel->desired_state != desired_state) { AWS_LOGF_INFO( AWS_LS_IOTDEVICE_SECURE_TUNNELING, @@ -1037,6 +1590,27 @@ static int s_aws_secure_tunnel_change_desired_state( return AWS_OP_SUCCESS; } +/* + * Disconnect the Secure Tunnel from the Secure Tunnel service and reset all stream ids + */ +int reset_secure_tunnel_connection(struct aws_secure_tunnel *secure_tunnel) { + + struct aws_secure_tunnel_change_desired_state_task *task = s_aws_secure_tunnel_change_desired_state_task_new( + secure_tunnel->allocator, secure_tunnel, AWS_STS_CLEAN_DISCONNECT); + + if (task == NULL) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to create change desired state task", + (void *)secure_tunnel); + return AWS_OP_ERR; + } + + aws_event_loop_schedule_task_now(secure_tunnel->loop, &task->task); + + return AWS_OP_SUCCESS; +} + /********************************************************************************************************************* * vtable functions ********************************************************************************************************************/ @@ -1110,16 +1684,8 @@ static bool s_aws_secure_tunnel_has_pending_operational_work(const struct aws_se return false; } - struct aws_linked_list_node *next_operation_node = aws_linked_list_front(&secure_tunnel->queued_operations); - struct aws_secure_tunnel_operation *next_operation = - AWS_CONTAINER_OF(next_operation_node, struct aws_secure_tunnel_operation, node); - switch (secure_tunnel->current_state) { case AWS_STS_CLEAN_DISCONNECT: - /* Except for finishing the current operation, only allowed to send STREAM RESET messages in this state - */ - return next_operation->operation_type == AWS_STOT_STREAM_RESET; - case AWS_STS_CONNECTED: return true; @@ -1212,8 +1778,8 @@ int aws_secure_tunnel_service_operational_state(struct aws_secure_tunnel *secure switch (current_operation->operation_type) { case AWS_STOT_PING:; /* - * TODO Currently, pings are sent to keep the websocket alive but we do not receive responses from the - * secure tunnel service until a src is also connected. This is a known bug that is in their + * TODO Currently, pings are sent to keep the websocket alive but we do not receive responses from + * the secure tunnel service until a src is also connected. This is a known bug that is in their * backlog. Once it is fixed, we should implement ping timeout checks to determine whether we are * still connected to the secure tunnel through WebSocket. */ @@ -1222,44 +1788,57 @@ int aws_secure_tunnel_service_operational_state(struct aws_secure_tunnel *secure frame_options.opcode = AWS_WEBSOCKET_OPCODE_PING; frame_options.fin = true; secure_tunnel->vtable->aws_websocket_send_frame_fn(secure_tunnel->websocket, &frame_options); - break; + case AWS_STOT_MESSAGE: /* If a data message attempts to be sent on an unopen stream, discard it. */ if ((*current_operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)( current_operation, secure_tunnel)) { - error_code = aws_last_error(); - - if (current_operation->message_view->service_id) { - AWS_LOGF_DEBUG( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: failed to assign service id '" PRInSTR - "' DATA message a stream id with error %d(%s)", - (void *)secure_tunnel, - AWS_BYTE_CURSOR_PRI(*current_operation->message_view->service_id), - error_code, - aws_error_debug_str(error_code)); - } else { - AWS_LOGF_DEBUG( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: failed to assign V1 DATA message a stream id with error %d(%s)", - (void *)secure_tunnel, - error_code, - aws_error_debug_str(error_code)); - } } else { - /* Send the Data message through the WebSocket */ - if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + if (s_aws_secure_tunnel_active_stream_check(secure_tunnel, current_operation->message_view)) { + /* Send the Data message through the WebSocket */ + if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + error_code = aws_last_error(); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send DATA message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + } + aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); + } else { error_code = aws_last_error(); - AWS_LOGF_ERROR( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: failed to send DATA message with error %d(%s)", - (void *)secure_tunnel, - error_code, - aws_error_debug_str(error_code)); + if (current_operation->message_view->service_id && + current_operation->message_view->service_id->len > 0) { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send DATA message with service id '" PRInSTR + "' stream id (%d) and connection id (%d) with error %d(%s)", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*current_operation->message_view->service_id), + current_operation->message_view->stream_id, + current_operation->message_view->connection_id, + error_code, + aws_error_debug_str(error_code)); + } else { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send DATA message with stream id (%d) and connection id (%d) " + "with " + "error %d(%s)", + (void *)secure_tunnel, + current_operation->message_view->stream_id, + current_operation->message_view->connection_id, + error_code, + aws_error_debug_str(error_code)); + } } - aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); + } + if (error_code && secure_tunnel->config->on_send_message_complete) { + secure_tunnel->config->on_send_message_complete( + AWS_SECURE_TUNNEL_MT_DATA, error_code, secure_tunnel->config->user_data); } break; @@ -1275,36 +1854,102 @@ int aws_secure_tunnel_service_operational_state(struct aws_secure_tunnel *secure error_code, aws_error_debug_str(error_code)); } else { - /* Send the Stream Start message through the WebSocket */ if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { error_code = aws_last_error(); } aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); } + + if (error_code && secure_tunnel->config->on_send_message_complete) { + secure_tunnel->config->on_send_message_complete( + AWS_SECURE_TUNNEL_MT_STREAM_START, error_code, secure_tunnel->config->user_data); + } break; case AWS_STOT_STREAM_RESET: + if ((*current_operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)( + current_operation, secure_tunnel) == AWS_OP_SUCCESS) { + if (current_operation->message_view->connection_id == 0) { + /* Send the Stream Reset message through the WebSocket */ + if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + error_code = aws_last_error(); + } else { + s_aws_secure_tunnel_set_stream( + secure_tunnel, + current_operation->message_view->service_id, + INVALID_STREAM_ID, + current_operation->message_view->connection_id); + } + aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); + } else { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send STREAM RESET message must not have a connection id", + (void *)secure_tunnel); + } + } + + break; + + case AWS_STOT_CONNECTION_START: + /* If a connection start attempts to be sent on an unopen stream, discard it. */ if ((*current_operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)( current_operation, secure_tunnel)) { error_code = aws_last_error(); - AWS_LOGF_DEBUG( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: failed to send STREAM RESET message with error %d(%s)", - (void *)secure_tunnel, - error_code, - aws_error_debug_str(error_code)); + } else if ((*current_operation->vtable->aws_secure_tunnel_operation_set_connection_start_id)( + current_operation, secure_tunnel)) { + error_code = aws_last_error(); } else { - /* Send the Stream Reset message through the WebSocket */ if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { error_code = aws_last_error(); - } else { - s_aws_secure_tunnel_set_stream_id( - secure_tunnel, current_operation->message_view->service_id, INVALID_STREAM_ID); + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send CONNECTION START message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); } aws_secure_tunnel_message_view_log(current_operation->message_view, AWS_LL_DEBUG); } + if (error_code && secure_tunnel->config->on_send_message_complete) { + secure_tunnel->config->on_send_message_complete( + AWS_SECURE_TUNNEL_MT_CONNECTION_START, error_code, secure_tunnel->config->user_data); + } + break; + + case AWS_STOT_CONNECTION_RESET: + if ((*current_operation->vtable->aws_secure_tunnel_operation_assign_stream_id_fn)( + current_operation, secure_tunnel)) { + error_code = aws_last_error(); + } else { + error_code = + s_aws_secure_tunnel_remove_connection_id(secure_tunnel, current_operation->message_view); + + /* + * If we have a stream id, we should send the CONNECTION RESET message even if we do not have a + * currently active stream + */ + if (s_secure_tunneling_send(secure_tunnel, current_operation->message_view)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send DATA message with error %d(%s)", + (void *)secure_tunnel, + aws_last_error(), + aws_error_debug_str(aws_last_error())); + } + } + + if (error_code) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: failed to send CONNECTION RESET message with error %d(%s)", + (void *)secure_tunnel, + error_code, + aws_error_debug_str(error_code)); + } + break; case AWS_STOT_NONE: @@ -1486,7 +2131,8 @@ static uint64_t s_compute_next_service_time_secure_tunnel_connected( static uint64_t s_compute_next_service_time_secure_tunnel_clean_disconnect( struct aws_secure_tunnel *secure_tunnel, uint64_t now) { - return s_aws_secure_tunnel_compute_operational_state_service_time(secure_tunnel, now); + (void)secure_tunnel; + return now; } static uint64_t s_compute_next_service_time_secure_tunnel_websocket_shutdown( @@ -1644,6 +2290,24 @@ static void s_service_state_connected(struct aws_secure_tunnel *secure_tunnel, u static void s_service_state_clean_disconnect(struct aws_secure_tunnel *secure_tunnel, uint64_t now) { (void)now; + + enum aws_secure_tunnel_state desired_state = secure_tunnel->desired_state; + if (desired_state != AWS_STS_CONNECTED) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: channel shutdown due to user Stop request", + (void *)secure_tunnel); + s_secure_tunnel_shutdown_websocket(secure_tunnel, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_USER_REQUESTED_STOP); + return; + } + + if (aws_linked_list_empty(&secure_tunnel->queued_operations)) { + s_reset_secure_tunnel_streams(secure_tunnel); + s_secure_tunnel_shutdown_websocket( + secure_tunnel, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_PROTOCOL_VERSION_MISSMATCH); + return; + } + if (aws_secure_tunnel_service_operational_state(secure_tunnel)) { int error_code = aws_last_error(); AWS_LOGF_ERROR( @@ -1745,44 +2409,76 @@ struct aws_secure_tunnel *aws_secure_tunnel_new( goto error; } + secure_tunnel->connections = aws_secure_tunnel_connections_new(allocator); + if (secure_tunnel->connections == NULL) { + goto error; + } + /* all secure tunnel activity will take place on this event loop */ secure_tunnel->loop = aws_event_loop_group_get_next_loop(secure_tunnel->config->bootstrap->event_loop_group); if (secure_tunnel->loop == NULL) { goto error; } + secure_tunnel->host_resolution_config = aws_host_resolver_init_default_resolution_config(); + secure_tunnel->host_resolution_config.resolve_frequency_ns = + aws_timestamp_convert(MAX_RECONNECT_DELAY_MS, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL); + secure_tunnel->desired_state = AWS_STS_STOPPED; secure_tunnel->current_state = AWS_STS_STOPPED; /* tls setup */ - struct aws_tls_ctx_options tls_ctx_opt; - AWS_ZERO_STRUCT(tls_ctx_opt); - aws_tls_ctx_options_init_default_client(&tls_ctx_opt, secure_tunnel->allocator); + if (options->tls_options) { + if (aws_tls_connection_options_copy(&secure_tunnel->tls_con_opt, options->tls_options)) { + goto error; + } + } else { + struct aws_tls_ctx_options tls_ctx_opt; + AWS_ZERO_STRUCT(tls_ctx_opt); + + aws_tls_ctx_options_init_default_client(&tls_ctx_opt, secure_tunnel->allocator); + + if (options->root_ca != NULL) { + if (aws_tls_ctx_options_override_default_trust_store_from_path(&tls_ctx_opt, NULL, options->root_ca)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "Failed to load %s with error %s", + options->root_ca, + aws_error_debug_str(aws_last_error())); + aws_tls_ctx_options_clean_up(&tls_ctx_opt); + goto error; + } + } - if (options->root_ca != NULL) { - if (aws_tls_ctx_options_override_default_trust_store_from_path(&tls_ctx_opt, NULL, options->root_ca)) { + secure_tunnel->tls_ctx = aws_tls_client_ctx_new(allocator, &tls_ctx_opt); + if (secure_tunnel->tls_ctx == NULL) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "Failed to initialize TLS context with error %s.", + aws_error_debug_str(aws_last_error())); + aws_tls_ctx_options_clean_up(&tls_ctx_opt); goto error; } - } - secure_tunnel->tls_ctx = aws_tls_client_ctx_new(allocator, &tls_ctx_opt); - if (secure_tunnel->tls_ctx == NULL) { - goto error; + aws_tls_connection_options_init_from_ctx(&secure_tunnel->tls_con_opt, secure_tunnel->tls_ctx); + aws_tls_ctx_options_clean_up(&tls_ctx_opt); } - /* tls_connection_options */ - aws_tls_connection_options_init_from_ctx(&secure_tunnel->tls_con_opt, secure_tunnel->tls_ctx); - if (aws_tls_connection_options_set_server_name( - &secure_tunnel->tls_con_opt, allocator, (struct aws_byte_cursor *)&options->endpoint_host)) { - goto error; + if (!secure_tunnel->tls_con_opt.server_name) { + if (aws_tls_connection_options_set_server_name( + &secure_tunnel->tls_con_opt, secure_tunnel->allocator, &options->endpoint_host)) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "Failed to set endpoint host name with error %s.", + aws_error_debug_str(aws_last_error())); + goto error; + } } - aws_tls_ctx_options_clean_up(&tls_ctx_opt); - /* Connection reset */ - secure_tunnel->config->stream_id = INVALID_STREAM_ID; + secure_tunnel->connections->stream_id = INVALID_STREAM_ID; - aws_hash_table_foreach(&secure_tunnel->config->service_ids, s_reset_service_id, NULL); + aws_hash_table_foreach(&secure_tunnel->connections->service_ids, s_reset_service_id, NULL); secure_tunnel->handshake_request = NULL; secure_tunnel->websocket = NULL; @@ -1794,7 +2490,6 @@ struct aws_secure_tunnel *aws_secure_tunnel_new( return secure_tunnel; error: - aws_tls_ctx_options_clean_up(&tls_ctx_opt); aws_secure_tunnel_release(secure_tunnel); return NULL; } @@ -1834,7 +2529,16 @@ int aws_secure_tunnel_send_message( secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_MESSAGE); if (message_op == NULL) { - return AWS_OP_ERR; + return aws_last_error(); + } + + /* + * If message is being sent from DESTINATION MODE, it might be expected that a V2 or V1 connection has established a + * default connection id of 1. This default connection id must be stripped before sending a V1 or V2 message out. + */ + if (secure_tunnel->config->local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE && + secure_tunnel->connections->protocol_version < 3 && message_options->connection_id == 1) { + message_op->options_storage.storage_view.connection_id = 0; } AWS_LOGF_DEBUG( @@ -1851,7 +2555,7 @@ int aws_secure_tunnel_send_message( error: aws_secure_tunnel_operation_release(&message_op->base); - return AWS_OP_ERR; + return aws_last_error(); } int aws_secure_tunnel_stream_start( @@ -1861,10 +2565,36 @@ int aws_secure_tunnel_stream_start( AWS_PRECONDITION(message_options != NULL); if (secure_tunnel->config->local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Stream Start can only be sent from source mode"); + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Stream Start can only be sent from Source Mode"); return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE; } + uint8_t message_protocol_version = s_aws_secure_tunnel_message_min_protocol_check(message_options); + if (secure_tunnel->connections->protocol_version != 0 && + message_protocol_version != secure_tunnel->connections->protocol_version) { + /* + * Protocol missmatch results in a full disconnect/reconnect to the Secure Tunnel Service followed by + * sending the STREAM START request that caused the missmatch. + */ + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure Tunnel will be reset due to Protocol Version missmatch between previously established " + "Protocol Version (%d) and Protocol Version used by outbound STREAM START message (%d).", + (void *)secure_tunnel, + (int)secure_tunnel->connections->protocol_version, + (int)message_protocol_version); + reset_secure_tunnel_connection(secure_tunnel); + } + + if (secure_tunnel->connections->protocol_version == 0) { + secure_tunnel->connections->protocol_version = message_protocol_version; + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel client Protocol set to V%d based on outbound STREAM START", + (void *)secure_tunnel, + (int)secure_tunnel->connections->protocol_version); + } + struct aws_secure_tunnel_operation_message *message_op = aws_secure_tunnel_operation_message_new( secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_STREAM_START); @@ -1889,6 +2619,60 @@ int aws_secure_tunnel_stream_start( return AWS_OP_ERR; } +int aws_secure_tunnel_connection_start( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options) { + AWS_PRECONDITION(secure_tunnel != NULL); + AWS_PRECONDITION(message_options != NULL); + + if (secure_tunnel->config->local_proxy_mode == AWS_SECURE_TUNNELING_DESTINATION_MODE) { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Connection Start can only be sent from Source Mode"); + return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INCORRECT_MODE; + } + + if (secure_tunnel->connections->protocol_version != 3) { + AWS_LOGF_WARN( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Connection Start may only be used with a Protocol V3 stream."); + return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_PROTOCOL_VERSION_MISSMATCH; + } + + if (message_options->connection_id == 0) { + AWS_LOGF_WARN(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Connection Start must include a connection id."); + return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID; + } + + struct aws_secure_tunnel_operation_message *message_op = aws_secure_tunnel_operation_message_new( + secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_CONNECTION_START); + + if (message_op == NULL) { + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Submitting CONNECTION START operation (%p)", + (void *)secure_tunnel, + (void *)message_op); + + if (s_submit_operation(secure_tunnel, &message_op->base)) { + goto error; + } + + return AWS_OP_SUCCESS; + +error: + aws_secure_tunnel_operation_release(&message_op->base); + return AWS_OP_ERR; +} + +/********************************************************************************************************************* + * Internal Operation Calls + ********************************************************************************************************************/ + +/* + * This is currently exposed by the initial implementation of Secure Tunnel and has been marked as deprecated. + * Should this be called, it will be honored but it should be made private when possible. + */ int aws_secure_tunnel_stream_reset( struct aws_secure_tunnel *secure_tunnel, const struct aws_secure_tunnel_message_view *message_options) { @@ -1918,3 +2702,33 @@ int aws_secure_tunnel_stream_reset( aws_secure_tunnel_operation_release(&message_op->base); return AWS_OP_ERR; } + +int aws_secure_tunnel_connection_reset( + struct aws_secure_tunnel *secure_tunnel, + const struct aws_secure_tunnel_message_view *message_options) { + AWS_PRECONDITION(secure_tunnel != NULL); + AWS_PRECONDITION(message_options != NULL); + + struct aws_secure_tunnel_operation_message *message_op = aws_secure_tunnel_operation_message_new( + secure_tunnel->allocator, secure_tunnel, message_options, AWS_STOT_CONNECTION_RESET); + + if (message_op == NULL) { + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Submitting CONNECTION RESET operation (%p)", + (void *)secure_tunnel, + (void *)message_op); + + if (s_submit_operation(secure_tunnel, &message_op->base)) { + goto error; + } + + return AWS_OP_SUCCESS; + +error: + aws_secure_tunnel_operation_release(&message_op->base); + return AWS_OP_ERR; +} diff --git a/source/secure_tunneling_operations.c b/source/secure_tunneling_operations.c index 9931805d..38f65602 100644 --- a/source/secure_tunneling_operations.c +++ b/source/secure_tunneling_operations.c @@ -15,9 +15,48 @@ #define INVALID_STREAM_ID 0 +/********************************************************************************************************************* + * SERVICE AND CONNECTION ID HASH TABLE + ********************************************************************************************************************/ + +static const uint32_t s_bit_scrambling_magic = 0x45d9f3bU; +static const uint32_t s_bit_shift_magic = 16U; + +/* this is a repurposed hash function based on the technique in splitmix64. The magic number was a result of numerical + * analysis on maximum bit entropy. */ +uint64_t aws_secure_tunnel_hash_connection_id(const void *to_hash) { + uint32_t int_to_hash = *(const uint32_t *)to_hash; + uint32_t hash = ((int_to_hash >> s_bit_shift_magic) ^ int_to_hash) * s_bit_scrambling_magic; + hash = ((hash >> s_bit_shift_magic) ^ hash) * s_bit_scrambling_magic; + hash = (hash >> s_bit_shift_magic) ^ hash; + return (uint64_t)hash; +} + +bool aws_secure_tunnel_connection_id_eq(const void *a, const void *b) { + return *(const uint32_t *)a == *(const uint32_t *)b; +} + +void aws_connection_id_destroy(void *data) { + struct aws_connection_id_element *elem = data; + aws_mem_release(elem->allocator, elem); +} + +struct aws_connection_id_element *aws_connection_id_element_new( + struct aws_allocator *allocator, + uint32_t connection_id) { + AWS_PRECONDITION(allocator != NULL); + AWS_PRECONDITION(connection_id > 0); + struct aws_connection_id_element *elem = aws_mem_calloc(allocator, 1, sizeof(struct aws_service_id_element)); + elem->allocator = allocator; + elem->connection_id = connection_id; + + return elem; +} + /* for the hash table, to destroy elements */ static void s_destroy_service_id(void *data) { struct aws_service_id_element *elem = data; + aws_hash_table_clean_up(&elem->connection_ids); aws_string_destroy(elem->service_id_string); aws_mem_release(elem->allocator, elem); } @@ -38,6 +77,17 @@ struct aws_service_id_element *aws_service_id_element_new( elem->service_id_cur = aws_byte_cursor_from_string(elem->service_id_string); elem->stream_id = stream_id; + if (aws_hash_table_init( + &elem->connection_ids, + allocator, + 1, + aws_secure_tunnel_hash_connection_id, + aws_secure_tunnel_connection_id_eq, + NULL, + aws_connection_id_destroy)) { + goto error; + } + return elem; error: @@ -46,7 +96,7 @@ struct aws_service_id_element *aws_service_id_element_new( } /********************************************************************************************************************* - * Operation base + * OPERATION BASE ********************************************************************************************************************/ struct aws_secure_tunnel_operation *aws_secure_tunnel_operation_acquire(struct aws_secure_tunnel_operation *operation) { @@ -91,10 +141,11 @@ static struct aws_secure_tunnel_operation_vtable s_empty_operation_vtable = { .aws_secure_tunnel_operation_completion_fn = NULL, .aws_secure_tunnel_operation_assign_stream_id_fn = NULL, .aws_secure_tunnel_operation_set_next_stream_id_fn = NULL, + .aws_secure_tunnel_operation_set_connection_start_id = NULL, }; /********************************************************************************************************************* - * Message + * MESSAGE ********************************************************************************************************************/ int aws_secure_tunnel_message_view_validate(const struct aws_secure_tunnel_message_view *message_view) { @@ -106,15 +157,15 @@ int aws_secure_tunnel_message_view_validate(const struct aws_secure_tunnel_messa if (message_view->type == AWS_SECURE_TUNNEL_MT_DATA && message_view->stream_id != 0) { AWS_LOGF_ERROR( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: aws_secure_tunnel_message_view stream id for DATA MESSAGES must be 0", + "id=%p: aws_secure_tunnel_message_view - stream id for DATA MESSAGES must be 0", (void *)message_view); return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION); } - if (message_view->payload != NULL && message_view->payload->len > AWS_IOT_ST_MAX_MESSAGE_SIZE) { + if (message_view->payload != NULL && message_view->payload->len > AWS_IOT_ST_MAX_PAYLOAD_SIZE) { AWS_LOGF_ERROR( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: aws_secure_tunnel_message_view - payload too long", + "id=%p: aws_secure_tunnel_message_view - payload too large", (void *)message_view); return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION); } @@ -138,30 +189,83 @@ void aws_secure_tunnel_message_view_log( (void *)message_view, aws_secure_tunnel_message_type_to_c_string(message_view->type)); - if (message_view->service_id != NULL) { - AWS_LOGUF( - log_handle, - level, - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: aws_secure_tunnel_message_view service_id set to '" PRInSTR "'", - (void *)message_view, - AWS_BYTE_CURSOR_PRI(*message_view->service_id)); - } else { - AWS_LOGUF( - log_handle, - level, - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: aws_secure_tunnel_message_view service_id not set", - (void *)message_view); - } + switch (message_view->type) { + case AWS_SECURE_TUNNEL_MT_DATA: + case AWS_SECURE_TUNNEL_MT_STREAM_START: + case AWS_SECURE_TUNNEL_MT_STREAM_RESET: + case AWS_SECURE_TUNNEL_MT_CONNECTION_START: + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + if (message_view->service_id != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id set to '" PRInSTR "'", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id)); + } else { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id not set", + (void *)message_view); + } - AWS_LOGUF( - log_handle, - level, - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: aws_secure_tunnel_message_view stream_id set to %d", - (void *)message_view, - (int)message_view->stream_id); + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view stream_id set to %d", + (void *)message_view, + (int)message_view->stream_id); + + if (message_view->connection_id != 0) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view connection_id set to %d", + (void *)message_view, + (int)message_view->connection_id); + } + + break; + + case AWS_SECURE_TUNNEL_MT_SERVICE_IDS: + if (message_view->service_id != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id 1 set to '" PRInSTR "'", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id)); + } + if (message_view->service_id_2 != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id 2 set to '" PRInSTR "'", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id_2)); + } + if (message_view->service_id_3 != NULL) { + AWS_LOGUF( + log_handle, + level, + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: aws_secure_tunnel_message_view service_id 3 set to '" PRInSTR "'", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id_3)); + } + break; + case AWS_SECURE_TUNNEL_MT_SESSION_RESET: + case AWS_SECURE_TUNNEL_MT_UNKNOWN: + default: + break; + } if (message_view->payload != NULL) { AWS_LOGUF( @@ -199,6 +303,7 @@ int aws_secure_tunnel_message_storage_init( storage_view->type = message_options->type; storage_view->ignorable = message_options->ignorable; storage_view->stream_id = message_options->stream_id; + storage_view->connection_id = message_options->connection_id; switch (type) { case AWS_STOT_MESSAGE: @@ -210,6 +315,12 @@ int aws_secure_tunnel_message_storage_init( case AWS_STOT_STREAM_RESET: storage_view->type = AWS_SECURE_TUNNEL_MT_STREAM_RESET; break; + case AWS_STOT_CONNECTION_START: + storage_view->type = AWS_SECURE_TUNNEL_MT_CONNECTION_START; + break; + case AWS_STOT_CONNECTION_RESET: + storage_view->type = AWS_SECURE_TUNNEL_MT_CONNECTION_RESET; + break; default: storage_view->type = AWS_SECURE_TUNNEL_MT_UNKNOWN; break; @@ -238,7 +349,9 @@ void aws_secure_tunnel_message_storage_clean_up(struct aws_secure_tunnel_message aws_byte_buf_clean_up(&message_storage->storage); } -/* Sets the stream id on outbound message based on the service id (or lack of for V1) to the current one being used. */ +/* + * Retreives and assigns the stream id on an outbound message based on the service id (or lack of one for V1). + */ static int s_aws_secure_tunnel_operation_message_assign_stream_id( struct aws_secure_tunnel_operation *operation, struct aws_secure_tunnel *secure_tunnel) { @@ -248,30 +361,47 @@ static int s_aws_secure_tunnel_operation_message_assign_stream_id( struct aws_secure_tunnel_message_view *message_view = &message_op->options_storage.storage_view; - if (message_view->service_id != NULL) { + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + stream_id = secure_tunnel->connections->stream_id; + } else { struct aws_hash_element *elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, message_view->service_id, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, message_view->service_id, &elem); if (elem == NULL) { AWS_LOGF_WARN( AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: invalid service_id:'" PRInSTR "' attempted to be used with an outbound message", + "id=%p: invalid service id '" PRInSTR "' attempted to be assigned a stream id on an outbound message", (void *)message_view, AWS_BYTE_CURSOR_PRI(*message_view->service_id)); - stream_id = INVALID_STREAM_ID; - } else { - struct aws_service_id_element *service_id_elem = elem->value; - stream_id = service_id_elem->stream_id; + goto error; } - } else { - stream_id = secure_tunnel->config->stream_id; + struct aws_service_id_element *service_id_elem = elem->value; + stream_id = service_id_elem->stream_id; } if (stream_id == INVALID_STREAM_ID) { - return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM); + goto error; } message_op->options_storage.storage_view.stream_id = stream_id; return AWS_OP_SUCCESS; + +error: + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: No active stream to assign outbound %s message a stream id", + (void *)secure_tunnel, + aws_secure_tunnel_message_type_to_c_string(message_view->type)); + } else { + AWS_LOGF_DEBUG( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: No active stream with service id '" PRInSTR "' to assign outbound %s message a stream id", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + aws_secure_tunnel_message_type_to_c_string(message_view->type)); + } + + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM_ID); } /* @@ -287,9 +417,30 @@ static int s_aws_secure_tunnel_operation_message_set_next_stream_id( struct aws_secure_tunnel_message_view *message_view = &message_op->options_storage.storage_view; - if (message_view->service_id != NULL && message_view->service_id->len > 0) { + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + stream_id = secure_tunnel->connections->stream_id + 1; + secure_tunnel->connections->stream_id = stream_id; + + aws_hash_table_clear(&secure_tunnel->connections->connection_ids); + struct aws_connection_id_element *connection_id_elem = NULL; + if (message_view->connection_id > 0) { + connection_id_elem = aws_connection_id_element_new(secure_tunnel->allocator, message_view->connection_id); + } else { + connection_id_elem = aws_connection_id_element_new(secure_tunnel->allocator, 1); + } + + aws_hash_table_put( + &secure_tunnel->connections->connection_ids, &connection_id_elem->connection_id, connection_id_elem, NULL); + + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel set to stream id (%d) with connection id (%d)", + (void *)secure_tunnel, + stream_id, + connection_id_elem->connection_id); + } else { struct aws_hash_element *elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, message_view->service_id, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, message_view->service_id, &elem); if (elem == NULL) { AWS_LOGF_WARN( AWS_LS_IOTDEVICE_SECURE_TUNNELING, @@ -304,26 +455,108 @@ static int s_aws_secure_tunnel_operation_message_set_next_stream_id( struct aws_service_id_element *replacement_elem = aws_service_id_element_new(secure_tunnel->allocator, message_view->service_id, stream_id); + + struct aws_connection_id_element *connection_id_elem = NULL; + if (message_view->connection_id > 0) { + connection_id_elem = + aws_connection_id_element_new(secure_tunnel->allocator, message_view->connection_id); + } else { + connection_id_elem = aws_connection_id_element_new(secure_tunnel->allocator, 1); + } + + aws_hash_table_put( + &replacement_elem->connection_ids, &connection_id_elem->connection_id, connection_id_elem, NULL); aws_hash_table_put( - &secure_tunnel->config->service_ids, &replacement_elem->service_id_cur, replacement_elem, NULL); + &secure_tunnel->connections->service_ids, &replacement_elem->service_id_cur, replacement_elem, NULL); + + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Secure tunnel service id '" PRInSTR "' set to stream id (%d) with connection id (%d)", + (void *)secure_tunnel, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + stream_id, + connection_id_elem->connection_id); } - } else { - stream_id = secure_tunnel->config->stream_id + 1; - secure_tunnel->config->stream_id = stream_id; } if (stream_id == INVALID_STREAM_ID) { - return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM); + return aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_STREAM_ID); } message_op->options_storage.storage_view.stream_id = stream_id; - AWS_LOGF_INFO( - AWS_LS_IOTDEVICE_SECURE_TUNNELING, - "id=%p: Secure tunnel service_id '" PRInSTR "' stream_id set to %d", - (void *)secure_tunnel, - AWS_BYTE_CURSOR_PRI(*message_view->service_id), - stream_id); + return AWS_OP_SUCCESS; +} + +static int s_aws_secure_tunnel_operation_set_connection_start_id( + struct aws_secure_tunnel_operation *operation, + struct aws_secure_tunnel *secure_tunnel) { + + struct aws_secure_tunnel_operation_message *message_op = operation->impl; + struct aws_secure_tunnel_message_view *message_view = &message_op->options_storage.storage_view; + + /* + * Get the appropriate connection id hash table to add the new connection id to + */ + struct aws_hash_table *table_to_put_in = NULL; + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + table_to_put_in = &secure_tunnel->connections->connection_ids; + } else { + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, message_view->service_id, &elem); + if (elem == NULL) { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: invalid service_id:'" PRInSTR + "' attempted to be used to start a stream using a connection id (%d)", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + message_view->connection_id); + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_SERVICE_ID); + } else { + struct aws_service_id_element *service_id_elem = elem->value; + table_to_put_in = &service_id_elem->connection_ids; + } + } + + if (message_view->connection_id != 0) { + struct aws_connection_id_element *connection_id_elem = NULL; + connection_id_elem = aws_connection_id_element_new(secure_tunnel->allocator, message_view->connection_id); + struct aws_hash_element *connection_elem = NULL; + + aws_hash_table_find(table_to_put_in, &connection_id_elem->connection_id, &connection_elem); + /* + * If the connection id is already stored, it does not need to be put into the hash table. The CONNECTION START + * will still be sent but if there is already an active stream on this connection id on the Destination, they + * will send a CONNECTION RESET to close it. + */ + if (connection_elem == NULL) { + aws_hash_table_put(table_to_put_in, &connection_id_elem->connection_id, connection_id_elem, NULL); + } else { + aws_connection_id_destroy(connection_id_elem); + } + + if (message_view->service_id == NULL || message_view->service_id->len == 0) { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Stream started using connection id (%d)", + (void *)message_view, + message_view->connection_id); + } else { + AWS_LOGF_INFO( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Stream started on service_id:'" PRInSTR "' using connection id (%d)", + (void *)message_view, + AWS_BYTE_CURSOR_PRI(*message_view->service_id), + message_view->connection_id); + } + } else { + AWS_LOGF_ERROR( + AWS_LS_IOTDEVICE_SECURE_TUNNELING, + "id=%p: Connection Id can not be set to 0 on a CONNECTION START", + (void *)message_view); + aws_raise_error(AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_INVALID_CONNECTION_ID); + } return AWS_OP_SUCCESS; } @@ -331,6 +564,7 @@ static int s_aws_secure_tunnel_operation_message_set_next_stream_id( static struct aws_secure_tunnel_operation_vtable s_message_operation_vtable = { .aws_secure_tunnel_operation_assign_stream_id_fn = s_aws_secure_tunnel_operation_message_assign_stream_id, .aws_secure_tunnel_operation_set_next_stream_id_fn = s_aws_secure_tunnel_operation_message_set_next_stream_id, + .aws_secure_tunnel_operation_set_connection_start_id = s_aws_secure_tunnel_operation_set_connection_start_id, }; static void s_destroy_operation_message(void *object) { @@ -547,7 +781,6 @@ void aws_secure_tunnel_options_storage_destroy(struct aws_secure_tunnel_options_ aws_string_destroy(storage->endpoint_host); aws_string_destroy(storage->access_token); aws_string_destroy(storage->client_token); - aws_hash_table_clean_up(&storage->service_ids); aws_mem_release(storage->allocator, storage); } @@ -626,26 +859,17 @@ struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( aws_http_proxy_options_init_from_config(&storage->http_proxy_options, storage->http_proxy_config); } - if (aws_hash_table_init( - &storage->service_ids, - allocator, - 3, - aws_hash_byte_cursor_ptr, - (aws_hash_callback_eq_fn *)aws_byte_cursor_eq, - NULL, - s_destroy_service_id)) { - goto error; - } - storage->on_message_received = options->on_message_received; storage->user_data = options->user_data; storage->local_proxy_mode = options->local_proxy_mode; storage->on_connection_complete = options->on_connection_complete; storage->on_connection_shutdown = options->on_connection_shutdown; - storage->on_send_data_complete = options->on_send_data_complete; + storage->on_send_message_complete = options->on_send_message_complete; storage->on_stream_start = options->on_stream_start; storage->on_stream_reset = options->on_stream_reset; + storage->on_connection_start = options->on_connection_start; + storage->on_connection_reset = options->on_connection_reset; storage->on_session_reset = options->on_session_reset; storage->on_stopped = options->on_stopped; storage->on_termination_complete = options->on_termination_complete; @@ -658,6 +882,56 @@ struct aws_secure_tunnel_options_storage *aws_secure_tunnel_options_storage_new( return NULL; } +void aws_secure_tunnel_connections_destroy(struct aws_secure_tunnel_connections *connections) { + if (connections == NULL) { + return; + } + + if (connections->restore_stream_message_view != NULL) { + aws_secure_tunnel_message_storage_clean_up(&connections->restore_stream_message); + connections->restore_stream_message_view = NULL; + } + aws_hash_table_clean_up(&connections->service_ids); + aws_hash_table_clean_up(&connections->connection_ids); + + aws_mem_release(connections->allocator, connections); +} + +struct aws_secure_tunnel_connections *aws_secure_tunnel_connections_new(struct aws_allocator *allocator) { + AWS_PRECONDITION(allocator != NULL); + + struct aws_secure_tunnel_connections *connections = + aws_mem_calloc(allocator, 1, sizeof(struct aws_secure_tunnel_connections)); + + connections->allocator = allocator; + + if (aws_hash_table_init( + &connections->service_ids, + allocator, + 3, + aws_hash_byte_cursor_ptr, + (aws_hash_callback_eq_fn *)aws_byte_cursor_eq, + NULL, + s_destroy_service_id)) { + goto error; + } + if (aws_hash_table_init( + &connections->connection_ids, + allocator, + 1, + aws_secure_tunnel_hash_connection_id, + aws_secure_tunnel_connection_id_eq, + NULL, + aws_connection_id_destroy)) { + goto error; + } + return connections; + +error: + aws_secure_tunnel_connections_destroy(connections); + return NULL; +} + /********************************************************************************************************************* * Data Tunnel Pair ********************************************************************************************************************/ @@ -684,15 +958,12 @@ struct data_tunnel_pair *aws_secure_tunnel_data_tunnel_pair_new( struct data_tunnel_pair *pair = aws_mem_calloc(allocator, 1, sizeof(struct data_tunnel_pair)); pair->allocator = allocator; pair->secure_tunnel = secure_tunnel; + pair->type = message_view->type; pair->length_prefix_written = false; if (aws_iot_st_msg_serialize_from_view(&pair->buf, allocator, message_view)) { AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failure serializing message"); goto error; } - if (pair->buf.len > AWS_IOT_ST_MAX_MESSAGE_SIZE) { - AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Message size greater than AWS_IOT_ST_MAX_MESSAGE_SIZE"); - goto error; - } pair->cur = aws_byte_cursor_from_buf(&pair->buf); @@ -716,6 +987,10 @@ const char *aws_secure_tunnel_operation_type_to_c_string(enum aws_secure_tunnel_ return "STREAM RESET"; case AWS_STOT_STREAM_START: return "STREAM START"; + case AWS_STOT_CONNECTION_START: + return "CONNECTION START"; + case AWS_STOT_CONNECTION_RESET: + return "CONNECTION RESET"; default: return "UNKNOWN"; } diff --git a/source/serializer.c b/source/serializer.c index 565178a3..15bcfba7 100644 --- a/source/serializer.c +++ b/source/serializer.c @@ -92,6 +92,10 @@ static int s_iot_st_encode_stream_id(int32_t data, struct aws_byte_buf *buffer) return s_iot_st_encode_varint(AWS_SECURE_TUNNEL_FN_STREAM_ID, AWS_SECURE_TUNNEL_PBWT_VARINT, data, buffer); } +static int s_iot_st_encode_connection_id(uint32_t data, struct aws_byte_buf *buffer) { + return s_iot_st_encode_varint(AWS_SECURE_TUNNEL_FN_CONNECTION_ID, AWS_SECURE_TUNNEL_PBWT_VARINT, data, buffer); +} + static int s_iot_st_encode_ignorable(int32_t data, struct aws_byte_buf *buffer) { return s_iot_st_encode_varint(AWS_SECURE_TUNNEL_FN_IGNORABLE, AWS_SECURE_TUNNEL_PBWT_VARINT, data, buffer); } @@ -158,6 +162,21 @@ static int s_iot_st_compute_message_length( local_length += (1 + stream_id_length); } + if (message->connection_id != 0) { + /* + * 1 byte connection_id key + * 1-4 byte connection_id varint + */ + + size_t connection_id_length = 0; + + if (s_iot_st_get_varint_size(message->connection_id, &connection_id_length)) { + return AWS_OP_ERR; + } + + local_length += (1 + connection_id_length); + } + if (message->ignorable != 0) { /* * 1 byte ignorable key @@ -256,6 +275,12 @@ int aws_iot_st_msg_serialize_from_view( } } + if (message_view->connection_id != 0) { + if (s_iot_st_encode_connection_id(message_view->connection_id, buffer)) { + goto cleanup; + } + } + if (message_view->ignorable != 0) { if (s_iot_st_encode_ignorable(message_view->ignorable, buffer)) { goto cleanup; @@ -290,11 +315,6 @@ int aws_iot_st_msg_serialize_from_view( } } - if (buffer->capacity > AWS_IOT_ST_MAX_MESSAGE_SIZE) { - aws_raise_error(AWS_ERROR_INVALID_BUFFER_SIZE); - goto cleanup; - } - return AWS_OP_SUCCESS; cleanup: @@ -349,6 +369,9 @@ int aws_secure_tunnel_deserialize_varint_from_cursor_to_message( case AWS_SECURE_TUNNEL_FN_IGNORABLE: message->ignorable = result; break; + case AWS_SECURE_TUNNEL_FN_CONNECTION_ID: + message->connection_id = result; + break; default: AWS_LOGF_WARN( AWS_LS_IOTDEVICE_SECURE_TUNNELING, @@ -372,7 +395,6 @@ int aws_secure_tunnel_deserialize_message_from_cursor( (void *)secure_tunnel, cursor->len); - AWS_RETURN_ERROR_IF2(cursor->len < AWS_IOT_ST_MAX_MESSAGE_SIZE, AWS_ERROR_INVALID_BUFFER_SIZE); uint8_t wire_type; uint8_t field_number; struct aws_secure_tunnel_message_view message_view; @@ -468,34 +490,3 @@ int aws_secure_tunnel_deserialize_message_from_cursor( error: return AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DECODE_FAILURE; } - -const char *aws_secure_tunnel_message_type_to_c_string(enum aws_secure_tunnel_message_type message_type) { - switch (message_type) { - case AWS_SECURE_TUNNEL_MT_UNKNOWN: - return "ST_MT_UNKNOWN"; - - case AWS_SECURE_TUNNEL_MT_DATA: - return "DATA"; - - case AWS_SECURE_TUNNEL_MT_STREAM_START: - return "STREAM START"; - - case AWS_SECURE_TUNNEL_MT_STREAM_RESET: - return "STREAM RESET"; - - case AWS_SECURE_TUNNEL_MT_SESSION_RESET: - return "SESSION RESET"; - - case AWS_SECURE_TUNNEL_MT_SERVICE_IDS: - return "SERVICE IDS"; - - case AWS_SECURE_TUNNEL_MT_CONNECTION_START: - return "CONNECTION START"; - - case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: - return "CONNECTION RESET"; - - default: - return "UNKNOWN"; - } -} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cee0e152..36db970e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -19,6 +19,7 @@ if (UNIX AND NOT APPLE) add_test_case(devicedefender_publish_failure_callback_invoked) endif() +# Secure Tunnel Tets add_net_test_case(secure_tunneling_functionality_connect_test) add_net_test_case(secure_tunneling_functionality_client_token_test) add_net_test_case(secure_tunneling_fail_and_retry_connection_test) @@ -26,8 +27,23 @@ add_net_test_case(secure_tunneling_store_service_ids_test) add_net_test_case(secure_tunneling_receive_stream_start_test) add_net_test_case(secure_tunneling_rejected_service_id_stream_start_test) add_net_test_case(secure_tunneling_close_stream_on_stream_reset_test) +add_net_test_case(secure_tunneling_ignore_stream_reset_for_inactive_stream_test) add_net_test_case(secure_tunneling_session_reset_test) add_net_test_case(secure_tunneling_serializer_data_message_test) +add_net_test_case(secure_tunneling_max_payload_test) +add_net_test_case(secure_tunneling_max_payload_exceed_test) +add_net_test_case(secure_tunneling_receive_connection_start_test) +add_net_test_case(secure_tunneling_ignore_inactive_stream_message_test) +add_net_test_case(secure_tunneling_ignore_inactive_connection_id_message_test) +add_net_test_case(secure_tunneling_v1_to_v2_stream_start_test) +add_net_test_case(secure_tunneling_v1_to_v3_stream_start_test) +add_net_test_case(secure_tunneling_v2_to_v1_stream_start_test) +add_net_test_case(secure_tunneling_v3_to_v1_stream_start_test) +add_net_test_case(secure_tunneling_v1_stream_start_v3_message_reset_test) +add_net_test_case(secure_tunneling_v2_stream_start_connection_start_reset_test) +add_net_test_case(secure_tunneling_ignore_outbound_inactive_connection_id_message_sending_test) +add_net_test_case(secure_tunneling_close_stream_on_connection_reset_test) +add_net_test_case(secure_tunneling_existing_connection_start_send_reset_test) generate_test_driver(${PROJECT_NAME}-tests) diff --git a/tests/secure_tunnel_tests.c b/tests/secure_tunnel_tests.c index 795c0bf8..18776a7e 100644 --- a/tests/secure_tunnel_tests.c +++ b/tests/secure_tunnel_tests.c @@ -38,6 +38,18 @@ AWS_STATIC_STRING_FROM_LITERAL(s_payload_text, "IAmABunchOfPayloadText"); # define LOCAL_SOCK_TEST_PATTERN "testsock%llu.sock" #endif +static uint8_t s_too_long_for_uint16[UINT16_MAX + 1]; + +static struct aws_byte_cursor s_payload_cursor_max_size_exceeded = { + .ptr = s_too_long_for_uint16, + .len = AWS_IOT_ST_MAX_PAYLOAD_SIZE + 1, +}; + +static struct aws_byte_cursor s_payload_cursor_max_size = { + .ptr = s_too_long_for_uint16, + .len = AWS_IOT_ST_MAX_PAYLOAD_SIZE, +}; + struct aws_secure_tunnel_mock_websocket_vtable { aws_websocket_on_connection_setup_fn *on_connection_setup_fn; aws_websocket_on_connection_shutdown_fn *on_connection_shutdown_fn; @@ -91,6 +103,7 @@ struct aws_secure_tunnel_mock_test_fixture { struct aws_mutex lock; struct aws_condition_variable signal; bool listener_destroyed; + bool secure_tunnel_connected; bool secure_tunnel_terminated; bool secure_tunnel_connected_succesfully; bool secure_tunnel_connection_shutdown; @@ -98,16 +111,77 @@ struct aws_secure_tunnel_mock_test_fixture { bool secure_tunnel_stream_started; bool secure_tunnel_bad_stream_request; bool secure_tunnel_stream_reset_received; + bool secure_tunnel_connection_started; + bool secure_tunnel_bad_connection_request; + bool secure_tunnel_connection_reset_received; bool secure_tunnel_session_reset_received; struct aws_byte_buf last_message_payload_buf; int secure_tunnel_message_received_count; + int secure_tunnel_message_sent_count; int secure_tunnel_stream_started_count; int secure_tunnel_stream_started_count_target; - int secure_tunnel_message_count_target; + int secure_tunnel_connection_started_count; + int secure_tunnel_connection_started_count_target; + int secure_tunnel_message_received_count_target; + int secure_tunnel_message_sent_count_target; + int secure_tunnel_message_sent_connection_reset_count; + int secure_tunnel_message_sent_data_count; }; +static bool s_secure_tunnel_check_active_stream_id( + struct aws_secure_tunnel *secure_tunnel, + struct aws_byte_cursor *service_id, + int32_t stream_id) { + if (service_id == NULL) { + return secure_tunnel->connections->stream_id == stream_id; + } + + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, service_id, &elem); + if (elem == NULL) { + return false; + } + + struct aws_service_id_element *service_id_elem = elem->value; + if (service_id_elem->stream_id != stream_id) { + return false; + } + + return true; +} + +static bool s_secure_tunnel_check_active_connection_id( + struct aws_secure_tunnel *secure_tunnel, + struct aws_byte_cursor *service_id, + int32_t stream_id, + uint32_t connection_id) { + struct aws_hash_table *table_to_check = NULL; + if (service_id) { + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, service_id, &elem); + if (elem == NULL) { + return false; + } + struct aws_service_id_element *service_id_elem = elem->value; + table_to_check = &service_id_elem->connection_ids; + } else { + if (secure_tunnel->connections->stream_id != stream_id) { + return false; + } + table_to_check = &secure_tunnel->connections->connection_ids; + } + + struct aws_hash_element *connection_elem = NULL; + aws_hash_table_find(table_to_check, &connection_id, &connection_elem); + if (connection_elem == NULL) { + return false; + } + + return true; +} + /***************************************************************************************************************** * SECURE TUNNEL CALLBACKS *****************************************************************************************************************/ @@ -120,13 +194,15 @@ static void s_on_test_secure_tunnel_connection_complete( struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; aws_mutex_lock(&test_fixture->lock); - if (error_code == 0) { + if (error_code == 0 && test_fixture->secure_tunnel_connected == false) { + test_fixture->secure_tunnel_connection_shutdown = false; test_fixture->secure_tunnel_connected_succesfully = true; + test_fixture->secure_tunnel_connected = true; } else { test_fixture->secure_tunnel_connection_failed = true; } - aws_mutex_unlock(&test_fixture->lock); aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } static void s_on_test_secure_tunnel_connection_shutdown(int error_code, void *user_data) { @@ -135,8 +211,11 @@ static void s_on_test_secure_tunnel_connection_shutdown(int error_code, void *us aws_mutex_lock(&test_fixture->lock); test_fixture->secure_tunnel_connection_shutdown = true; - aws_mutex_unlock(&test_fixture->lock); + test_fixture->secure_tunnel_connected = false; + test_fixture->secure_tunnel_connected_succesfully = false; + test_fixture->secure_tunnel_stream_started = false; aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } static void s_on_test_secure_tunnel_message_received( @@ -152,11 +231,15 @@ static void s_on_test_secure_tunnel_message_received( .len = message->payload->len, }; aws_byte_buf_write_from_whole_cursor(&test_fixture->last_message_payload_buf, payload_cur); - aws_mutex_unlock(&test_fixture->lock); aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } -static void s_on_test_secure_tunnel_send_data_complete(int error_code, void *user_data) { +static void s_on_test_secure_tunnel_send_message_complete( + enum aws_secure_tunnel_message_type type, + int error_code, + void *user_data) { + (void)type; (void)error_code; (void)user_data; } @@ -166,8 +249,8 @@ static void s_on_test_secure_tunnel_on_session_reset(void *user_data) { aws_mutex_lock(&test_fixture->lock); test_fixture->secure_tunnel_session_reset_received = true; - aws_mutex_unlock(&test_fixture->lock); aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } static void s_on_test_secure_tunnel_on_stopped(void *user_data) { @@ -179,8 +262,8 @@ static void s_on_test_secure_tunnel_termination(void *user_data) { aws_mutex_lock(&test_fixture->lock); test_fixture->secure_tunnel_terminated = true; - aws_mutex_unlock(&test_fixture->lock); aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } static void s_on_test_secure_tunnel_on_stream_reset( @@ -194,8 +277,8 @@ static void s_on_test_secure_tunnel_on_stream_reset( aws_mutex_lock(&test_fixture->lock); test_fixture->secure_tunnel_stream_reset_received = true; - aws_mutex_unlock(&test_fixture->lock); aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } static void s_on_test_secure_tunnel_on_stream_start( @@ -213,8 +296,42 @@ static void s_on_test_secure_tunnel_on_stream_start( } else { test_fixture->secure_tunnel_bad_stream_request = true; } + aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); +} + +static void s_on_test_secure_tunnel_on_connection_start( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data) { + (void)message; + + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + if (error_code == AWS_OP_SUCCESS) { + test_fixture->secure_tunnel_connection_started = true; + test_fixture->secure_tunnel_connection_started_count++; + } else { + test_fixture->secure_tunnel_bad_connection_request = true; + } + aws_condition_variable_notify_all(&test_fixture->signal); aws_mutex_unlock(&test_fixture->lock); +} + +static void s_on_test_secure_tunnel_on_connection_reset( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data) { + (void)message; + (void)error_code; + + struct aws_secure_tunnel_mock_test_fixture *test_fixture = user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_connection_reset_received = true; aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); } /***************************************************************************************************************** @@ -269,6 +386,54 @@ static void s_wait_for_stream_started(struct aws_secure_tunnel_mock_test_fixture aws_mutex_unlock(&test_fixture->lock); } +static bool s_has_secure_tunnel_connection_started(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_connection_started; +} + +static void s_wait_for_connection_started(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_connection_started, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_bad_connection_started(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_bad_connection_request; +} + +static void s_wait_for_bad_connection_started(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_bad_connection_started, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_connection_reset_message_sent(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_message_sent_connection_reset_count > 0; +} + +static void s_wait_for_connection_reset_message_sent(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_connection_reset_message_sent, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + +static bool s_has_secure_tunnel_connection_reset_received(void *arg) { + struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; + return test_fixture->secure_tunnel_connection_reset_received; +} + +static void s_wait_for_connection_reset_received(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_mutex_lock(&test_fixture->lock); + aws_condition_variable_wait_pred( + &test_fixture->signal, &test_fixture->lock, s_has_secure_tunnel_connection_reset_received, test_fixture); + aws_mutex_unlock(&test_fixture->lock); +} + static bool s_has_secure_tunnel_bad_stream_request(void *arg) { struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; return test_fixture->secure_tunnel_bad_stream_request; @@ -319,7 +484,8 @@ static void s_wait_for_session_reset_received(struct aws_secure_tunnel_mock_test static bool s_has_secure_tunnel_n_messages_received(void *arg) { struct aws_secure_tunnel_mock_test_fixture *test_fixture = arg; - return test_fixture->secure_tunnel_stream_started_count == test_fixture->secure_tunnel_message_count_target; + return test_fixture->secure_tunnel_message_received_count == + test_fixture->secure_tunnel_message_received_count_target; } static void s_wait_for_n_messages_received(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { @@ -337,7 +503,10 @@ static void s_wait_for_n_messages_received(struct aws_secure_tunnel_mock_test_fi void aws_secure_tunnel_send_mock_message( struct aws_secure_tunnel_mock_test_fixture *test_fixture, const struct aws_secure_tunnel_message_view *message_view) { - + /* The actual WebSocket is assigned the same event loop as the secure tunnel but the mock websocket for tests + * requires a short sleep to insure there aren't race conditions related to the incoming websocket data being + * processed. */ + aws_thread_current_sleep(aws_timestamp_convert(350, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL)); struct aws_byte_buf data_buf; struct aws_byte_cursor data_cur; struct aws_byte_buf out_buf; @@ -352,6 +521,10 @@ void aws_secure_tunnel_send_mock_message( aws_byte_buf_clean_up(&out_buf); aws_byte_buf_clean_up(&data_buf); + /* The actual WebSocket is assigned the same event loop as the secure tunnel but the mock websocket for tests + * requires a short sleep to insure there aren't race conditions related to the incoming websocket data being + * processed. */ + aws_thread_current_sleep(aws_timestamp_convert(350, AWS_TIMESTAMP_MILLIS, AWS_TIMESTAMP_NANOS, NULL)); } int aws_websocket_client_connect_mock_fn(const struct aws_websocket_client_connection_options *options) { @@ -381,6 +554,7 @@ int aws_websocket_client_connect_mock_fn(const struct aws_websocket_client_conne .websocket = pointer}; (test_fixture->websocket_function_table->on_connection_setup_fn)(&websocket_setup, secure_tunnel); + secure_tunnel->websocket = pointer; struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); @@ -398,11 +572,45 @@ int aws_websocket_client_connect_mock_fn(const struct aws_websocket_client_conne return AWS_OP_SUCCESS; } +void aws_secure_tunnel_test_on_message_received( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + (void)message_view; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = secure_tunnel->config->user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_message_sent_count++; + switch (message_view->type) { + case AWS_SECURE_TUNNEL_MT_DATA: + test_fixture->secure_tunnel_message_sent_data_count++; + break; + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + test_fixture->secure_tunnel_message_sent_connection_reset_count++; + break; + default: + break; + } + aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); +} + int aws_websocket_send_frame_mock_fn( struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options) { - (void)websocket; - (void)options; + + if (options->opcode == AWS_WEBSOCKET_OPCODE_PING) { + return AWS_OP_SUCCESS; + } + + void *pointer = websocket; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = pointer; + + struct data_tunnel_pair *pair = options->user_data; + aws_secure_tunnel_deserialize_message_from_cursor( + test_fixture->secure_tunnel, &pair->cur, &aws_secure_tunnel_test_on_message_received); + + options->on_complete(websocket, AWS_OP_SUCCESS, options->user_data); + return AWS_OP_SUCCESS; } @@ -480,11 +688,13 @@ int aws_secure_tunnel_mock_test_fixture_init( options->secure_tunnel_options->on_connection_complete = s_on_test_secure_tunnel_connection_complete; options->secure_tunnel_options->on_connection_shutdown = s_on_test_secure_tunnel_connection_shutdown; options->secure_tunnel_options->on_message_received = s_on_test_secure_tunnel_message_received; - options->secure_tunnel_options->on_send_data_complete = s_on_test_secure_tunnel_send_data_complete; + options->secure_tunnel_options->on_send_message_complete = s_on_test_secure_tunnel_send_message_complete; options->secure_tunnel_options->on_session_reset = s_on_test_secure_tunnel_on_session_reset; options->secure_tunnel_options->on_stopped = s_on_test_secure_tunnel_on_stopped; options->secure_tunnel_options->on_stream_reset = s_on_test_secure_tunnel_on_stream_reset; options->secure_tunnel_options->on_stream_start = s_on_test_secure_tunnel_on_stream_start; + options->secure_tunnel_options->on_connection_start = s_on_test_secure_tunnel_on_connection_start; + options->secure_tunnel_options->on_connection_reset = s_on_test_secure_tunnel_on_connection_reset; options->secure_tunnel_options->on_termination_complete = s_on_test_secure_tunnel_termination; options->secure_tunnel_options->secure_tunnel_on_termination_user_data = test_fixture; @@ -504,8 +714,30 @@ int aws_secure_tunnel_mock_test_fixture_init( return AWS_OP_SUCCESS; } -void aws_secure_tunnel_mock_test_fixture_clean_up(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { +void aws_secure_tunnel_mock_test_init( + struct aws_allocator *allocator, + struct secure_tunnel_test_options *test_options, + struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + + aws_http_library_init(allocator); + aws_iotdevice_library_init(allocator); + + s_secure_tunnel_test_init_default_options(test_options); + + test_options->secure_tunnel_options.client_token = aws_byte_cursor_from_string(s_client_token); + + struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { + .secure_tunnel_options = &test_options->secure_tunnel_options, + .websocket_function_table = &test_options->websocket_function_table, + }; + + aws_secure_tunnel_mock_test_fixture_init(test_fixture, allocator, &test_fixture_options); +} + +void aws_secure_tunnel_mock_test_clean_up(struct aws_secure_tunnel_mock_test_fixture *test_fixture) { + aws_secure_tunnel_release(test_fixture->secure_tunnel); s_wait_for_secure_tunnel_terminated(test_fixture); + aws_client_bootstrap_release(test_fixture->secure_tunnel_bootstrap); aws_host_resolver_release(test_fixture->host_resolver); @@ -514,13 +746,16 @@ void aws_secure_tunnel_mock_test_fixture_clean_up(struct aws_secure_tunnel_mock_ aws_byte_buf_clean_up(&test_fixture->last_message_payload_buf); aws_mutex_clean_up(&test_fixture->lock); aws_condition_variable_clean_up(&test_fixture->signal); + + aws_iotdevice_library_clean_up(); + aws_http_library_clean_up(); + aws_iotdevice_library_clean_up(); } /********************************************************************************************************************* * TESTS ********************************************************************************************************************/ -/* [Func-UC1] */ int secure_tunneling_access_token_check(const struct aws_http_headers *request_headers, void *user_data) { (void)user_data; struct aws_byte_cursor access_token_cur; @@ -536,44 +771,27 @@ int secure_tunneling_access_token_check(const struct aws_http_headers *request_h static int s_secure_tunneling_functionality_connect_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; test_fixture.header_check = secure_tunneling_access_token_check; - struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; - ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); s_wait_for_connected_successfully(&test_fixture); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } AWS_TEST_CASE(secure_tunneling_functionality_connect_test, s_secure_tunneling_functionality_connect_test_fn) -/* [Func-UC2] */ int secure_tunneling_client_token_check(const struct aws_http_headers *request_headers, void *user_data) { (void)user_data; struct aws_byte_cursor client_token_cur; @@ -589,46 +807,26 @@ int secure_tunneling_client_token_check(const struct aws_http_headers *request_h static int s_secure_tunneling_functionality_client_token_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - test_options.secure_tunnel_options.client_token = aws_byte_cursor_from_string(s_client_token); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; test_fixture.header_check = secure_tunneling_client_token_check; - struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; - ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); s_wait_for_connected_successfully(&test_fixture); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } AWS_TEST_CASE(secure_tunneling_functionality_client_token_test, s_secure_tunneling_functionality_client_token_test_fn) -/* [Func-UC3] */ - int aws_websocket_client_connect_fail_once_fn(const struct aws_websocket_client_connection_options *options) { struct aws_secure_tunnel *secure_tunnel = options->user_data; struct aws_secure_tunnel_mock_test_fixture *test_fixture = secure_tunnel->config->user_data; @@ -684,19 +882,10 @@ int aws_websocket_client_connect_fail_once_fn(const struct aws_websocket_client_ static int s_secure_tunneling_fail_and_retry_connection_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; test_fixture.secure_tunnel_vtable = *aws_secure_tunnel_get_default_vtable(); test_fixture.secure_tunnel_vtable.aws_websocket_client_connect_fn = aws_websocket_client_connect_fail_once_fn; @@ -705,45 +894,24 @@ static int s_secure_tunneling_fail_and_retry_connection_test_fn(struct aws_alloc test_fixture.secure_tunnel_vtable.aws_websocket_close_fn = aws_websocket_close_mock_fn; test_fixture.secure_tunnel_vtable.vtable_user_data = &test_fixture; - struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; - ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); s_wait_for_connected_successfully(&test_fixture); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } AWS_TEST_CASE(secure_tunneling_fail_and_retry_connection_test, s_secure_tunneling_fail_and_retry_connection_test_fn) -/* [Func-UC4] */ - static int s_secure_tunneling_store_service_ids_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); - + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); @@ -752,51 +920,32 @@ static int s_secure_tunneling_store_service_ids_test_fn(struct aws_allocator *al /* check that service ids have been stored */ struct aws_hash_element *elem = NULL; struct aws_byte_cursor service_id_1_cur = aws_byte_cursor_from_string(s_service_id_1); - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_1_cur, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, &service_id_1_cur, &elem); ASSERT_NOT_NULL(elem); elem = NULL; struct aws_byte_cursor service_id_2_cur = aws_byte_cursor_from_string(s_service_id_2); - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_2_cur, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, &service_id_2_cur, &elem); ASSERT_NOT_NULL(elem); elem = NULL; struct aws_byte_cursor service_id_3_cur = aws_byte_cursor_from_string(s_service_id_3); - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_3_cur, &elem); + aws_hash_table_find(&secure_tunnel->connections->service_ids, &service_id_3_cur, &elem); ASSERT_NOT_NULL(elem); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } AWS_TEST_CASE(secure_tunneling_store_service_ids_test, s_secure_tunneling_store_service_ids_test_fn) -/* [Func-UC5] */ - static int s_secure_tunneling_receive_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); - + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); @@ -815,46 +964,24 @@ static int s_secure_tunneling_receive_stream_start_test_fn(struct aws_allocator s_wait_for_stream_started(&test_fixture); /* check that service id stream has been set properly */ - struct aws_hash_element *elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, stream_start_message_view.service_id, &elem); - ASSERT_NOT_NULL(elem); - struct aws_service_id_element *service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id( + secure_tunnel, stream_start_message_view.service_id, stream_start_message_view.stream_id)); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } AWS_TEST_CASE(secure_tunneling_receive_stream_start_test, s_secure_tunneling_receive_stream_start_test_fn) -/* [Func-UC6] */ - static int s_secure_tunneling_rejected_service_id_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); - + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); @@ -875,13 +1002,7 @@ static int s_secure_tunneling_rejected_service_id_stream_start_test_fn(struct aw ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } @@ -890,24 +1011,11 @@ AWS_TEST_CASE( secure_tunneling_rejected_service_id_stream_start_test, s_secure_tunneling_rejected_service_id_stream_start_test_fn) -/* [Func-UC7] */ - static int s_secure_tunneling_close_stream_on_stream_reset_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); - + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); @@ -925,6 +1033,9 @@ static int s_secure_tunneling_close_stream_on_stream_reset_test_fn(struct aws_al /* Wait and confirm that a stream has been started */ s_wait_for_stream_started(&test_fixture); + /* Check that stream is active */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + /* Send a stream reset message from the server to the destination client */ stream_start_message_view.type = AWS_SECURE_TUNNEL_MT_STREAM_RESET; @@ -933,23 +1044,13 @@ static int s_secure_tunneling_close_stream_on_stream_reset_test_fn(struct aws_al /* Wait for a stream reset to have been received */ s_wait_for_stream_reset_received(&test_fixture); - /* check that service id stream has been reset */ - struct aws_hash_element *elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, stream_start_message_view.service_id, &elem); - ASSERT_NOT_NULL(elem); - struct aws_service_id_element *service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == 0); + /* Check that stream id has been reset */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 0)); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } @@ -958,23 +1059,13 @@ AWS_TEST_CASE( secure_tunneling_close_stream_on_stream_reset_test, s_secure_tunneling_close_stream_on_stream_reset_test_fn) -/* [Func-UC8] */ -static int s_secure_tunneling_session_reset_test_fn(struct aws_allocator *allocator, void *ctx) { +static int s_secure_tunneling_ignore_stream_reset_for_inactive_stream_test_fn( + struct aws_allocator *allocator, + void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); - + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); @@ -982,41 +1073,80 @@ static int s_secure_tunneling_session_reset_test_fn(struct aws_allocator *alloca /* Create and send a stream start message from the server to the destination client */ struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); - struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); - struct aws_byte_cursor service_3 = aws_byte_cursor_from_string(s_service_id_3); struct aws_secure_tunnel_message_view stream_start_message_view = { .type = AWS_SECURE_TUNNEL_MT_STREAM_START, .service_id = &service_1, .stream_id = 1, }; aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); - stream_start_message_view.service_id = &service_2; - aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); - stream_start_message_view.service_id = &service_3; - aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); - test_fixture.secure_tunnel_stream_started_count_target = 3; - s_wait_for_n_stream_started(&test_fixture); + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); - /* check that stream ids have been set */ - struct aws_hash_element *elem = NULL; - struct aws_byte_cursor service_id_1_cur = aws_byte_cursor_from_string(s_service_id_1); - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_1_cur, &elem); - ASSERT_NOT_NULL(elem); - struct aws_service_id_element *service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); - elem = NULL; - struct aws_byte_cursor service_id_2_cur = aws_byte_cursor_from_string(s_service_id_2); - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_2_cur, &elem); - ASSERT_NOT_NULL(elem); - service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); - elem = NULL; - struct aws_byte_cursor service_id_3_cur = aws_byte_cursor_from_string(s_service_id_3); - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_3_cur, &elem); - ASSERT_NOT_NULL(elem); - service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == stream_start_message_view.stream_id); + /* Check that stream is active */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + /* Send a stream reset message for a different stream id from the server to the destination client */ + struct aws_secure_tunnel_message_view stream_reset_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_RESET, + .service_id = &service_1, + .stream_id = 2, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_reset_message_view); + + /* Stream reset is ignored by client on an inactive stream id. Wait for client to process the message that should be + * ignored. */ + aws_thread_current_sleep(aws_timestamp_convert(1, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + + /* Check that stream is still active */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_ignore_stream_reset_for_inactive_stream_test, + s_secure_tunneling_ignore_stream_reset_for_inactive_stream_test_fn) + +static int s_secure_tunneling_session_reset_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); + struct aws_byte_cursor service_3 = aws_byte_cursor_from_string(s_service_id_3); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + stream_start_message_view.service_id = &service_2; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + stream_start_message_view.service_id = &service_3; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + test_fixture.secure_tunnel_stream_started_count_target = 3; + s_wait_for_n_stream_started(&test_fixture); + + /* check that stream ids have been set */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_2, 1)); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_3, 1)); /* Create and send a session reset message from the server to the destination client */ struct aws_secure_tunnel_message_view reset_message_view = { @@ -1027,32 +1157,14 @@ static int s_secure_tunneling_session_reset_test_fn(struct aws_allocator *alloca s_wait_for_session_reset_received(&test_fixture); /* Check that stream ids have been reset */ - elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_1_cur, &elem); - ASSERT_NOT_NULL(elem); - service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == 0); - elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_2_cur, &elem); - ASSERT_NOT_NULL(elem); - service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == 0); - elem = NULL; - aws_hash_table_find(&secure_tunnel->config->service_ids, &service_id_3_cur, &elem); - ASSERT_NOT_NULL(elem); - service_id_elem = elem->value; - ASSERT_TRUE(service_id_elem->stream_id == 0); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 0)); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_2, 0)); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_3, 0)); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); - - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } @@ -1061,20 +1173,9 @@ AWS_TEST_CASE(secure_tunneling_session_reset_test, s_secure_tunneling_session_re static int s_secure_tunneling_serializer_data_message_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; - aws_http_library_init(allocator); - aws_iotdevice_library_init(allocator); - struct secure_tunnel_test_options test_options; - s_secure_tunnel_test_init_default_options(&test_options); - - struct aws_secure_tunnel_mock_test_fixture_options test_fixture_options = { - .secure_tunnel_options = &test_options.secure_tunnel_options, - .websocket_function_table = &test_options.websocket_function_table, - }; - struct aws_secure_tunnel_mock_test_fixture test_fixture; - ASSERT_SUCCESS(aws_secure_tunnel_mock_test_fixture_init(&test_fixture, allocator, &test_fixture_options)); - + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); @@ -1089,6 +1190,10 @@ static int s_secure_tunneling_serializer_data_message_test_fn(struct aws_allocat }; aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + /* Create and send a data message from the server to the destination client */ struct aws_byte_cursor payload_cur = aws_byte_cursor_from_string(s_payload_text); struct aws_secure_tunnel_message_view data_message_view = { @@ -1099,7 +1204,7 @@ static int s_secure_tunneling_serializer_data_message_test_fn(struct aws_allocat }; aws_secure_tunnel_send_mock_message(&test_fixture, &data_message_view); - test_fixture.secure_tunnel_message_count_target = 1; + test_fixture.secure_tunnel_message_received_count_target = 1; s_wait_for_n_messages_received(&test_fixture); struct aws_byte_cursor payload_comp_cur = { @@ -1108,21 +1213,737 @@ static int s_secure_tunneling_serializer_data_message_test_fn(struct aws_allocat }; ASSERT_CURSOR_VALUE_STRING_EQUALS(payload_comp_cur, s_payload_text); + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_serializer_data_message_test, s_secure_tunneling_serializer_data_message_test_fn) + +static int s_secure_tunneling_max_payload_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + /* Wait and confirm that a stream has been started */ s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .stream_id = 0, + .service_id = &service_1, + .payload = &s_payload_cursor_max_size, + }; + + aws_secure_tunnel_send_message(secure_tunnel, &data_message_view); ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); s_wait_for_connection_shutdown(&test_fixture); - aws_secure_tunnel_release(secure_tunnel); - s_wait_for_secure_tunnel_terminated(&test_fixture); + aws_secure_tunnel_mock_test_clean_up(&test_fixture); - aws_secure_tunnel_mock_test_fixture_clean_up(&test_fixture); - aws_iotdevice_library_clean_up(); - aws_http_library_clean_up(); - aws_iotdevice_library_clean_up(); + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_max_payload_test, s_secure_tunneling_max_payload_test_fn) + +static int s_secure_tunneling_max_payload_exceed_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .stream_id = 0, + .service_id = &service_1, + .connection_id = 1, + .payload = &s_payload_cursor_max_size_exceeded, + }; + + int result = aws_secure_tunnel_send_message(secure_tunnel, &data_message_view); + + ASSERT_INT_EQUALS(result, AWS_ERROR_IOTDEVICE_SECURE_TUNNELING_DATA_OPTIONS_VALIDATION); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); return AWS_OP_SUCCESS; } -AWS_TEST_CASE(secure_tunneling_serializer_data_message_test, s_secure_tunneling_serializer_data_message_test_fn) +AWS_TEST_CASE(secure_tunneling_max_payload_exceed_test, s_secure_tunneling_max_payload_exceed_test_fn) + +static int s_secure_tunneling_receive_connection_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 1)); + + struct aws_secure_tunnel_message_view connection_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_CONNECTION_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &connection_start_message_view); + + /* Wait and confirm that a connection has been started */ + s_wait_for_connection_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 2)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_receive_connection_start_test, s_secure_tunneling_receive_connection_start_test_fn) + +static int s_secure_tunneling_ignore_inactive_stream_message_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1); + + /* Create and send a data message on a different stream id from the server to the destination client */ + struct aws_byte_cursor payload_cur = aws_byte_cursor_from_string(s_payload_text); + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .service_id = &service_1, + .stream_id = 2, + .payload = &payload_cur, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &data_message_view); + + /* Messages on inactive streams are ignored and no callback is emitted. Wait for client to process and ignore + * message */ + aws_thread_current_sleep(aws_timestamp_convert(1, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + ASSERT_INT_EQUALS(test_fixture.secure_tunnel_message_received_count, 0); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_ignore_inactive_stream_message_test, + s_secure_tunneling_ignore_inactive_stream_message_test_fn) + +static int s_secure_tunneling_ignore_inactive_connection_id_message_test_fn( + struct aws_allocator *allocator, + void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 2)); + + /* Create and send a data message on a different stream id from the server to the destination client */ + struct aws_byte_cursor payload_cur = aws_byte_cursor_from_string(s_payload_text); + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .service_id = &service_1, + .stream_id = 2, + .connection_id = 4, + .payload = &payload_cur, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &data_message_view); + + /* Messages on inactive streams are ignored and no callback is emitted. Wait for client to process and ignore + * message */ + aws_thread_current_sleep(aws_timestamp_convert(1, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + ASSERT_INT_EQUALS(test_fixture.secure_tunnel_message_received_count, 0); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_ignore_inactive_connection_id_message_test, + s_secure_tunneling_ignore_inactive_connection_id_message_test_fn) + +static int s_secure_tunneling_v1_to_v2_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_2 = aws_byte_cursor_from_string(s_service_id_2); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 1)); + + struct aws_secure_tunnel_message_view stream_start_message_view_2 = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_2, + .stream_id = 1, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view_2); + + /* Client should disconnect, clear previous V1 connection and stream, reconnect, and start a V2 stream */ + + s_wait_for_connection_shutdown(&test_fixture); + s_wait_for_connected_successfully(&test_fixture); + + /* Check that the established stream is cleared */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 0)); + + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_2, 1)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_v1_to_v2_stream_start_test, s_secure_tunneling_v1_to_v2_stream_start_test_fn) + +static int s_secure_tunneling_v1_to_v3_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 1)); + + struct aws_secure_tunnel_message_view stream_start_message_view_2 = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 3, + }; + + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view_2); + + /* Client should disconnect, clear previous V1 connection and stream, reconnect, and start a V3 stream */ + + s_wait_for_connection_shutdown(&test_fixture); + s_wait_for_connected_successfully(&test_fixture); + + /* Check that the established stream is cleared */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 0)); + + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 3)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_v1_to_v3_stream_start_test, s_secure_tunneling_v1_to_v3_stream_start_test_fn) + +static int s_secure_tunneling_v2_to_v1_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a v2 stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + struct aws_secure_tunnel_message_view stream_start_message_view_2 = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .stream_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view_2); + + /* Client should disconnect, clear previous V2 connection and stream, reconnect, and start a V1 stream */ + + s_wait_for_connection_shutdown(&test_fixture); + s_wait_for_connected_successfully(&test_fixture); + + /* Confirm that previous stream has been closed */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 0)); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 2)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_v2_to_v1_stream_start_test, s_secure_tunneling_v2_to_v1_stream_start_test_fn) + +static int s_secure_tunneling_v3_to_v1_stream_start_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a v2 stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 2)); + + struct aws_secure_tunnel_message_view stream_start_message_view_2 = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .stream_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view_2); + + /* Client should disconnect, clear previous V3 connection and stream, reconnect, and start a V1 stream */ + + s_wait_for_connection_shutdown(&test_fixture); + s_wait_for_connected_successfully(&test_fixture); + + /* Check that the established stream is cleared */ + ASSERT_FALSE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 2)); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + /* Check that V1 Stream is established */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 2)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_v3_to_v1_stream_start_test, s_secure_tunneling_v3_to_v1_stream_start_test_fn) + +static int s_secure_tunneling_v1_stream_start_v3_message_reset_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 1)); + + struct aws_byte_cursor payload_cur = aws_byte_cursor_from_string(s_payload_text); + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .service_id = &service_1, + .stream_id = 1, + .payload = &payload_cur, + .connection_id = 3, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &data_message_view); + + /* On receipt of an unexpected protocol version message, Client should disconnect/reconnect and clear all streams */ + + s_wait_for_connection_shutdown(&test_fixture); + s_wait_for_connected_successfully(&test_fixture); + + /* Check that the established stream is cleared */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, NULL, 0)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_v1_stream_start_v3_message_reset_test, + s_secure_tunneling_v1_stream_start_v3_message_reset_test_fn) + +static int s_secure_tunneling_v2_stream_start_connection_start_reset_test_fn( + struct aws_allocator *allocator, + void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + struct aws_secure_tunnel_message_view connection_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_CONNECTION_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 3, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &connection_start_message_view); + + /* Client should disconnect and reconnect with no active streams on receiving a wrong version connection start */ + + s_wait_for_connection_shutdown(&test_fixture); + s_wait_for_connected_successfully(&test_fixture); + + /* pause to process a new connection */ + aws_thread_current_sleep(aws_timestamp_convert(1, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + + /* Check that the established stream is cleared */ + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 0)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_v2_stream_start_connection_start_reset_test, + s_secure_tunneling_v2_stream_start_connection_start_reset_test_fn) + +static int s_secure_tunneling_ignore_outbound_inactive_connection_id_message_sending_test_fn( + struct aws_allocator *allocator, + void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 2)); + + /* Create and send a data message from the server to the destination client to an inactive connection id */ + struct aws_byte_cursor payload_cur = aws_byte_cursor_from_string(s_payload_text); + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .service_id = &service_1, + .payload = &payload_cur, + .connection_id = 3, + }; + aws_secure_tunnel_send_message(secure_tunnel, &data_message_view); + + /* Confirm that no messages have gone out from the client */ + aws_thread_current_sleep(aws_timestamp_convert(1, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + ASSERT_INT_EQUALS(test_fixture.secure_tunnel_message_sent_count, 0); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_ignore_outbound_inactive_connection_id_message_sending_test, + s_secure_tunneling_ignore_outbound_inactive_connection_id_message_sending_test_fn) + +static int s_secure_tunneling_close_stream_on_connection_reset_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a v3 stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 2)); + + /* Send a connection start */ + struct aws_secure_tunnel_message_view connection_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_CONNECTION_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 3, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &connection_start_message_view); + + s_wait_for_connection_started(&test_fixture); + /* Check that connections has been started */ + ASSERT_TRUE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 3)); + + /* Send a connection reset */ + struct aws_secure_tunnel_message_view connection_reset_message_view = { + .type = AWS_SECURE_TUNNEL_MT_CONNECTION_RESET, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 3, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &connection_reset_message_view); + + s_wait_for_connection_reset_received(&test_fixture); + + /* Check that connection has been closed */ + ASSERT_FALSE(s_secure_tunnel_check_active_connection_id(secure_tunnel, &service_1, 1, 3)); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_close_stream_on_connection_reset_test, + s_secure_tunneling_close_stream_on_connection_reset_test_fn) + +static int s_secure_tunneling_existing_connection_start_send_reset_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a v3 stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + + /* Send a CONNECTION START on existing connection id */ + struct aws_secure_tunnel_message_view connection_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_CONNECTION_START, + .service_id = &service_1, + .stream_id = 1, + .connection_id = 2, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &connection_start_message_view); + + /* Wait and confirm that a bad connection request was received */ + s_wait_for_bad_connection_started(&test_fixture); + + s_wait_for_connection_reset_message_sent(&test_fixture); + + /* check that stream with connection id has been closed properly */ + struct aws_hash_element *elem = NULL; + aws_hash_table_find(&secure_tunnel->connections->service_ids, stream_start_message_view.service_id, &elem); + ASSERT_NOT_NULL(elem); + struct aws_service_id_element *service_id_elem = elem->value; + ASSERT_INT_EQUALS((int)aws_hash_table_get_entry_count(&service_id_elem->connection_ids), 0); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE( + secure_tunneling_existing_connection_start_send_reset_test, + s_secure_tunneling_existing_connection_start_send_reset_test_fn)