Skip to content

Commit

Permalink
Adding functionality for MySQL and PostgreSQL to use schemas in the m…
Browse files Browse the repository at this point in the history
…etadata functions
  • Loading branch information
cstiborg committed Aug 6, 2024
1 parent e50e135 commit bca17c1
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 25 deletions.
17 changes: 17 additions & 0 deletions include/soci/mysql/soci-mysql.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,23 @@ struct mysql_session_backend : details::session_backend
return "SELECT table_name AS 'TABLE_NAME' FROM information_schema.tables WHERE table_schema = DATABASE()";
}

std::string get_column_descriptions_query() const override
{
return "SELECT column_name as \"COLUMN_NAME\","
" data_type as \"DATA_TYPE\","
" character_maximum_length as \"CHARACTER_MAXIMUM_LENGTH\","
" numeric_precision as \"NUMERIC_PRECISION\","
" numeric_scale as \"NUMERIC_SCALE\","
" is_nullable as \"IS_NULLABLE\""
" from information_schema.columns"
" where"
" case"
" when :s is not NULL THEN table_schema = :s"
" else table_schema = DATABASE()"
" end"
" and table_name = :t";
}

MYSQL *conn_;
};

Expand Down
3 changes: 3 additions & 0 deletions include/soci/postgresql/soci-postgresql.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ struct postgresql_session_backend : details::session_backend

std::string get_next_statement_name();

std::string get_table_names_query() const override;
std::string get_column_descriptions_query() const override;

int statementCount_;
bool single_row_mode_;
PGconn * conn_;
Expand Down
8 changes: 7 additions & 1 deletion include/soci/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace soci
{
class values;
class backend_factory;
struct schema_table_name;

namespace details
{
Expand Down Expand Up @@ -146,7 +147,8 @@ class SOCI_DECL session
// Since this is intended for use with statement objects, where results are obtained one row after another,
// it makes sense to bind either std::string for each output field or soci::column_info for the whole row.
// Note: table_name is a non-const reference to prevent temporary objects,
// this argument is bound as a regular "use" element.
// this argument is bound as a regular "use" element. The table_name can consist of both a schema name and
// a table_name separated by a dot.
details::prepare_temp_type prepare_column_descriptions(std::string & table_name);

// Functions for basic portable DDL statements.
Expand Down Expand Up @@ -215,6 +217,8 @@ class SOCI_DECL session
SOCI_NOT_COPYABLE(session)

void reset_after_move();
struct schema_table_name * alloc_schema_table_name(std::string & tableName);
void clean_schema_table_names();

std::ostringstream query_stream_;
std::unique_ptr<details::query_transformation_function> query_transformation_;
Expand All @@ -232,6 +236,8 @@ class SOCI_DECL session
bool isFromPool_;
std::size_t poolPosition_;
connection_pool * pool_;

struct schema_table_name * schema_table_name_;
};

} // namespace soci
Expand Down
7 changes: 3 additions & 4 deletions include/soci/soci-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class blob_backend
class session_backend
{
public:
session_backend() : failoverCallback_(NULL), session_(NULL), schema_name_("public") {}
session_backend() : failoverCallback_(NULL), session_(NULL) {}
virtual ~session_backend() {}

virtual bool is_connected() = 0;
Expand Down Expand Up @@ -366,7 +366,7 @@ class session_backend
{
return "select table_name as \"TABLE_NAME\""
" from information_schema.tables"
" where table_schema = '" + schema_name_ + "'";
" where table_schema = 'public'";
}

// Returns a query with a single parameter (table name) for the list
Expand All @@ -380,7 +380,7 @@ class session_backend
" numeric_scale as \"NUMERIC_SCALE\","
" is_nullable as \"IS_NULLABLE\""
" from information_schema.columns"
" where table_schema = '" + schema_name_ + "' and table_name = :t";
" where table_schema = 'public' and table_name = :t";
}

virtual std::string create_table(const std::string & tableName)
Expand Down Expand Up @@ -565,7 +565,6 @@ class session_backend

failover_callback * failoverCallback_;
session * session_;
std::string schema_name_;

private:
SOCI_NOT_COPYABLE(session_backend)
Expand Down
1 change: 0 additions & 1 deletion src/backends/mysql/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,6 @@ mysql_session_backend::mysql_session_backend(
clean_up();
throw mysql_soci_error(errMsg, errNum);
}
schema_name_ = db;
}

#if defined(__GNUC__) && ( __GNUC__ > 4 || (__GNUC__ == 4 && (__GNUC_MINOR__ > 6)))
Expand Down
124 changes: 114 additions & 10 deletions src/backends/postgresql/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,87 @@ void hard_exec(postgresql_session_backend & session_backend,
postgresql_result(session_backend, PQexec(conn, query)).check_for_errors(errMsg);
}

// helper function to quote a string before sinding to PostgreSQL
char * quote(PGconn * conn, const char *s, size_t len)
{
int error_code;
char *retv = new char[2 * len + 3];
retv[0] = '\'';
int len_esc = PQescapeStringConn(conn, retv + 1, s, len, &error_code);
if (error_code > 0)
{
len_esc = 0;
}
retv[len_esc + 1] = '\'';
retv[len_esc + 2] = '\0';

return retv;
}

// helper function to collect schemas from search_path
std::vector<std::string> get_schema_names(PGconn * conn)
{
std::vector<std::string> schema_names;
PGresult* search_path = PQexec(conn, "SHOW search_path");
if (PQresultStatus(search_path) == PGRES_TUPLES_OK)
{
if (PQntuples(search_path) > 0)
{
std::string schema_name = PQgetvalue(search_path, 0, 0);
if (!(schema_name.length() == 2 && schema_name[0] == '"' && schema_name[1] == '"'))
{
// Assure no bad characters
char * escaped_schema = quote(conn, schema_name.c_str(), schema_name.length());
schema_names.push_back(escaped_schema);
delete[] escaped_schema;
}
}
}
PQclear(search_path);
if (schema_names.empty())
{
PGresult* current_user = PQexec(conn, "SELECT current_user");
if (PQresultStatus(current_user) == PGRES_TUPLES_OK)
{
if (PQntuples(current_user) > 0)
{
std::string user = PQgetvalue(current_user, 0, 0);

// Assure no bad characters
char * escaped_user = quote(conn, user.c_str(), user.length());
schema_names.push_back(escaped_user);
delete[] escaped_user;
}
}
schema_names.push_back("public");
}

return schema_names;
}

// helper function to create a comma separated list of strings
std::string create_list_of_strings(const std::vector<std::string>& list)
{
std::ostringstream oss;
for (size_t i = 0; i < list.size(); ++i) {
if (i != 0) {
oss << ", ";
}
oss << list[i];
}
return oss.str();
}

// helper function to create a case list for strings
std::string create_case_list_of_strings(const std::vector<std::string>& list)
{
std::ostringstream oss;
for (size_t i = 0; i < list.size(); ++i) {
oss << " WHEN " << list[i] << " THEN " << i;
}
return oss.str();
}

} // namespace unnamed

postgresql_session_backend::postgresql_session_backend(
Expand Down Expand Up @@ -67,16 +148,6 @@ void postgresql_session_backend::connect(
: "SET extra_float_digits = 2",
"Cannot set extra_float_digits parameter");

PGresult* res = PQexec(conn, "SHOW search_path");
if (PQresultStatus(res) == PGRES_TUPLES_OK)
{
if (PQntuples(res) > 0)
{
schema_name_ = PQgetvalue(res, 0, 0);
}
}
PQclear(res);

conn_ = conn;
connectionParameters_ = parameters;
}
Expand Down Expand Up @@ -162,3 +233,36 @@ postgresql_blob_backend * postgresql_session_backend::make_blob_backend()
{
return new postgresql_blob_backend(*this);
}

std::string postgresql_session_backend::get_table_names_query() const
{
return std::string("SELECT table_schema || '.' || table_name AS \"TABLE_NAME\" FROM information_schema.tables WHERE table_schema in (") + create_list_of_strings(get_schema_names(conn_)) + ")";
}

std::string postgresql_session_backend::get_column_descriptions_query() const
{
std::vector<std::string> schema_list = get_schema_names(conn_);
return std::string("WITH Schema AS ("
" SELECT table_schema"
" FROM information_schema.columns"
" WHERE table_name = :t"
" AND CASE"
" WHEN :s::VARCHAR is not NULL THEN table_schema = :s::VARCHAR"
" ELSE table_schema in (") + create_list_of_strings(schema_list) + ") END"
" ORDER BY"
" CASE table_schema" +
create_case_list_of_strings(schema_list) +
" ELSE " + std::to_string(schema_list.size()) + " END"
" LIMIT 1 )"
" SELECT column_name as \"COLUMN_NAME\","
" data_type as \"DATA_TYPE\","
" character_maximum_length as \"CHARACTER_MAXIMUM_LENGTH\","
" numeric_precision as \"NUMERIC_PRECISION\","
" numeric_scale as \"NUMERIC_SCALE\","
" is_nullable as \"IS_NULLABLE\""
" FROM information_schema.columns"
" WHERE table_name = :t"
" AND table_schema = ("
" SELECT table_schema"
" FROM Schema )";
}
Loading

0 comments on commit bca17c1

Please sign in to comment.