Skip to content

Commit

Permalink
Added UTF-16 support to session, statement and standard-into-type sou…
Browse files Browse the repository at this point in the history
…rces
  • Loading branch information
bold84 committed Mar 17, 2024
1 parent 334ff8d commit 832a796
Show file tree
Hide file tree
Showing 11 changed files with 423 additions and 71 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ option(SOCI_TESTS "Enable build of collection of SOCI tests" ON)
option(SOCI_ASAN "Enable address sanitizer on GCC v4.8+/Clang v 3.1+" OFF)
option(SOCI_LTO "Enable link time optimization" OFF)
option(SOCI_VISIBILITY "Enable hiding private symbol using ELF visibility if supported by the platform" ON)
option(SOCI_ENABLE_UNICODE "Enable Unicode support for ODBC backend" OFF)

if (SOCI_LTO)
cmake_minimum_required(VERSION 3.9)
Expand Down
8 changes: 8 additions & 0 deletions cmake/SociBackend.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ macro(soci_backend NAME)
VERSION ${${PROJECT_NAME}_VERSION}
CLEAN_DIRECT_OUTPUT 1)
endif()

if(SOCI_ENABLE_UNICODE)
target_compile_definitions(${THIS_BACKEND_TARGET} PRIVATE SOCI_ODBC_WIDE UNICODE)
endif()

# Static library target
if(SOCI_STATIC)
Expand Down Expand Up @@ -345,6 +349,10 @@ macro(soci_backend_test)
${THIS_TEST_DEPENDS_LIBRARIES}
soci_core
soci_${BACKENDL})

if(SOCI_ENABLE_UNICODE)
target_compile_definitions(${TEST_TARGET} PRIVATE SOCI_ODBC_WIDE UNICODE)
endif()

add_test(${TEST_TARGET}
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TEST_TARGET}
Expand Down
74 changes: 69 additions & 5 deletions include/soci/odbc/soci-odbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
#endif

#include <vector>
#include <memory>
#include <soci/soci-backend.h>
#include <sstream>
#include <locale>
#include <codecvt>
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <windows.h>
#endif
#include <sqlext.h> // ODBC
#ifdef SOCI_ODBC_WIDE
#include <sqlucode.h>
#endif
#include <string.h> // strcpy()

namespace soci
Expand All @@ -39,10 +45,67 @@ namespace details

// This cast is only used to avoid compiler warnings when passing strings
// to ODBC functions, the returned string may *not* be really modified.

inline SQLCHAR* sqlchar_cast(std::string const& s)
{
return reinterpret_cast<SQLCHAR*>(const_cast<char*>(s.c_str()));
}

inline char* sqlchar_cast(SQLCHAR* s)
{
return reinterpret_cast<char*>(s);
}

inline const char* sqlchar_cast(const SQLCHAR* s)
{
return reinterpret_cast<const char*>(s);
}

inline SQLWCHAR* sqlchar_cast(std::wstring const& s)
{
return reinterpret_cast<SQLWCHAR*>(const_cast<wchar_t*>(s.c_str()));
}

inline std::string toUtf8(std::wstring const& s)
{
static std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.to_bytes(s);
}

inline std::string toUtf8(const wchar_t* s)
{
static std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.to_bytes(s);
}

inline std::wstring toUtf16(std::string const& s)
{
static std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(s);
}

inline std::wstring toUtf16(const char* s)
{
static std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(s);
}

// convert single wchar_t to char
inline char toUtf8(wchar_t c)
{
static std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.to_bytes(c)[0];
}

// convert single char to wchar_t
inline wchar_t toUtf16(char c)
{
static std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(&c, &c + 1)[0];
}



}

// Option allowing to specify the "driver completion" parameter of
Expand Down Expand Up @@ -88,7 +151,7 @@ struct odbc_standard_into_type_backend : details::standard_into_type_backend,
private odbc_standard_type_backend_base
{
odbc_standard_into_type_backend(odbc_statement_backend &st)
: odbc_standard_type_backend_base(st), buf_(0)
: odbc_standard_type_backend_base(st), buf_(nullptr)
{}

void define_by_pos(int &position,
Expand All @@ -99,8 +162,8 @@ struct odbc_standard_into_type_backend : details::standard_into_type_backend,
indicator *ind) override;

void clean_up() override;

char *buf_; // generic buffer
char* buf_; // generic buffer
void *data_;
details::exchange_type type_;
int position_;
Expand Down Expand Up @@ -156,7 +219,7 @@ struct odbc_standard_use_type_backend : details::standard_use_type_backend,
{
odbc_standard_use_type_backend(odbc_statement_backend &st)
: odbc_standard_type_backend_base(st),
position_(-1), data_(0), buf_(0), indHolder_(0) {}
position_(-1), data_(0), buf_(nullptr), indHolder_(0) {}

void bind_by_pos(int &position,
void *data, details::exchange_type type, bool readOnly) override;
Expand Down Expand Up @@ -225,7 +288,8 @@ struct odbc_vector_use_type_backend : details::vector_use_type_backend,
void *data_;
details::exchange_type type_;
int position_;
char *buf_; // generic buffer
//details::odbc_char_type *buf_; // generic buffer
char* buf_; // generic buffer
std::size_t colSize_; // size of the string column (used for strings)
// used for strings only
std::size_t maxSize_;
Expand Down
82 changes: 63 additions & 19 deletions src/backends/odbc/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ odbc_session_backend::odbc_session_backend(
"allocating connection handle");
}

#ifdef SOCI_ODBC_WIDE
SQLWCHAR outConnString[1024];
#else
SQLCHAR outConnString[1024];
#endif // SOCI_ODBC_WIDE
SQLSMALLINT strLength = 0;

// Prompt the user for any missing information (typically UID/PWD) in the
Expand Down Expand Up @@ -83,7 +87,11 @@ odbc_session_backend::odbc_session_backend(
hwnd_for_prompt = ::GetDesktopWindow();
#endif // _WIN32

std::string const & connectString = parameters.get_connect_string();
#ifdef SOCI_ODBC_WIDE
std::wstring const& connectString = toUtf16(parameters.get_connect_string());
#else
std::string const& connectString = parameters.get_connect_string();
#endif // SOCI_ODBC_WIDE

// This "infinite" loop can be executed at most twice.
std::string errContext;
Expand Down Expand Up @@ -136,7 +144,12 @@ odbc_session_backend::odbc_session_backend(
break;
}

#ifdef SOCI_ODBC_WIDE
const std::string outConnStringUtf8 = toUtf8((const wchar_t*)outConnString);
connection_string_.assign(outConnStringUtf8.c_str(), outConnStringUtf8.size());
#else
connection_string_.assign((const char*)outConnString, strLength);
#endif

reset_transaction();

Expand All @@ -151,9 +164,14 @@ void odbc_session_backend::configure_connection()
// ensure that the conversions to/from text round trip correctly, which
// is not the case with the default value of 0. Use the maximal
// supported value, which was 2 until 9.x and is 3 since it.

#ifdef SOCI_ODBC_WIDE
SQLWCHAR product_ver[1024];
#else
char product_ver[1024];
#endif // SOCI_ODBC_WIDE

SQLSMALLINT len = sizeof(product_ver);
// In case UNICODE is defined, SQLGetInfoW is called
SQLRETURN rc = SQLGetInfo(hdbc_, SQL_DBMS_VER, product_ver, len, &len);
if (is_odbc_error(rc))
{
Expand All @@ -165,16 +183,25 @@ void odbc_session_backend::configure_connection()
// need to parse it fully, we just need the major version which,
// conveniently, comes first.
unsigned major_ver = 0;
if (std::sscanf(product_ver, "%u", &major_ver) != 1)
#ifdef SOCI_ODBC_WIDE
const std::string product_ver_utf8(toUtf8(product_ver));
#else
const std::string product_ver_utf8(product_ver);
#endif // SOCI_ODBC_WIDE
if (std::sscanf(product_ver_utf8.c_str(), "%u", &major_ver) != 1)
{
throw soci_error("DBMS version \"" + std::string(product_ver) +
throw soci_error("DBMS version \"" + std::string(product_ver_utf8) +
"\" in unrecognizable format.");
}

details::auto_statement<odbc_statement_backend> st(*this);

#ifdef SOCI_ODBC_WIDE
std::wstring const q(major_ver >= 9 ? L"SET extra_float_digits = 3"
: L"SET extra_float_digits = 2");
#else
std::string const q(major_ver >= 9 ? "SET extra_float_digits = 3"
: "SET extra_float_digits = 2");
#endif // SOCI_ODBC_WIDE
rc = SQLExecDirect(st.hstmt_, sqlchar_cast(q), static_cast<SQLINTEGER>(q.size()));

if (is_odbc_error(rc))
Expand All @@ -192,6 +219,7 @@ void odbc_session_backend::configure_connection()
// Also configure the driver to handle unknown types, such as "xml",
// that we use for x_xmltype, as long varchar instead of limiting them
// to 256 characters (by default).
// In case UNICODE is defined, SQLSetConnectAttrW is called
rc = SQLSetConnectAttr(hdbc_, SQL_ATTR_PGOPT_UNKNOWNSASLONGVARCHAR, (SQLPOINTER)1, 0);

// Ignore the error from this one, failure to set it is not fatal and
Expand All @@ -213,10 +241,16 @@ bool odbc_session_backend::is_connected()

// The name of the table we check for is irrelevant, as long as we have a
// working connection, it should still find (or, hopefully, not) something.

#ifdef SOCI_ODBC_WIDE
SQLWCHAR* dummyText = L"bloordyblop";
#else
SQLCHAR* dummyText = sqlchar_cast("bloordyblop");
#endif // SOCI_ODBC_WIDE
return !is_odbc_error(SQLTables(st.hstmt_,
NULL, SQL_NTS,
NULL, SQL_NTS,
sqlchar_cast("bloordyblop"), SQL_NTS,
dummyText, SQL_NTS,
NULL, SQL_NTS));
}

Expand Down Expand Up @@ -451,7 +485,11 @@ odbc_session_backend::get_database_product() const
if (product_ != prod_uninitialized)
return product_;

#ifdef SOCI_ODBC_WIDE
SQLWCHAR product_name[1024];
#else
char product_name[1024];
#endif // SOCI_ODBC_WIDE
SQLSMALLINT len = sizeof(product_name);
SQLRETURN rc = SQLGetInfo(hdbc_, SQL_DBMS_NAME, product_name, len, &len);
if (is_odbc_error(rc))
Expand All @@ -460,19 +498,25 @@ odbc_session_backend::get_database_product() const
"getting ODBC driver name");
}

if (strcmp(product_name, "Firebird") == 0)
product_ = prod_firebird;
else if (strcmp(product_name, "Microsoft SQL Server") == 0)
product_ = prod_mssql;
else if (strcmp(product_name, "MySQL") == 0)
product_ = prod_mysql;
else if (strcmp(product_name, "Oracle") == 0)
product_ = prod_oracle;
else if (strcmp(product_name, "PostgreSQL") == 0)
product_ = prod_postgresql;
else if (strcmp(product_name, "SQLite") == 0)
product_ = prod_sqlite;
else if (strstr(product_name, "DB2") == product_name) // "DB2/LINUXX8664"
#ifdef SOCI_ODBC_WIDE
const std::string product_name_str(toUtf8(product_name));
#else
const std::string product_name_str(product_name);
#endif

if (product_name_str == "Firebird")
product_ = prod_firebird;
else if (product_name_str == "Microsoft SQL Server")
product_ = prod_mssql;
else if (product_name_str == "MySQL")
product_ = prod_mysql;
else if (product_name_str == "Oracle")
product_ = prod_oracle;
else if (product_name_str == "PostgreSQL")
product_ = prod_postgresql;
else if (product_name_str == "SQLite")
product_ = prod_sqlite;
else if(product_name_str.find("DB2") == 0) // "DB2/LINUXX8664"
product_ = prod_db2;
else
product_ = prod_unknown;
Expand Down
Loading

0 comments on commit 832a796

Please sign in to comment.