Skip to content

Commit

Permalink
Add connection mode functionality to KnowledgeStore to allow multiple
Browse files Browse the repository at this point in the history
access

Implement connection modes to allow persistent or on-demand database connections when using the KnowledgeStore.

Introduce the `ConnectionMode` enum to define the two modes and modify the `StoreBuilder` structure to allow setting a connection mode during store creation. The `KnowledgeStoreImpl` is updated to handle both connection modes, establishing a database connection only when necessary for on-demand mode, while maintaining a persistent connection if specified.
  • Loading branch information
Christian Stolz committed Oct 15, 2024
1 parent 619eaa9 commit 6642bfc
Showing 5 changed files with 73 additions and 29 deletions.
7 changes: 5 additions & 2 deletions src/cli/knowledge/add.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::cli::knowledge::knowledge_args::AddArgs;
use crate::knowledge::{DataSource, EmbeddingData, KnowledgeStore, StoreBuilder};
use crate::knowledge::{ConnectionMode, DataSource, EmbeddingData, KnowledgeStore, StoreBuilder};
use log::{info, warn};
use std::borrow::Cow;
use std::error::Error;
@@ -10,7 +10,10 @@ use tokio::task::JoinHandle;
use walkdir::WalkDir;

pub async fn add(add: AddArgs) -> Result<(), Box<dyn Error>> {
let store = StoreBuilder::new().build().await?;
let store = StoreBuilder::new()
.connection_mode(ConnectionMode::Persistent)
.build()
.await?;

if let Some(dir) = add.dir {
add_directory_to_knowledge(&dir, &store).await?;
7 changes: 5 additions & 2 deletions src/cli/knowledge/init.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ use crate::config;
use crate::config::Config;
use crate::context::{load_files_into_context, ContextConsumer};
use crate::knowledge::DataSource::Context;
use crate::knowledge::{EmbeddingData, KnowledgeStore, StoreBuilder};
use crate::knowledge::{ConnectionMode, EmbeddingData, KnowledgeStore, StoreBuilder};
use crate::persona::resolve_persona;
use async_channel::{Receiver, Sender};
use std::borrow::Cow;
@@ -28,7 +28,10 @@ fn get_config() -> Config {
/// Entry point for initializing the knowledge system.
pub async fn init(args: InitArgs) -> Result<(), Box<dyn Error>> {
let config = get_config();
let db = StoreBuilder::new().build().await?;
let db = StoreBuilder::new()
.connection_mode(ConnectionMode::Persistent)
.build()
.await?;
let (sender, receiver) = async_channel::bounded(CHANNEL_SIZE);
let persona = resolve_persona(&args.persona, config.default_persona.as_str())?;

7 changes: 7 additions & 0 deletions src/knowledge/interface.rs
Original file line number Diff line number Diff line change
@@ -32,6 +32,13 @@ impl Clone for EmbeddingServiceHandle {
}
}

#[derive(Debug, Clone, Copy, Default)]
pub enum ConnectionMode {
#[default]
OnDemand, // Open/close connection as needed
Persistent, // Keep connection open
}

/// KnowledgeStore trait to abstract knowledge retrieval.
/// It should generate embeddings from the user input and then perform the database query
/// to retrieve relevant documents.
17 changes: 13 additions & 4 deletions src/knowledge/store_builder.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::knowledge::store_impl::KnowledgeStoreImpl;
use crate::knowledge::KnowledgeStore;
use crate::knowledge::{ConnectionMode, KnowledgeStore};
use std::error::Error;
use std::sync::Arc;

#[derive(Default)]
pub struct StoreBuilder {}
pub struct StoreBuilder {
connection_mode: ConnectionMode,
}

impl StoreBuilder {
pub(crate) fn new() -> StoreBuilder {
@@ -14,7 +16,14 @@ impl StoreBuilder {

impl StoreBuilder {
// Build method to construct the ChatServiceFactory
pub async fn build(self) -> Result<Arc<dyn KnowledgeStore>, Box<dyn Error>> {
Ok(Arc::new(KnowledgeStoreImpl::new().await?))
pub async fn build(&self) -> Result<Arc<dyn KnowledgeStore>, Box<dyn Error>> {
Ok(Arc::new(
KnowledgeStoreImpl::new(self.connection_mode).await?,
))
}

pub fn connection_mode(&mut self, mode: ConnectionMode) -> &mut Self {
self.connection_mode = mode;
self
}
}
64 changes: 43 additions & 21 deletions src/knowledge/store_impl.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::config::{get_knowledge_dir, CONFIG};
use crate::knowledge::{
EmbeddingData, EmbeddingServiceBuilder, EmbeddingServiceHandle, KnowledgeResult,
KnowledgeStore, Record,
ConnectionMode, EmbeddingData, EmbeddingServiceBuilder, EmbeddingServiceHandle,
KnowledgeResult, KnowledgeStore, Record,
};
use async_trait::async_trait;
use log::{info, warn};
use std::borrow::Cow;
use std::error::Error;
use std::sync::Arc;
use surrealdb::engine::local::{Db, RocksDb};
use surrealdb::Surreal;

@@ -15,39 +16,47 @@ use surrealdb::Surreal;
/// using the EmbeddingServiceHandle.
pub struct KnowledgeStoreImpl {
embedding_service: EmbeddingServiceHandle,
db: Surreal<Db>,
db: Option<Arc<Surreal<Db>>>,
}

impl KnowledgeStoreImpl {
/// Creates a new instance of `KnowledgeStoreImpl` by connecting to the knowledge database
/// and initializing an `EmbeddingServiceHandle` based on the current configuration.
pub async fn new() -> Result<Self, Box<dyn Error>> {
// Get the embedding model from the configuration
pub async fn new(mode: ConnectionMode) -> Result<Self, Box<dyn Error>> {
let embedding_model = {
let config = CONFIG.lock().unwrap();
config.ai.embedding_model.clone()
};
// Create an embedding service based on the selected model
let embedding_service = EmbeddingServiceBuilder::new()
.model_name(embedding_model.into())
.build()?;

// Connect to the SurrealDB local database
let db = Surreal::new::<RocksDb>(get_knowledge_dir()?.to_str().unwrap()).await?;
db.use_ns("knowledge").use_db("knowledge_db").await?;

// Ensure the knowledge database has an index for HNSW-based embeddings similarity search
db.query(format!(
"DEFINE INDEX idx_mtree_cosine ON context_embeddings FIELDS embedding MTREE DIMENSION {} DIST COSINE TYPE F32;",
embedding_service.inner.embedding_len()
))
.await?;
let db = match mode {
ConnectionMode::Persistent => {
Some(connect_to_db(embedding_service.inner.embedding_len()).await?)
}
ConnectionMode::OnDemand => None,
};

Ok(KnowledgeStoreImpl {
embedding_service,
db,
})
}

async fn connect(&self, idx_len: usize) -> Result<Arc<Surreal<Db>>, Box<dyn Error>> {
connect_to_db(idx_len).await
}
}

async fn connect_to_db(idx_len: usize) -> Result<Arc<Surreal<Db>>, Box<dyn Error>> {
info!("Connecting to db");
let db = Surreal::new::<RocksDb>(get_knowledge_dir()?.to_str().unwrap()).await?;
db.use_ns("knowledge").use_db("knowledge_db").await?;
db.query(format!(
"DEFINE INDEX idx_mtree_cosine ON context_embeddings FIELDS embedding MTREE DIMENSION {} DIST COSINE TYPE F32;",
idx_len
))
.await?;
Ok(Arc::new(db))
}

#[async_trait]
@@ -65,8 +74,15 @@ impl KnowledgeStore for KnowledgeStoreImpl {
.get_embedding(user_input)
.await?;
info!("Searching for knowledge for embedding");
let db_handle = if let Some(db) = &self.db {
db
} else {
&self
.connect(self.embedding_service.inner.embedding_len())
.await?
};
// Query the knowledge base for the closest embeddings (most relevant documents)
let mut results = match self.db
let mut results = match db_handle
.query("SELECT data_source, content, metadata, vector::similarity::cosine(embedding, $embedding) AS distance FROM context_embeddings WHERE embedding <|10|> $embedding ORDER BY distance;")
.bind(("embedding", embedding))
.await {
@@ -88,9 +104,15 @@ impl KnowledgeStore for KnowledgeStoreImpl {
async fn store_knowledge(&self, knowledge: EmbeddingData) -> Result<(), Box<dyn Error>> {
let data_source = knowledge.data_source.to_string();
info!("Storing knowledge for: {}", data_source);
let db_handle = if let Some(db) = &self.db {
db
} else {
&self
.connect(self.embedding_service.inner.embedding_len())
.await?
};

match self
.db
match db_handle
.upsert::<Option<Record>>(("context_embeddings", knowledge.data_source.to_string()))
.content(knowledge)
.await

0 comments on commit 6642bfc

Please sign in to comment.