From d8106257de142c351c13fb4d711822cd8a97c31a Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Thu, 31 May 2018 02:24:47 -0300 Subject: [PATCH 1/8] Basic support for prepared statements on sqlite3. With this modification we can enjoy the security benefits of having prepared-statement-alike additional parameters. To do this, the additional parameters should be passed after the statement in the execute method. This means that a new prepared statement will be created on each execute call, so don't expect big a performance increase. Maybe in a distant future a LRU cache of prepared statements could be added. --- doc/us/manual.html | 5 +++-- src/ls_sqlite3.c | 47 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/doc/us/manual.html b/doc/us/manual.html index edbc80f..96498af 100644 --- a/doc/us/manual.html +++ b/doc/us/manual.html @@ -219,8 +219,9 @@

Methods

the operation could not be performed or when it is not implemented. -
conn:execute(statement)
-
Executes the given SQL statement.
+
conn:execute(statement[,...])
+
Executes the given SQL statement. As in traditional prepared statements, + additional parameters can be used to avoid SQL injections, though not all drivers support this.
Returns: a cursor object if there are results, or the number of rows affected by the command otherwise.
diff --git a/src/ls_sqlite3.c b/src/ls_sqlite3.c index 33672ca..6d8d52c 100644 --- a/src/ls_sqlite3.c +++ b/src/ls_sqlite3.c @@ -379,6 +379,7 @@ static int conn_escape(lua_State *L) */ static int conn_execute(lua_State *L) { + int i; conn_data *conn = getconnection(L); const char *statement = luaL_checkstring(L, 2); int res; @@ -398,6 +399,52 @@ static int conn_execute(lua_State *L) return luasql_faildirect(L, errmsg); } + /* bind any additional arguments to the statement */ + numcols = lua_gettop(L); + for (i = 3; i <= numcols; i++) + { + const char * buffer; + size_t size; + switch (lua_type(L, i)) { + case LUA_TNIL: + res = sqlite3_bind_null(vm, i - 2); + break; + + case LUA_TBOOLEAN: + case LUA_TNUMBER: +#ifdef LUA_INT_TYPE + if (lua_isnumber(L, i) && !lua_isinteger(L, i)) + { + res = sqlite3_bind_double(vm, i - 2, lua_tonumber(L, i)); + } + else + { + res = sqlite3_bind_int64(vm, i - 2, lua_tointeger(L, i)); + } +#else + res = sqlite3_bind_double(vm, i - 2, lua_tonumber(L, i)); +#endif + break; + + case LUA_TSTRING: + buffer = lua_tolstring(L, i, &size); + res = sqlite3_bind_blob(vm, i - 2, buffer, size, SQLITE_TRANSIENT); + break; + + default: + sqlite3_finalize(vm); + return luaL_error(L, LUASQL_PREFIX"Invalid type for execute parameter %d", i - 2); + } + + /* handle errors */ + if (res != SQLITE_OK) + { + errmsg = sqlite3_errmsg(conn->sql_conn); + sqlite3_finalize(vm); + return luaL_error(L, LUASQL_PREFIX"Error binding parameter %d: %s", i - 2, errmsg); + } + } + /* process first result to retrive query information and type */ res = sqlite3_step(vm); numcols = sqlite3_column_count(vm); From 65b6faa1eb75ae7406bf6c15417d3ee667c5a2f2 Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Thu, 31 May 2018 03:41:14 -0300 Subject: [PATCH 2/8] Basic input parameters support for PostgreSQL As in SQLite3, we now support prepared-statements-alike passing of optional parameters by using the PQexecParams function, and since adding support for binary types is not an easy task, any argument is converted to a string before being sent and converted back to the expected type by PostgreSQL... in any case this is better than nothing. You may want to use a cast ``::type'' if it's not inferred. Example: > db = require'luasql.postgres'.postgres():connect('') > assert(db:execute('create table t(a int)')) > assert(db:execute('insert into t values($1)', 17)) > res = assert(db:execute('select $1+$2::int, a from t where a>$1', 3, 4)) > =res:fetch() 7 17 --- src/ls_postgres.c | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/ls_postgres.c b/src/ls_postgres.c index dd97ea7..85162dc 100644 --- a/src/ls_postgres.c +++ b/src/ls_postgres.c @@ -414,7 +414,20 @@ static int conn_escape (lua_State *L) { static int conn_execute (lua_State *L) { conn_data *conn = getconnection (L); const char *statement = luaL_checkstring (L, 2); - PGresult *res = PQexec(conn->pg_conn, statement); + int nparams = lua_gettop(L); + PGresult *res; + if (nparams > 2) { + int i; + const char ** values = malloc(sizeof (char *) * (nparams - 2)); + for (i = 3; i <= nparams; i++) + values[i - 3] = lua_tostring(L, i); + res = PQexecParams(conn->pg_conn, statement, nparams - 2, NULL, values, NULL, NULL, 0); + free(values); + } + else { + /* for multiple statements support */ + res = PQexec(conn->pg_conn, statement); + } if (res && PQresultStatus(res)==PGRES_COMMAND_OK) { /* no tuples returned */ lua_pushnumber(L, atof(PQcmdTuples(res))); From 792896896fc3db92b246c7c18385fe9d8c92f8b8 Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Thu, 22 Nov 2018 20:27:10 -0300 Subject: [PATCH 3/8] Bind string values as text. Since sqlite3_bind_text is binary safe, binding as text can be done without worries. --- src/ls_sqlite3.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ls_sqlite3.c b/src/ls_sqlite3.c index 6d8d52c..a9b9f2b 100644 --- a/src/ls_sqlite3.c +++ b/src/ls_sqlite3.c @@ -428,7 +428,7 @@ static int conn_execute(lua_State *L) case LUA_TSTRING: buffer = lua_tolstring(L, i, &size); - res = sqlite3_bind_blob(vm, i - 2, buffer, size, SQLITE_TRANSIENT); + res = sqlite3_bind_text(vm, i - 2, buffer, size, SQLITE_TRANSIENT); break; default: From e970d8f1ecb661407a3a12b07ce88a62532b7a64 Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Fri, 23 Nov 2018 04:42:26 -0300 Subject: [PATCH 4/8] Refactor for mysql_stmt API. This eliminates the need for escaping parameters, since now they can be specified as additional arguments to execute. --- src/ls_mysql.c | 213 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 169 insertions(+), 44 deletions(-) diff --git a/src/ls_mysql.c b/src/ls_mysql.c index 1e01c3d..48d959d 100644 --- a/src/ls_mysql.c +++ b/src/ls_mysql.c @@ -10,6 +10,7 @@ #include #include #include +#include #ifdef WIN32 #include @@ -71,13 +72,25 @@ typedef struct { } conn_data; typedef struct { - short closed; - int conn; /* reference to connection */ - int numcols; /* number of columns */ - int colnames, coltypes; /* reference to column information tables */ - MYSQL_RES *my_res; + short closed; + int conn; /* reference to connection */ + int numcols; /* number of columns */ + int colnames, coltypes; /* reference to column information tables */ + MYSQL_RES *my_res; + MYSQL_STMT *stmt; + MYSQL_BIND *params; /* bound to result columns */ + unsigned long *real_lengths; /* params[i].length will point to these real_lengths */ + bool *nulls; /* buffer for is_null */ + bool *errors; /* buffer for error */ } cur_data; +typedef union { + double number; + size_t size; + long long int longlong; + char c; +} column_data; + LUASQL_API int luaopen_luasql_mysql (lua_State *L); @@ -117,11 +130,22 @@ static cur_data *getcursor (lua_State *L) { /* ** Push the value of #i field of #tuple row. */ -static void pushvalue (lua_State *L, void *row, long int len) { - if (row == NULL) - lua_pushnil (L); - else - lua_pushlstring (L, row, len); +static void pushvalue (lua_State *L, cur_data *cur, int i) { + if (cur->nulls[i]) { + lua_pushnil(L); + cur->nulls[i] = 0; + } else { + /* error flags are set whenever lengths differ, but we resize only when real lengths are bigger */ + if (cur->errors[i]) { + if (cur->real_lengths[i] > cur->params[i].buffer_length) { + cur->params[i].buffer = realloc(cur->params[i].buffer, cur->real_lengths[i]); + cur->params[i].buffer_length = cur->real_lengths[i]; + } + mysql_stmt_fetch_column(cur->stmt, &cur->params[i], i, 0); + cur->errors[i] = 0; + } + lua_pushlstring(L, cur->params[i].buffer, cur->real_lengths[i]); + } } @@ -185,38 +209,40 @@ static void create_colinfo (lua_State *L, cur_data *cur) { ** Closes the cursos and nullify all structure fields. */ static void cur_nullify (lua_State *L, cur_data *cur) { + int i; /* Nullify structure fields. */ cur->closed = 1; mysql_free_result(cur->my_res); + mysql_stmt_close(cur->stmt); + for (i = 0; i < cur->numcols; i++) { + free(cur->params[i].buffer); + cur->params[i].buffer = NULL; + } luaL_unref (L, LUA_REGISTRYINDEX, cur->conn); luaL_unref (L, LUA_REGISTRYINDEX, cur->colnames); luaL_unref (L, LUA_REGISTRYINDEX, cur->coltypes); } - + /* ** Get another row of the given cursor. */ static int cur_fetch (lua_State *L) { cur_data *cur = getcursor (L); - MYSQL_RES *res = cur->my_res; - unsigned long *lengths; - MYSQL_ROW row = mysql_fetch_row(res); - if (row == NULL) { - cur_nullify (L, cur); + int r = mysql_stmt_fetch(cur->stmt); + if (r && r != MYSQL_DATA_TRUNCATED) { + cur_nullify(L, cur); lua_pushnil(L); /* no more results */ return 1; } - lengths = mysql_fetch_lengths(res); - if (lua_istable (L, 2)) { const char *opts = luaL_optstring (L, 3, "n"); if (strchr (opts, 'n') != NULL) { /* Copy values to numerical indices */ int i; for (i = 0; i < cur->numcols; i++) { - pushvalue (L, row[i], lengths[i]); - lua_rawseti (L, 2, i+1); + pushvalue(L, cur, i); + lua_rawseti(L, 2, i+1); } } if (strchr (opts, 'a') != NULL) { @@ -231,7 +257,7 @@ static int cur_fetch (lua_State *L) { lua_rawgeti(L, -1, i+1); /* push the field name */ /* Actually push the value */ - pushvalue (L, row[i], lengths[i]); + pushvalue (L, cur, i); lua_rawset (L, 2); } /* lua_pop(L, 1); Pops colnames table. Not needed */ @@ -243,7 +269,7 @@ static int cur_fetch (lua_State *L) { int i; luaL_checkstack (L, cur->numcols, LUASQL_PREFIX"too many columns"); for (i = 0; i < cur->numcols; i++) - pushvalue (L, row[i], lengths[i]); + pushvalue (L, cur, i); return cur->numcols; /* return #numcols values */ } } @@ -317,7 +343,7 @@ static int cur_getcoltypes (lua_State *L) { ** Push the number of rows. */ static int cur_numrows (lua_State *L) { - lua_pushinteger (L, (lua_Number)mysql_num_rows (getcursor(L)->my_res)); + lua_pushinteger (L, (lua_Number)mysql_stmt_num_rows (getcursor(L)->stmt)); return 1; } @@ -325,9 +351,12 @@ static int cur_numrows (lua_State *L) { /* ** Create a new Cursor object and push it on top of the stack. */ -static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) { - cur_data *cur = (cur_data *)lua_newuserdata(L, sizeof(cur_data)); +static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *result, int cols) { + int i; + size_t memsize = sizeof (cur_data) + cols * (sizeof (MYSQL_BIND) + sizeof (unsigned long) + 2 * sizeof (bool)); + cur_data *cur = (cur_data *)lua_newuserdata(L, memsize); luasql_setmeta (L, LUASQL_CURSOR_MYSQL); + memset(cur, 0, memsize); /* fill in structure */ cur->closed = 0; @@ -336,8 +365,28 @@ static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) { cur->colnames = LUA_NOREF; cur->coltypes = LUA_NOREF; cur->my_res = result; + cur->stmt = stmt; lua_pushvalue (L, conn); cur->conn = luaL_ref (L, LUA_REGISTRYINDEX); + cur->params = (MYSQL_BIND *)(sizeof (cur_data) + (char *)cur); /* after cur */ + cur->real_lengths = (unsigned long *)(cols * sizeof (MYSQL_BIND) + (char *)cur->params); + cur->nulls = (bool*)(cols * sizeof (unsigned long) + (char *)cur->real_lengths); + cur->errors = (bool*)(cols * sizeof (bool) + (char *)cur->nulls); + for (i = 0; i < cols; i++) { + cur->params[i].buffer_type = MYSQL_TYPE_STRING; + cur->params[i].buffer = malloc(0); + cur->params[i].buffer_length = 0; + cur->params[i].length = &cur->real_lengths[i]; + /* old versions use my_bool, newer use bool. There's no simple way to detect it */ + cur->params[i].is_null = (void*)&cur->nulls[i]; + cur->params[i].error = (void*)&cur->errors[i]; + } + + if (mysql_stmt_bind_result(cur->stmt, cur->params)) { + int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt)); + cur_nullify(L, cur); + return n; + } return 1; } @@ -417,27 +466,103 @@ static int conn_execute (lua_State *L) { conn_data *conn = getconnection (L); size_t st_len; const char *statement = luaL_checklstring (L, 2, &st_len); - if (mysql_real_query(conn->my_conn, statement, st_len)) - /* error executing query */ - return luasql_failmsg(L, "error executing query. MySQL: ", mysql_error(conn->my_conn)); - else - { - MYSQL_RES *res = mysql_store_result(conn->my_conn); - unsigned int num_cols = mysql_field_count(conn->my_conn); - - if (res) { /* tuples returned */ - return create_cursor (L, 1, res, num_cols); - } - else { /* mysql_use_result() returned nothing; should it have? */ - if(num_cols == 0) { /* no tuples returned */ - /* query does not return data (it was not a SELECT) */ - lua_pushinteger(L, mysql_affected_rows(conn->my_conn)); - return 1; - } - else /* mysql_use_result() should have returned data */ - return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_error(conn->my_conn)); + int i, nparams = lua_gettop(L); + MYSQL_STMT * stmt; + MYSQL_BIND * params; + column_data * params_data; + MYSQL_RES * res; + unsigned int num_cols; + + stmt = mysql_stmt_init(conn->my_conn); + if (stmt == NULL) + return luasql_failmsg(L, "error executing query (stmt_init). MySQL: ", mysql_error(conn->my_conn)); + if (mysql_stmt_prepare(stmt, statement, st_len)) { + int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt)); + mysql_stmt_close(stmt); + return n; + } + if (nparams - 2 != mysql_stmt_param_count(stmt)) { + mysql_stmt_close(stmt); + return luasql_faildirect(L, "error executing query. Invalid parameter count"); + } + params = calloc(sizeof (MYSQL_BIND), nparams - 2); + params_data = calloc(sizeof (column_data), nparams - 2); + for (i = 3; i <= nparams; i++) { + switch (lua_type(L, i)) { + case LUA_TNIL: + params[i-3].buffer_type = MYSQL_TYPE_NULL; + break; + case LUA_TBOOLEAN: + params_data[i-3].c = lua_toboolean(L, i); + params[i-3].buffer_type = MYSQL_TYPE_TINY; + params[i-3].buffer = ¶ms_data[i-3].c; + params[i-3].buffer_length = sizeof (char); + break; + case LUA_TNUMBER: +#ifdef LUA_INT_TYPE + if (lua_isinteger(L, i)) { + params_data[i-3].longlong = lua_tointeger(L, i); + params[i-3].buffer_type = MYSQL_TYPE_LONGLONG; + params[i-3].buffer = ¶ms_data[i-3].longlong; + params[i-3].buffer_length = sizeof (long long int); + break; + } +#endif + params_data[i-3].number = lua_tonumber(L, i); + params[i-3].buffer_type = MYSQL_TYPE_DOUBLE; + params[i-3].buffer = ¶ms_data[i-3].number; + params[i-3].buffer_length = sizeof (double); + break; + case LUA_TSTRING: + params[i-3].buffer_type = MYSQL_TYPE_STRING; + params[i-3].buffer = (char*)lua_tolstring(L, i, ¶ms_data[i-3].size); + params[i-3].buffer_length = params_data[i-3].size; + params[i-3].length = ¶ms_data[i-3].size; + break; + default: + free(params); + free(params_data); + mysql_stmt_close(stmt); + return luasql_faildirect(L, "error executing query. Invalid parameter type"); } } + if (mysql_stmt_bind_param(stmt, params)) { + int n = luasql_failmsg(L, "error executing query (stmt_bind_param). MySQL: ", mysql_stmt_error(stmt)); + free(params); + free(params_data); + mysql_stmt_close(stmt); + return n; + } + if (mysql_stmt_execute(stmt)) { + int n = luasql_failmsg(L, "error executing query (stmt_execute). MySQL: ", mysql_stmt_error(stmt)); + free(params); + free(params_data); + mysql_stmt_close(stmt); + return n; + } + free(params); + free(params_data); + if (mysql_stmt_store_result(stmt)) { + int n = luasql_failmsg(L, "error executing query (stmt_store_result). MySQL: ", mysql_stmt_error(stmt)); + mysql_stmt_close(stmt); + return n; + } + + res = mysql_stmt_result_metadata(stmt); + num_cols = mysql_stmt_field_count(stmt); + if (res) { /* tuples returned */ + return create_cursor (L, 1, stmt, res, num_cols); + } + + if(num_cols == 0) { /* no tuples returned */ + /* query does not return data (it was not a SELECT) */ + lua_pushinteger(L, mysql_stmt_affected_rows(stmt)); + mysql_stmt_close(stmt); + return 1; + } else { /* mysql_use_result() should have returned data */ + mysql_stmt_close(stmt); + return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_stmt_error(stmt)); + } } From 2baecd2dbc429f75f00a5fd567f4b814b3135e14 Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Fri, 23 Nov 2018 04:48:20 -0300 Subject: [PATCH 5/8] update documentation regarding execute arguments. --- doc/us/manual.html | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/us/manual.html b/doc/us/manual.html index 9871fad..7754314 100644 --- a/doc/us/manual.html +++ b/doc/us/manual.html @@ -221,7 +221,8 @@

Methods

conn:execute(statement[,...])
Executes the given SQL statement. As in traditional prepared statements, - additional parameters can be used to avoid SQL injections, though not all drivers support this.
+ additional parameters can be used to avoid SQL injections. Although this is only + supported by sqlite3, postgres and mysql drivers.
Returns: a cursor object if there are results, or the number of rows affected by the command otherwise.
From 6fea651677c6609f67878973c096e0d2361d9f5a Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Fri, 23 Nov 2018 20:57:40 -0300 Subject: [PATCH 6/8] Add a generic conn:execute(sql, ...) function. This function can be used by all the drivers simplifying the logic, requiring only to implement conn:prepare(sql) and stmt:execute(...). --- src/luasql.c | 32 ++++++++++++++++++++++++++++++++ src/luasql.h | 1 + 2 files changed, 33 insertions(+) diff --git a/src/luasql.c b/src/luasql.c index ed1cae7..ecf463a 100644 --- a/src/luasql.c +++ b/src/luasql.c @@ -131,3 +131,35 @@ LUASQL_API void luasql_set_info (lua_State *L) { lua_pushliteral (L, "LuaSQL 2.3.5 (for "LUA_VERSION")"); lua_settable (L, -3); } + +/* +** Execute an SQL statement from a string. +** Return a Cursor object if the statement is a query, otherwise +** return the number of tuples affected by the statement. +** It's nothing more than a C implementation of: +** function conn:execute(sql, ...) +** local stmt, msg = conn:prepare(sql) +** if stmt == nil then return nil, msg end +** return stmt:execute(...) +** end +*/ +LUASQL_API int luasql_conn_execute (lua_State *L) { + // stack: conn sql ... + lua_getfield(L, 1, "prepare"); + lua_pushvalue(L, 1); + lua_pushvalue(L, 2); + // stack: conn sql ... conn.prepare conn sql + lua_call(L, 2, 2); + // stack: conn sql ... stmt msg + if (lua_isnil(L, -2)) + return 2; + lua_pop(L, 1); + lua_replace(L, 2); + // stack: conn stmt ... + lua_getfield(L, 2, "execute"); + lua_replace(L, 1); + // stack: stmt.execute stmt ... + lua_call(L, lua_gettop(L)-1, LUA_MULTRET); + // stack: cur msg? + return lua_gettop(L); +} diff --git a/src/luasql.h b/src/luasql.h index 345bf57..d7fe655 100644 --- a/src/luasql.h +++ b/src/luasql.h @@ -29,6 +29,7 @@ LUASQL_API int luasql_failmsg (lua_State *L, const char *err, const char *m); LUASQL_API int luasql_createmeta (lua_State *L, const char *name, const luaL_Reg *methods); LUASQL_API void luasql_setmeta (lua_State *L, const char *name); LUASQL_API void luasql_set_info (lua_State *L); +LUASQL_API int luasql_conn_execute (lua_State *L); #if !defined LUA_VERSION_NUM || LUA_VERSION_NUM==501 void luaL_setfuncs (lua_State *L, const luaL_Reg *l, int nup); From 04e30371628125bc7adda5ff62dbfe887b25e5d0 Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Sat, 24 Nov 2018 01:56:27 -0300 Subject: [PATCH 7/8] Support for reusable Prepared Statements in MySQL --- src/ls_mysql.c | 329 +++++++++++++++++++++++++++++-------------------- 1 file changed, 193 insertions(+), 136 deletions(-) diff --git a/src/ls_mysql.c b/src/ls_mysql.c index 48d959d..91e83bb 100644 --- a/src/ls_mysql.c +++ b/src/ls_mysql.c @@ -27,6 +27,7 @@ #define LUASQL_ENVIRONMENT_MYSQL "MySQL environment" #define LUASQL_CONNECTION_MYSQL "MySQL connection" +#define LUASQL_STATEMENT_MYSQL "MySQL statement" #define LUASQL_CURSOR_MYSQL "MySQL cursor" /* For compat with old version 4.0 */ @@ -61,6 +62,15 @@ #endif +struct cur_data; + +typedef union { + double number; + size_t size; + long long int longlong; + char c; +} column_data; + typedef struct { short closed; } env_data; @@ -72,25 +82,28 @@ typedef struct { } conn_data; typedef struct { - short closed; - int conn; /* reference to connection */ - int numcols; /* number of columns */ - int colnames, coltypes; /* reference to column information tables */ - MYSQL_RES *my_res; - MYSQL_STMT *stmt; - MYSQL_BIND *params; /* bound to result columns */ - unsigned long *real_lengths; /* params[i].length will point to these real_lengths */ - bool *nulls; /* buffer for is_null */ - bool *errors; /* buffer for error */ + short closed; + int conn_ref; /* reference to connection */ + MYSQL_STMT *my_stmt; + struct cur_data *cur; /* for closing already open cursors */ + int nparams; + MYSQL_BIND *params; + column_data *params_data; +} statement_data; + +typedef struct cur_data { + short closed; + int stmt_ref; /* reference to statement */ + statement_data *st; + int numcols; /* number of columns */ + int colnames, coltypes; /* references to column information tables */ + MYSQL_RES *my_res; + MYSQL_BIND *params; /* bound to result columns */ + unsigned long *real_lengths; /* params[i].length will point to these real_lengths */ + bool *nulls; /* buffer for is_null */ + bool *errors; /* buffer for error */ } cur_data; -typedef union { - double number; - size_t size; - long long int longlong; - char c; -} column_data; - LUASQL_API int luaopen_luasql_mysql (lua_State *L); @@ -116,6 +129,17 @@ static conn_data *getconnection (lua_State *L) { } +/* +** Check for valid statement. +*/ +static statement_data *getstatement (lua_State *L) { + statement_data *stmt = (statement_data *)luaL_checkudata (L, 1, LUASQL_STATEMENT_MYSQL); + luaL_argcheck (L, stmt != NULL, 1, "statement expected"); + luaL_argcheck (L, !stmt->closed, 1, "statement is closed"); + return stmt; +} + + /* ** Check for valid cursor. */ @@ -141,7 +165,7 @@ static void pushvalue (lua_State *L, cur_data *cur, int i) { cur->params[i].buffer = realloc(cur->params[i].buffer, cur->real_lengths[i]); cur->params[i].buffer_length = cur->real_lengths[i]; } - mysql_stmt_fetch_column(cur->stmt, &cur->params[i], i, 0); + mysql_stmt_fetch_column(cur->st->my_stmt, &cur->params[i], i, 0); cur->errors[i] = 0; } lua_pushlstring(L, cur->params[i].buffer, cur->real_lengths[i]); @@ -212,15 +236,16 @@ static void cur_nullify (lua_State *L, cur_data *cur) { int i; /* Nullify structure fields. */ cur->closed = 1; + if (cur->st->cur == cur) + cur->st->cur = NULL; mysql_free_result(cur->my_res); - mysql_stmt_close(cur->stmt); for (i = 0; i < cur->numcols; i++) { free(cur->params[i].buffer); cur->params[i].buffer = NULL; } - luaL_unref (L, LUA_REGISTRYINDEX, cur->conn); - luaL_unref (L, LUA_REGISTRYINDEX, cur->colnames); - luaL_unref (L, LUA_REGISTRYINDEX, cur->coltypes); + luaL_unref(L, LUA_REGISTRYINDEX, cur->stmt_ref); + luaL_unref(L, LUA_REGISTRYINDEX, cur->colnames); + luaL_unref(L, LUA_REGISTRYINDEX, cur->coltypes); } @@ -229,7 +254,8 @@ static void cur_nullify (lua_State *L, cur_data *cur) { */ static int cur_fetch (lua_State *L) { cur_data *cur = getcursor (L); - int r = mysql_stmt_fetch(cur->stmt); + statement_data * st = cur->st; + int r = mysql_stmt_fetch(st->my_stmt); if (r && r != MYSQL_DATA_TRUNCATED) { cur_nullify(L, cur); lua_pushnil(L); /* no more results */ @@ -249,7 +275,7 @@ static int cur_fetch (lua_State *L) { int i; /* Check if colnames exists */ if (cur->colnames == LUA_NOREF) - create_colinfo(L, cur); + create_colinfo(L, cur); lua_rawgeti (L, LUA_REGISTRYINDEX, cur->colnames);/* Push colnames*/ /* Copy values to alphanumerical indices */ @@ -343,7 +369,7 @@ static int cur_getcoltypes (lua_State *L) { ** Push the number of rows. */ static int cur_numrows (lua_State *L) { - lua_pushinteger (L, (lua_Number)mysql_stmt_num_rows (getcursor(L)->stmt)); + lua_pushinteger (L, (lua_Number)mysql_stmt_num_rows (getcursor(L)->st->my_stmt)); return 1; } @@ -351,7 +377,7 @@ static int cur_numrows (lua_State *L) { /* ** Create a new Cursor object and push it on top of the stack. */ -static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *result, int cols) { +static int create_cursor (lua_State *L, int st_index, statement_data * st, MYSQL_RES *result, int cols) { int i; size_t memsize = sizeof (cur_data) + cols * (sizeof (MYSQL_BIND) + sizeof (unsigned long) + 2 * sizeof (bool)); cur_data *cur = (cur_data *)lua_newuserdata(L, memsize); @@ -360,14 +386,13 @@ static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *r /* fill in structure */ cur->closed = 0; - cur->conn = LUA_NOREF; cur->numcols = cols; cur->colnames = LUA_NOREF; cur->coltypes = LUA_NOREF; cur->my_res = result; - cur->stmt = stmt; - lua_pushvalue (L, conn); - cur->conn = luaL_ref (L, LUA_REGISTRYINDEX); + cur->st = st; + lua_pushvalue(L, st_index); + cur->stmt_ref = luaL_ref(L, LUA_REGISTRYINDEX); cur->params = (MYSQL_BIND *)(sizeof (cur_data) + (char *)cur); /* after cur */ cur->real_lengths = (unsigned long *)(cols * sizeof (MYSQL_BIND) + (char *)cur->params); cur->nulls = (bool*)(cols * sizeof (unsigned long) + (char *)cur->real_lengths); @@ -382,8 +407,8 @@ static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *r cur->params[i].error = (void*)&cur->errors[i]; } - if (mysql_stmt_bind_result(cur->stmt, cur->params)) { - int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt)); + if (mysql_stmt_bind_result(st->my_stmt, cur->params)) { + int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(st->my_stmt)); cur_nullify(L, cur); return n; } @@ -458,114 +483,138 @@ static int escape_string (lua_State *L) { } /* -** Execute an SQL statement. -** Return a Cursor object if the statement is a query, otherwise -** return the number of tuples affected by the statement. +** Create a prepared statement. +** Return a statement_data object that can be used as a first the parameter of conn:execute. +** This allows improved performance when the same query is called several times. */ -static int conn_execute (lua_State *L) { +static int conn_prepare (lua_State *L) { conn_data *conn = getconnection (L); + statement_data * st; size_t st_len; - const char *statement = luaL_checklstring (L, 2, &st_len); - int i, nparams = lua_gettop(L); - MYSQL_STMT * stmt; - MYSQL_BIND * params; - column_data * params_data; - MYSQL_RES * res; - unsigned int num_cols; + const char * sql = luaL_checklstring (L, 2, &st_len); - stmt = mysql_stmt_init(conn->my_conn); - if (stmt == NULL) + st = (statement_data *)lua_newuserdata(L, sizeof (statement_data)); + memset(st, 0, sizeof (statement_data)); + st->conn_ref = LUA_NOREF; + lua_pushvalue(L, 1); + st->conn_ref = luaL_ref (L, LUA_REGISTRYINDEX); + luasql_setmeta(L, LUASQL_STATEMENT_MYSQL); + + st->my_stmt = mysql_stmt_init(conn->my_conn); + if (st->my_stmt == NULL) return luasql_failmsg(L, "error executing query (stmt_init). MySQL: ", mysql_error(conn->my_conn)); - if (mysql_stmt_prepare(stmt, statement, st_len)) { - int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt)); - mysql_stmt_close(stmt); - return n; + if (mysql_stmt_prepare(st->my_stmt, sql, st_len)) { + return luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(st->my_stmt)); } - if (nparams - 2 != mysql_stmt_param_count(stmt)) { - mysql_stmt_close(stmt); - return luasql_faildirect(L, "error executing query. Invalid parameter count"); + st->nparams = mysql_stmt_param_count(st->my_stmt); + st->params = calloc(sizeof (MYSQL_BIND), st->nparams); + st->params_data = calloc(sizeof (column_data), st->nparams); + return 1; +} + +static void statement_nullify(lua_State * L, statement_data * st) { + st->closed = 1; + if (st->cur) + cur_nullify(L, st->cur); + if (st->my_stmt) + mysql_stmt_close(st->my_stmt); + if (st->params) + free(st->params); + if (st->params_data) + free(st->params_data); + luaL_unref(L, LUA_REGISTRYINDEX, st->conn_ref); +} + +static int statement_gc (lua_State *L) { + statement_data * st = (statement_data *)luaL_checkudata (L, 1, LUASQL_STATEMENT_MYSQL); + if (st != NULL && !(st->closed)) + statement_nullify (L, st); + return 0; +} + +static int statement_close (lua_State *L) { + statement_data * st = (statement_data *)luaL_checkudata (L, 1, LUASQL_STATEMENT_MYSQL); + luaL_argcheck (L, st != NULL, 1, LUASQL_PREFIX"statement expected"); + if (st->closed) { + lua_pushboolean (L, 0); + return 1; } - params = calloc(sizeof (MYSQL_BIND), nparams - 2); - params_data = calloc(sizeof (column_data), nparams - 2); - for (i = 3; i <= nparams; i++) { - switch (lua_type(L, i)) { + statement_nullify (L, st); + lua_pushboolean (L, 1); + return 1; +} + +/* +** Execute an SQL statement from a prepared statement. +** Return a Cursor object if the statement is a query, otherwise +** return the number of tuples affected by the statement. +*/ +static int statement_execute (lua_State *L) { + statement_data * st = getstatement (L); + int i; + MYSQL_RES * res; + unsigned int num_cols; + + if (lua_gettop(L) - 1 != st->nparams) + return luasql_faildirect(L, "error executing query. Invalid parameter count"); + memset(st->params, 0, sizeof (MYSQL_BIND) * st->nparams); + memset(st->params_data, 0, sizeof (column_data) * st->nparams); + for (i = 0; i < st->nparams; i++) { + switch (lua_type(L, i + 2)) { case LUA_TNIL: - params[i-3].buffer_type = MYSQL_TYPE_NULL; + st->params[i].buffer_type = MYSQL_TYPE_NULL; break; case LUA_TBOOLEAN: - params_data[i-3].c = lua_toboolean(L, i); - params[i-3].buffer_type = MYSQL_TYPE_TINY; - params[i-3].buffer = ¶ms_data[i-3].c; - params[i-3].buffer_length = sizeof (char); + st->params_data[i].c = lua_toboolean(L, i + 2); + st->params[i].buffer_type = MYSQL_TYPE_TINY; + st->params[i].buffer = &st->params_data[i].c; + st->params[i].buffer_length = sizeof (char); break; case LUA_TNUMBER: #ifdef LUA_INT_TYPE if (lua_isinteger(L, i)) { - params_data[i-3].longlong = lua_tointeger(L, i); - params[i-3].buffer_type = MYSQL_TYPE_LONGLONG; - params[i-3].buffer = ¶ms_data[i-3].longlong; - params[i-3].buffer_length = sizeof (long long int); + st->params_data[i].longlong = lua_tointeger(L, i + 2); + st->params[i].buffer_type = MYSQL_TYPE_LONGLONG; + st->params[i].buffer = &st->params_data[i].longlong; + st->params[i].buffer_length = sizeof (long long int); break; } #endif - params_data[i-3].number = lua_tonumber(L, i); - params[i-3].buffer_type = MYSQL_TYPE_DOUBLE; - params[i-3].buffer = ¶ms_data[i-3].number; - params[i-3].buffer_length = sizeof (double); + st->params_data[i].number = lua_tonumber(L, i + 2); + st->params[i].buffer_type = MYSQL_TYPE_DOUBLE; + st->params[i].buffer = &st->params_data[i].number; + st->params[i].buffer_length = sizeof (double); break; case LUA_TSTRING: - params[i-3].buffer_type = MYSQL_TYPE_STRING; - params[i-3].buffer = (char*)lua_tolstring(L, i, ¶ms_data[i-3].size); - params[i-3].buffer_length = params_data[i-3].size; - params[i-3].length = ¶ms_data[i-3].size; + st->params[i].buffer_type = MYSQL_TYPE_STRING; + st->params[i].buffer = (char*)lua_tolstring(L, i + 2, &st->params_data[i].size); + st->params[i].buffer_length = st->params_data[i].size; + st->params[i].length = &st->params_data[i].size; break; default: - free(params); - free(params_data); - mysql_stmt_close(stmt); return luasql_faildirect(L, "error executing query. Invalid parameter type"); } } - if (mysql_stmt_bind_param(stmt, params)) { - int n = luasql_failmsg(L, "error executing query (stmt_bind_param). MySQL: ", mysql_stmt_error(stmt)); - free(params); - free(params_data); - mysql_stmt_close(stmt); - return n; - } - if (mysql_stmt_execute(stmt)) { - int n = luasql_failmsg(L, "error executing query (stmt_execute). MySQL: ", mysql_stmt_error(stmt)); - free(params); - free(params_data); - mysql_stmt_close(stmt); - return n; - } - free(params); - free(params_data); - if (mysql_stmt_store_result(stmt)) { - int n = luasql_failmsg(L, "error executing query (stmt_store_result). MySQL: ", mysql_stmt_error(stmt)); - mysql_stmt_close(stmt); - return n; - } - - res = mysql_stmt_result_metadata(stmt); - num_cols = mysql_stmt_field_count(stmt); - if (res) { /* tuples returned */ - return create_cursor (L, 1, stmt, res, num_cols); - } - - if(num_cols == 0) { /* no tuples returned */ - /* query does not return data (it was not a SELECT) */ - lua_pushinteger(L, mysql_stmt_affected_rows(stmt)); - mysql_stmt_close(stmt); - return 1; - } else { /* mysql_use_result() should have returned data */ - mysql_stmt_close(stmt); - return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_stmt_error(stmt)); - } + if (mysql_stmt_bind_param(st->my_stmt, st->params)) + return luasql_failmsg(L, "error executing query (stmt_bind_param). MySQL: ", mysql_stmt_error(st->my_stmt)); + if (mysql_stmt_execute(st->my_stmt)) + return luasql_failmsg(L, "error executing query (stmt_execute). MySQL: ", mysql_stmt_error(st->my_stmt)); + if (mysql_stmt_store_result(st->my_stmt)) + return luasql_failmsg(L, "error executing query (stmt_store_result). MySQL: ", mysql_stmt_error(st->my_stmt)); + + res = mysql_stmt_result_metadata(st->my_stmt); + num_cols = mysql_stmt_field_count(st->my_stmt); + + if (res) /* tuples returned */ + return create_cursor (L, 1, st, res, num_cols); + if (num_cols > 0) /* mysql_use_result() should have returned data */ + return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_stmt_error(st->my_stmt)); + + /* no tuples returned: query does not return data (it was not a SELECT) */ + lua_pushinteger(L, mysql_stmt_affected_rows(st->my_stmt)); + return 1; } - /* ** Commit the current transaction. */ @@ -692,35 +741,43 @@ static int env_close (lua_State *L) { ** Create metatables for each class of object. */ static void create_metatables (lua_State *L) { - struct luaL_Reg environment_methods[] = { - {"__gc", env_gc}, - {"close", env_close}, - {"connect", env_connect}, + struct luaL_Reg environment_methods[] = { + {"__gc", env_gc}, + {"close", env_close}, + {"connect", env_connect}, {NULL, NULL}, }; - struct luaL_Reg connection_methods[] = { - {"__gc", conn_gc}, - {"close", conn_close}, - {"ping", conn_ping}, - {"escape", escape_string}, - {"execute", conn_execute}, - {"commit", conn_commit}, - {"rollback", conn_rollback}, - {"setautocommit", conn_setautocommit}, + struct luaL_Reg connection_methods[] = { + {"__gc", conn_gc}, + {"close", conn_close}, + {"ping", conn_ping}, + {"escape", escape_string}, + {"execute", luasql_conn_execute}, + {"prepare", conn_prepare}, + {"commit", conn_commit}, + {"rollback", conn_rollback}, + {"setautocommit", conn_setautocommit}, {"getlastautoid", conn_getlastautoid}, {NULL, NULL}, - }; - struct luaL_Reg cursor_methods[] = { - {"__gc", cur_gc}, - {"close", cur_close}, - {"getcolnames", cur_getcolnames}, - {"getcoltypes", cur_getcoltypes}, - {"fetch", cur_fetch}, - {"numrows", cur_numrows}, + }; + struct luaL_Reg statement_methods[] = { + {"__gc", statement_gc}, + {"close", statement_close}, + {"execute", statement_execute}, + {NULL, NULL}, + }; + struct luaL_Reg cursor_methods[] = { + {"__gc", cur_gc}, + {"close", cur_close}, + {"getcolnames", cur_getcolnames}, + {"getcoltypes", cur_getcoltypes}, + {"fetch", cur_fetch}, + {"numrows", cur_numrows}, {NULL, NULL}, - }; + }; luasql_createmeta (L, LUASQL_ENVIRONMENT_MYSQL, environment_methods); luasql_createmeta (L, LUASQL_CONNECTION_MYSQL, connection_methods); + luasql_createmeta (L, LUASQL_STATEMENT_MYSQL, statement_methods); luasql_createmeta (L, LUASQL_CURSOR_MYSQL, cursor_methods); lua_pop (L, 3); } From 9cfa0dd91d18e2a2ab6514ce319784c1001754c0 Mon Sep 17 00:00:00 2001 From: Francisco Castro Date: Tue, 27 Nov 2018 20:52:24 -0300 Subject: [PATCH 8/8] add test cases for prepare statements --- tests/mysql.lua | 2 ++ tests/postgres.lua | 2 ++ tests/sqlite3.lua | 1 + tests/test.lua | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+) diff --git a/tests/mysql.lua b/tests/mysql.lua index aa1f313..3f6bfdc 100644 --- a/tests/mysql.lua +++ b/tests/mysql.lua @@ -9,6 +9,8 @@ table.insert (CUR_METHODS, "numrows") table.insert (EXTENSIONS, numrows) table.insert (CONN_METHODS, "escape") table.insert (EXTENSIONS, escape) +table.insert (EXTENSIONS, prepare) +table.insert (EXTENSIONS, execparams) --------------------------------------------------------------------- -- Build SQL command to create the test table. diff --git a/tests/postgres.lua b/tests/postgres.lua index 8cec53d..19cff97 100644 --- a/tests/postgres.lua +++ b/tests/postgres.lua @@ -4,8 +4,10 @@ --------------------------------------------------------------------- DEFAULT_USERNAME = "postgres" +PREPARED_STATEMENT_ARGUMENT = function(n) return "$"..n end table.insert (CUR_METHODS, "numrows") table.insert (EXTENSIONS, numrows) table.insert (CONN_METHODS, "escape") table.insert (EXTENSIONS, escape) +table.insert (EXTENSIONS, execparams) diff --git a/tests/sqlite3.lua b/tests/sqlite3.lua index 1773cf1..102c7ef 100644 --- a/tests/sqlite3.lua +++ b/tests/sqlite3.lua @@ -20,3 +20,4 @@ end table.insert (CONN_METHODS, "escape") table.insert (EXTENSIONS, escape) +table.insert (EXTENSIONS, execparams) diff --git a/tests/test.lua b/tests/test.lua index 00bcb1f..9811ad1 100644 --- a/tests/test.lua +++ b/tests/test.lua @@ -15,6 +15,10 @@ MSG_CURSOR_NOT_CLOSED = "cursor was not automatically closed by fetch" CHECK_GETCOL_INFO_TABLES = true +PREPARED_STATEMENT_ARGUMENT = function(n) + return "?" -- postgres uses $1, $2, ... +end + --------------------------------------------------------------------- if not string.find(_VERSION, " 5.0") then table.getn = assert((loadstring or load)[[return function (t) return #t end]])() @@ -111,6 +115,10 @@ CUR_METHODS = { "close", "fetch", "getcolnames", "getcoltypes", } CUR_OK = function (obj) return test_object (obj, CUR_METHODS) end +STMT_METHODS = { "close", "execute", } +STMT_OK = function (obj) + return test_object (obj, STMT_METHODS) +end function checkUnknownDatabase(ENV) assert2 (nil, ENV:connect ("/unknown-data-base"), "this should be an error") @@ -511,6 +519,38 @@ function escape () io.write (" escape") end +function execparams () + local arg = PREPARED_STATEMENT_ARGUMENT + + assert2 (1, CONN:execute ("insert into t (f1, f2) values ("..arg(1)..", "..arg(2)..")", "x", "'")) + local cur = CUR_OK (CONN:execute ("select f1 from t where f2 = "..arg(1), "'")) + assert2 ((cur:fetch()), 'x') + cur:close() + assert2 (1, CONN:execute ("delete from t where f2 in ("..arg(1)..")", "'")) + + io.write (" execparams") +end + +function prepare () + local arg = PREPARED_STATEMENT_ARGUMENT + + local sql = "insert into t (f1, f2) values ("..arg(1)..", "..arg(2)..")" + local stmt = STMT_OK (CONN:prepare(sql)) + + for i = 1, 10 do + assert2 (1, stmt:execute ("foo", i)) + end + stmt:close() + local stmt = STMT_OK (CONN:prepare("select count(*) from t where f1 = "..arg(1))) + local cur = CUR_OK (stmt:execute "foo") + cur:close() + stmt:close() + + assert2 (10, CONN:execute ("delete from t where f1 = 'foo'")) + + io.write (" execparams") +end + --------------------------------------------------------------------- --------------------------------------------------------------------- function check_close()