Skip to content

Commit

Permalink
Add first poc for new knowledge management feature
Browse files Browse the repository at this point in the history
Implement a knowledge management system to allow users to store and retrieve knowledge entries. This includes commands for initializing knowledge, searching, and processing embedding data. Introduce several new structs and methods to handle embedding services through different providers.

The changes introduce a new module for handling knowledge operations, which differs from the previous implementation by providing structured commands and functionality for knowledge initialization and search through the use of embeddings. It also introduces a new embedding service builder for integrating with backends like OpenAI and Ollama, improving upon the previous query mechanisms by implementing organized and reusable code for managing knowledge inputs efficiently.
  • Loading branch information
Christian Stolz committed Oct 10, 2024
1 parent dbcf7c9 commit 355b8d3
Show file tree
Hide file tree
Showing 19 changed files with 2,892 additions and 123 deletions.
2,577 changes: 2,489 additions & 88 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ async-trait = "0.1"
ollama-rs = { version = "0.2", features = ["stream"] }
chrono = { version = "0.4", features = ["serde"]}
chrono-tz = "0.10"
surrealdb = { version = "2.0", features = ["kv-rocksdb"] }
async-channel = "2.3"
7 changes: 6 additions & 1 deletion src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::cli::chat::ChatArgs;
use crate::cli::commitmessage::CommitMessageArgs;
use crate::cli::createbackground::BackgroundArgs;
use crate::cli::createicon::CreateIconArgs;
use crate::cli::knowledge::KnowledgeArgs;
use crate::cli::wish::WishArgs;
use clap::{Parser, Subcommand};
use clap_complete::aot::Shell;
Expand Down Expand Up @@ -65,11 +66,15 @@ pub enum Commands {
CreateIcon(CreateIconArgs),

/// Create a background using DALL·E based on user input.
CreateBackground(BackgroundArgs), // <-- New command
CreateBackground(BackgroundArgs),

/// Collect files from a specified directory and create a context for chat.
Wish(WishArgs),

/// Manage knowledge entries.
#[clap(subcommand)]
Knowledge(KnowledgeArgs),

/// Initialize configuration and environment.
Init,
}
14 changes: 1 addition & 13 deletions src/cli/chat/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use crate::cli::spinner::{start_spinner, stop_spinner};
use crate::cli::style::configure_mad_skin;
use crate::config;
use crate::config::{get_chat_sessions_dir, Config};
use crate::persona::{get_persona, Persona};
use crate::persona::{resolve_persona, Persona};
use atty::Stream;
use chrono::{DateTime, Local, Utc};
use log::error;
Expand Down Expand Up @@ -103,18 +103,6 @@ fn initialize_command_registry() -> CommandRegistry {
command_registry
}

fn resolve_persona(
persona_name: &Option<String>,
default_persona: &str,
) -> Result<Persona, Box<dyn Error>> {
match persona_name {
Some(name) => {
get_persona(name).ok_or_else(|| "Specified persona not found. Using default.".into())
}
None => Ok(get_persona(default_persona).unwrap()),
}
}

fn handle_session(
chat_service: &mut ChatService,
start_new: bool,
Expand Down
123 changes: 123 additions & 0 deletions src/cli/knowledge/init.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use crate::cli::knowledge::knowledge_args::InitArgs;
use crate::config;
use crate::config::{get_knowledge_dir, Config};
use crate::context::{load_files_into_context, ContextConsumer};
use crate::knowledge::EmbeddingServiceBuilder;
use crate::persona::resolve_persona;
use async_channel::Receiver;
use async_channel::Sender;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::env;
use std::error::Error;
use surrealdb::engine::local::RocksDb;
use surrealdb::sql::Thing;
use surrealdb::Surreal;

/// Bounded channel size
const CHANNEL_SIZE: usize = 10;

type Job = (String, String);

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingData {
pub file_name: String,
pub embedding: Vec<f32>, // Adjust the size based on the model
pub metadata: Option<String>,
}

// Add this to your SurrealDB record structure
#[derive(Debug, Deserialize)]
pub struct Record {
#[allow(dead_code)]
id: Thing,
}

fn get_config() -> Config {
let config = config::CONFIG.lock().unwrap();
config.clone()
}

pub async fn init(args: InitArgs) -> Result<(), Box<dyn Error>> {
let config = get_config();
let db = Surreal::new::<RocksDb>(get_knowledge_dir()?.to_str().unwrap()).await?;
db.use_ns("knowledge").use_db("knowledge_db").await?;
let (sender, receiver) = async_channel::bounded(CHANNEL_SIZE);
let persona = resolve_persona(&args.persona, config.default_persona.as_str())?;
let model_name = if let Some(model) = args.model {
model // Use user-specified model
} else {
config.ai.embedding_model.clone() // Default model from configuration
};
let client = EmbeddingServiceBuilder::new()
.model_name(model_name.into())
.build()?;

let max_threads = 10; // Define the maximum number of threads
let mut handles = vec![];

// Spawn a fixed number of threads
for _ in 0..max_threads {
let client_clone = client.clone();
let db_clone = db.clone();
let receiver_clone: Receiver<Job> = receiver.clone();

let handle = tokio::spawn(async move {
loop {
let job = receiver_clone.recv().await;
let (filename, content) = match job {
Ok(job) => job,
Err(_) => break,
};
eprintln!("Processing File: {}", filename);

let embedding_response = client_clone.inner.get_embedding(content).await.unwrap();

let embedding_data = EmbeddingData {
file_name: filename.clone(),
embedding: embedding_response.to_vec(),
metadata: Some("Add additional details here if needed".to_string()),
};

let _: Option<Record> = db_clone
.create("embedding_table")
.content(embedding_data)
.await
.unwrap();
}
});
handles.push(handle);
}

// Initialize the OpenAI client
let mut consumer = Consumer { sender };
load_files_into_context(
&mut consumer,
&env::current_dir().unwrap(),
&persona.file_types,
)
.unwrap();

// Close the channel after loading files
drop(consumer);

// Wait for all threads to finish processing
for handle in handles {
let _ = handle.await;
}

Ok(())
}

struct Consumer {
sender: Sender<Job>,
}

impl ContextConsumer for Consumer {
fn consume(&mut self, filename: Cow<str>, content: Cow<str>) -> Result<(), Box<dyn Error>> {
let job = (filename.into(), content.into());
// Use block_in_place to execute blocking operations
tokio::task::block_in_place(|| self.sender.send_blocking(job))?;
Ok(())
}
}
30 changes: 30 additions & 0 deletions src/cli/knowledge/knowledge_args.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//! This module defines the `KnowledgeArgs` struct, which holds the command-line arguments
//! related to managing knowledge entries in the Rusty Buddy application.
//!
//! It utilizes the `clap` library for parsing, providing a clear interface for users specifying
//! how to manage knowledge entries.
use clap::{Args, Subcommand};

/// This struct holds the command-line arguments for managing knowledge entries.
#[derive(Subcommand)]
pub enum KnowledgeArgs {
Init(InitArgs),
Search(SearchArgs),
}

#[derive(Args)]
pub struct SearchArgs {
pub search: String,
}

#[derive(Args)]
pub struct InitArgs {
/// Specify a persona for the knowledge initalisation
#[arg(short, long)]
pub persona: Option<String>,

/// Sets the AI model to use in this knowledge initalisation
#[arg(short = 'm', long = "model")]
pub model: Option<String>,
}
11 changes: 11 additions & 0 deletions src/cli/knowledge/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//! This module provides functionality for managing knowledge entries for various commands.
//! It allows users to store, retrieve, and delete knowledge entries that enhance the interactivity
//! of whether it's for wishes or chats.
mod init;
mod knowledge_args;
mod run;
mod search;

pub use knowledge_args::KnowledgeArgs;
pub use run::run_knowledge;
27 changes: 27 additions & 0 deletions src/cli/knowledge/run.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//! This module provides functionality for managing knowledge entries based
//! on user-defined commands within the Rusty Buddy application.
use crate::cli::knowledge::{init, search, KnowledgeArgs};
use std::error::Error;

/// Runs the knowledge command, executing the specified action based on user input.
///
/// # Arguments
///
/// * `args` - KnowledgeArgs holding the parameters for the command.
///
/// # Returns
///
/// Returns a Result indicating success or an error if the process fails.
pub async fn run_knowledge(args: KnowledgeArgs) -> Result<(), Box<dyn Error>> {
match args {
KnowledgeArgs::Init(init) => {
init::init(init).await?;
}
KnowledgeArgs::Search(search) => {
search::search(search).await?;
}
}

Ok(())
}
42 changes: 42 additions & 0 deletions src/cli/knowledge/search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use crate::cli::knowledge::knowledge_args::SearchArgs;
use crate::config;
use crate::config::{get_knowledge_dir, Config};
use crate::knowledge::EmbeddingServiceBuilder;
use serde::{Deserialize, Serialize};
use std::error::Error;
use surrealdb::engine::local::RocksDb;
use surrealdb::Surreal;

#[derive(Debug, Serialize, Deserialize)]
struct Search {
distance: f32,
file_name: String,
}

fn get_config() -> Config {
let config = config::CONFIG.lock().unwrap();
config.clone()
}

pub async fn search(args: SearchArgs) -> Result<(), Box<dyn Error>> {
let config = get_config();
let db = Surreal::new::<RocksDb>(get_knowledge_dir()?.to_str().unwrap()).await?;
db.use_ns("knowledge").use_db("knowledge_db").await?;
let model_name = config.ai.embedding_model.clone();
let client = EmbeddingServiceBuilder::new()
.model_name(model_name.into())
.build()?;
let embedding = client.inner.get_embedding(args.search).await?;
db.query("DEFINE INDEX hnsw_pts ON embedding_table FIELDS embedding HNSW DIMENSION 3072;")
.await?;
// Assuming response.data holds embedding data
let mut groups = db
.query("SELECT file_name, vector::similarity::cosine(embedding, $embedding) AS distance FROM embedding_table WHERE embedding <|10,40|> $embedding ORDER BY distance;")
.bind(("embedding", embedding))
.await?;
let files: Vec<Search> = groups.take(0)?;
for file in files {
println!("{} {}", file.file_name, file.distance);
}
Ok(())
}
1 change: 1 addition & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub mod createbackground;
pub mod createicon;
pub mod editor;
pub mod init;
pub mod knowledge;
mod slash_completer;
mod spinner;
mod style;
Expand Down
9 changes: 9 additions & 0 deletions src/config/config_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ pub struct AI {

#[serde(default = "wish_model")]
pub wish_model: String,

#[serde(default = "embedding_model")]
pub embedding_model: String,

#[serde(default = "default_timeout_secs")]
pub chat_timeout_secs: u64,
}
Expand All @@ -115,6 +119,7 @@ fn default_ai() -> AI {
commit_model: commit_model(),
chat_model: chat_model(),
chat_timeout_secs: default_timeout_secs(),
embedding_model: embedding_model(),
}
}

Expand All @@ -133,6 +138,9 @@ fn chat_model() -> String {
fn commit_model() -> String {
"gpt-4o-mini".to_string()
}
fn embedding_model() -> String {
"text-embedding-3-large".to_string()
}

fn wish_model() -> String {
default_model()
Expand Down Expand Up @@ -181,6 +189,7 @@ impl Default for Config {
commit_model: "".to_string(),
wish_model: "".to_string(),
chat_timeout_secs: default_timeout_secs(),
embedding_model: "".to_string(),
},
personas: vec![],
models: None,
Expand Down
6 changes: 6 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ pub fn get_chat_sessions_dir() -> Result<PathBuf, String> {
Ok(config_dir.join("chat"))
}

pub fn get_knowledge_dir() -> Result<PathBuf, String> {
let config_file = get_config_file()?;
let config_dir = config_file.parent().expect("Expected a parent directory");
Ok(config_dir.join("knowledge"))
}

pub fn get_config_file() -> Result<PathBuf, String> {
get_config_file_from_dir(
env::current_dir().map_err(|e| format!("Failed to get the current directory: {}", e))?,
Expand Down
Loading

0 comments on commit 355b8d3

Please sign in to comment.