-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add first poc for new knowledge management feature
Implement a knowledge management system to allow users to store and retrieve knowledge entries. This includes commands for initializing knowledge, searching, and processing embedding data. Introduce several new structs and methods to handle embedding services through different providers. The changes introduce a new module for handling knowledge operations, which differs from the previous implementation by providing structured commands and functionality for knowledge initialization and search through the use of embeddings. It also introduces a new embedding service builder for integrating with backends like OpenAI and Ollama, improving upon the previous query mechanisms by implementing organized and reusable code for managing knowledge inputs efficiently.
- Loading branch information
Christian Stolz
committed
Oct 10, 2024
1 parent
dbcf7c9
commit 355b8d3
Showing
19 changed files
with
2,892 additions
and
123 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
use crate::cli::knowledge::knowledge_args::InitArgs; | ||
use crate::config; | ||
use crate::config::{get_knowledge_dir, Config}; | ||
use crate::context::{load_files_into_context, ContextConsumer}; | ||
use crate::knowledge::EmbeddingServiceBuilder; | ||
use crate::persona::resolve_persona; | ||
use async_channel::Receiver; | ||
use async_channel::Sender; | ||
use serde::{Deserialize, Serialize}; | ||
use std::borrow::Cow; | ||
use std::env; | ||
use std::error::Error; | ||
use surrealdb::engine::local::RocksDb; | ||
use surrealdb::sql::Thing; | ||
use surrealdb::Surreal; | ||
|
||
/// Capacity of the bounded job channel that connects the file loader to the
/// embedding workers.
const CHANNEL_SIZE: usize = 10;

/// A unit of work sent over the channel: `(file name, file content)`.
type Job = (String, String);
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct EmbeddingData { | ||
pub file_name: String, | ||
pub embedding: Vec<f32>, // Adjust the size based on the model | ||
pub metadata: Option<String>, | ||
} | ||
|
||
// Add this to your SurrealDB record structure | ||
#[derive(Debug, Deserialize)] | ||
pub struct Record { | ||
#[allow(dead_code)] | ||
id: Thing, | ||
} | ||
|
||
fn get_config() -> Config { | ||
let config = config::CONFIG.lock().unwrap(); | ||
config.clone() | ||
} | ||
|
||
pub async fn init(args: InitArgs) -> Result<(), Box<dyn Error>> { | ||
let config = get_config(); | ||
let db = Surreal::new::<RocksDb>(get_knowledge_dir()?.to_str().unwrap()).await?; | ||
db.use_ns("knowledge").use_db("knowledge_db").await?; | ||
let (sender, receiver) = async_channel::bounded(CHANNEL_SIZE); | ||
let persona = resolve_persona(&args.persona, config.default_persona.as_str())?; | ||
let model_name = if let Some(model) = args.model { | ||
model // Use user-specified model | ||
} else { | ||
config.ai.embedding_model.clone() // Default model from configuration | ||
}; | ||
let client = EmbeddingServiceBuilder::new() | ||
.model_name(model_name.into()) | ||
.build()?; | ||
|
||
let max_threads = 10; // Define the maximum number of threads | ||
let mut handles = vec![]; | ||
|
||
// Spawn a fixed number of threads | ||
for _ in 0..max_threads { | ||
let client_clone = client.clone(); | ||
let db_clone = db.clone(); | ||
let receiver_clone: Receiver<Job> = receiver.clone(); | ||
|
||
let handle = tokio::spawn(async move { | ||
loop { | ||
let job = receiver_clone.recv().await; | ||
let (filename, content) = match job { | ||
Ok(job) => job, | ||
Err(_) => break, | ||
}; | ||
eprintln!("Processing File: {}", filename); | ||
|
||
let embedding_response = client_clone.inner.get_embedding(content).await.unwrap(); | ||
|
||
let embedding_data = EmbeddingData { | ||
file_name: filename.clone(), | ||
embedding: embedding_response.to_vec(), | ||
metadata: Some("Add additional details here if needed".to_string()), | ||
}; | ||
|
||
let _: Option<Record> = db_clone | ||
.create("embedding_table") | ||
.content(embedding_data) | ||
.await | ||
.unwrap(); | ||
} | ||
}); | ||
handles.push(handle); | ||
} | ||
|
||
// Initialize the OpenAI client | ||
let mut consumer = Consumer { sender }; | ||
load_files_into_context( | ||
&mut consumer, | ||
&env::current_dir().unwrap(), | ||
&persona.file_types, | ||
) | ||
.unwrap(); | ||
|
||
// Close the channel after loading files | ||
drop(consumer); | ||
|
||
// Wait for all threads to finish processing | ||
for handle in handles { | ||
let _ = handle.await; | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
struct Consumer { | ||
sender: Sender<Job>, | ||
} | ||
|
||
impl ContextConsumer for Consumer { | ||
fn consume(&mut self, filename: Cow<str>, content: Cow<str>) -> Result<(), Box<dyn Error>> { | ||
let job = (filename.into(), content.into()); | ||
// Use block_in_place to execute blocking operations | ||
tokio::task::block_in_place(|| self.sender.send_blocking(job))?; | ||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//! This module defines the `KnowledgeArgs` enum, which holds the command-line arguments | ||
//! related to managing knowledge entries in the Rusty Buddy application. | ||
//! | ||
//! It utilizes the `clap` library for parsing, providing a clear interface for users specifying | ||
//! how to manage knowledge entries. | ||
use clap::{Args, Subcommand}; | ||
|
||
/// This struct holds the command-line arguments for managing knowledge entries. | ||
#[derive(Subcommand)] | ||
pub enum KnowledgeArgs { | ||
Init(InitArgs), | ||
Search(SearchArgs), | ||
} | ||
|
||
#[derive(Args)] | ||
pub struct SearchArgs { | ||
pub search: String, | ||
} | ||
|
||
#[derive(Args)] | ||
pub struct InitArgs { | ||
/// Specify a persona for the knowledge initalisation | ||
#[arg(short, long)] | ||
pub persona: Option<String>, | ||
|
||
/// Sets the AI model to use in this knowledge initalisation | ||
#[arg(short = 'm', long = "model")] | ||
pub model: Option<String>, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
//! This module provides functionality for managing knowledge entries for various commands. | ||
//! It allows users to store, retrieve, and delete knowledge entries that enrich | ||
//! interactions, whether they are wishes or chats. | ||
mod init; | ||
mod knowledge_args; | ||
mod run; | ||
mod search; | ||
|
||
pub use knowledge_args::KnowledgeArgs; | ||
pub use run::run_knowledge; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//! This module provides functionality for managing knowledge entries based | ||
//! on user-defined commands within the Rusty Buddy application. | ||
use crate::cli::knowledge::{init, search, KnowledgeArgs}; | ||
use std::error::Error; | ||
|
||
/// Runs the knowledge command, executing the specified action based on user input. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `args` - KnowledgeArgs holding the parameters for the command. | ||
/// | ||
/// # Returns | ||
/// | ||
/// Returns a Result indicating success or an error if the process fails. | ||
pub async fn run_knowledge(args: KnowledgeArgs) -> Result<(), Box<dyn Error>> { | ||
match args { | ||
KnowledgeArgs::Init(init) => { | ||
init::init(init).await?; | ||
} | ||
KnowledgeArgs::Search(search) => { | ||
search::search(search).await?; | ||
} | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
use crate::cli::knowledge::knowledge_args::SearchArgs; | ||
use crate::config; | ||
use crate::config::{get_knowledge_dir, Config}; | ||
use crate::knowledge::EmbeddingServiceBuilder; | ||
use serde::{Deserialize, Serialize}; | ||
use std::error::Error; | ||
use surrealdb::engine::local::RocksDb; | ||
use surrealdb::Surreal; | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Search { | ||
distance: f32, | ||
file_name: String, | ||
} | ||
|
||
fn get_config() -> Config { | ||
let config = config::CONFIG.lock().unwrap(); | ||
config.clone() | ||
} | ||
|
||
pub async fn search(args: SearchArgs) -> Result<(), Box<dyn Error>> { | ||
let config = get_config(); | ||
let db = Surreal::new::<RocksDb>(get_knowledge_dir()?.to_str().unwrap()).await?; | ||
db.use_ns("knowledge").use_db("knowledge_db").await?; | ||
let model_name = config.ai.embedding_model.clone(); | ||
let client = EmbeddingServiceBuilder::new() | ||
.model_name(model_name.into()) | ||
.build()?; | ||
let embedding = client.inner.get_embedding(args.search).await?; | ||
db.query("DEFINE INDEX hnsw_pts ON embedding_table FIELDS embedding HNSW DIMENSION 3072;") | ||
.await?; | ||
// Assuming response.data holds embedding data | ||
let mut groups = db | ||
.query("SELECT file_name, vector::similarity::cosine(embedding, $embedding) AS distance FROM embedding_table WHERE embedding <|10,40|> $embedding ORDER BY distance;") | ||
.bind(("embedding", embedding)) | ||
.await?; | ||
let files: Vec<Search> = groups.take(0)?; | ||
for file in files { | ||
println!("{} {}", file.file_name, file.distance); | ||
} | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.