Skip to content

Commit

Permalink
Support Postgres in database.ts (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zarel authored Nov 30, 2023
1 parent 84e373f commit 7bbc471
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 96 deletions.
17 changes: 9 additions & 8 deletions src/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {Replays} from './replays';
import {ActionError, QueryHandler, Server} from './server';
import {toID, updateserver, bash, time, escapeHTML} from './utils';
import * as tables from './tables';
import {SQL} from './database';
import * as pathModule from 'path';
import IPTools from './ip-tools';
import * as crypto from 'crypto';
Expand Down Expand Up @@ -662,9 +663,9 @@ export const actions: {[k: string]: QueryHandler} = {
}
let teams = [];
try {
teams = await tables.pgdb.query(
'SELECT teamid, team, format, title as name FROM teams WHERE ownerid = $1', [this.user.id]
) ?? [];
teams = await tables.teams.selectAll<any>(
SQL`teamid, team, format, title as name`
)`WHERE ownerid = ${this.user.id}`;
} catch (e) {
Server.crashlog(e, 'a teams database query', params);
throw new ActionError('The server could not load your teams. Please try again later.');
Expand Down Expand Up @@ -693,13 +694,13 @@ export const actions: {[k: string]: QueryHandler} = {
throw new ActionError("Invalid team ID");
}
try {
const data = await tables.pgdb.query(
`SELECT ownerid, team, private as privacy FROM teams WHERE teamid = $1`, [teamid]
);
if (!data || !data.length || data[0].ownerid !== this.user.id) {
const data = await tables.teams.selectOne<any>(
SQL`ownerid, team, private as privacy`
)`WHERE teamid = ${teamid}`;
if (!data || data.ownerid !== this.user.id) {
return {team: null};
}
return data[0];
return data;
} catch (e) {
Server.crashlog(e, 'a teams database request', params);
throw new ActionError("Failed to fetch team. Please try again later.");
Expand Down
208 changes: 125 additions & 83 deletions src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ export class SQLStatement {
} else if (value === undefined) {
this.sql[this.sql.length - 1] += nextString;
} else if (Array.isArray(value)) {
if (this.sql[this.sql.length - 1].endsWith(`\``)) {
if ('"`'.includes(this.sql[this.sql.length - 1].slice(-1))) {
// "`a`, `b`" syntax
const quoteChar = this.sql[this.sql.length - 1].slice(-1);
for (const col of value) {
this.append(col, `\`, \``);
this.append(col, `${quoteChar}, ${quoteChar}`);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + nextString;
} else {
Expand All @@ -52,21 +53,21 @@ export class SQLStatement {
}
} else if (this.sql[this.sql.length - 1].endsWith('(')) {
// "(`a`, `b`) VALUES (1, 2)" syntax
this.sql[this.sql.length - 1] += `\``;
this.sql[this.sql.length - 1] += `"`;
for (const col in value) {
this.append(col, `\`, \``);
this.append(col, `", "`);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `\`) VALUES (`;
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `") VALUES (`;
for (const col in value) {
this.append(value[col], `, `);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2) + nextString;
} else if (this.sql[this.sql.length - 1].toUpperCase().endsWith(' SET ')) {
// "`a` = 1, `b` = 2" syntax
this.sql[this.sql.length - 1] += `\``;
this.sql[this.sql.length - 1] += `"`;
for (const col in value) {
this.append(col, `\` = `);
this.append(value[col], `, \``);
this.append(col, `" = `);
this.append(value[col], `, "`);
}
this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -3) + nextString;
} else {
Expand All @@ -83,27 +84,29 @@ export class SQLStatement {
* Tag function for SQL, with some magic.
*
* * `` SQL`UPDATE table SET a = ${'hello"'}` ``
* * `` 'UPDATE table SET a = "hello"' ``
* * `` `UPDATE table SET a = 'hello'` ``
*
* Values surrounded by `` \` `` become names:
* Values surrounded by `"` or `` ` `` become identifiers:
*
* * ``` SQL`SELECT * FROM \`${'table'}\`` ```
* * `` 'SELECT * FROM `table`' ``
* * ``` SQL`SELECT * FROM "${'table'}"` ```
* * `` `SELECT * FROM "table"` ``
*
* (Make sure to use `"` for Postgres and `` ` `` for MySQL.)
*
* Objects preceded by SET become setters:
*
* * `` SQL`UPDATE table SET ${{a: 1, b: 2}}` ``
* * `` 'UPDATE table SET `a` = 1, `b` = 2' ``
* * `` `UPDATE table SET "a" = 1, "b" = 2` ``
*
* Objects surrounded by `()` become keys and values:
*
* * `` SQL`INSERT INTO table (${{a: 1, b: 2}})` ``
* * `` 'INSERT INTO table (`a`, `b`) VALUES (1, 2)' ``
* * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
*
* Arrays become lists; surrounding by `` \` `` turns them into lists of names:
* Arrays become lists; surrounding by `"` or `` ` `` turns them into lists of names:
*
* * `` SQL`INSERT INTO table (\`${['a', 'b']}\`) VALUES (${[1, 2]})` ``
* * `` 'INSERT INTO table (`a`, `b`) VALUES (1, 2)' ``
* * `` SQL`INSERT INTO table ("${['a', 'b']}") VALUES (${[1, 2]})` ``
* * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
*/
export function SQL(strings: TemplateStringsArray, ...values: SQLValue[]) {
return new SQLStatement(strings, values);
Expand All @@ -113,53 +116,24 @@ export interface ResultRow {[k: string]: BasicSQLValue}

export const connectedDatabases: Database[] = [];

export class Database {
connection: mysql.Pool;
export abstract class Database<Pool extends mysql.Pool | pg.Pool = mysql.Pool | pg.Pool, OkPacket = unknown> {
connection: Pool;
prefix: string;
constructor(config: mysql.PoolOptions & {prefix?: string}) {
this.prefix = config.prefix || "";
if (config.prefix) {
config = {...config};
delete config.prefix;
}
this.connection = mysql.createPool(config);
constructor(connection: Pool, prefix = '') {
this.prefix = prefix;
this.connection = connection;
connectedDatabases.push(this);
}
resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
let sql = query.sql[0];
const values = [];
for (let i = 0; i < query.values.length; i++) {
const value = query.values[i];
if (query.sql[i + 1].startsWith('`')) {
sql = sql.slice(0, -1) + this.connection.escapeId('' + value) + query.sql[i + 1].slice(1);
} else {
sql += '?' + query.sql[i + 1];
values.push(value);
}
}
return [sql, values];
}
abstract _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]];
abstract _query(sql: string, values: BasicSQLValue[]): Promise<any>;
abstract escapeId(param: string): string;
query<T = ResultRow>(sql: SQLStatement): Promise<T[]>;
query<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
query<T = ResultRow>(sql?: SQLStatement) {
if (!sql) return (strings: any, ...rest: any) => this.query<T>(new SQLStatement(strings, rest));

return new Promise<T[]>((resolve, reject) => {
const [query, values] = this.resolveSQL(sql);
this.connection.query(query, values, (e, results: any) => {
if (e) {
return reject(new Error(`${e.message} (${query}) (${values}) [${e.code}]`));
}
if (Array.isArray(results)) {
for (const row of results) {
for (const col in row) {
if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
}
}
}
return resolve(results);
});
});
const [query, values] = this._resolveSQL(sql);
return this._query(query, values);
}
queryOne<T = ResultRow>(sql: SQLStatement): Promise<T | undefined>;
queryOne<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
Expand All @@ -168,14 +142,14 @@ export class Database {

return this.query<T>(sql).then(res => Array.isArray(res) ? res[0] : res);
}
queryExec(sql: SQLStatement): Promise<mysql.OkPacket>;
queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket>;
queryExec(sql: SQLStatement): Promise<OkPacket>;
queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacket>;
queryExec(sql?: SQLStatement) {
if (!sql) return (strings: any, ...rest: any) => this.queryExec(new SQLStatement(strings, rest));
return this.queryOne<mysql.OkPacket>(sql);
return this.queryOne<OkPacket>(sql);
}
close() {
this.connection.end();
void this.connection.end();
}
}

Expand All @@ -198,7 +172,7 @@ export class DatabaseTable<Row> {
this.primaryKeyName = primaryKeyName;
}
escapeId(param: string) {
return this.db.connection.escapeId(param);
return this.db.escapeId(param);
}

// raw
Expand All @@ -224,45 +198,52 @@ export class DatabaseTable<Row> {
selectAll<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]> {
if (!entries) entries = SQL`*`;
if (Array.isArray(entries)) entries = SQL`\`${entries}\``;
if (Array.isArray(entries)) entries = SQL`"${entries}"`;
return (strings, ...rest) =>
this.query<T>()`SELECT ${entries} FROM \`${this.name}\` ${new SQLStatement(strings, rest)}`;
this.query<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
}
selectOne<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
if (!entries) entries = SQL`*`;
if (Array.isArray(entries)) entries = SQL`\`${entries}\``;
if (Array.isArray(entries)) entries = SQL`"${entries}"`;
return (strings, ...rest) =>
this.queryOne<T>()`SELECT ${entries} FROM \`${this.name}\` ${new SQLStatement(strings, rest)} LIMIT 1`;
this.queryOne<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
}
updateAll(partialRow: PartialOrSQL<Row>):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (strings, ...rest) =>
this.queryExec()`UPDATE \`${this.name}\` SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
}
updateOne(partialRow: PartialOrSQL<Row>):
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (s, ...r) =>
this.queryExec()`UPDATE \`${this.name}\` SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`;
this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`;
}
deleteAll():
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (strings, ...rest) =>
this.queryExec()`DELETE FROM \`${this.name}\` ${new SQLStatement(strings, rest)}`;
this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
}
deleteOne():
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<mysql.OkPacket> {
return (strings, ...rest) =>
this.queryExec()`DELETE FROM \`${this.name}\` ${new SQLStatement(strings, rest)} LIMIT 1`;
this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
}
eval<T>():
(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
return (strings, ...rest) =>
this.queryOne<{result: T}>(
)`SELECT ${new SQLStatement(strings, rest)} AS result FROM "${this.name}" LIMIT 1`
.then(row => row?.result);
}

// high-level

insert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
return this.queryExec()`INSERT INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`;
return this.queryExec()`INSERT INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
}
insertIgnore(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
return this.queryExec()`INSERT IGNORE INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`;
return this.queryExec()`INSERT IGNORE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
}
async tryInsert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
try {
Expand All @@ -279,28 +260,89 @@ export class DatabaseTable<Row> {
return this.replace(partialRow, where);
}
replace(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
return this.queryExec()`REPLACE INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`;
return this.queryExec()`REPLACE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
}
get(primaryKey: BasicSQLValue, entries?: (keyof Row & string)[] | SQLStatement) {
return this.selectOne(entries)`WHERE \`${this.primaryKeyName}\` = ${primaryKey}`;
return this.selectOne(entries)`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
}
delete(primaryKey: BasicSQLValue) {
return this.deleteAll()`WHERE \`${this.primaryKeyName}\` = ${primaryKey} LIMIT 1`;
return this.deleteAll()`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
}
update(primaryKey: BasicSQLValue, data: PartialOrSQL<Row>) {
return this.updateAll(data)`WHERE \`${this.primaryKeyName}\` = ${primaryKey} LIMIT 1`;
return this.updateAll(data)`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
}
}

export class PGDatabase {
database: pg.Pool | null;
constructor(config: pg.PoolConfig | null) {
this.database = config ? new pg.Pool(config) : null;
export class MySQLDatabase extends Database<mysql.Pool, mysql.OkPacket> {
constructor(config: mysql.PoolOptions & {prefix?: string}) {
const prefix = config.prefix || "";
if (config.prefix) {
config = {...config};
delete config.prefix;
}
super(mysql.createPool(config), prefix);
}
async query<O = any>(query: string, values: BasicSQLValue[]) {
if (!this.database) return null;
const result = await this.database.query(query, values);
return result.rows as O[];
override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
let sql = query.sql[0];
const values = [];
for (let i = 0; i < query.values.length; i++) {
const value = query.values[i];
if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) {
sql = sql.slice(0, -1) + this.escapeId('' + value) + query.sql[i + 1].slice(1);
} else {
sql += '?' + query.sql[i + 1];
values.push(value);
}
}
return [sql, values];
}
override _query(query: string, values: BasicSQLValue[]): Promise<any> {
return new Promise((resolve, reject) => {
this.connection.query(query, values, (e, results: any) => {
if (e) {
return reject(new Error(`${e.message} (${query}) (${values}) [${e.code}]`));
}
if (Array.isArray(results)) {
for (const row of results) {
for (const col in row) {
if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
}
}
}
return resolve(results);
});
});
}
override escapeId(id: string) {
return this.connection.escapeId(id);
}
}

export class PGDatabase extends Database<pg.Pool, []> {
constructor(config: pg.PoolConfig) {
super(new pg.Pool(config));
}
override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
let sql = query.sql[0];
const values = [];
let paramCount = 0;
for (let i = 0; i < query.values.length; i++) {
const value = query.values[i];
if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) {
sql = sql.slice(0, -1) + this.escapeId('' + value) + query.sql[i + 1].slice(1);
} else {
paramCount++;
sql += `$${paramCount}` + query.sql[i + 1];
values.push(value);
}
}
return [sql, values];
}
override _query(query: string, values: BasicSQLValue[]) {
return this.connection.query(query, values).then(res => res.rows);
}
override escapeId(id: string) {
// @ts-expect-error @types/pg really needs to be updated
return pg.escapeIdentifier(id);
}
}
Loading

0 comments on commit 7bbc471

Please sign in to comment.