diff --git a/CMakeLists.txt b/CMakeLists.txt index 75e013d98f..d26f513a30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,6 +57,7 @@ set(SOURCE_FILES lib/basic_types.c lib/error.h lib/error.c + lib/snowflake_util.h lib/client_int.h lib/chunk_downloader.h lib/chunk_downloader.c @@ -97,6 +98,7 @@ set (SOURCE_FILES_PUT_GET cpp/util/Proxy.cpp cpp/util/ThreadPool.hpp cpp/util/SnowflakeCommon.hpp + cpp/util/entities.cpp cpp/crypto/CryptoTypes.hpp cpp/crypto/Cryptor.hpp cpp/crypto/CipherContext.hpp @@ -129,9 +131,9 @@ set (SOURCE_FILES_PUT_GET include/snowflake/ITransferResult.hpp include/snowflake/PutGetParseResponse.hpp include/snowflake/SnowflakeTransferException.hpp - include/snowflake/IJwt.hpp include/snowflake/IBase64.hpp include/snowflake/Proxy.hpp + include/snowflake/entities.hpp ) set(SOURCE_FILES_CPP_WRAPPER @@ -146,6 +148,9 @@ set(SOURCE_FILES_CPP_WRAPPER include/snowflake/SFURL.hpp include/snowflake/CurlDesc.hpp include/snowflake/CurlDescPool.hpp + include/snowflake/IJwt.hpp + include/snowflake/IAuth.hpp + cpp/lib/SnowflakeUtil.cpp cpp/lib/Exceptions.cpp cpp/lib/Connection.cpp cpp/lib/Statement.cpp @@ -167,7 +172,8 @@ set(SOURCE_FILES_CPP_WRAPPER cpp/lib/ResultSetJson.cpp cpp/lib/ResultSetJson.hpp cpp/lib/Authenticator.cpp - cpp/lib/Authenticator.hpp + cpp/lib/Authenticator.hpp + cpp/lib/IAuth.cpp cpp/jwt/jwtWrapper.cpp cpp/util/SnowflakeCommon.cpp cpp/util/SFURL.cpp @@ -337,6 +343,7 @@ if (LINUX) deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/azure/include deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/cmocka/include deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/uuid/include + deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/picojson/include include lib) endif() @@ -352,6 +359,7 @@ if (APPLE) deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/aws/include deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/azure/include deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/cmocka/include + deps-build/${PLATFORM}/${CMAKE_BUILD_TYPE}/picojson/include include lib) endif() @@ -366,6 +374,7 @@ if (WIN32) deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/aws/include deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/azure/include deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/cmocka/include + deps-build/${PLATFORM}/${VSDIR}/${CMAKE_BUILD_TYPE}/picojson/include include lib) if (CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/cpp/lib/Authenticator.cpp b/cpp/lib/Authenticator.cpp index d12bc194a3..9a6146f647 100644 --- a/cpp/lib/Authenticator.cpp +++ b/cpp/lib/Authenticator.cpp @@ -12,10 +12,17 @@ #include "Authenticator.hpp" #include "../logger/SFLogger.hpp" #include "error.h" +#include "../include/snowflake/entities.hpp" #include #include #include +#include +#include "curl_desc_pool.h" +#include "snowflake/Exceptions.hpp" +#include "cJSON.h" +#include "memory.h" +#include "../cpp/jwt/Util.hpp" #include @@ -31,10 +38,10 @@ #define strcasecmp _stricmp #endif -#define JWT_THROW(msg) \ -{ \ - throw Snowflake::Client::Jwt::JwtException(msg); \ -} +#define RETRY_THROW(elapsedSeconds, retriedCount) \ +{ \ + throw RenewTimeoutException(elapsedSeconds, retriedCount, false);\ +} // wrapper functions for C extern "C" { @@ -55,10 +62,10 @@ extern "C" { return AUTH_OAUTH; } - return AUTH_UNSUPPORTED; + return AUTH_OKTA; } - SF_STATUS STDCALL auth_initialize(SF_CONNECT * conn) + SF_STATUS STDCALL auth_initialize(SF_CONNECT *conn) { if (!conn) { @@ -73,6 +80,11 @@ extern "C" { conn->auth_object = static_cast( new Snowflake::Client::AuthenticatorJWT(conn)); } + else if (AUTH_OKTA == auth_type) + { + conn->auth_object = static_cast( + new Snowflake::Client::AuthenticatorOKTA(conn)); + } } catch (...) { @@ -85,7 +97,7 @@ extern "C" { return SF_STATUS_SUCCESS; } - int64 auth_get_renew_timeout(SF_CONNECT * conn) + int64 auth_get_renew_timeout(SF_CONNECT *conn) { if (!conn || !conn->auth_object) { @@ -103,7 +115,7 @@ extern "C" { } } - SF_STATUS STDCALL auth_authenticate(SF_CONNECT * conn) + SF_STATUS STDCALL auth_authenticate(SF_CONNECT *conn) { if (!conn || !conn->auth_object) { @@ -116,32 +128,24 @@ extern "C" { } catch (...) { - SET_SNOWFLAKE_ERROR(&conn->error, SF_STATUS_ERROR_GENERAL, - "authentication failed", - SF_SQLSTATE_GENERAL_ERROR); return SF_STATUS_ERROR_GENERAL; } return SF_STATUS_SUCCESS; } - void auth_update_json_body(SF_CONNECT * conn, cJSON* body) + void auth_update_json_body(SF_CONNECT *conn, cJSON* body) { - if (AUTH_OAUTH == getAuthenticatorType(conn->authenticator)) + cJSON* data = snowflake_cJSON_GetObjectItem(body, "data"); + if (!data) { - cJSON* data = snowflake_cJSON_GetObjectItem(body, "data"); - if (!data) - { - data = snowflake_cJSON_CreateObject(); - snowflake_cJSON_AddItemToObject(body, "data", data); - } - - snowflake_cJSON_DeleteItemFromObject(data, "AUTHENTICATOR"); - snowflake_cJSON_AddStringToObject(data, "AUTHENTICATOR", SF_AUTHENTICATOR_OAUTH); - snowflake_cJSON_DeleteItemFromObject(data, "TOKEN"); - snowflake_cJSON_AddStringToObject(data, "TOKEN", conn->oauth_token); + data = snowflake_cJSON_CreateObject(); + snowflake_cJSON_AddItemToObject(body, "data", data); } + snowflake_cJSON_DeleteItemFromObject(data, "AUTHENTICATOR"); + snowflake_cJSON_DeleteItemFromObject(data, "TOKEN"); + if (!conn || !conn->auth_object) { return; @@ -149,8 +153,11 @@ extern "C" { try { - static_cast(conn->auth_object)-> - updateDataMap(body); + jsonObject_t picoBody; + cJSONtoPicoJson(data, picoBody); + static_cast(conn->auth_object)-> + updateDataMap(picoBody); + picoJsonTocJson(picoBody, &body); } catch (...) { @@ -160,7 +167,7 @@ extern "C" { return; } - void auth_renew_json_body(SF_CONNECT * conn, cJSON* body) + void auth_renew_json_body(SF_CONNECT *conn, cJSON* body) { if (!conn || !conn->auth_object) { @@ -169,8 +176,12 @@ extern "C" { try { - static_cast(conn->auth_object)-> - renewDataMap(body); + jsonObject_t picoBody; + cJSON* data = snowflake_cJSON_GetObjectItem(body, "data"); + cJSONtoPicoJson(data, picoBody); + static_cast(conn->auth_object)-> + renewDataMap(picoBody); + picoJsonTocJson(picoBody, &body); } catch (...) { @@ -180,7 +191,7 @@ extern "C" { return; } - void STDCALL auth_terminate(SF_CONNECT * conn) + void STDCALL auth_terminate(SF_CONNECT *conn) { if (!conn || !conn->auth_object) { @@ -190,10 +201,7 @@ extern "C" { AuthenticatorType auth_type = getAuthenticatorType(conn->authenticator); try { - if (AUTH_JWT == auth_type) - { - delete static_cast(conn->auth_object); - } + delete static_cast(conn->auth_object); } catch (...) { @@ -202,19 +210,13 @@ extern "C" { return; } - } // extern "C" namespace Snowflake { namespace Client { - - void IAuthenticator::renewDataMap(cJSON *dataMap) - { - authenticate(); - updateDataMap(dataMap); - } + using namespace picojson; void AuthenticatorJWT::loadPrivateKey(const std::string &privateKeyFile, const std::string &passcode) @@ -223,7 +225,7 @@ namespace Client if (sf_fopen(&file, privateKeyFile.c_str(), "r") == nullptr) { CXX_LOG_ERROR("Failed to open private key file. Errno: %d", errno); - JWT_THROW("Failed to open private key file"); + AUTH_THROW("Failed to open private key file"); } m_privKey = PEM_read_PrivateKey(file, nullptr, nullptr, (void *)passcode.c_str()); @@ -231,7 +233,7 @@ namespace Client if (m_privKey == nullptr) { CXX_LOG_ERROR("Loading private key from %s failed", privateKeyFile.c_str()); - JWT_THROW("Marshaling private key failed"); + AUTH_THROW("Marshaling private key failed"); } } @@ -248,31 +250,37 @@ namespace Client { privKeyFilePwd = conn->priv_key_file_pwd; } - loadPrivateKey(privKeyFile, privKeyFilePwd); - m_timeOut = conn->jwt_timeout; - m_renewTimeout = conn->jwt_cnxn_wait_time; + try { + loadPrivateKey(privKeyFile, privKeyFilePwd); + m_timeOut = conn->jwt_timeout; + m_renewTimeout = conn->jwt_cnxn_wait_time; - // Prepare header - std::shared_ptr
header {Header::buildHeader()}; - header->setAlgorithm(Snowflake::Client::Jwt::AlgorithmType::RS256); - m_jwt->setHeader(std::move(header)); + // Prepare header + std::shared_ptr
header{ Header::buildHeader() }; + header->setAlgorithm(Snowflake::Client::Jwt::AlgorithmType::RS256); + m_jwt->setHeader(std::move(header)); - // Prepare claim set - std::shared_ptr claimSet {ClaimSet::buildClaimSet()}; + // Prepare claim set + std::shared_ptr claimSet{ ClaimSet::buildClaimSet() }; - std::string account = conn->account; - std::string user = conn->user; - for (char &c : account) c = std::toupper(c); - for (char &c : user) c = std::toupper(c); + std::string account = conn->account; + std::string user = conn->user; + for (char& c : account) c = std::toupper(c); + for (char& c : user) c = std::toupper(c); - // issuer - std::string subject = account + '.'; - subject += user; - claimSet->addClaim("sub", subject); + // issuer + std::string subject = account + '.'; + subject += user; + claimSet->addClaim("sub", subject); - std::string issuer = subject + ".SHA256:" + extractPublicKey(m_privKey); - claimSet->addClaim("iss", issuer); - m_jwt->setClaimSet(std::move(claimSet)); + std::string issuer = subject + ".SHA256:" + extractPublicKey(m_privKey); + claimSet->addClaim("iss", issuer); + m_jwt->setClaimSet(std::move(claimSet)); + } + catch (AuthException& e) { + SET_SNOWFLAKE_ERROR(&conn->error, SF_STATUS_ERROR_GENERAL, e.message_.c_str(), SF_SQLSTATE_GENERAL_ERROR); + AUTH_THROW("JWT Authentication failed"); + } } AuthenticatorJWT::~AuthenticatorJWT() @@ -294,18 +302,10 @@ namespace Client claimSet->addClaim("exp", (long)seconds.count() + m_timeOut); } - void AuthenticatorJWT::updateDataMap(cJSON* dataMap) + void AuthenticatorJWT::updateDataMap(jsonObject_t& dataMap) { - cJSON* data = snowflake_cJSON_GetObjectItem(dataMap, "data"); - if (!data) - { - data = snowflake_cJSON_CreateObject(); - snowflake_cJSON_AddItemToObject(dataMap, "data", data); - } - snowflake_cJSON_DeleteItemFromObject(data, "AUTHENTICATOR"); - snowflake_cJSON_DeleteItemFromObject(data, "TOKEN"); - snowflake_cJSON_AddStringToObject(data, "AUTHENTICATOR", SF_AUTHENTICATOR_JWT); - snowflake_cJSON_AddStringToObject(data, "TOKEN", m_jwt->serialize(m_privKey).c_str()); + dataMap["AUTHENTICATOR"] = picojson::value(SF_AUTHENTICATOR_JWT); + dataMap["TOKEN"] = picojson::value(m_jwt->serialize(m_privKey)); } std::string AuthenticatorJWT::extractPublicKey(EVP_PKEY *privKey) @@ -316,7 +316,7 @@ namespace Client if (size < 0) { CXX_LOG_ERROR("Fail to extract public key"); - JWT_THROW("Public Key extract failed"); + AUTH_THROW("Public Key extract failed"); } std::vector pubKeyBytes(out, out + size); OPENSSL_free(out); @@ -335,19 +335,19 @@ namespace Client if (mdctx == nullptr) { CXX_LOG_ERROR("EVP context create failed."); - JWT_THROW("EVP context create failed"); + AUTH_THROW("EVP context create failed"); } if (1 != EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr)) { CXX_LOG_ERROR("Digest Init failed."); - JWT_THROW("Digest Init failed"); + AUTH_THROW("Digest Init failed"); } if (1 != EVP_DigestUpdate(mdctx.get(), message.data(), message.size())) { CXX_LOG_ERROR("Digest update failed."); - JWT_THROW("Digest update failed"); + AUTH_THROW("Digest update failed"); } std::vector coded(EVP_MD_size(EVP_sha256())); @@ -356,12 +356,206 @@ namespace Client if (1 != EVP_DigestFinal_ex(mdctx.get(), (unsigned char *)coded.data(), &code_size)) { CXX_LOG_ERROR("Digest final failed."); - JWT_THROW("Digest final failed"); + AUTH_THROW("Digest final failed"); } coded.resize(code_size); return coded; } + + void AuthenticatorOKTA::curl_post_call(SFURL& url, const jsonObject_t& obj, jsonObject_t& resp) + { + std::string destination = url.toString(); + void* curl_desc; + CURL* curl; + curl_desc = get_curl_desc_from_pool(destination.c_str(), m_connection->proxy, m_connection->no_proxy); + curl = get_curl_from_desc(curl_desc); + SF_ERROR_STRUCT* err = &m_connection->error; + + int64 elapsedTime = 0; + int8 maxRetryCount = get_login_retry_count(m_connection); + int64 renewTimeout = auth_get_renew_timeout(m_connection); + + // add headers for account and authentication + SF_HEADER* httpExtraHeaders = sf_header_create(); + std::string s_body = value(obj).serialize(); + cJSON* resp_data = NULL; + try { + httpExtraHeaders->use_application_json_accept_type = SF_BOOLEAN_TRUE; + if (!create_header(m_connection, httpExtraHeaders, &m_connection->error)) { + log_trace("sf", "IDPAuthenticator", + "post_curl_call", + "Failed to create the header for the request to get the token URL and the SSO URL"); + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_GENERAL, "OktaConnectionFailed: failed to create the header", SF_SQLSTATE_GENERAL_ERROR); + AUTH_THROW(err); + } + + if (!::curl_post_call(m_connection, curl, (char*)destination.c_str(), httpExtraHeaders, (char*)s_body.c_str(), + &resp_data, err, renewTimeout, maxRetryCount, m_retryTimeout, &elapsedTime, + &m_retriedCount, NULL, SF_BOOLEAN_TRUE)) + { + log_info("sf", "IDPAuthenticator", "post_curl_call", + "Fail to get authenticator info, response body=%s\n", + snowflake_cJSON_Print(snowflake_cJSON_GetObjectItem(resp_data, "data"))); + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_GENERAL, "SFConnectionFailed", SF_SQLSTATE_GENERAL_ERROR); + AUTH_THROW(err); + } + + if (elapsedTime >= m_retryTimeout) + { + CXX_LOG_WARN("sf", "IDPAuthenticator", "post_curl_call", + "timeout reached: %d, elapsed time: %d", + m_retryTimeout, elapsedTime); + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_REQUEST_TIMEOUT, "OktaConnectionFailed: timeout reached", SF_SQLSTATE_GENERAL_ERROR); + AUTH_THROW(err); + } + cJSONtoPicoJson(resp_data, resp); + } + catch (AuthException& e) { + // just to escape from the try block + } + + //Clean up resources + m_retryTimeout -= elapsedTime; + sf_header_destroy(httpExtraHeaders); + free_curl_desc(curl_desc); + snowflake_cJSON_Delete(resp_data); + if (err->error_code != SF_STATUS_SUCCESS) { + AUTH_THROW(err); + } + } + + + AuthenticatorOKTA::AuthenticatorOKTA( + SF_CONNECT* connection) : m_connection(connection) + { + m_account = m_connection->account; + m_authenticator = m_connection->authenticator; + m_user = m_connection->user; + m_password = m_connection->password; + m_port = m_connection->port; + m_host = m_connection->host; + m_protocol = m_connection->protocol; + m_disableSamlUrlCheck = m_connection->disable_saml_url_check; + m_retriedCount = get_login_retry_count(m_connection); + m_retryTimeout = get_retry_timeout(m_connection); + + //m_appID = m_connection->application_name; + //m_appVersion = m_connection->application_version; + m_appID = "ODBC"; + m_appVersion = "3.4.1"; + } + + AuthenticatorOKTA::~AuthenticatorOKTA() + { + // nop + } + + void AuthenticatorOKTA::curl_get_call(SFURL& url, jsonObject_t& resp, bool parseJSON, std::string& rawData) + { + bool isRetry = false; + int64 maxRetryCount = get_login_retry_count(m_connection); + int64 elapsedTime = 0; + int64 renewTimeout = auth_get_renew_timeout(m_connection); + int64 curlTimeout = m_connection->network_timeout; + + std::string destination = url.toString(); + void* curl_desc; + CURL* curl; + curl_desc = get_curl_desc_from_pool(destination.c_str(), m_connection->proxy, m_connection->no_proxy); + curl = get_curl_from_desc(curl_desc); + + SF_ERROR_STRUCT *err = &m_connection->error; + + NON_JSON_RESP* raw_resp = (NON_JSON_RESP*) SF_MALLOC(sizeof(NON_JSON_RESP)); + raw_resp->write_callback = non_json_resp_write_callback; + RAW_JSON_BUFFER buf = { NULL,0 }; + raw_resp->buffer = (void*)&buf; + + // add headers for account and authentication + SF_HEADER* httpExtraHeaders = sf_header_create(); + httpExtraHeaders->use_application_json_accept_type = SF_BOOLEAN_TRUE; + + try { + if (!create_header(m_connection, httpExtraHeaders, &m_connection->error)) { + log_trace("sf", "AuthenticatorOKTA", + "get_curl_call", + "Failed to create the header for the request to get onetime token"); + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_GENERAL, "OktaConnectionFailed: failed to create the header", SF_SQLSTATE_GENERAL_ERROR); + AUTH_THROW(err); + } + + if (!http_perform(curl, GET_REQUEST_TYPE, (char*)destination.c_str(), httpExtraHeaders, NULL, NULL, raw_resp, + curlTimeout, SF_BOOLEAN_FALSE, err, + m_connection->insecure_mode, m_connection->ocsp_fail_open, + m_connection->retry_on_curle_couldnt_connect_count, + renewTimeout, maxRetryCount, &elapsedTime, &m_retriedCount, NULL, SF_BOOLEAN_FALSE, + m_connection->proxy, m_connection->no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE)) + { + //Fail to get the saml response. Retry. + isRetry = true; + AUTH_THROW("retry"); + } + + if (elapsedTime >= m_retryTimeout) + { + CXX_LOG_WARN("sf", "AuthenticatorOKTA", "get_curl_call", + "Fail to get SAML response, timeout reached: %d, elapsed time: %d", + m_retryTimeout, elapsedTime); + + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_REQUEST_TIMEOUT, "OktaConnectionFailed: timeout reached", SF_SQLSTATE_GENERAL_ERROR); + AUTH_THROW(err); + } + rawData = buf.buffer; + } + catch (AuthException& e) { + // just to escape from the try block + } + + m_retryTimeout -= elapsedTime; + sf_header_destroy(httpExtraHeaders); + free_curl_desc(curl_desc); + SF_FREE(raw_resp); + if (isRetry) { + RETRY_THROW(elapsedTime, m_retriedCount); + } + if (err->error_code != SF_STATUS_SUCCESS) { + AUTH_THROW(err); + } + } + + void AuthenticatorOKTA::authenticate() + { + try + { + IAuthenticatorOKTA::authenticate(); + } + catch (AuthException& e) + { + SF_ERROR_STRUCT* err = &m_connection->error; + if (!err) + { + std::string timeoutErr = "timeout"; + if (timeoutErr.compare(e.what())) + { + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_REQUEST_TIMEOUT, "OktaConnectionFailed: timeout reached", SF_SQLSTATE_GENERAL_ERROR); + } + else + { + SET_SNOWFLAKE_ERROR(err, SF_STATUS_ERROR_BAD_REQUEST, e.what(), SF_SQLSTATE_GENERAL_ERROR); + } + } + AUTH_THROW(e.what()); + } + } + + void AuthenticatorOKTA::updateDataMap(jsonObject_t& dataMap) + { + dataMap.erase("LOGIN_NAME"); + dataMap.erase("PASSWORD"); + dataMap.erase("EXT_AUTHN_DUO_METHOD"); + IAuthenticatorOKTA::updateDataMap(dataMap); + } } // namespace Client } // namespace Snowflake diff --git a/cpp/lib/Authenticator.hpp b/cpp/lib/Authenticator.hpp index 4c9dde6690..7527d6e91c 100644 --- a/cpp/lib/Authenticator.hpp +++ b/cpp/lib/Authenticator.hpp @@ -15,46 +15,16 @@ #include "snowflake/IJwt.hpp" #include "snowflake/IBase64.hpp" #include "authenticator.h" -#include "cJSON.h" +#include "picojson.h" +#include "snowflake/SFURL.hpp" +#include "../../lib/snowflake_util.h" +#include "../include/snowflake/IAuth.hpp" namespace Snowflake { namespace Client { - /** - * Authenticator - */ - class IAuthenticator - { - public: - - IAuthenticator() : m_renewTimeout(0) - {} - - virtual ~IAuthenticator() - {} - - virtual void authenticate()=0; - - virtual void updateDataMap(cJSON * dataMap)=0; - - // Retrieve authenticator renew timeout, return 0 if not available. - // When the authenticator renew timeout is available, the connection should - // renew the authentication (call renewDataMap) for each time the - // authenticator specific timeout exceeded within the entire login timeout. - int64 getAuthRenewTimeout() - { - return m_renewTimeout; - } - - // Renew the autentication and update datamap. - // The default behavior is to call authenticate() and updateDataMap(). - virtual void renewDataMap(cJSON * dataMap); - - protected: - int64 m_renewTimeout; - }; - + using namespace Snowflake::Client::IAuth; /** * JWT Authenticator */ @@ -67,7 +37,7 @@ namespace Client void authenticate(); - void updateDataMap(cJSON* dataMap); + void updateDataMap(jsonObject_t& dataMap); private: void loadPrivateKey(const std::string &privateKeyFile, const std::string &passcode); @@ -87,6 +57,21 @@ namespace Client static std::vector SHA256(const std::vector &message); }; + class AuthenticatorOKTA : public IAuthenticatorOKTA + { + public: + AuthenticatorOKTA(SF_CONNECT *conn); + + ~AuthenticatorOKTA(); + + void authenticate(); + void updateDataMap(jsonObject_t& dataMap); + void curl_post_call(SFURL& url, const jsonObject_t& body, jsonObject_t& resp); + void curl_get_call(SFURL& url, jsonObject_t& resp, bool parseJSON, std::string& raw_data); + + private: + SF_CONNECT* m_connection; + }; } // namespace Client } // namespace Snowflake #endif //PROJECT_AUTHENTICATOR_HPP diff --git a/cpp/lib/IAuth.cpp b/cpp/lib/IAuth.cpp new file mode 100644 index 0000000000..c8e48e1f4e --- /dev/null +++ b/cpp/lib/IAuth.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2024 Snowflake Computing, Inc. All rights reserved. + */ + +#include +#include "snowflake/Exceptions.hpp" +#include "cJSON.h" +#include "../include/snowflake/entities.hpp" +#include "../logger/SFLogger.hpp" +#include "snowflake/IAuth.hpp" + +namespace Snowflake +{ +namespace Client +{ + namespace IAuth { + using namespace picojson; + + void IAuthenticator::renewDataMap(jsonObject_t& dataMap) + { + authenticate(); + updateDataMap(dataMap); + } + + void IDPAuthenticator::getIDPInfo() + { + jsonObject_t dataMap; + SFURL connectURL = getServerURLSync().path("/session/authenticator-request"); + dataMap["ACCOUNT_NAME"] = value(m_account); + dataMap["AUTHENTICATOR"] = value(m_authenticator); + dataMap["LOGIN_NAME"] = value(m_user); + dataMap["PORT"] = value(m_port); + dataMap["PROTOCOL"] = value(m_protocol); + dataMap["CLIENT_APP_ID"] = value(m_appID); + dataMap["CLIENT_APP_VERSION"] = value(m_appVersion);; + + jsonObject_t authnData, respData; + authnData["data"] = value(dataMap); + + curl_post_call(connectURL, authnData, respData); + jsonObject_t& data = respData["data"].get(); + tokenURLStr = data["tokenUrl"].get(); + ssoURLStr = data["ssoUrl"].get(); + } + + SFURL IDPAuthenticator::getServerURLSync() + { + SFURL url = SFURL().scheme(m_protocol) + .host(m_host) + .port(m_port); + + return url; + } + + void IAuthenticatorOKTA::authenticate() + { + // 1. get authenticator info + getIDPInfo(); + + // 2. verify ssoUrl and tokenUrl contains same prefix + if (!urlHasSamePrefix(tokenURLStr, m_authenticator)) + { + CXX_LOG_ERROR("sf", "AuthenticatorOKTA", "authenticate", + "The specified authenticator is not supported, " + "authenticator=%s, token url=%s, sso url=%s", + m_authenticator.c_str(), tokenURLStr.c_str(), ssoURLStr.c_str()); + AUTH_THROW("SFAuthenticatorVerificationFailed: the token URL does not have the same prefix with the authenticator"); + } + + // 3. get one time token from okta + while (true) + { + SFURL tokenURL = SFURL::parse(tokenURLStr); + + jsonObject_t dataMap, respData; + dataMap["username"] = picojson::value(m_user); + dataMap["password"] = picojson::value(m_password); + + try { + curl_post_call(tokenURL, dataMap, respData); + } + catch (...) + { + CXX_LOG_WARN("sf", "AuthenticatorOKTA", "getOneTimeToken", + "Fail to get one time token response, response body=%s", + picojson::value(respData).serialize().c_str()); + AUTH_THROW("Failed to get the one time token from Okta authentication.") + } + + oneTimeToken = respData["sessionToken"].get(); + if (oneTimeToken.empty()) { + oneTimeToken = respData["cookieToken"].get(); + } + // 4. get SAML response + try { + + jsonObject_t resp; + SFURL sso_url = SFURL::parse(ssoURLStr); + sso_url.addQueryParam("onetimetoken", oneTimeToken); + curl_get_call(sso_url, resp, false, m_samlResponse); + break; + } + catch (RenewTimeoutException& e) + { + int64 elapsedSeconds = e.getElapsedSeconds(); + + if (elapsedSeconds >= m_retryTimeout) + { + CXX_LOG_WARN("sf", "AuthenticatorOKTA", "getSamlResponse", + "Fail to get SAML response, timeout reached: %d, elapsed time: %d", + m_retryTimeout, elapsedSeconds); + + AUTH_THROW("timeout"); + } + + m_retriedCount = e.getRetriedCount(); + m_retryTimeout -= elapsedSeconds; + CXX_LOG_TRACE("sf", "Connection", "Connect", + "Retry on getting SAML response with one time token renewed for %d times " + "with updated retryTimeout = %d", + m_retriedCount, m_retryTimeout); + } + } + + // 5. Validate post_back_url matches Snowflake URL + std::string post_back_url = extractPostBackUrlFromSamlResponse(m_samlResponse); + std::string server_url = getServerURLSync().toString(); + if ((!m_disableSamlUrlCheck) && + (!urlHasSamePrefix(post_back_url, server_url))) + { + CXX_LOG_ERROR("sf", "AuthenticatorOKTA", "authenticate", + "The specified authenticator and destination URL in " + "Saml Assertion did not " + "match, expected=%s, post back=%s", + server_url.c_str(), + post_back_url.c_str()); + AUTH_THROW("SFSamlResponseVerificationFailed"); + } + } + + void IAuthenticatorOKTA::updateDataMap(jsonObject_t& dataMap) + { + dataMap["RAW_SAML_RESPONSE"] = picojson::value(m_samlResponse); + } + + std::string IAuthenticatorOKTA::extractPostBackUrlFromSamlResponse(std::string html) + { + std::size_t form_start = html.find("(); + } + + void picoJsonTocJson(jsonObject_t& picojson, cJSON** cjson) + { + std::string body_str = picojson::value(picojson).serialize(); + cJSON* new_body = snowflake_cJSON_Parse(body_str.c_str()); + snowflake_cJSON_DeleteItemFromObject(*cjson, "data"); + snowflake_cJSON_AddItemToObject(*cjson, "data", new_body); + } + + void strToPicoJson(jsonObject_t& picojson, std::string& str) + { + jsonValue_t v; + picojson::parse(v, str); + picojson = v.get(); + } + + bool urlHasSamePrefix(std::string url1, std::string url2) + { + SFURL parsed_url1 = SFURL::parse(url1); + SFURL parsed_url2 = SFURL::parse(url2); + + if (parsed_url1.port() == "" && parsed_url1.scheme() == "https") + { + parsed_url1.port("443"); + } + + if (parsed_url2.port() == "" && parsed_url2.scheme() == "https") + { + parsed_url2.port("443"); + } + + return parsed_url1.scheme() == parsed_url2.scheme() && + parsed_url1.host() == parsed_url2.host() && + parsed_url1.port() == parsed_url2.port(); + } + } +} +} \ No newline at end of file diff --git a/cpp/util/entities.cpp b/cpp/util/entities.cpp new file mode 100644 index 0000000000..0373e54e89 --- /dev/null +++ b/cpp/util/entities.cpp @@ -0,0 +1,415 @@ +/** + * Copyright 2012, 2016 Christoph Gärtner + * Distributed under the Boost Software License, Version 1.0 + */ + +#include "../include/snowflake/entities.hpp" +#include +#include +#include +#include "snowflake/SF_CRTFunctionSafe.h" + +#define UNICODE_MAX 0x10FFFFul + +namespace Snowflake +{ + namespace Client + { + + static const char* const NAMED_ENTITIES[][2] = { + {"AElig;", "Æ"}, + {"Aacute;", "Á"}, + {"Acirc;", "Â"}, + {"Agrave;", "À"}, + {"Alpha;", "Α"}, + {"Aring;", "Å"}, + {"Atilde;", "Ã"}, + {"Auml;", "Ä"}, + {"Beta;", "Β"}, + {"Ccedil;", "Ç"}, + {"Chi;", "Χ"}, + {"Dagger;", "‡"}, + {"Delta;", "Δ"}, + {"ETH;", "Ð"}, + {"Eacute;", "É"}, + {"Ecirc;", "Ê"}, + {"Egrave;", "È"}, + {"Epsilon;", "Ε"}, + {"Eta;", "Η"}, + {"Euml;", "Ë"}, + {"Gamma;", "Γ"}, + {"Iacute;", "Í"}, + {"Icirc;", "Î"}, + {"Igrave;", "Ì"}, + {"Iota;", "Ι"}, + {"Iuml;", "Ï"}, + {"Kappa;", "Κ"}, + {"Lambda;", "Λ"}, + {"Mu;", "Μ"}, + {"Ntilde;", "Ñ"}, + {"Nu;", "Ν"}, + {"OElig;", "Œ"}, + {"Oacute;", "Ó"}, + {"Ocirc;", "Ô"}, + {"Ograve;", "Ò"}, + {"Omega;", "Ω"}, + {"Omicron;", "Ο"}, + {"Oslash;", "Ø"}, + {"Otilde;", "Õ"}, + {"Ouml;", "Ö"}, + {"Phi;", "Φ"}, + {"Pi;", "Π"}, + {"Prime;", "″"}, + {"Psi;", "Ψ"}, + {"Rho;", "Ρ"}, + {"Scaron;", "Š"}, + {"Sigma;", "Σ"}, + {"THORN;", "Þ"}, + {"Tau;", "Τ"}, + {"Theta;", "Θ"}, + {"Uacute;", "Ú"}, + {"Ucirc;", "Û"}, + {"Ugrave;", "Ù"}, + {"Upsilon;", "Υ"}, + {"Uuml;", "Ü"}, + {"Xi;", "Ξ"}, + {"Yacute;", "Ý"}, + {"Yuml;", "Ÿ"}, + {"Zeta;", "Ζ"}, + {"aacute;", "á"}, + {"acirc;", "â"}, + {"acute;", "´"}, + {"aelig;", "æ"}, + {"agrave;", "à"}, + {"alefsym;", "ℵ"}, + {"alpha;", "α"}, + {"amp;", "&"}, + {"and;", "∧"}, + {"ang;", "∠"}, + {"apos;", "'"}, + {"aring;", "å"}, + {"asymp;", "≈"}, + {"atilde;", "ã"}, + {"auml;", "ä"}, + {"bdquo;", "„"}, + {"beta;", "β"}, + {"brvbar;", "¦"}, + {"bull;", "•"}, + {"cap;", "∩"}, + {"ccedil;", "ç"}, + {"cedil;", "¸"}, + {"cent;", "¢"}, + {"chi;", "χ"}, + {"circ;", "ˆ"}, + {"clubs;", "♣"}, + {"cong;", "≅"}, + {"copy;", "©"}, + {"crarr;", "↵"}, + {"cup;", "∪"}, + {"curren;", "¤"}, + {"dArr;", "⇓"}, + {"dagger;", "†"}, + {"darr;", "↓"}, + {"deg;", "°"}, + {"delta;", "δ"}, + {"diams;", "♦"}, + {"divide;", "÷"}, + {"eacute;", "é"}, + {"ecirc;", "ê"}, + {"egrave;", "è"}, + {"empty;", "∅"}, + {"emsp;", "\xE2\x80\x83"}, + {"ensp;", "\xE2\x80\x82"}, + {"epsilon;", "ε"}, + {"equiv;", "≡"}, + {"eta;", "η"}, + {"eth;", "ð"}, + {"euml;", "ë"}, + {"euro;", "€"}, + {"exist;", "∃"}, + {"fnof;", "ƒ"}, + {"forall;", "∀"}, + {"frac12;", "½"}, + {"frac14;", "¼"}, + {"frac34;", "¾"}, + {"frasl;", "⁄"}, + {"gamma;", "γ"}, + {"ge;", "≥"}, + {"gt;", ">"}, + {"hArr;", "⇔"}, + {"harr;", "↔"}, + {"hearts;", "♥"}, + {"hellip;", "…"}, + {"iacute;", "í"}, + {"icirc;", "î"}, + {"iexcl;", "¡"}, + {"igrave;", "ì"}, + {"image;", "ℑ"}, + {"infin;", "∞"}, + {"int;", "∫"}, + {"iota;", "ι"}, + {"iquest;", "¿"}, + {"isin;", "∈"}, + {"iuml;", "ï"}, + {"kappa;", "κ"}, + {"lArr;", "⇐"}, + {"lambda;", "λ"}, + {"lang;", "〈"}, + {"laquo;", "«"}, + {"larr;", "←"}, + {"lceil;", "⌈"}, + {"ldquo;", "“"}, + {"le;", "≤"}, + {"lfloor;", "⌊"}, + {"lowast;", "∗"}, + {"loz;", "◊"}, + {"lrm;", "\xE2\x80\x8E"}, + {"lsaquo;", "‹"}, + {"lsquo;", "‘"}, + {"lt;", "<"}, + {"macr;", "¯"}, + {"mdash;", "—"}, + {"micro;", "µ"}, + {"middot;", "·"}, + {"minus;", "−"}, + {"mu;", "μ"}, + {"nabla;", "∇"}, + {"nbsp;", "\xC2\xA0"}, + {"ndash;", "–"}, + {"ne;", "≠"}, + {"ni;", "∋"}, + {"not;", "¬"}, + {"notin;", "∉"}, + {"nsub;", "⊄"}, + {"ntilde;", "ñ"}, + {"nu;", "ν"}, + {"oacute;", "ó"}, + {"ocirc;", "ô"}, + {"oelig;", "œ"}, + {"ograve;", "ò"}, + {"oline;", "‾"}, + {"omega;", "ω"}, + {"omicron;", "ο"}, + {"oplus;", "⊕"}, + {"or;", "∨"}, + {"ordf;", "ª"}, + {"ordm;", "º"}, + {"oslash;", "ø"}, + {"otilde;", "õ"}, + {"otimes;", "⊗"}, + {"ouml;", "ö"}, + {"para;", "¶"}, + {"part;", "∂"}, + {"permil;", "‰"}, + {"perp;", "⊥"}, + {"phi;", "φ"}, + {"pi;", "π"}, + {"piv;", "ϖ"}, + {"plusmn;", "±"}, + {"pound;", "£"}, + {"prime;", "′"}, + {"prod;", "∏"}, + {"prop;", "∝"}, + {"psi;", "ψ"}, + {"quot;", "\""}, + {"rArr;", "⇒"}, + {"radic;", "√"}, + {"rang;", "〉"}, + {"raquo;", "»"}, + {"rarr;", "→"}, + {"rceil;", "⌉"}, + {"rdquo;", "”"}, + {"real;", "ℜ"}, + {"reg;", "®"}, + {"rfloor;", "⌋"}, + {"rho;", "ρ"}, + {"rlm;", "\xE2\x80\x8F"}, + {"rsaquo;", "›"}, + {"rsquo;", "’"}, + {"sbquo;", "‚"}, + {"scaron;", "š"}, + {"sdot;", "⋅"}, + {"sect;", "§"}, + {"shy;", "\xC2\xAD"}, + {"sigma;", "σ"}, + {"sigmaf;", "ς"}, + {"sim;", "∼"}, + {"spades;", "♠"}, + {"sub;", "⊂"}, + {"sube;", "⊆"}, + {"sum;", "∑"}, + {"sup1;", "¹"}, + {"sup2;", "²"}, + {"sup3;", "³"}, + {"sup;", "⊃"}, + {"supe;", "⊇"}, + {"szlig;", "ß"}, + {"tau;", "τ"}, + {"there4;", "∴"}, + {"theta;", "θ"}, + {"thetasym;", "ϑ"}, + {"thinsp;", "\xE2\x80\x89"}, + {"thorn;", "þ"}, + {"tilde;", "˜"}, + {"times;", "×"}, + {"trade;", "™"}, + {"uArr;", "⇑"}, + {"uacute;", "ú"}, + {"uarr;", "↑"}, + {"ucirc;", "û"}, + {"ugrave;", "ù"}, + {"uml;", "¨"}, + {"upsih;", "ϒ"}, + {"upsilon;", "υ"}, + {"uuml;", "ü"}, + {"weierp;", "℘"}, + {"xi;", "ξ"}, + {"yacute;", "ý"}, + {"yen;", "¥"}, + {"yuml;", "ÿ"}, + {"zeta;", "ζ"}, + {"zwj;", "\xE2\x80\x8D"}, + {"zwnj;", "\xE2\x80\x8C"} + }; + + static int cmp(const void* key, const void* value) + { + return strncmp((const char*)key, *(const char* const*)value, + strlen(*(const char* const*)value)); + } + +#if defined(_WIN32) || defined(_WIN64) + static int cmp_s(void* pvlocale, const void* key, const void* value) + { + //UNUSED(pvlocale); + return cmp(key, value); + } +#endif + + static const char* get_named_entity(const char* name) + { +#if defined(_WIN32) || defined(_WIN64) + const char* const* entity = (const char* const*)bsearch_s(name, + NAMED_ENTITIES, + sizeof NAMED_ENTITIES / sizeof * NAMED_ENTITIES, + sizeof * NAMED_ENTITIES, cmp_s, + NULL); +#else + const char* const* entity = (const char* const*)bsearch(name, + NAMED_ENTITIES, + sizeof NAMED_ENTITIES / sizeof * NAMED_ENTITIES, + sizeof * NAMED_ENTITIES, cmp); +#endif + + return entity ? entity[1] : NULL; + } + + static size_t putc_utf8(unsigned long cp, char* buffer) + { + unsigned char* bytes = (unsigned char*)buffer; + + if (cp <= 0x007Ful) + { + bytes[0] = (unsigned char)cp; + return 1; + } + + if (cp <= 0x07FFul) + { + bytes[1] = (unsigned char)((2 << 6) | (cp & 0x3F)); + bytes[0] = (unsigned char)((6 << 5) | (cp >> 6)); + return 2; + } + + if (cp <= 0xFFFFul) + { + bytes[2] = (unsigned char)((2 << 6) | (cp & 0x3F)); + bytes[1] = (unsigned char)((2 << 6) | ((cp >> 6) & 0x3F)); + bytes[0] = (unsigned char)((14 << 4) | (cp >> 12)); + return 3; + } + + if (cp <= 0x10FFFFul) + { + bytes[3] = (unsigned char)((2 << 6) | (cp & 0x3F)); + bytes[2] = (unsigned char)((2 << 6) | ((cp >> 6) & 0x3F)); + bytes[1] = (unsigned char)((2 << 6) | ((cp >> 12) & 0x3F)); + bytes[0] = (unsigned char)((30 << 3) | (cp >> 18)); + return 4; + } + + return 0; + } + + static bool parse_entity( + const char* current, char** to, const char** from) + { + const char* end = strchr(current, ';'); + if (!end) return 0; + + if (current[1] == '#') + { + char* tail = NULL; + int errno_save = errno; + bool hex = current[2] == 'x' || current[2] == 'X'; + + errno = 0; + unsigned long cp = strtoul( + current + (hex ? 3 : 2), &tail, hex ? 16 : 10); + + bool fail = errno || tail != end || cp > UNICODE_MAX; + errno = errno_save; + if (fail) return 0; + + *to += putc_utf8(cp, *to); + *from = end + 1; + + return 1; + } + else + { + const char* entity = get_named_entity(¤t[1]); + if (!entity) return 0; + + size_t len = strlen(entity); + sf_memcpy(*to, len, entity, len); + + *to += len; + *from = end + 1; + + return 1; + } + } + + + + size_t decode_html_entities_utf8(char* dest, const char* src) + { + if (!src) src = dest; + + char* to = dest; + const char* from = src; + + for (const char* current; (current = strchr(from, '&'));) + { + memmove(to, from, (size_t)(current - from)); + to += current - from; + + if (parse_entity(current, &to, &from)) + continue; + + from = current; + *to++ = *from++; + } + + size_t remaining = strlen(from); + + memmove(to, from, remaining); + to += remaining; + *to = 0; + + return (size_t)(to - dest); + } + } +} + diff --git a/include/snowflake/Exceptions.hpp b/include/snowflake/Exceptions.hpp index ea32f0a6b0..4663b7e487 100644 --- a/include/snowflake/Exceptions.hpp +++ b/include/snowflake/Exceptions.hpp @@ -35,4 +35,55 @@ class GeneralException: public SnowflakeException { GeneralException(SF_ERROR_STRUCT *error) : SnowflakeException(error) {}; }; +class RenewTimeoutException : public std::exception +{ +public: + RenewTimeoutException(int64 elapsedSeconds, + int8 retriedCount, + bool isCurlTimeoutNoBackoff) : + m_elapsedSeconds(elapsedSeconds), + m_retriedCount(retriedCount), + m_isCurlTimeoutNoBackoff(isCurlTimeoutNoBackoff) + {} + + int64 getElapsedSeconds() + { + return m_elapsedSeconds; + } + + int8 getRetriedCount() + { + return m_retriedCount; + } + + bool isCurlTimeoutNoBackoff() + { + return m_isCurlTimeoutNoBackoff; + } + + virtual const char* what() const noexcept + { + return "internal renew timeout exception"; + } + +private: + int64 m_elapsedSeconds; + int8 m_retriedCount; + // The flag indicate if the renew exception is thrown for renew the request + // within curl timeout and no backoff made + bool m_isCurlTimeoutNoBackoff; +}; + +struct AuthException : public std::exception +{ + AuthException(SF_ERROR_STRUCT* error) : message_(error->msg) {} + AuthException(const std::string& message) : message_(message) {} + + const char* what() const noexcept + { + return message_.c_str(); + } + + std::string message_; +}; #endif //SNOWFLAKECLIENT_EXCEPTIONS_HPP diff --git a/include/snowflake/IAuth.hpp b/include/snowflake/IAuth.hpp new file mode 100644 index 0000000000..5fdfcfe235 --- /dev/null +++ b/include/snowflake/IAuth.hpp @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2018-2019 Snowflake Computing, Inc. All rights reserved. + */ + +#ifndef SNOWFLAKECLIENT_IAUTH_HPP +#define SNOWFLAKECLIENT_IAUTH_HPP + +#include +#include +#include "../../lib/snowflake_util.h" +#include "./Exceptions.hpp" + +#define AUTH_THROW(msg) \ +{ \ + throw AuthException(msg); \ +} + +namespace Snowflake +{ +namespace Client +{ +namespace IAuth +{ + /** + * Authenticator + */ + class IAuthenticator + { + public: + + IAuthenticator() : m_renewTimeout(0) + {} + + virtual ~IAuthenticator() + {} + + virtual void authenticate() = 0; + + virtual void updateDataMap(jsonObject_t& dataMap) = 0; + + // Retrieve authenticator renew timeout, return 0 if not available. + // When the authenticator renew timeout is available, the connection should + // renew the authentication (call renewDataMap) for each time the + // authenticator specific timeout exceeded within the entire login timeout. + int64 getAuthRenewTimeout() + { + return m_renewTimeout; + } + + // Renew the autentication and update datamap. + // The default behavior is to call authenticate() and updateDataMap(). + virtual void renewDataMap(jsonObject_t& dataMap); + + protected: + int64 m_renewTimeout; + }; + + + class IDPAuthenticator + { + public: + IDPAuthenticator() + {}; + + virtual ~IDPAuthenticator() + {}; + + void getIDPInfo(); + + virtual SFURL getServerURLSync(); + /* + * Get IdpInfo for OKTA and SAML 2.0 application + */ + virtual void curl_post_call(SFURL& url, const jsonObject_t& body, jsonObject_t& resp) = 0; + virtual void curl_get_call(SFURL& url, jsonObject_t& resp, bool parseJSON, std::string& raw_data) = 0; + + protected: + std::string tokenURLStr; + std::string ssoURLStr; + //For EXTERNALBROSER in the future + std::string proofKeyStr; + + //These fields should be definied in the child class. + std::string m_authenticator; + std::string m_account; + std::string m_appID; + std::string m_appVersion; + std::string m_user; + std::string m_port; + std::string m_host; + std::string m_protocol; + }; + + class IAuthenticatorOKTA : public IDPAuthenticator, public IAuthenticator + { + public: + IAuthenticatorOKTA() {}; + + virtual ~IAuthenticatorOKTA() {}; + + virtual void authenticate() = 0; + + virtual void updateDataMap(jsonObject_t& dataMap); + + /** + * Extract post back url from samel response. Input is in HTML format. + */ + std::string extractPostBackUrlFromSamlResponse(std::string html); + + protected: + //These fields should be definied in the child class. + std::string m_password; + bool m_disableSamlUrlCheck; + int8 m_retriedCount; + int64 m_retryTimeout; + + private: + std::string oneTimeToken; + std::string m_samlResponse; + }; +} // namespace Auth +} // namespace Client +} // namespace Snowflake + +#endif //SNOWFLAKECLIENT_IIDP_AUTH_HPP diff --git a/include/snowflake/client.h b/include/snowflake/client.h index 1dd5f53b30..8d31948d2c 100644 --- a/include/snowflake/client.h +++ b/include/snowflake/client.h @@ -267,13 +267,14 @@ typedef enum SF_ATTRIBUTE { SF_CON_MAX_VARCHAR_SIZE, SF_CON_MAX_BINARY_SIZE, SF_CON_MAX_VARIANT_SIZE, + SF_CON_DISABLE_SAML_URL_CHECK, SF_CON_OCSP_FAIL_OPEN, SF_DIR_QUERY_URL, SF_DIR_QUERY_URL_PARAM, SF_DIR_QUERY_TOKEN, SF_RETRY_ON_CURLE_COULDNT_CONNECT_COUNT, SF_QUERY_RESULT_TYPE, - SF_CON_OAUTH_TOKEN + SF_CON_OAUTH_TOKEN, } SF_ATTRIBUTE; /** @@ -402,6 +403,7 @@ typedef struct SF_CONNECT { uint64 max_binary_size; uint64 max_variant_size; + sf_bool disable_saml_url_check; //token for OAuth authentication char *oauth_token; } SF_CONNECT; diff --git a/include/snowflake/entities.hpp b/include/snowflake/entities.hpp new file mode 100644 index 0000000000..123601ceec --- /dev/null +++ b/include/snowflake/entities.hpp @@ -0,0 +1,29 @@ +/** + * Copyright 2012 Christoph Gärtner + * Distributed under the Boost Software License, Version 1.0 + */ + +#ifndef DECODE_HTML_ENTITIES_UTF8_ +#define DECODE_HTML_ENTITIES_UTF8_ + +#include +#include + +namespace Snowflake +{ + namespace Client + { + /* Takes input from and decodes into , which should be a buffer + large enough to hold characters. + + If is , input will be taken from , decoding + the entities in-place. + + The function returns the length of the decoded string. + */ + size_t decode_html_entities_utf8(char* dest, const char* src); + + } +} + +#endif diff --git a/lib/client.c b/lib/client.c index 6b818a4fb6..306f45c796 100644 --- a/lib/client.c +++ b/lib/client.c @@ -1164,6 +1164,9 @@ SF_STATUS STDCALL snowflake_set_attribute( case SF_CON_INCLUDE_RETRY_REASON: sf->include_retry_reason = value ? *((sf_bool *)value) : SF_BOOLEAN_TRUE; break; + case SF_CON_DISABLE_SAML_URL_CHECK: + sf->disable_saml_url_check = value ? *((sf_bool*)value) : SF_BOOLEAN_FALSE; + break; default: SET_SNOWFLAKE_ERROR(&sf->error, SF_STATUS_ERROR_BAD_ATTRIBUTE_TYPE, "Invalid attribute type", @@ -1313,6 +1316,9 @@ SF_STATUS STDCALL snowflake_get_attribute( case SF_CON_MAX_VARIANT_SIZE: *value = &sf->max_variant_size; break; + case SF_CON_DISABLE_SAML_URL_CHECK: + *value = &sf->disable_saml_url_check; + break; default: SET_SNOWFLAKE_ERROR(&sf->error, SF_STATUS_ERROR_BAD_ATTRIBUTE_TYPE, "Invalid attribute type", diff --git a/lib/connection.c b/lib/connection.c index 94661ac6d9..cecc8285ac 100644 --- a/lib/connection.c +++ b/lib/connection.c @@ -187,6 +187,7 @@ cJSON *STDCALL create_auth_json_body(SF_CONNECT *sf, snowflake_cJSON_AddStringToObject(data, "EXT_AUTHN_DUO_METHOD", "push"); } } + snowflake_cJSON_AddItemToObject(data, "CLIENT_ENVIRONMENT", client_env); snowflake_cJSON_AddItemToObject(data, "SESSION_PARAMETERS", session_parameters); @@ -198,6 +199,12 @@ cJSON *STDCALL create_auth_json_body(SF_CONNECT *sf, // update authentication information to body auth_update_json_body(sf, body); + if (AUTH_OAUTH == getAuthenticatorType(sf->authenticator)) + { + snowflake_cJSON_AddStringToObject(data, "AUTHENTICATOR", SF_AUTHENTICATOR_OAUTH); + snowflake_cJSON_AddStringToObject(data, "TOKEN", sf->oauth_token); + } + return body; } @@ -470,7 +477,12 @@ sf_bool STDCALL curl_get_call(SF_CONNECT *sf, char *url, SF_HEADER *header, cJSON **json, - SF_ERROR_STRUCT *error) { + SF_ERROR_STRUCT *error, + int64 renew_timeout, + int8 retry_max_count, + int64 retry_timeout, + int64* elapsed_time, + int8* retried_count){ SF_JSON_ERROR json_error; const char *error_msg; char query_code[QUERYCODE_LEN]; @@ -486,7 +498,7 @@ sf_bool STDCALL curl_get_call(SF_CONNECT *sf, get_retry_timeout(sf), SF_BOOLEAN_FALSE, error, sf->insecure_mode, sf->ocsp_fail_open, sf->retry_on_curle_couldnt_connect_count, - 0, sf->retry_count, NULL, NULL, NULL, SF_BOOLEAN_FALSE, + renew_timeout, retry_max_count, elapsed_time, retried_count, NULL, SF_BOOLEAN_FALSE, sf->proxy, sf->no_proxy, SF_BOOLEAN_FALSE, SF_BOOLEAN_FALSE) || !*json) { // Error is set in the perform function @@ -518,7 +530,7 @@ sf_bool STDCALL curl_get_call(SF_CONNECT *sf, if (!create_header(sf, new_header, error)) { break; } - if (!curl_get_call(sf, curl, url, new_header, json, error)) { + if (!curl_get_call(sf, curl, url, new_header, json, error, retry_max_count, renew_timeout, retry_timeout, elapsed_time, retried_count)) { // Error is set in curl call break; } @@ -953,7 +965,8 @@ sf_bool STDCALL request(SF_CONNECT *sf, elapsed_time, retried_count, is_renew, renew_injection); } else if (request_type == GET_REQUEST_TYPE) { - ret = curl_get_call(sf, curl, encoded_url, my_header, json, error); + ret = curl_get_call(sf, curl, encoded_url, my_header, json, error, + renew_timeout, retry_max_count, retry_timeout, elapsed_time, retried_count); } else { SET_SNOWFLAKE_ERROR(error, SF_STATUS_ERROR_BAD_REQUEST, "An unknown request type was passed to the request function", @@ -1344,4 +1357,14 @@ int64 get_retry_timeout(SF_CONNECT *sf) int8 get_login_retry_count(SF_CONNECT *sf) { return (int8)get_less_one(sf->retry_on_connect_count, sf->retry_count); +} + +sf_bool is_one_time_token_request(cJSON* resp) +{ + return snowflake_cJSON_HasObjectItem(resp, "cookieToken") || snowflake_cJSON_HasObjectItem(resp, "sessionToken"); +} + +size_t non_json_resp_write_callback(char* ptr, size_t size, size_t nmemb, void* userdata) +{ + return json_resp_cb(ptr, size, nmemb, userdata); } \ No newline at end of file diff --git a/lib/connection.h b/lib/connection.h index fe605526d0..f05733065d 100644 --- a/lib/connection.h +++ b/lib/connection.h @@ -233,6 +233,9 @@ sf_bool STDCALL create_header(SF_CONNECT *sf, SF_HEADER *header, SF_ERROR_STRUCT * reached and the caller can renew the credentials and * then go back to the retry by calling curl_post_call() again. * 0 means no renew timeout needed. + * For Okta Authentication, whenever the authentication failed, the connector + * should update the onetime token. In this case, the renew timeout < 0, which means + * the request should be renewed for each request. * @param retry_max_count The max number of retry attempts. 0 means no limit. * @param retry_timeout The timeout for retry. Will stop retry when it's exceeded. 0 means no limit. * @param elapsed_time The in/out paramter to record the elapsed time before @@ -261,10 +264,26 @@ sf_bool STDCALL curl_post_call(SF_CONNECT *sf, CURL *curl, char *url, SF_HEADER * @param header Header passed to cURL for use in the request * @param json Reference to a cJSON pointer that is used to store the JSON response upon a successful request * @param error Reference to the Snowflake Error object to set an error if one occurs + * @param renew_timeout For key pair authentication. Credentials could expire + * during the connection retry. Set renew timeout in such + * case so http_perform will return when renew_timeout is + * reached and the caller can renew the credentials and + * then go back to the retry by calling curl_post_call() again. + * 0 means no renew timeout needed. + * For Okta Authentication, whenever the authentication failed, the connector + * should update the onetime token. In this case, the renew timeout < 0, which means + * the request should be renewed for each request. + * @param retry_max_count The max number of retry attempts. 0 means no limit. + * @param retry_timeout The timeout for retry. Will stop retry when it's exceeded. 0 means no limit. + * @param elapsed_time The in/out paramter to record the elapsed time before + * curl_post_call() returned due to renew timeout last time + * @param retried_count The in/out paramter to record the number of retry attempts + * has been done before http_perform() returned due to renew + * timeout last time. * @return Success/failure status of get call. 1 = Success; 0 = Failure */ sf_bool STDCALL curl_get_call(SF_CONNECT *sf, CURL *curl, char *url, SF_HEADER *header, cJSON **json, - SF_ERROR_STRUCT *error); + SF_ERROR_STRUCT *error, int64 renew_timeout, int8 retry_max_count, int64 retry_timeout, int64* elapsed_time, int8* retried_count); /** * Used to determine the sleep time during the next backoff caused by request failure. @@ -624,6 +643,15 @@ int64 get_retry_timeout(SF_CONNECT *sf); */ uint64 sf_get_current_time_millis(); +/* +* A function to check that this request is whether the one time token request. +*/ +sf_bool is_one_time_token_request(cJSON *resp); + +/* +* A write callback function to use to write the response text received from the cURL response with non_json_resp +*/ +size_t non_json_resp_write_callback(char* ptr, size_t size, size_t nmemb, void* userdata); #ifdef __cplusplus } #endif diff --git a/lib/http_perform.c b/lib/http_perform.c index b7291e2034..8190a261f3 100644 --- a/lib/http_perform.c +++ b/lib/http_perform.c @@ -460,8 +460,9 @@ sf_bool STDCALL http_perform(CURL *curl, // When renew timeout is reached, stop retry and return to the caller // to renew request - if ((retry) && (renew_timeout > 0) && - ((time(NULL) - elapsedRetryTime) >= renew_timeout)) { + sf_bool renew_timeout_reached = retry && (renew_timeout > 0) && ((time(NULL) - elapsedRetryTime) >= renew_timeout); + sf_bool renew_timeout_disabled = retry && renew_timeout < 0; + if (renew_timeout_reached || renew_timeout_disabled) { retry = SF_BOOLEAN_FALSE; if (elapsed_time) { *elapsed_time += (time(NULL) - elapsedRetryTime); @@ -489,8 +490,12 @@ sf_bool STDCALL http_perform(CURL *curl, snowflake_cJSON_Delete(*json); *json = NULL; *json = snowflake_cJSON_Parse(buffer.buffer); + if (*json) { ret = SF_BOOLEAN_TRUE; + if (is_one_time_token_request(*json)) { + snowflake_cJSON_AddNullToObject(*json, "code"); + } } else { SET_SNOWFLAKE_ERROR(error, SF_STATUS_ERROR_BAD_JSON, "Unable to parse JSON text response.", diff --git a/lib/snowflake_util.h b/lib/snowflake_util.h new file mode 100644 index 0000000000..c6c9025dc7 --- /dev/null +++ b/lib/snowflake_util.h @@ -0,0 +1,43 @@ +#ifndef SNOWFLAKE_UTIL_H +#define SNOWFLAKE_UTIL_H + +#include "picojson.h" +#include "cJSON.h" + +#ifdef __cplusplus +extern "C" { +#endif + typedef picojson::value jsonValue_t; + typedef std::map jsonObject_t; + typedef std::vector jsonArray_t; + + /* + * Convert the cJSON to picoJSON + */ + void cJSONtoPicoJson(cJSON* cjson, jsonObject_t& picojson); + + /* + * Convert the picojson to cJSON + */ + void picoJsonTocJson(jsonObject_t &picojson, cJSON** cjson); + + /* + * Stringfy the picojson data. + */ + void strToPicoJson(jsonObject_t& picojson, std::string& str); + + + /** + * Verify that if two urls has same prefix (protocl + host + port) + * @param url1 + * @param url2 + * + * @return true if same prefix otherwise false + */ + bool urlHasSamePrefix(std::string url1, std::string url2); + +#ifdef __cplusplus +} +#endif + +#endif //SNOWFLAKE_UTIL_H \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7eab19d7ea..b41ee49612 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -48,8 +48,8 @@ SET(TESTS_C # will enable lob test when the change on server side will be published # test_lob # test_stats -# This test needs the Oauth token, so it should be run manually - test_manual_connect +# MFA, Oauth, and Okta connections are only able to run testing manually. + test_manual_connect ) SET(TESTS_CXX diff --git a/tests/test_manual_connect.c b/tests/test_manual_connect.c index 65ebf5a6c2..c34f872828 100644 --- a/tests/test_manual_connect.c +++ b/tests/test_manual_connect.c @@ -160,6 +160,36 @@ void test_mfa_connect_with_duo_passcodeInPassword(void** unused) snowflake_term(sf); } +void test_okta_connect(void** unused) { + SF_CONNECT* sf = snowflake_init(); + snowflake_set_attribute(sf, SF_CON_ACCOUNT, + getenv("SNOWFLAKE_TEST_ACCOUNT")); + snowflake_set_attribute(sf, SF_CON_USER, getenv("SNOWFLAKE_TEST_OKTA_USERNAME")); + snowflake_set_attribute(sf, SF_CON_PASSWORD, + getenv("SNOWFLAKE_TEST_OKTA_PASSWORD")); + snowflake_set_attribute(sf, SF_CON_AUTHENTICATOR, + getenv("SNOWFLAKE_TEST_AUTHENTICATOR")); + char* host, * port, * protocol; + host = getenv("SNOWFLAKE_TEST_HOST"); + if (host) { + snowflake_set_attribute(sf, SF_CON_HOST, host); + } + port = getenv("SNOWFLAKE_TEST_PORT"); + if (port) { + snowflake_set_attribute(sf, SF_CON_PORT, port); + } + protocol = getenv("SNOWFLAKE_TEST_PROTOCOL"); + if (protocol) { + snowflake_set_attribute(sf, SF_CON_PROTOCOL, protocol); + } + SF_STATUS status = snowflake_connect(sf); + if (status != SF_STATUS_SUCCESS) { + dump_error(&(sf->error)); + } + assert_int_equal(status, SF_STATUS_SUCCESS); + snowflake_term(sf); +} + void test_none(void** unused) {} @@ -187,6 +217,10 @@ int main(void) tests[0].name = "test_mfa_connect_with_duo_passcodeInPassword"; tests[0].test_func = test_mfa_connect_with_duo_passcodeInPassword; } + else if (strcmp(manual_test, "test_okta_connect") == 0) { + tests[0].name = "test_okta_connect"; + tests[0].test_func = test_okta_connect; + } else { printf("No matching test found for: %s\n", manual_test); }