diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 090bc3ee..c99cb14f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,18 +8,16 @@ on: pull_request: workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: build: runs-on: ubuntu-latest env: RUST_LOG: info steps: - - uses: styfle/cancel-workflow-action@0.12.1 - name: Cancel Outdated Builds - with: - all_but_latest: true - access_token: ${{ github.token }} - - uses: actions/checkout@v4 name: Checkout Repository @@ -36,21 +34,12 @@ jobs: # Run Clippy on all targets. The lint workflow doesn't run Clippy on tests, because the tests # don't compile with all combinations of features. - - name: Clippy + - name: Clippy(all-features) run: cargo clippy --workspace --all-features --all-targets -- -D warnings - # Install nextest - - name: Install Nextest - run: cargo install cargo-nextest - - - name: Test - run: | - cargo nextest run --workspace --release --all-features - timeout-minutes: 60 + - name: Clippy(no-storage) + run: cargo clippy --workspace --features no-storage --all-targets -- -D warnings - - name: Doc Test - run: cargo test --release --all-features --doc - - name: Generate Documentation run: | cargo doc --no-deps --lib --release --all-features @@ -63,3 +52,42 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./target/doc cname: tide-disco.docs.espressosys.com + test-sqlite: + runs-on: ubuntu-latest + env: + RUST_LOG: info + steps: + - uses: actions/checkout@v4 + name: Checkout Repository + + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + # Install nextest + - name: Install Nextest + run: cargo install cargo-nextest + + - name: Test + run: | + cargo nextest run --workspace --release --all-features + timeout-minutes: 60 + + test-postgres: + runs-on: ubuntu-latest + env: + RUST_LOG: info + steps: + - uses: actions/checkout@v4 + name: Checkout Repository + + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + # Install nextest + - name: Install Nextest + run: cargo install cargo-nextest + + - name: Test + run: | + cargo nextest run --workspace --release --features "no-storage, testing" + timeout-minutes: 60 diff --git a/.gitignore b/.gitignore index 35f4835c..b00a88f4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,7 @@ lcov.info /vsc -/.vscode \ No newline at end of file +/.vscode + +# for sqlite databases created during the tests +/tmp \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index e6a162b7..e279aa46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,11 @@ license = "GPL-3.0-or-later" [features] default = ["file-system-data-source", "metrics-data-source", "sql-data-source"] +# Enables support for an embedded SQLite database instead of PostgreSQL. +# Ideal for lightweight nodes that benefit from pruning and merklized state storage, +# offering advantages over file system storage. +embedded-db = [] + # Enable the availability data source backed by the local file system. 
file-system-data-source = ["atomic_store"] diff --git a/examples/simple-server.rs b/examples/simple-server.rs index 503b61bf..42fdba6f 100644 --- a/examples/simple-server.rs +++ b/examples/simple-server.rs @@ -84,15 +84,20 @@ async fn init_db() -> Db { } #[cfg(not(target_os = "windows"))] -async fn init_data_source(db: &Db) -> DataSource { - data_source::sql::Config::default() - .user("postgres") - .password("password") - .host(db.host()) - .port(db.port()) - .connect(Default::default()) - .await - .unwrap() +async fn init_data_source(#[allow(unused_variables)] db: &Db) -> DataSource { + let mut cfg = data_source::sql::Config::default(); + + #[cfg(not(feature = "embedded-db"))] + { + cfg = cfg.host(db.host()).port(db.port()); + } + + #[cfg(feature = "embedded-db")] + { + cfg = cfg.db_path(db.path()); + } + + cfg.connect(Default::default()).await.unwrap() } #[cfg(target_os = "windows")] diff --git a/migrations/V100__drop_leaf_payload.sql b/migrations/postgres/V100__drop_leaf_payload.sql similarity index 100% rename from migrations/V100__drop_leaf_payload.sql rename to migrations/postgres/V100__drop_leaf_payload.sql diff --git a/migrations/V10__init_schema.sql b/migrations/postgres/V10__init_schema.sql similarity index 100% rename from migrations/V10__init_schema.sql rename to migrations/postgres/V10__init_schema.sql diff --git a/migrations/V200__create_aggregates_table.sql b/migrations/postgres/V200__create_aggregates_table.sql similarity index 100% rename from migrations/V200__create_aggregates_table.sql rename to migrations/postgres/V200__create_aggregates_table.sql diff --git a/migrations/V20__payload_hash_index.sql b/migrations/postgres/V20__payload_hash_index.sql similarity index 100% rename from migrations/V20__payload_hash_index.sql rename to migrations/postgres/V20__payload_hash_index.sql diff --git a/migrations/V300__transactions_count.sql b/migrations/postgres/V300__transactions_count.sql similarity index 100% rename from migrations/V300__transactions_count.sql rename to migrations/postgres/V300__transactions_count.sql diff --git a/migrations/V30__drop_leaf_block_hash_fkey_constraint.sql b/migrations/postgres/V30__drop_leaf_block_hash_fkey_constraint.sql similarity index 100% rename from migrations/V30__drop_leaf_block_hash_fkey_constraint.sql rename to migrations/postgres/V30__drop_leaf_block_hash_fkey_constraint.sql diff --git a/migrations/postgres/V400__rename_transaction_table.sql b/migrations/postgres/V400__rename_transaction_table.sql new file mode 100644 index 00000000..06fac653 --- /dev/null +++ b/migrations/postgres/V400__rename_transaction_table.sql @@ -0,0 +1,5 @@ +ALTER TABLE transaction + RENAME TO transactions; + +ALTER TABLE transactions + RENAME COLUMN index TO idx; \ No newline at end of file diff --git a/migrations/sqlite/V100__init_schema.sql b/migrations/sqlite/V100__init_schema.sql new file mode 100644 index 00000000..876fd7a5 --- /dev/null +++ b/migrations/sqlite/V100__init_schema.sql @@ -0,0 +1,76 @@ +CREATE TABLE header +( + height BIGINT PRIMARY KEY, + hash TEXT NOT NULL UNIQUE, + payload_hash TEXT NOT NULL, + timestamp BIGINT NOT NULL, + + -- For convenience, we store the entire application-specific header type as JSON. Just like + -- `leaf.leaf` and `leaf.qc`, this allows us to easily reconstruct the entire header using + -- `serde_json`, and to run queries and create indexes on application-specific header fields + -- without having a specific column for those fields. 
In many cases, this will enable new + -- application-specific API endpoints to be implemented without altering the schema (beyond + -- possibly adding an index for performance reasons). + data JSONB NOT NULL +); + +CREATE INDEX header_timestamp_idx ON header (timestamp); + +CREATE TABLE payload +( + height BIGINT PRIMARY KEY REFERENCES header (height) ON DELETE CASCADE, + size INTEGER, + data BLOB, + num_transactions INTEGER +); + +CREATE TABLE vid +( + height BIGINT PRIMARY KEY REFERENCES header (height) ON DELETE CASCADE, + common BLOB NOT NULL, + share BLOB +); + +CREATE TABLE leaf +( + height BIGINT PRIMARY KEY REFERENCES header (height) ON DELETE CASCADE, + hash TEXT NOT NULL UNIQUE, + block_hash TEXT NOT NULL, + + -- For convenience, we store the entire leaf and justifying QC as JSON blobs. There is a bit of + -- redundancy here with the indexed fields above, but it makes it easy to reconstruct the entire + -- leaf without depending on the specific fields of the application-specific leaf type. We + -- choose JSON over a binary format, even though it has a larger storage footprint, because + -- Postgres actually has decent JSON support: we don't have to worry about escaping non-ASCII + -- characters in inputs, and we can even do queries on the JSON and add indices over sub-objects + -- of the JSON blobs. + leaf JSONB NOT NULL, + qc JSONB NOT NULL +); + +CREATE TABLE transactions +( + hash TEXT NOT NULL, + -- Block containing this transaction. + block_height BIGINT NOT NULL REFERENCES header(height) ON DELETE CASCADE, + -- Position within the block. Transaction indices are an application-specific type, so we store + -- it as a serialized blob. We use JSON instead of a binary format so that the application can + -- make use of the transaction index in its own SQL queries. + idx JSONB NOT NULL, + PRIMARY KEY (block_height, idx) +); +-- This index is not unique, because nothing stops HotShot from sequencing duplicate transactions. +CREATE INDEX transaction_hash ON transactions (hash); + +CREATE TABLE pruned_height ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The height of the last pruned block. + last_height BIGINT NOT NULL +); + +CREATE TABLE last_merklized_state_height ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + height BIGINT NOT NULL +); + +CREATE INDEX header_payload_hash_idx ON header (payload_hash); \ No newline at end of file diff --git a/migrations/sqlite/V200__create_aggregates_table.sql b/migrations/sqlite/V200__create_aggregates_table.sql new file mode 100644 index 00000000..b81e5a8e --- /dev/null +++ b/migrations/sqlite/V200__create_aggregates_table.sql @@ -0,0 +1,5 @@ +CREATE TABLE aggregate ( + height BIGINT PRIMARY KEY REFERENCES header (height) ON DELETE CASCADE, + num_transactions BIGINT NOT NULL, + payload_size BIGINT NOT NULL +); \ No newline at end of file diff --git a/src/data_source/sql.rs b/src/data_source/sql.rs index a6100661..1c2e80ea 100644 --- a/src/data_source/sql.rs +++ b/src/data_source/sql.rs @@ -26,10 +26,12 @@ pub use anyhow::Error; use hotshot_types::traits::node_implementation::NodeType; pub use refinery::Migration; -pub use sql::{Config, Transaction}; +pub use sql::Transaction; pub type Builder = fetching::Builder; +pub type Config = sql::Config; + impl Config { /// Connect to the database with this config. pub async fn connect>( @@ -78,9 +80,11 @@ impl Config { /// /// ## Initialization /// -/// When creating a [`SqlDataSource`], the caller can use [`Config`] to specify the host, user, and -/// database to connect to. 
As such, [`SqlDataSource`] is not very opinionated about how the -/// Postgres instance is set up. The administrator must simply ensure that there is a database +/// When creating a PostgreSQL [`SqlDataSource`], the caller can use [`Config`] to specify the host, user, and +/// database for the connection. If the `embedded-db` feature is enabled, the caller can instead specify the +/// file path for an SQLite database. +/// As such, [`SqlDataSource`] is not very opinionated about how the +/// database instance is set up. The administrator must simply ensure that there is a database /// dedicated to the [`SqlDataSource`] and a user with appropriate permissions (all on `SCHEMA` and /// all on `DATABASE`) over that database. /// @@ -96,10 +100,13 @@ impl Config { /// GRANT ALL ON DATABASE hotshot_query_service TO hotshot_user WITH GRANT OPTION; /// ``` /// -/// One could then connect to this database with the following [`Config`]: +/// For SQLite, simply provide the file path, and the file will be created if it does not already exist. +/// +/// One could then connect to this database with the following [`Config`] for postgres: /// /// ``` /// # use hotshot_query_service::data_source::sql::Config; +/// #[cfg(not(feature= "embedded-db"))] /// Config::default() /// .host("postgres.database.hostname") /// .database("hotshot_query_service") @@ -107,7 +114,15 @@ impl Config { /// .password("password") /// # ; /// ``` +/// Or, if the `embedded-db` feature is enabled, configure it as follows for SQLite: /// +/// ``` +/// # use hotshot_query_service::data_source::sql::Config; +/// #[cfg(feature= "embedded-db")] +/// Config::default() +/// .db_path("temp.db".into()) +/// # ; +/// ``` /// ## Resetting /// /// In general, resetting the database when necessary is left up to the administrator. However, for diff --git a/src/data_source/storage/sql.rs b/src/data_source/storage/sql.rs index 07941fc0..c38275db 100644 --- a/src/data_source/storage/sql.rs +++ b/src/data_source/storage/sql.rs @@ -24,17 +24,22 @@ use crate::{ }; use async_trait::async_trait; use chrono::Utc; -use futures::future::FutureExt; + use hotshot_types::traits::metrics::Metrics; use itertools::Itertools; use log::LevelFilter; + +#[cfg(not(feature = "embedded-db"))] +use futures::future::FutureExt; +#[cfg(not(feature = "embedded-db"))] +use sqlx::postgres::{PgConnectOptions, PgSslMode}; +#[cfg(feature = "embedded-db")] +use sqlx::sqlite::SqliteConnectOptions; use sqlx::{ pool::{Pool, PoolOptions}, - postgres::{PgConnectOptions, PgSslMode}, ConnectOptions, Row, }; use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration}; - pub extern crate sqlx; pub use sqlx::{Database, Sqlite}; @@ -48,7 +53,7 @@ pub use anyhow::Error; // in the expansion of `include_migrations`, even when `include_migrations` is invoked from another // crate which doesn't have `include_dir` as a dependency. pub use crate::include_migrations; -pub use db::Db; +pub use db::*; pub use include_dir::include_dir; pub use queries::QueryBuilder; pub use refinery::Migration; @@ -56,7 +61,7 @@ pub use transaction::*; use self::{migrate::Migrator, transaction::PoolMetrics}; -/// Embed migrations from the given directory into the current binary. +/// Embed migrations from the given directory into the current binary for PostgreSQL or SQLite. /// /// The macro invocation `include_migrations!(path)` evaluates to an expression of type `impl /// Iterator`. 
Each migration must be a text file which is an immediate child of @@ -71,15 +76,23 @@ use self::{migrate::Migrator, transaction::PoolMetrics}; /// /// As an example, this is the invocation used to load the default migrations from the /// `hotshot-query-service` crate. The migrations are located in a directory called `migrations` at -/// the root of the crate. +/// - PostgreSQL migrations are in `/migrations/postgres`. +/// - SQLite migrations are in `/migrations/sqlite`. /// /// ``` /// # use hotshot_query_service::data_source::sql::{include_migrations, Migration}; +/// // For PostgreSQL +/// #[cfg(not(feature = "embedded-db"))] +/// let mut migrations: Vec = +/// include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect(); +/// // For SQLite +/// #[cfg(feature = "embedded-db")] /// let mut migrations: Vec = -/// include_migrations!("$CARGO_MANIFEST_DIR/migrations").collect(); -/// migrations.sort(); -/// assert_eq!(migrations[0].version(), 10); -/// assert_eq!(migrations[0].name(), "init_schema"); +/// include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect(); +/// +/// migrations.sort(); +/// assert_eq!(migrations[0].version(), 10); +/// assert_eq!(migrations[0].name(), "init_schema"); /// ``` /// /// Note that a similar macro is available from Refinery: @@ -117,7 +130,13 @@ macro_rules! include_migrations { /// The migrations requied to build the default schema for this version of [`SqlStorage`]. pub fn default_migrations() -> Vec { - let mut migrations = include_migrations!("$CARGO_MANIFEST_DIR/migrations").collect::>(); + #[cfg(not(feature = "embedded-db"))] + let mut migrations = + include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::>(); + + #[cfg(feature = "embedded-db")] + let mut migrations = + include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::>(); // Check version uniqueness and sort by version. validate_migrations(&mut migrations).expect("default migrations are invalid"); @@ -184,28 +203,64 @@ fn add_custom_migrations( .map(|pair| pair.reduce(|_, custom| custom)) } -/// Postgres client config. 
-#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Config { + #[cfg(feature = "embedded-db")] + db_opt: SqliteConnectOptions, + #[cfg(not(feature = "embedded-db"))] db_opt: PgConnectOptions, pool_opt: PoolOptions, + #[cfg(not(feature = "embedded-db"))] schema: String, reset: bool, migrations: Vec, no_migrations: bool, pruner_cfg: Option, archive: bool, + pool: Option>, } +#[cfg(not(feature = "embedded-db"))] impl Default for Config { fn default() -> Self { PgConnectOptions::default() + .username("postgres") + .password("password") .host("localhost") .port(5432) .into() } } +#[cfg(feature = "embedded-db")] +impl Default for Config { + fn default() -> Self { + SqliteConnectOptions::default() + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) + .busy_timeout(Duration::from_secs(30)) + .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) + .create_if_missing(true) + .into() + } +} + +#[cfg(feature = "embedded-db")] +impl From for Config { + fn from(db_opt: SqliteConnectOptions) -> Self { + Self { + db_opt, + pool_opt: PoolOptions::default(), + reset: false, + migrations: vec![], + no_migrations: false, + pruner_cfg: None, + archive: false, + pool: None, + } + } +} + +#[cfg(not(feature = "embedded-db"))] impl From for Config { fn from(db_opt: PgConnectOptions) -> Self { Self { @@ -217,10 +272,12 @@ impl From for Config { no_migrations: false, pruner_cfg: None, archive: false, + pool: None, } } } +#[cfg(not(feature = "embedded-db"))] impl FromStr for Config { type Err = ::Err; @@ -229,6 +286,29 @@ impl FromStr for Config { } } +#[cfg(feature = "embedded-db")] +impl FromStr for Config { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ok(SqliteConnectOptions::from_str(s)?.into()) + } +} + +#[cfg(feature = "embedded-db")] +impl Config { + pub fn busy_timeout(mut self, timeout: Duration) -> Self { + self.db_opt = self.db_opt.busy_timeout(timeout); + self + } + + pub fn db_path(mut self, path: std::path::PathBuf) -> Self { + self.db_opt = self.db_opt.filename(path); + self + } +} + +#[cfg(not(feature = "embedded-db"))] impl Config { /// Set the hostname of the database server. /// @@ -281,6 +361,15 @@ impl Config { self.schema = schema.into(); self } +} + +impl Config { + /// Sets the database connection pool + /// This allows reusing an existing connection pool when building a new `SqlStorage` instance. + pub fn pool(mut self, pool: Pool) -> Self { + self.pool = Some(pool); + self + } /// Reset the schema on connection. /// @@ -384,7 +473,7 @@ impl Config { } /// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct SqlStorage { pool: Pool, metrics: PrometheusMetrics, @@ -400,32 +489,59 @@ pub struct Pruner { } impl SqlStorage { + pub fn pool(&self) -> Pool { + self.pool.clone() + } /// Connect to a remote database. 
pub async fn connect(mut config: Config) -> Result { + let metrics = PrometheusMetrics::default(); + let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into())); + let pool = config.pool_opt.clone(); + let pruner_cfg = config.pruner_cfg; + + // re-use the same pool if present and return early + if let Some(pool) = config.pool { + return Ok(Self { + metrics, + pool_metrics, + pool, + pruner_cfg, + }); + } + + #[cfg(not(feature = "embedded-db"))] let schema = config.schema.clone(); - let pool = config - .pool_opt - .after_connect(move |conn, _| { - let schema = schema.clone(); - async move { - query(&format!("SET search_path TO {schema}")) - .execute(conn) - .await?; - Ok(()) - } - .boxed() - }) - .connect_with(config.db_opt) - .await?; + #[cfg(not(feature = "embedded-db"))] + let pool = pool.after_connect(move |conn, _| { + let schema = config.schema.clone(); + async move { + query(&format!("SET search_path TO {schema}")) + .execute(conn) + .await?; + Ok(()) + } + .boxed() + }); + + #[cfg(feature = "embedded-db")] + if config.reset { + std::fs::remove_file(config.db_opt.get_filename())?; + } + + let pool = pool.connect_with(config.db_opt).await?; // Create or connect to the schema for this query service. let mut conn = pool.acquire().await?; + + #[cfg(not(feature = "embedded-db"))] if config.reset { - query(&format!("DROP SCHEMA IF EXISTS {} CASCADE", config.schema)) + query(&format!("DROP SCHEMA IF EXISTS {} CASCADE", schema)) .execute(conn.as_mut()) .await?; } - query(&format!("CREATE SCHEMA IF NOT EXISTS {}", config.schema)) + + #[cfg(not(feature = "embedded-db"))] + query(&format!("CREATE SCHEMA IF NOT EXISTS {}", schema)) .execute(conn.as_mut()) .await?; @@ -471,12 +587,13 @@ impl SqlStorage { .await?; } - let metrics = PrometheusMetrics::default(); + conn.close().await?; + Ok(Self { pool, - pool_metrics: PoolMetrics::new(&*metrics.subgroup("sql".into())), + pool_metrics, metrics, - pruner_cfg: config.pruner_cfg, + pruner_cfg, }) } } @@ -545,10 +662,17 @@ impl PruneStorage for SqlStorage { async fn get_disk_usage(&self) -> anyhow::Result { let mut tx = self.read().await?; - let row = tx - .fetch_one("SELECT pg_database_size(current_database())") - .await?; + + #[cfg(not(feature = "embedded-db"))] + let query = "SELECT pg_database_size(current_database())"; + + #[cfg(feature = "embedded-db")] + let query = " + SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) AS total_bytes"; + + let row = tx.fetch_one(query).await?; let size: i64 = row.get(0); + Ok(size as u64) } @@ -672,38 +796,69 @@ impl VersionedDataSource for SqlStorage { // These tests run the `postgres` Docker image, which doesn't work on Windows. 
#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))] pub mod testing { + #![allow(unused_imports)] + use refinery::Migration; use std::{ env, process::{Command, Stdio}, - str, + str::{self, FromStr}, time::Duration, }; use tokio::net::TcpStream; use tokio::time::timeout; use portpicker::pick_unused_port; - use refinery::Migration; use super::Config; use crate::testing::sleep; #[derive(Debug)] pub struct TmpDb { + #[cfg(not(feature = "embedded-db"))] host: String, + #[cfg(not(feature = "embedded-db"))] port: u16, + #[cfg(not(feature = "embedded-db"))] container_id: String, + #[cfg(feature = "embedded-db")] + db_path: std::path::PathBuf, + #[allow(dead_code)] persistent: bool, } impl TmpDb { + #[cfg(feature = "embedded-db")] + fn init_sqlite_db(persistent: bool) -> Self { + let file = tempfile::Builder::new() + .prefix("sqlite-") + .suffix(".db") + .tempfile() + .unwrap(); + + let (_, db_path) = file.keep().unwrap(); + + Self { + db_path, + persistent, + } + } pub async fn init() -> Self { - Self::init_inner(false).await + #[cfg(feature = "embedded-db")] + return Self::init_sqlite_db(false); + + #[cfg(not(feature = "embedded-db"))] + Self::init_postgres(false).await } pub async fn persistent() -> Self { - Self::init_inner(true).await + #[cfg(feature = "embedded-db")] + return Self::init_sqlite_db(true); + + #[cfg(not(feature = "embedded-db"))] + Self::init_postgres(true).await } - async fn init_inner(persistent: bool) -> Self { + #[cfg(not(feature = "embedded-db"))] + async fn init_postgres(persistent: bool) -> Self { let docker_hostname = env::var("DOCKER_HOSTNAME"); // This picks an unused port on the current system. If docker is // configured to run on a different host then this may not find a @@ -718,9 +873,11 @@ pub mod testing { .arg("-d") .args(["-p", &format!("{port}:5432")]) .args(["-e", "POSTGRES_PASSWORD=password"]); + if !persistent { cmd.arg("--rm"); } + let output = cmd.arg("postgres").output().unwrap(); let stdout = str::from_utf8(&output.stdout).unwrap(); let stderr = str::from_utf8(&output.stderr).unwrap(); @@ -743,28 +900,50 @@ pub mod testing { db } + #[cfg(not(feature = "embedded-db"))] pub fn host(&self) -> String { self.host.clone() } + #[cfg(not(feature = "embedded-db"))] pub fn port(&self) -> u16 { self.port } + #[cfg(feature = "embedded-db")] + pub fn path(&self) -> std::path::PathBuf { + self.db_path.clone() + } + pub fn config(&self) -> Config { - Config::default() + #[cfg(feature = "embedded-db")] + let mut cfg: Config = { + let db_path = self.db_path.to_string_lossy(); + let path = format!("sqlite:{db_path}"); + sqlx::sqlite::SqliteConnectOptions::from_str(&path) + .expect("invalid db path") + .create_if_missing(true) + .into() + }; + + #[cfg(not(feature = "embedded-db"))] + let mut cfg = Config::default() .user("postgres") .password("password") .host(self.host()) - .port(self.port()) - .migrations(vec![Migration::unapplied( - "V11__create_test_merkle_tree_table.sql", - &TestMerkleTreeMigration::create("test_tree"), - ) - .unwrap()]) + .port(self.port()); + + cfg = cfg.migrations(vec![Migration::unapplied( + "V101__create_test_merkle_tree_table.sql", + &TestMerkleTreeMigration::create("test_tree"), + ) + .unwrap()]); + + cfg } - pub fn stop(&mut self) { + #[cfg(not(feature = "embedded-db"))] + pub fn stop_postgres(&mut self) { tracing::info!(container = self.container_id, "stopping postgres"); let output = Command::new("docker") .args(["stop", self.container_id.as_str()]) @@ -778,7 +957,8 @@ pub mod testing { ); } - pub async fn start(&mut 
self) { + #[cfg(not(feature = "embedded-db"))] + pub async fn start_postgres(&mut self) { tracing::info!(container = self.container_id, "resuming postgres"); let output = Command::new("docker") .args(["start", self.container_id.as_str()]) @@ -794,6 +974,7 @@ pub mod testing { self.wait_for_ready().await; } + #[cfg(not(feature = "embedded-db"))] async fn wait_for_ready(&self) { let timeout_duration = Duration::from_secs( env::var("SQL_TMP_DB_CONNECT_TIMEOUT") @@ -858,20 +1039,18 @@ pub mod testing { } } + #[cfg(not(feature = "embedded-db"))] impl Drop for TmpDb { fn drop(&mut self) { - self.stop(); - if self.persistent { - let output = Command::new("docker") - .args(["container", "rm", self.container_id.as_str()]) - .output() - .unwrap(); - assert!( - output.status.success(), - "error removing postgres docker {}: {}", - self.container_id, - str::from_utf8(&output.stderr).unwrap() - ); + self.stop_postgres(); + } + } + + #[cfg(feature = "embedded-db")] + impl Drop for TmpDb { + fn drop(&mut self) { + if !self.persistent { + std::fs::remove_file(self.db_path.clone()).unwrap(); } } } @@ -880,29 +1059,44 @@ pub mod testing { impl TestMerkleTreeMigration { fn create(name: &str) -> String { + let (bit_vec, binary, hash_pk, root_stored_column) = if cfg!(feature = "embedded-db") { + ( + "TEXT", + "BLOB", + "INTEGER PRIMARY KEY AUTOINCREMENT", + " (json_extract(data, '$.test_merkle_tree_root'))", + ) + } else { + ( + "BIT(8)", + "BYTEA", + "SERIAL PRIMARY KEY", + "(data->>'test_merkle_tree_root')", + ) + }; + format!( "CREATE TABLE IF NOT EXISTS hash ( - id SERIAL PRIMARY KEY, - value BYTEA NOT NULL UNIQUE + id {hash_pk}, + value {binary} NOT NULL UNIQUE ); - - + ALTER TABLE header ADD column test_merkle_tree_root text - GENERATED ALWAYS as (data->>'test_merkle_tree_root') STORED; + GENERATED ALWAYS as {root_stored_column} STORED; CREATE TABLE {name} ( - path integer[] NOT NULL, + path JSONB NOT NULL, created BIGINT NOT NULL, - hash_id INT NOT NULL REFERENCES hash (id), - children INT[], - children_bitvec BIT(8), - index JSONB, - entry JSONB + hash_id INT NOT NULL, + children JSONB, + children_bitvec {bit_vec}, + idx JSONB, + entry JSONB, + PRIMARY KEY (path, created) ); - ALTER TABLE {name} ADD CONSTRAINT {name}_pk PRIMARY KEY (path, created); CREATE INDEX {name}_created ON {name} (created);" ) } @@ -931,21 +1125,18 @@ mod test { setup_test(); let db = TmpDb::init().await; - let port = db.port(); - let host = &db.host(); - - let connect = |migrations: bool, custom_migrations| async move { - let mut cfg = Config::default() - .user("postgres") - .password("password") - .host(host) - .port(port) - .migrations(custom_migrations); - if !migrations { - cfg = cfg.no_migrations(); + let cfg = db.config(); + + let connect = |migrations: bool, custom_migrations| { + let cfg = cfg.clone(); + async move { + let mut cfg = cfg.migrations(custom_migrations); + if !migrations { + cfg = cfg.no_migrations(); + } + let client = SqlStorage::connect(cfg).await?; + Ok::<_, Error>(client) } - let client = SqlStorage::connect(cfg).await?; - Ok::<_, Error>(client) }; // Connecting with migrations disabled should fail if the database is not already up to date @@ -967,7 +1158,11 @@ mod test { "ALTER TABLE test ADD COLUMN data INTEGER;", ) .unwrap(), - Migration::unapplied("V998__create_test_table.sql", "CREATE TABLE test ();").unwrap(), + Migration::unapplied( + "V998__create_test_table.sql", + "CREATE TABLE test (x bigint);", + ) + .unwrap(), ]; connect(true, migrations.clone()).await.unwrap(); @@ -981,6 +1176,7 @@ mod test 
{ } #[test] + #[cfg(not(feature = "embedded-db"))] fn test_config_from_str() { let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap(); assert_eq!(cfg.db_opt.get_username(), "user"); @@ -988,6 +1184,13 @@ mod test { assert_eq!(cfg.db_opt.get_port(), 8080); } + #[test] + #[cfg(feature = "embedded-db")] + fn test_config_from_str() { + let cfg = Config::from_str("sqlite://data.db").unwrap(); + assert_eq!(cfg.db_opt.get_filename().to_string_lossy(), "data.db"); + } + async fn vacuum(storage: &SqlStorage) { storage .pool @@ -1004,14 +1207,7 @@ mod test { setup_test(); let db = TmpDb::init().await; - let port = db.port(); - let host = &db.host(); - - let cfg = Config::default() - .user("postgres") - .password("password") - .host(host) - .port(port); + let cfg = db.config(); let mut storage = SqlStorage::connect(cfg).await.unwrap(); let mut leaf = LeafQueryData::::genesis::( @@ -1179,14 +1375,7 @@ mod test { setup_test(); let db = TmpDb::init().await; - let port = db.port(); - let host = &db.host(); - - let cfg = Config::default() - .user("postgres") - .password("password") - .host(host) - .port(port); + let cfg = db.config(); let storage = SqlStorage::connect(cfg).await.unwrap(); assert!(storage diff --git a/src/data_source/storage/sql/db.rs b/src/data_source/storage/sql/db.rs index 94a868d6..5b3c86aa 100644 --- a/src/data_source/storage/sql/db.rs +++ b/src/data_source/storage/sql/db.rs @@ -10,24 +10,25 @@ // You should have received a copy of the GNU General Public License along with this program. If not, // see . -/// The concrete database backing a SQL data source. +/// The underlying database type for a SQL data source. /// -/// Currently only Postgres is supported. In the future we can support SQLite as well by making this -/// an enum with variants for each (we'll then need to create enums and trait implementations for -/// all the associated types as well; it will be messy). +/// Currently, only PostgreSQL and SQLite are supported, with selection based on the "embedded-db" feature flag. +/// - When the "embedded-db" feature is enabled, SQLite is used. +/// - When it’s disabled, PostgreSQL is used. /// -/// The reason for taking this approach over sqlx's `Any` database is that we can support SQL types +/// ### Design Choice +/// The reason for taking this approach over sqlx's Any database is that we can support SQL types /// which are implemented for the two backends we care about (Postgres and SQLite) but not for _any_ /// SQL database, such as MySQL. Crucially, JSON types fall in this category. /// /// The reason for taking this approach rather than writing all of our code to be generic over the -/// `Database` implementation is that `sqlx` does not have the necessary trait bounds on all of the -/// associated types (e.g. `Database::Connection` does not implement `Executor` for all possible -/// databases, the `Executor` impl lives on each concrete connection type) and Rust does not provide +/// Database implementation is that sqlx does not have the necessary trait bounds on all of the +/// associated types (e.g. Database::Connection does not implement Executor for all possible +/// databases, the Executor impl lives on each concrete connection type) and Rust does not provide /// a good way of encapsulating a collection of trait bounds on associated types. 
Thus, our function /// signatures become untenably messy with bounds like /// -/// ``` +/// ```rust /// # use sqlx::{Database, Encode, Executor, IntoArguments, Type}; /// fn foo() /// where @@ -36,5 +37,20 @@ /// for<'a> i64: Type + Encode<'a, DB>, /// {} /// ``` -/// etc. + +#[cfg(feature = "embedded-db")] +pub type Db = sqlx::Sqlite; +#[cfg(not(feature = "embedded-db"))] pub type Db = sqlx::Postgres; + +#[cfg(feature = "embedded-db")] +pub mod syntax_helpers { + pub const MAX_FN: &str = "MAX"; + pub const BINARY_TYPE: &str = "BLOB"; +} + +#[cfg(not(feature = "embedded-db"))] +pub mod syntax_helpers { + pub const MAX_FN: &str = "GREATEST"; + pub const BINARY_TYPE: &str = "BYTEA"; +} diff --git a/src/data_source/storage/sql/queries.rs b/src/data_source/storage/sql/queries.rs index 77eef922..d2a94063 100644 --- a/src/data_source/storage/sql/queries.rs +++ b/src/data_source/storage/sql/queries.rs @@ -95,6 +95,7 @@ impl<'q> QueryBuilder<'q> { self.arguments.add(arg).map_err(|err| QueryError::Error { message: format!("{err:#}"), })?; + Ok(format!("${}", self.arguments.len())) } @@ -352,7 +353,7 @@ impl Transaction { "SELECT {HEADER_COLUMNS} FROM header AS h WHERE {where_clause} - ORDER BY h.height ASC + ORDER BY h.height LIMIT 1" ); let row = query.query(&sql).fetch_one(self.as_mut()).await?; diff --git a/src/data_source/storage/sql/queries/availability.rs b/src/data_source/storage/sql/queries/availability.rs index bf7caf99..ad54818d 100644 --- a/src/data_source/storage/sql/queries/availability.rs +++ b/src/data_source/storage/sql/queries/availability.rs @@ -67,7 +67,7 @@ where FROM header AS h JOIN payload AS p ON h.height = p.height WHERE {where_clause} - ORDER BY h.height ASC + ORDER BY h.height LIMIT 1" ); let row = query.query(&sql).fetch_one(self.as_mut()).await?; @@ -89,7 +89,7 @@ where FROM header AS h JOIN payload AS p ON h.height = p.height WHERE {where_clause} - ORDER BY h.height ASC + ORDER BY h.height LIMIT 1" ); let row = query.query(&sql).fetch_one(self.as_mut()).await?; @@ -135,7 +135,7 @@ where FROM header AS h JOIN vid AS v ON h.height = v.height WHERE {where_clause} - ORDER BY h.height ASC + ORDER BY h.height LIMIT 1" ); let row = query.query(&sql).fetch_one(self.as_mut()).await?; @@ -173,7 +173,7 @@ where { let mut query = QueryBuilder::default(); let where_clause = query.bounds_to_where_clause(range, "height")?; - let sql = format!("SELECT {LEAF_COLUMNS} FROM leaf {where_clause} ORDER BY height ASC"); + let sql = format!("SELECT {LEAF_COLUMNS} FROM leaf {where_clause} ORDER BY height"); Ok(query .query(&sql) .fetch(self.as_mut()) @@ -197,7 +197,7 @@ where FROM header AS h JOIN payload AS p ON h.height = p.height {where_clause} - ORDER BY h.height ASC" + ORDER BY h.height" ); Ok(query .query(&sql) @@ -222,7 +222,7 @@ where FROM header AS h JOIN payload AS p ON h.height = p.height {where_clause} - ORDER BY h.height ASC" + ORDER BY h.height" ); Ok(query .query(&sql) @@ -272,7 +272,7 @@ where FROM header AS h JOIN vid AS v ON h.height = v.height {where_clause} - ORDER BY h.height ASC" + ORDER BY h.height" ); Ok(query .query(&sql) @@ -318,12 +318,12 @@ where // ORDER BY ASC ensures that if there are duplicate transactions, we return the first // one. 
let sql = format!( - "SELECT {BLOCK_COLUMNS}, t.index AS tx_index + "SELECT {BLOCK_COLUMNS}, t.idx AS tx_index FROM header AS h JOIN payload AS p ON h.height = p.height - JOIN transaction AS t ON t.block_height = h.height + JOIN transactions AS t ON t.block_height = h.height WHERE t.hash = {hash_param} - ORDER BY t.block_height, t.index + ORDER BY t.block_height, t.idx LIMIT 1" ); let row = query.query(&sql).fetch_one(self.as_mut()).await?; diff --git a/src/data_source/storage/sql/queries/explorer.rs b/src/data_source/storage/sql/queries/explorer.rs index 7ce2c774..4b3b22ad 100644 --- a/src/data_source/storage/sql/queries/explorer.rs +++ b/src/data_source/storage/sql/queries/explorer.rs @@ -217,14 +217,14 @@ where // returned results based on. let transaction_target_query = match target { TransactionIdentifier::Latest => query( - "SELECT t.block_height AS height, t.index AS index FROM transaction AS t ORDER BY t.block_height DESC, t.index DESC LIMIT 1", + "SELECT t.block_height AS height, t.idx AS \"index\" FROM transactions AS t ORDER BY t.block_height DESC, t.idx DESC LIMIT 1", ), TransactionIdentifier::HeightAndOffset(height, _) => query( - "SELECT t.block_height AS height, t.index AS index FROM transaction AS t WHERE t.block_height = $1 ORDER BY t.block_height DESC, t.index DESC LIMIT 1", + "SELECT t.block_height AS height, t.idx AS \"index\" FROM transactions AS t WHERE t.block_height = $1 ORDER BY t.block_height DESC, t.idx DESC LIMIT 1", ) .bind(*height as i64), TransactionIdentifier::Hash(hash) => query( - "SELECT t.block_height AS height, t.index AS index FROM transaction AS t WHERE t.hash = $1 ORDER BY t.block_height DESC, t.index DESC LIMIT 1", + "SELECT t.block_height AS height, t.idx AS \"index\" FROM transactions AS t WHERE t.hash = $1 ORDER BY t.block_height DESC, t.idx DESC LIMIT 1", ) .bind(hash.to_string()), }; @@ -238,7 +238,8 @@ where }; let block_height = transaction_target.get::("height") as usize; - let transaction_index = transaction_target.get::>, _>("index"); + let transaction_index = + transaction_target.get_unchecked::>, _>("index"); let offset = if let TransactionIdentifier::HeightAndOffset(_, offset) = target { *offset } else { @@ -262,9 +263,9 @@ where JOIN payload AS p ON h.height = p.height WHERE h.height IN ( SELECT t.block_height - FROM transaction AS t - WHERE (t.block_height, t.index) <= ({}, {}) - ORDER BY t.block_height DESC, t.index DESC + FROM transactions AS t + WHERE (t.block_height, t.idx) <= ({}, {}) + ORDER BY t.block_height DESC, t.idx DESC LIMIT {} ) ORDER BY h.height DESC", @@ -338,7 +339,7 @@ where JOIN payload AS p ON h.height = p.height WHERE h.height = ( SELECT MAX(t1.block_height) - FROM transaction AS t1 + FROM transactions AS t1 ) ORDER BY h.height DESC" ), @@ -348,9 +349,9 @@ where JOIN payload AS p ON h.height = p.height WHERE h.height = ( SELECT t1.block_height - FROM transaction AS t1 + FROM transactions AS t1 WHERE t1.block_height = {} - ORDER BY t1.block_height, t1.index + ORDER BY t1.block_height, t1.idx OFFSET {} LIMIT 1 ) @@ -364,9 +365,9 @@ where JOIN payload AS p ON h.height = p.height WHERE h.height = ( SELECT t1.block_height - FROM transaction AS t1 + FROM transactions AS t1 WHERE t1.hash = {} - ORDER BY t1.block_height DESC, t1.index DESC + ORDER BY t1.block_height DESC, t1.idx DESC LIMIT 1 ) ORDER BY h.height DESC", @@ -420,7 +421,7 @@ where p.height = h.height WHERE h.height IN (SELECT height FROM header ORDER BY height DESC LIMIT 50) - ORDER BY h.height ASC + ORDER BY h.height ", ) .fetch(self.as_mut()); @@ -472,12 
+473,14 @@ where let latest_block: BlockDetail = self.get_block_detail(BlockIdentifier::Latest).await?; + let latest_blocks: Vec> = self .get_block_summaries(GetBlockSummariesRequest(BlockRange { target: BlockIdentifier::Latest, num_blocks: NonZeroUsize::new(10).unwrap(), })) .await?; + let latest_transactions: Vec> = self .get_transaction_summaries(GetTransactionSummariesRequest { range: TransactionRange { @@ -535,7 +538,7 @@ where "SELECT {BLOCK_COLUMNS} FROM header AS h JOIN payload AS p ON h.height = p.height - JOIN transaction AS t ON h.height = t.block_height + JOIN transactions AS t ON h.height = t.block_height WHERE t.hash = $1 ORDER BY h.height DESC LIMIT 5" diff --git a/src/data_source/storage/sql/queries/node.rs b/src/data_source/storage/sql/queries/node.rs index dc708b37..70147e84 100644 --- a/src/data_source/storage/sql/queries/node.rs +++ b/src/data_source/storage/sql/queries/node.rs @@ -118,7 +118,7 @@ where "SELECT v.share AS share FROM vid AS v JOIN header AS h ON v.height = h.height WHERE {where_clause} - ORDER BY h.height ASC + ORDER BY h.height LIMIT 1" ); let (share_data,) = query @@ -159,7 +159,7 @@ where (SELECT count(*) AS null_payloads FROM payload WHERE data IS NULL) AS p, (SELECT count(*) AS total_vid FROM vid) AS v, (SELECT count(*) AS null_vid FROM vid WHERE share IS NULL) AS vn, - coalesce((SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1)) as pruned_height + (SELECT(SELECT last_height FROM pruned_height ORDER BY id DESC LIMIT 1) as pruned_height) "; let row = query(sql) .fetch_optional(self.as_mut()) diff --git a/src/data_source/storage/sql/queries/state.rs b/src/data_source/storage/sql/queries/state.rs index aad05c4e..2febb182 100644 --- a/src/data_source/storage/sql/queries/state.rs +++ b/src/data_source/storage/sql/queries/state.rs @@ -16,6 +16,8 @@ use super::{ super::transaction::{query_as, Transaction, TransactionMode, Write}, DecodeError, QueryBuilder, }; +use crate::data_source::storage::sql::build_where_in; +use crate::data_source::storage::sql::sqlx::Row; use crate::{ data_source::storage::{MerklizedStateHeightStorage, MerklizedStateStorage}, merklized_state::{MerklizedState, Snapshot}, @@ -29,7 +31,8 @@ use jf_merkle_tree::{ prelude::{MerkleNode, MerkleProof}, DigestAlgorithm, MerkleCommitment, ToTraversalPath, }; -use sqlx::{types::BitVec, FromRow}; +use sqlx::types::BitVec; +use sqlx::types::JsonValue; use std::collections::{HashMap, HashSet, VecDeque}; use std::sync::Arc; @@ -56,38 +59,44 @@ where // Get all the nodes in the path to the index. 
// Order by pos DESC is to return nodes from the leaf to the root - let (query, sql) = build_get_path_query(state_type, traversal_path.clone(), created)?; - let nodes = query - .query_as::(&sql) - .fetch_all(self.as_mut()) - .await?; + let rows = query.query(&sql).fetch_all(self.as_mut()).await?; + + let nodes: Vec = rows.into_iter().map(|r| r.into()).collect(); // insert all the hash ids to a hashset which is used to query later // HashSet is used to avoid duplicates let mut hash_ids = HashSet::new(); - nodes.iter().for_each(|n| { - hash_ids.insert(n.hash_id); - if let Some(children) = &n.children { + for node in nodes.iter() { + hash_ids.insert(node.hash_id); + if let Some(children) = &node.children { + let children: Vec = + serde_json::from_value(children.clone()).map_err(|e| QueryError::Error { + message: format!("Error deserializing 'children' into Vec: {e}"), + })?; hash_ids.extend(children); } - }); + } // Find all the hash values and create a hashmap // Hashmap will be used to get the hash value of the nodes children and the node itself. - let hashes: HashMap> = - query_as("SELECT id, value FROM hash WHERE id = ANY( $1)") - .bind(hash_ids.into_iter().collect::>()) + let hashes = if !hash_ids.is_empty() { + let (query, sql) = build_where_in("SELECT id, value FROM hash", "id", hash_ids)?; + query + .query_as(&sql) .fetch(self.as_mut()) - .try_collect() - .await?; + .try_collect::>>() + .await? + } else { + HashMap::new() + }; let mut proof_path = VecDeque::with_capacity(State::tree_height()); for Node { hash_id, children, children_bitvec, - index, + idx, entry, .. } in nodes.iter() @@ -96,9 +105,18 @@ where let value = hashes.get(hash_id).ok_or(QueryError::Error { message: format!("node's value references non-existent hash {hash_id}"), })?; - match (children, children_bitvec, index, entry) { + + match (children, children_bitvec, idx, entry) { // If the row has children then its a branch (Some(children), Some(children_bitvec), None, None) => { + let children: Vec = + serde_json::from_value(children.clone()).map_err(|e| { + QueryError::Error { + message: format!( + "Error deserializing 'children' into Vec: {e}" + ), + } + })?; let mut children = children.iter(); // Reconstruct the Children MerkleNodes from storage. @@ -293,7 +311,9 @@ impl Transaction { }; // Make sure the requested snapshot is up to date. 
+ let height = self.get_last_state_height().await?; + if height < (created as usize) { return Err(QueryError::NotFound); } @@ -319,15 +339,49 @@ pub(crate) fn build_hash_batch_insert( } // Represents a row in a state table -#[derive(Debug, Default, Clone, FromRow)] +#[derive(Debug, Default, Clone)] pub(crate) struct Node { - pub(crate) path: Vec, + pub(crate) path: JsonValue, pub(crate) created: i64, pub(crate) hash_id: i32, - pub(crate) children: Option>, + pub(crate) children: Option, pub(crate) children_bitvec: Option, - pub(crate) index: Option, - pub(crate) entry: Option, + pub(crate) idx: Option, + pub(crate) entry: Option, +} + +#[cfg(feature = "embedded-db")] +impl From for Node { + fn from(row: sqlx::sqlite::SqliteRow) -> Self { + let bit_string: Option = row.get_unchecked("children_bitvec"); + let children_bitvec: Option = + bit_string.map(|b| b.chars().map(|c| c == '1').collect()); + + Self { + path: row.get_unchecked("path"), + created: row.get_unchecked("created"), + hash_id: row.get_unchecked("hash_id"), + children: row.get_unchecked("children"), + children_bitvec, + idx: row.get_unchecked("idx"), + entry: row.get_unchecked("entry"), + } + } +} + +#[cfg(not(feature = "embedded-db"))] +impl From for Node { + fn from(row: sqlx::postgres::PgRow) -> Self { + Self { + path: row.get_unchecked("path"), + created: row.get_unchecked("created"), + hash_id: row.get_unchecked("hash_id"), + children: row.get_unchecked("children"), + children_bitvec: row.get_unchecked("children_bitvec"), + idx: row.get_unchecked("idx"), + entry: row.get_unchecked("entry"), + } + } } impl Node { @@ -344,18 +398,27 @@ impl Node { "hash_id", "children", "children_bitvec", - "index", + "idx", "entry", ], ["path", "created"], nodes.into_iter().map(|n| { + #[cfg(feature = "embedded-db")] + let children_bitvec: Option = n + .children_bitvec + .clone() + .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect()); + + #[cfg(not(feature = "embedded-db"))] + let children_bitvec = n.children_bitvec.clone(); + ( n.path.clone(), n.created, n.hash_id, n.children.clone(), - n.children_bitvec.clone(), - n.index.clone(), + children_bitvec, + n.idx.clone(), n.entry.clone(), ) }), @@ -371,15 +434,20 @@ fn build_get_path_query<'q>( ) -> QueryResult<(QueryBuilder<'q>, String)> { let mut query = QueryBuilder::default(); let mut traversal_path = traversal_path.into_iter().map(|x| x as i32); - let created = query.bind(created)?; // We iterate through the path vector skipping the first element after each iteration let len = traversal_path.len(); let mut sub_queries = Vec::new(); + + query.bind(created)?; + for _ in 0..=len { - let node_path = query.bind(traversal_path.clone().rev().collect::>())?; + let path = traversal_path.clone().rev().collect::>(); + let path: serde_json::Value = path.into(); + let node_path = query.bind(path)?; + let sub_query = format!( - "(SELECT * FROM {table} WHERE path = {node_path} AND created <= {created} ORDER BY created DESC LIMIT 1)", + "SELECT * FROM (SELECT * FROM {table} WHERE path = {node_path} AND created <= $1 ORDER BY created DESC LIMIT 1)", ); sub_queries.push(sub_query); @@ -387,7 +455,17 @@ fn build_get_path_query<'q>( } let mut sql: String = sub_queries.join(" UNION "); - sql.push_str("ORDER BY path DESC"); + + sql = format!("SELECT * FROM ({sql}) as t "); + + // PostgreSQL already orders JSON arrays by length, so no additional function is needed + // For SQLite, `length()` is used to sort by length. 
+ if cfg!(feature = "embedded-db") { + sql.push_str("ORDER BY length(t.path) DESC"); + } else { + sql.push_str("ORDER BY t.path DESC"); + } + Ok((query, sql)) } @@ -442,7 +520,7 @@ mod test { [( block_height as i64, format!("randomHash{i}"), - "t", + "t".to_string(), 0, test_data, )], @@ -505,7 +583,13 @@ mod test { "header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"], - [(2i64, "randomstring", "t", 0, test_data)], + [( + 2i64, + "randomstring".to_string(), + "t".to_string(), + 0, + test_data, + )], ) .await .unwrap(); @@ -538,14 +622,10 @@ mod test { // Find all the nodes of Index 0 in table let mut tx = storage.read().await.unwrap(); let rows = query("SELECT * from test_tree where path = $1 ORDER BY created") - .bind(node_path) + .bind(serde_json::to_value(node_path).unwrap()) .fetch(tx.as_mut()); - let nodes: Vec<_> = rows - .map(|res| Node::from_row(&res.unwrap())) - .try_collect() - .await - .unwrap(); + let nodes: Vec = rows.map(|res| res.unwrap().into()).collect().await; // There should be only 2 versions of this node assert!(nodes.len() == 2, "incorrect number of nodes"); assert_eq!(nodes[0].created, 1, "wrong block height"); @@ -600,7 +680,13 @@ mod test { "header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"], - [(block_height as i64, "randomString", "t", 0, test_data)], + [( + block_height as i64, + "randomString".to_string(), + "t".to_string(), + 0, + test_data, + )], ) .await .unwrap(); @@ -665,8 +751,8 @@ mod test { ["height"], [( 2i64, - "randomString2", - "t", + "randomString2".to_string(), + "t".to_string(), 0, serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}), )], @@ -730,7 +816,7 @@ mod test { [( i as i64, format!("hash{i}"), - "t", + "t".to_string(), 0, serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}) )], @@ -796,7 +882,13 @@ mod test { "header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"], - [(block_height as i64, "randomString", "t", 0, test_data)], + [( + block_height as i64, + "randomString".to_string(), + "t".to_string(), + 0, + test_data, + )], ) .await .unwrap(); @@ -862,7 +954,7 @@ mod test { [( block_height as i64, format!("rarndomString{i}"), - "t", + "t".to_string(), 0, test_data, )], @@ -924,7 +1016,13 @@ mod test { "header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"], - [(block_height as i64, "randomStringgg", "t", 0, test_data)], + [( + block_height as i64, + "randomStringgg".to_string(), + "t".to_string(), + 0, + test_data, + )], ) .await .unwrap(); @@ -955,7 +1053,13 @@ mod test { "header", ["height", "hash", "payload_hash", "timestamp", "data"], ["height"], - [(2i64, "randomHashString", "t", 0, test_data)], + [( + 2i64, + "randomHashString".to_string(), + "t".to_string(), + 0, + test_data, + )], ) .await .unwrap(); @@ -975,12 +1079,12 @@ mod test { .rev() .map(|n| *n as i32) .collect::>(); - tx.execute_one( + tx.execute( query(&format!( "DELETE FROM {} WHERE created = 2 and path = $1", MockMerkleTree::state_type() )) - .bind(node_path), + .bind(serde_json::to_value(node_path).unwrap()), ) .await .expect("failed to delete internal node"); @@ -1076,7 +1180,7 @@ mod test { [( block_height as i64, format!("hash{block_height}"), - "hash", + "hash".to_string(), 0i64, serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(tree.commitment()).unwrap()}), )], @@ -1199,8 +1303,8 @@ 
mod test { ["height"], [( 0i64, - "hash", - "hash", + "hash".to_string(), + "hash".to_string(), 0, serde_json::json!({ MockMerkleTree::header_state_commitment_field() : serde_json::to_value(test_tree.commitment()).unwrap()}), )], @@ -1247,12 +1351,13 @@ mod test { // Now delete the leaf node for the last entry we inserted, corrupting the database. let index = serde_json::to_value(tree_size - 1).unwrap(); let mut tx = storage.write().await.unwrap(); - tx.execute_one_with_retries( - &format!( - "DELETE FROM {} WHERE index = $1", + + tx.execute( + query(&format!( + "DELETE FROM {} WHERE idx = $1", MockMerkleTree::state_type() - ), - (index,), + )) + .bind(serde_json::to_value(index).unwrap()), ) .await .unwrap(); diff --git a/src/data_source/storage/sql/transaction.rs b/src/data_source/storage/sql/transaction.rs index 18531109..4c1bbe1f 100644 --- a/src/data_source/storage/sql/transaction.rs +++ b/src/data_source/storage/sql/transaction.rs @@ -20,6 +20,7 @@ use super::{ queries::{ + self, state::{build_hash_batch_insert, Node}, DecodeError, }, @@ -35,9 +36,9 @@ use crate::{ }, merklized_state::{MerklizedState, UpdateStateData}, types::HeightIndexed, - Header, Payload, QueryError, VidShare, + Header, Payload, QueryError, QueryResult, VidShare, }; -use anyhow::{bail, ensure, Context}; +use anyhow::{bail, Context}; use ark_serialize::CanonicalSerialize; use async_trait::async_trait; use committable::Committable; @@ -51,15 +52,14 @@ use hotshot_types::traits::{ }; use itertools::Itertools; use jf_merkle_tree::prelude::{MerkleNode, MerkleProof}; -use sqlx::{pool::Pool, types::BitVec, Encode, Execute, FromRow, Type}; +use sqlx::types::BitVec; +pub use sqlx::Executor; +use sqlx::{pool::Pool, query_builder::Separated, Encode, FromRow, QueryBuilder, Type}; use std::{ collections::{HashMap, HashSet}, marker::PhantomData, - time::{Duration, Instant}, + time::Instant, }; -use tokio::time::sleep; - -pub use sqlx::Executor; pub type Query<'q> = sqlx::query::Query<'q, Db, ::Arguments<'q>>; pub type QueryAs<'q, T> = sqlx::query::QueryAs<'q, Db, T, ::Arguments<'q>>; @@ -93,9 +93,12 @@ pub trait TransactionMode: Send + Sync { } impl TransactionMode for Write { + #[allow(unused_variables)] async fn begin(conn: &mut ::Connection) -> anyhow::Result<()> { + #[cfg(not(feature = "embedded-db"))] conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") .await?; + Ok(()) } @@ -105,9 +108,12 @@ impl TransactionMode for Write { } impl TransactionMode for Read { + #[allow(unused_variables)] async fn begin(conn: &mut ::Connection) -> anyhow::Result<()> { + #[cfg(not(feature = "embedded-db"))] conn.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE") .await?; + Ok(()) } @@ -227,9 +233,12 @@ impl update::Transaction for Transaction { /// as long as the parameters outlive the duration of the query (the `'p: 'q`) bound on the /// [`bind`](Self::bind) function. pub trait Params<'p> { - fn bind<'q>(self, q: Query<'q>) -> Query<'q> + fn bind<'q, 'r>( + self, + q: &'q mut Separated<'r, 'p, Db, &'static str>, + ) -> &'q mut Separated<'r, 'p, Db, &'static str> where - 'p: 'q; + 'p: 'r; } /// A collection of parameters with a statically known length. @@ -241,19 +250,19 @@ pub trait FixedLengthParams<'p, const N: usize>: Params<'p> {} macro_rules! 
impl_tuple_params { ($n:literal, ($($t:ident,)+)) => { - impl<'p, $($t),+> Params<'p> for ($($t,)+) + impl<'p, $($t),+> Params<'p> for ($($t,)+) where $( - $t: 'p + for<'q> Encode<'q, Db> + Type - ),+ { - fn bind<'q>(self, q: Query<'q>) -> Query<'q> - where - 'p: 'q + $t: 'p + Encode<'p, Db> + Type + ),+{ + fn bind<'q, 'r>(self, q: &'q mut Separated<'r, 'p, Db, &'static str>) -> &'q mut Separated<'r, 'p, Db, &'static str> + where 'p: 'r, { #[allow(non_snake_case)] let ($($t,)+) = self; q $( - .bind($t) + .push_bind($t) )+ + } } @@ -274,117 +283,37 @@ impl_tuple_params!(6, (T1, T2, T3, T4, T5, T6,)); impl_tuple_params!(7, (T1, T2, T3, T4, T5, T6, T7,)); impl_tuple_params!(8, (T1, T2, T3, T4, T5, T6, T7, T8,)); -impl<'p, T> Params<'p> for Vec +pub fn build_where_in<'a, I>( + query: &'a str, + column: &'a str, + values: I, +) -> QueryResult<(queries::QueryBuilder<'a>, String)> where - T: Params<'p>, + I: IntoIterator, + I::Item: 'a + Encode<'a, Db> + Type, { - fn bind<'q>(self, mut q: Query<'q>) -> Query<'q> - where - 'p: 'q, - { - for params in self { - q = params.bind(q); - } - q + let mut builder = queries::QueryBuilder::default(); + let params = values + .into_iter() + .map(|v| Ok(format!("{} ", builder.bind(v)?))) + .collect::>>()?; + + if params.is_empty() { + return Err(QueryError::Error { + message: "failed to build WHERE IN query. No parameter found ".to_string(), + }); } + + let sql = format!( + "{query} where {column} IN ({}) ", + params.into_iter().join(",") + ); + + Ok((builder, sql)) } /// Low-level, general database queries and mutation. impl Transaction { - /// Execute a statement that is expected to modify exactly one row. - /// - /// Returns an error if the database is not modified. - pub async fn execute_one<'q, E>(&mut self, statement: E) -> anyhow::Result<()> - where - E: 'q + Execute<'q, Db>, - { - let nrows = self.execute_many(statement).await?; - if nrows > 1 { - // If more than one row is affected, we don't return an error, because clearly - // _something_ happened and modified the database. So we don't necessarily want the - // caller to retry. But we do log an error, because it seems the query did something - // different than the caller intended. - tracing::error!("statement modified more rows ({nrows}) than expected (1)"); - } - Ok(()) - } - - /// Execute a statement that is expected to modify exactly one row. - /// - /// Returns an error if the database is not modified. Retries several times before failing. - pub async fn execute_one_with_retries<'q>( - &mut self, - statement: &'q str, - params: impl Params<'q> + Clone, - ) -> anyhow::Result<()> { - let interval = Duration::from_secs(1); - let mut retries = 5; - - while let Err(err) = self - .execute_one(params.clone().bind(query(statement))) - .await - { - tracing::error!( - %statement, - "error in statement execution ({retries} tries remaining): {err}" - ); - if retries == 0 { - return Err(err); - } - retries -= 1; - sleep(interval).await; - } - - Ok(()) - } - - /// Execute a statement that is expected to modify at least one row. - /// - /// Returns an error if the database is not modified. - pub async fn execute_many<'q, E>(&mut self, statement: E) -> anyhow::Result - where - E: 'q + Execute<'q, Db>, - { - let nrows = self.execute(statement).await?.rows_affected(); - ensure!(nrows > 0, "statement failed: 0 rows affected"); - Ok(nrows) - } - - /// Execute a statement that is expected to modify at least one row. - /// - /// Returns an error if the database is not modified. Retries several times before failing. 
- pub async fn execute_many_with_retries<'q, 'p>( - &mut self, - statement: &'q str, - params: impl Params<'p> + Clone, - ) -> anyhow::Result - where - 'p: 'q, - { - let interval = Duration::from_secs(1); - let mut retries = 5; - - loop { - match self - .execute_many(params.clone().bind(query(statement))) - .await - { - Ok(nrows) => return Ok(nrows), - Err(err) => { - tracing::error!( - %statement, - "error in statement execution ({retries} tries remaining): {err}" - ); - if retries == 0 { - return Err(err); - } - retries -= 1; - sleep(interval).await; - } - } - } - } - pub async fn upsert<'p, const N: usize, R>( &mut self, table: &str, @@ -400,35 +329,31 @@ impl Transaction { .iter() .map(|col| format!("{col} = excluded.{col}")) .join(","); - let columns = columns.into_iter().join(","); + let columns_str = columns + .into_iter() + .map(|col| format!("\"{col}\"")) + .join(","); let pk = pk.into_iter().join(","); - let mut values = vec![]; - let mut params = vec![]; + let mut query_builder = + QueryBuilder::new(format!("INSERT INTO \"{table}\" ({columns_str}) ")); let mut num_rows = 0; - for (row, entries) in rows.into_iter().enumerate() { - let start = row * N; - let end = (row + 1) * N; - let row_params = (start..end).map(|i| format!("${}", i + 1)).join(","); - values.push(format!("({row_params})")); - params.push(entries); + query_builder.push_values(rows, |mut b, row| { num_rows += 1; - } + row.bind(&mut b); + }); if num_rows == 0 { tracing::warn!("trying to upsert 0 rows, this has no effect"); return Ok(()); } - tracing::debug!("upserting {num_rows} rows"); - - let values = values.into_iter().join(","); - let stmt = format!( - "INSERT INTO {table} ({columns}) - VALUES {values} - ON CONFLICT ({pk}) DO UPDATE SET {set_columns}" - ); - let rows_modified = self.execute_many_with_retries(&stmt, params).await?; + query_builder.push(format!(" ON CONFLICT ({pk}) DO UPDATE SET {set_columns}")); + + let res = self.execute(query_builder.build()).await?; + let stmt = query_builder.sql(); + let rows_modified = res.rows_affected() as usize; + if rows_modified != num_rows { tracing::error!( stmt, @@ -519,13 +444,14 @@ where // The header and payload tables should already have been initialized when we inserted the // corresponding leaf. All we have to do is add the payload itself and its size. 
let payload = block.payload.encode(); + self.upsert( "payload", ["height", "data", "size", "num_transactions"], ["height"], [( block.height() as i64, - payload.as_ref(), + payload.as_ref().to_vec(), block.size() as i32, block.num_transactions() as i32, )], @@ -541,9 +467,9 @@ where } if !rows.is_empty() { self.upsert( - "transaction", - ["hash", "block_height", "index"], - ["block_height", "index"], + "transactions", + ["hash", "block_height", "idx"], + ["block_height", "idx"], rows, ) .await?; @@ -630,7 +556,7 @@ impl, const ARITY: usize> nodes.push(( Node { path: node_path, - index: Some(index), + idx: Some(index), ..Default::default() }, None, @@ -658,7 +584,7 @@ impl, const ARITY: usize> nodes.push(( Node { path, - index: Some(index), + idx: Some(index), entry: Some(entry), ..Default::default() }, @@ -750,10 +676,12 @@ impl, const ARITY: usize> message: "Missing child hash".to_string(), })?; - node.children = Some(children_hashes); + node.children = Some(children_hashes.into()); } } + Node::upsert(name, nodes.into_iter().map(|(n, _, _)| n), self).await?; + Ok(()) } } diff --git a/src/fetching/provider/query_service.rs b/src/fetching/provider/query_service.rs index e2135cbc..3b844542 100644 --- a/src/fetching/provider/query_service.rs +++ b/src/fetching/provider/query_service.rs @@ -1135,6 +1135,8 @@ mod test { tracing::info!("retrieve from storage"); let fetch = data_source.get_leaf(1).await; assert_eq!(leaves[0], fetch.try_resolve().ok().unwrap()); + + drop(db); } #[tokio::test(flavor = "multi_thread")]
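Usage note: the sketch below shows how a caller might build a `Config` under the `embedded-db` feature flag introduced in this diff. It mirrors the pattern used in `examples/simple-server.rs` and the `SqlDataSource` doc comments above; the builder methods (`host`, `port`, `user`, `password`, `db_path`, `connect`) come from this patch, while the `"hotshot.db"` path and the `"localhost"` credentials are illustrative placeholders, not values mandated anywhere in the change.

```rust
use hotshot_query_service::data_source::sql::Config;

// Minimal sketch, assuming the `Config` re-exported by this patch. The two
// cfg-gated blocks are mutually exclusive: without `embedded-db` the config
// targets a PostgreSQL server, with it the config points at an SQLite file.
fn example_config() -> Config {
    let mut cfg = Config::default();

    #[cfg(not(feature = "embedded-db"))]
    {
        // PostgreSQL: hostname and credentials here are placeholders.
        cfg = cfg
            .host("localhost")
            .port(5432)
            .user("postgres")
            .password("password");
    }

    #[cfg(feature = "embedded-db")]
    {
        // SQLite: only a file path is needed; the file is created if missing.
        cfg = cfg.db_path("hotshot.db".into());
    }

    cfg
}
```

The resulting config is consumed exactly as before, e.g. `cfg.connect(Default::default()).await` as in `init_data_source` above.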