From edceae6655f6477c5b89b6c7a99a7201a204bba3 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Fri, 17 Nov 2023 16:58:10 +0100 Subject: [PATCH] [WIP] core: qdrant_migrator (#2571) * WIP qdrant_migrator * WIP migrate * code * nits * nit --- core/Cargo.lock | 107 +++++++ core/Cargo.toml | 7 +- core/bin/qdrant_migrator.rs | 431 +++++++++++++++++++++++++++ core/src/data_sources/data_source.rs | 78 +++-- core/src/data_sources/qdrant.rs | 47 ++- core/src/stores/postgres.rs | 26 ++ core/src/stores/store.rs | 10 +- 7 files changed, 666 insertions(+), 40 deletions(-) create mode 100644 core/bin/qdrant_migrator.rs diff --git a/core/Cargo.lock b/core/Cargo.lock index 9752a5cb0409..eb9007e1bcca 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -74,6 +74,54 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + +[[package]] +name = "anstyle-parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" +dependencies = [ + "anstyle", + "windows-sys", +] + [[package]] name = "anyhow" version = "1.0.75" @@ -594,6 +642,46 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "clap" +version = "4.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + [[package]] name = "cloud-storage" version = "0.11.1" @@ -617,6 +705,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + [[package]] name = "concurrent-queue" version = "2.3.0" @@ -902,6 +996,7 @@ dependencies = [ "bb8-postgres", "blake3", "bstr 0.2.17", + "clap", "cloud-storage", "deno_core", "dns-lookup", @@ -2904,6 +2999,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strum" version = "0.25.0" @@ -3509,6 +3610,12 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "uuid" version = "1.5.0" diff --git a/core/Cargo.toml b/core/Cargo.toml index 882aaa8aea05..61847036cfa6 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -7,6 +7,10 @@ edition = "2021" name = "dust-api" path = "bin/dust_api.rs" +[[bin]] +name = "qdrant_migrator" +path = "bin/qdrant_migrator.rs" + [dependencies] anyhow = "1.0" serde = { version = "1.0", features = ["rc", "derive"] } @@ -50,4 +54,5 @@ tower-http = {version = "0.4", features = ["full"]} tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } deno_core = "0.200" -rayon = "1.8.0" \ No newline at end of file +rayon = "1.8.0" +clap = { version = "4.4", features = ["derive"] } \ No newline at end of file diff --git a/core/bin/qdrant_migrator.rs b/core/bin/qdrant_migrator.rs new file mode 100644 index 000000000000..24ba65308249 --- /dev/null +++ b/core/bin/qdrant_migrator.rs @@ -0,0 +1,431 @@ +use std::str::FromStr; + +use anyhow::{anyhow, Result}; +use clap::{Parser, Subcommand}; +use dust::{ + data_sources::qdrant::{QdrantClients, QdrantCluster, QdrantDataSourceConfig}, + project, + run::Credentials, + stores::postgres, + stores::store, + utils, +}; +use qdrant_client::{ + prelude::Payload, + qdrant::{self, PointId, ScrollPoints}, +}; + +#[derive(Debug, Subcommand)] +enum Commands { + #[command(arg_required_else_help = true)] + #[command(about = "Show qdrant state for data source", long_about = None)] + Show { + project_id: i64, + data_source_id: String, + }, + #[command(arg_required_else_help = true)] + #[command(about = "Set `shadow_write_cluster` \ + (!!! creates collection on `shadow_write_cluster`)", long_about = None)] + SetShadowWrite { + project_id: i64, + data_source_id: String, + cluster: String, + }, + #[command(arg_required_else_help = true)] + #[command(about = "Clear `shadow_write_cluster` \ + (!!! deletes collection from `shadow_write_cluster`)", long_about = None)] + ClearShadowWrite { + project_id: i64, + data_source_id: String, + }, + #[command(arg_required_else_help = true)] + #[command(about = "Migrate `cluster` collection to `shadow_write_cluster`", long_about = None)] + MigrateShadowWrite { + project_id: i64, + data_source_id: String, + }, + #[command(arg_required_else_help = true)] + #[command(about = "Switch `shadow_write_cluster` and `cluster` \ + (!!! moves read traffic to `shadow_write_cluster`)", long_about = None)] + CommitShadowWrite { + project_id: i64, + data_source_id: String, + }, +} + +#[derive(Debug, Parser)] +#[command(name = "collection_migrator")] +#[command(about = "Tooling to migrate Qdrant collections", long_about = None)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +fn main() -> Result<()> { + let args = Cli::parse(); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(32) + .enable_all() + .build() + .unwrap(); + + let r = rt.block_on(async { + tracing_subscriber::fmt() + .with_target(false) + .compact() + .with_ansi(false) + .init(); + let store: Box = match std::env::var("CORE_DATABASE_URI") { + Ok(db_uri) => { + let store = postgres::PostgresStore::new(&db_uri).await?; + Box::new(store) + } + Err(_) => Err(anyhow!("CORE_DATABASE_URI is required (postgres)"))?, + }; + + let qdrant_clients = QdrantClients::build().await?; + + match args.command { + Commands::Show { + project_id, + data_source_id, + } => { + let project = project::Project::new_from_id(project_id); + let ds = match store.load_data_source(&project, &data_source_id).await? { + Some(ds) => ds, + None => Err(anyhow!("Data source not found"))?, + }; + + utils::info(&format!( + "Data source: collection={} cluster={} shadow_write_cluster={}", + ds.qdrant_collection(), + qdrant_clients + .main_cluster(&ds.config().qdrant_config) + .to_string(), + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + )); + + let qdrant_client = qdrant_clients.main_client(&ds.config().qdrant_config); + match qdrant_client + .collection_info(ds.qdrant_collection()) + .await? + .result + { + Some(info) => { + utils::info(&format!( + "[MAIN] Qdrant collection: cluster={} collection={} status={} \ + points_count={}", + qdrant_clients + .main_cluster(&ds.config().qdrant_config) + .to_string(), + ds.qdrant_collection(), + info.status.to_string(), + info.points_count, + )); + } + None => Err(anyhow!("Qdrant collection not found"))?, + } + + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(shadow_write_cluster) => { + let shadow_write_qdrant_client = qdrant_clients + .shadow_write_client(&ds.config().qdrant_config) + .unwrap(); + match shadow_write_qdrant_client + .collection_info(ds.qdrant_collection()) + .await? + .result + { + Some(info) => { + utils::info(&format!( + "[SHADOW] Qdrant collection: cluster={} collection={} status={}\ + points_count={}", + shadow_write_cluster.to_string(), + ds.qdrant_collection(), + info.status.to_string(), + info.points_count, + )); + } + None => Err(anyhow!("Qdrant collection not found"))?, + } + } + None => (), + }; + + Ok::<(), anyhow::Error>(()) + } + Commands::SetShadowWrite { + project_id, + data_source_id, + cluster, + } => { + let project = project::Project::new_from_id(project_id); + let mut ds = match store.load_data_source(&project, &data_source_id).await? { + Some(ds) => ds, + None => Err(anyhow!("Data source not found"))?, + }; + + let mut config = ds.config().clone(); + + config.qdrant_config = match config.qdrant_config { + Some(c) => Some(QdrantDataSourceConfig { + cluster: c.cluster, + shadow_write_cluster: Some(QdrantCluster::from_str(cluster.as_str())?), + }), + None => Some(QdrantDataSourceConfig { + cluster: QdrantCluster::Main0, + shadow_write_cluster: Some(QdrantCluster::from_str(cluster.as_str())?), + }), + }; + + // Create collection on shadow_write_cluster. + let shadow_write_qdrant_client = + match qdrant_clients.shadow_write_client(&config.qdrant_config) { + Some(client) => client, + None => unreachable!(), + }; + + // We send a fake credentials here since this is not really used for OpenAI to get + // the embeedding size (which is what happens here). May need to be revisited in + // future. + let mut credentials = Credentials::new(); + credentials.insert("OPENAI_API_KEY".to_string(), "foo".to_string()); + + ds.create_qdrant_collection(credentials, shadow_write_qdrant_client.clone()) + .await?; + + utils::done(&format!( + "Created qdrant shadow_write_cluster collection: \ + collection={} shadow_write_cluster={}", + ds.qdrant_collection(), + match qdrant_clients.shadow_write_cluster(&config.qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + )); + + // Add shadow_write_cluster to config. + ds.update_config(store, &config).await?; + + utils::done(&format!( + "Updated data source: collection={} cluster={} shadow_write_cluster={}", + ds.qdrant_collection(), + qdrant_clients + .main_cluster(&ds.config().qdrant_config) + .to_string(), + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + )); + + Ok::<(), anyhow::Error>(()) + } + Commands::ClearShadowWrite { + project_id, + data_source_id, + } => { + // This is the most dangerous command of all as it is the only one to actually + // delete data in an unrecoverable way. + let project = project::Project::new_from_id(project_id); + let mut ds = match store.load_data_source(&project, &data_source_id).await? { + Some(ds) => ds, + None => Err(anyhow!("Data source not found"))?, + }; + + let shadow_write_qdrant_client = + match qdrant_clients.shadow_write_client(&ds.config().qdrant_config) { + Some(client) => client, + None => Err(anyhow!("No shadow write cluster to clear"))?, + }; + + match shadow_write_qdrant_client + .collection_info(ds.qdrant_collection()) + .await? + .result + { + Some(info) => { + // confirm + match utils::confirm(&format!( + "[DANGER] Are you sure you want to delete this qdrant \ + shadow_write_cluster collection? \ + (this is definitive) shadow_write_cluster={} points_count={}", + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + .to_string(), + info.points_count, + ))? { + true => (), + false => Err(anyhow!("Aborted"))?, + } + } + None => Err(anyhow!("Qdrant collection not found"))?, + }; + + // Delete collection on shadow_write_cluster. + shadow_write_qdrant_client + .delete_collection(ds.qdrant_collection()) + .await?; + + utils::done(&format!( + "Deleted qdrant shadow_write_cluster collection: \ + collection={} shadow_write_cluster={}", + ds.qdrant_collection(), + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + )); + + // Remove shadow_write_cluster from config. + let mut config = ds.config().clone(); + + config.qdrant_config = match config.qdrant_config { + Some(c) => Some(QdrantDataSourceConfig { + cluster: c.cluster, + shadow_write_cluster: None, + }), + None => Some(QdrantDataSourceConfig { + cluster: QdrantCluster::Main0, + shadow_write_cluster: None, + }), + }; + + ds.update_config(store, &config).await?; + + utils::done(&format!( + "Updated data source: collection={} cluster={} shadow_write_cluster={}", + ds.qdrant_collection(), + qdrant_clients + .main_cluster(&ds.config().qdrant_config) + .to_string(), + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + )); + + Ok::<(), anyhow::Error>(()) + } + Commands::MigrateShadowWrite { + project_id, + data_source_id, + } => { + let project = project::Project::new_from_id(project_id); + let ds = match store.load_data_source(&project, &data_source_id).await? { + Some(ds) => ds, + None => Err(anyhow!("Data source not found"))?, + }; + + let qdrant_client = qdrant_clients.main_client(&ds.config().qdrant_config); + + // Delete collection on shadow_write_cluster. + let shadow_write_qdrant_client = + match qdrant_clients.shadow_write_client(&ds.config().qdrant_config) { + Some(client) => client, + None => Err(anyhow!("No shadow write cluster to migrate to"))?, + }; + + let mut page_offset: Option = None; + let mut total: usize = 0; + loop { + let scroll_results = qdrant_client + .scroll(&ScrollPoints { + collection_name: ds.qdrant_collection(), + with_vectors: Some(true.into()), + with_payload: Some(true.into()), + limit: Some(256), + offset: page_offset, + ..Default::default() + }) + .await?; + + let count = scroll_results.result.len(); + + let points = scroll_results + .result + .into_iter() + .map(|r| { + qdrant::PointStruct::new( + r.id.unwrap(), + r.vectors.unwrap(), + Payload::new_from_hashmap(r.payload), + ) + }) + .collect::>(); + + shadow_write_qdrant_client + .upsert_points(ds.qdrant_collection(), points, None) + .await?; + + total += count; + utils::info(&format!("Migrated points: count={} total={}", count, total)); + + page_offset = scroll_results.next_page_offset; + if page_offset.is_none() { + break; + } + } + + utils::info(&format!("Done migrating: total={}", total)); + + Ok::<(), anyhow::Error>(()) + } + Commands::CommitShadowWrite { + project_id, + data_source_id, + } => { + let project = project::Project::new_from_id(project_id); + let mut ds = match store.load_data_source(&project, &data_source_id).await? { + Some(ds) => ds, + None => Err(anyhow!("Data source not found"))?, + }; + + let mut config = ds.config().clone(); + + config.qdrant_config = match config.qdrant_config { + Some(c) => match c.shadow_write_cluster { + Some(cluster) => Some(QdrantDataSourceConfig { + cluster: cluster, + shadow_write_cluster: Some(c.cluster), + }), + None => Err(anyhow!("No shadow write cluster to commit"))?, + }, + None => Err(anyhow!("No shadow write cluster to commit"))?, + }; + + ds.update_config(store, &config).await?; + + utils::info(&format!( + "Updated data source: collection={} cluster={} shadow_write_cluster={}", + ds.qdrant_collection(), + qdrant_clients + .main_cluster(&ds.config().qdrant_config) + .to_string(), + match qdrant_clients.shadow_write_cluster(&ds.config().qdrant_config) { + Some(cluster) => cluster.to_string(), + None => "none".to_string(), + } + )); + + Ok::<(), anyhow::Error>(()) + } + } + }); + + match r { + Ok(_) => (), + Err(e) => { + utils::error(&format!("Error: {:?}", e)); + std::process::exit(1); + } + } + + Ok(()) +} diff --git a/core/src/data_sources/data_source.rs b/core/src/data_sources/data_source.rs index a25f749603a8..63a4b3a0252c 100644 --- a/core/src/data_sources/data_source.rs +++ b/core/src/data_sources/data_source.rs @@ -296,43 +296,30 @@ impl DataSource { &self.config } - fn qdrant_collection(&self) -> String { + pub fn qdrant_collection(&self) -> String { format!("ds_{}", self.internal_id) } - pub async fn setup( + pub async fn update_config( + &mut self, + store: Box, + config: &DataSourceConfig, + ) -> Result<()> { + self.config = config.clone(); + store + .update_data_source_config(&self.project, &self.data_source_id, &self.config) + .await?; + Ok(()) + } + + pub async fn create_qdrant_collection( &self, credentials: Credentials, - qdrant_clients: QdrantClients, + qdrant_client: Arc, ) -> Result<()> { - let qdrant_client = qdrant_clients.main_client(&self.config.qdrant_config); - let mut embedder = provider(self.config.provider_id).embedder(self.config.model_id.clone()); embedder.initialize(credentials).await?; - // GCP store created data to test GCP. - let bucket = match std::env::var("DUST_DATA_SOURCES_BUCKET") { - Ok(bucket) => bucket, - Err(_) => Err(anyhow!("DUST_DATA_SOURCES_BUCKET is not set"))?, - }; - - let bucket_path = format!("{}/{}", self.project.project_id(), self.internal_id); - let data_source_created_path = format!("{}/created.txt", bucket_path); - - Object::create( - &bucket, - format!("{}", self.created).as_bytes().to_vec(), - &data_source_created_path, - "application/text", - ) - .await?; - - utils::done(&format!( - "Created GCP bucket for data_source `{}`", - self.data_source_id - )); - - // Qdrant create collection. qdrant_client .create_collection(&qdrant::CreateCollection { collection_name: self.qdrant_collection(), @@ -369,6 +356,41 @@ impl DataSource { ..Default::default() }) .await?; + Ok(()) + } + + pub async fn setup( + &self, + credentials: Credentials, + qdrant_clients: QdrantClients, + ) -> Result<()> { + let qdrant_client = qdrant_clients.main_client(&self.config.qdrant_config); + + // GCP store created data to test GCP. + let bucket = match std::env::var("DUST_DATA_SOURCES_BUCKET") { + Ok(bucket) => bucket, + Err(_) => Err(anyhow!("DUST_DATA_SOURCES_BUCKET is not set"))?, + }; + + let bucket_path = format!("{}/{}", self.project.project_id(), self.internal_id); + let data_source_created_path = format!("{}/created.txt", bucket_path); + + Object::create( + &bucket, + format!("{}", self.created).as_bytes().to_vec(), + &data_source_created_path, + "application/text", + ) + .await?; + + utils::done(&format!( + "Created GCP bucket for data_source `{}`", + self.data_source_id + )); + + // Qdrant create collection. + self.create_qdrant_collection(credentials, qdrant_client.clone()) + .await?; let _ = qdrant_client .create_field_index( diff --git a/core/src/data_sources/qdrant.rs b/core/src/data_sources/qdrant.rs index 844b2a24f7ae..f62004cfe78d 100644 --- a/core/src/data_sources/qdrant.rs +++ b/core/src/data_sources/qdrant.rs @@ -1,5 +1,7 @@ +use crate::utils::ParseError; use anyhow::{anyhow, Result}; use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use parking_lot::Mutex; @@ -10,16 +12,37 @@ use serde::{Deserialize, Serialize}; pub enum QdrantCluster { #[serde(rename = "main-0")] Main0, - //#[serde(rename = "dedicated-0")] - //Dedicated0, + #[serde(rename = "dedicated-0")] + Dedicated0, } -static QDRANT_CLUSTER_VARIANTS: &[QdrantCluster] = &[QdrantCluster::Main0]; +static QDRANT_CLUSTER_VARIANTS: &[QdrantCluster] = + &[QdrantCluster::Main0, QdrantCluster::Dedicated0]; + +impl ToString for QdrantCluster { + fn to_string(&self) -> String { + match self { + QdrantCluster::Main0 => String::from("main-0"), + QdrantCluster::Dedicated0 => String::from("dedicated-0"), + } + } +} + +impl FromStr for QdrantCluster { + type Err = ParseError; + fn from_str(s: &str) -> Result { + match s { + "main-0" => Ok(QdrantCluster::Main0), + "dedicated-0" => Ok(QdrantCluster::Dedicated0), + _ => Err(ParseError::with_message("Unknown QdrantCluster"))?, + } + } +} pub fn env_var_prefix_for_cluster(cluster: QdrantCluster) -> &'static str { match cluster { QdrantCluster::Main0 => "QDRANT_MAIN_0", - // QDrantCluster::Dedicated0 => "QDRANT_DEDICATED_0", + QdrantCluster::Dedicated0 => "QDRANT_DEDICATED_0", } } @@ -30,8 +53,8 @@ pub struct QdrantClients { #[derive(Serialize, Deserialize, PartialEq, Clone, Debug)] pub struct QdrantDataSourceConfig { - cluster: QdrantCluster, - shadow_write_cluster: Option, + pub cluster: QdrantCluster, + pub shadow_write_cluster: Option, } impl QdrantClients { @@ -78,13 +101,17 @@ impl QdrantClients { } } + pub fn main_cluster(&self, config: &Option) -> QdrantCluster { + match config { + Some(config) => config.cluster, + None => QdrantCluster::Main0, + } + } + // Returns the client for the cluster specified in the config or the main-0 cluster if no config // is provided. pub fn main_client(&self, config: &Option) -> Arc { - match config { - Some(config) => self.client(config.cluster), - None => self.client(QdrantCluster::Main0), - } + self.client(self.main_cluster(config)) } pub fn shadow_write_cluster( diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs index 141a4e76a479..cca2eafc9b15 100644 --- a/core/src/stores/postgres.rs +++ b/core/src/stores/postgres.rs @@ -1033,6 +1033,32 @@ impl Store for PostgresStore { } } + async fn update_data_source_config( + &self, + project: &Project, + data_source_id: &str, + config: &DataSourceConfig, + ) -> Result<()> { + let project_id = project.project_id(); + let data_source_id = data_source_id.to_string(); + let data_source_config = config.clone(); + + let pool = self.pool.clone(); + let c = pool.get().await?; + + let config_data = serde_json::to_string(&data_source_config)?; + let stmt = c + .prepare( + "UPDATE data_sources SET config_json = $1 \ + WHERE project = $2 AND data_source_id = $3", + ) + .await?; + c.execute(&stmt, &[&config_data, &project_id, &data_source_id]) + .await?; + + Ok(()) + } + async fn load_data_source_document( &self, project: &Project, diff --git a/core/src/stores/store.rs b/core/src/stores/store.rs index 56a3a007c691..97f1ec9c66c5 100644 --- a/core/src/stores/store.rs +++ b/core/src/stores/store.rs @@ -1,5 +1,7 @@ use crate::blocks::block::BlockType; -use crate::data_sources::data_source::{DataSource, Document, DocumentVersion, SearchFilter}; +use crate::data_sources::data_source::{ + DataSource, DataSourceConfig, Document, DocumentVersion, SearchFilter, +}; use crate::databases::database::{Database, DatabaseRow, DatabaseTable}; use crate::dataset::Dataset; use crate::http::request::{HttpRequest, HttpResponse}; @@ -88,6 +90,12 @@ pub trait Store { project: &Project, data_source_id: &str, ) -> Result>; + async fn update_data_source_config( + &self, + project: &Project, + data_source_id: &str, + config: &DataSourceConfig, + ) -> Result<()>; async fn load_data_source_document( &self, project: &Project,