Skip to content

Commit

Permalink
Fix some typos in package and struct names (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
aadi-pcln authored Apr 16, 2024
1 parent 2aac755 commit 01ec67c
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 61 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// To run this example execute: cargo run --example conversational_retriver_chain --features postgres
// To run this example execute: cargo run --example conversational_retriever_chain --features postgres

#[cfg(feature = "postgres")]
use futures_util::StreamExt;
#[cfg(feature = "postgres")]
use langchain_rust::{
add_documents,
chain::{Chain, ConversationalRetriverChainBuilder},
chain::{Chain, ConversationalRetrieverChainBuilder},
embedding::openai::openai_embedder::OpenAiEmbedder,
llm::{OpenAI, OpenAIModel},
memory::SimpleMemory,
Expand Down Expand Up @@ -51,11 +51,11 @@ async fn main() {

let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());

let chain = ConversationalRetriverChainBuilder::new()
let chain = ConversationalRetrieverChainBuilder::new()
.llm(llm)
.rephrase_question(true)
.memory(SimpleMemory::new().into())
.retriver(Retriever::new(store, 5))
.retriever(Retriever::new(store, 5))
.build()
.expect("Error building ConversationalChain");

Expand Down Expand Up @@ -88,5 +88,5 @@ async fn main() {
fn main() {
println!("This example requires the 'postgres' feature to be enabled.");
println!("Please run the command as follows:");
println!("cargo run --example conversational_retriver_chain --features postgres");
println!("cargo run --example conversational_retriever_chain --features postgres");
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ use std::error::Error;
use async_trait::async_trait;
use futures_util::StreamExt;
use langchain_rust::{
chain::{Chain, ConversationalRetriverChainBuilder},
chain::{Chain, ConversationalRetrieverChainBuilder},
llm::{OpenAI, OpenAIModel},
memory::SimpleMemory,
prompt_args,
schemas::{Document, Retriever},
};

struct RetriverMock {}
struct RetrieverMock {}
#[async_trait]
impl Retriever for RetriverMock {
impl Retriever for RetrieverMock {
async fn get_relevant_documents(
&self,
_question: &str,
Expand Down Expand Up @@ -40,10 +40,10 @@ impl Retriever for RetriverMock {
#[tokio::main]
async fn main() {
let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
let chain = ConversationalRetriverChainBuilder::new()
let chain = ConversationalRetrieverChainBuilder::new()
.llm(llm)
.rephrase_question(true)
.retriver(RetriverMock {})
.retriever(RetrieverMock {})
.memory(SimpleMemory::new().into())
.build()
.expect("Error building ConversationalChain");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@ use std::sync::Arc;
use tokio::sync::Mutex;

use crate::{
chain::{Chain, ChainError, CondenseQuetionGeneratorChain, StuffDocument, DEFAULT_OUTPUT_KEY},
chain::{Chain, ChainError, CondenseQuestionGeneratorChain, StuffDocument, DEFAULT_OUTPUT_KEY},
language_models::llm::LLM,
memory::SimpleMemory,
schemas::{BaseMemory, Retriever},
};

use super::ConversationalRetriverChain;
use super::ConversationalRetrieverChain;

const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";

///Conversation Retriver Chain Builder
///Conversation Retriever Chain Builder
/// # Usage
/// ## Convensional way
/// ```rust,ignore
/// let chain = ConversationalRetriverChainBuilder::new()
/// let chain = ConversationalRetrieverChainBuilder::new()
/// .llm(llm)
/// .rephrase_question(true)
/// .retriver(RetriverMock {})
/// .retriever(RetrieverMock {})
/// .memory(SimpleMemory::new().into())
/// .build()
/// .expect("Error building ConversationalChain");
Expand All @@ -30,20 +30,20 @@ const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";
///
/// let llm = Box::new(OpenAI::default().with_model(OpenAIModel::Gpt35.to_string()));
/// let combine_documents_chain = StuffDocument::load_stuff_qa(llm.clone_box());
// let condense_question_chian = CondenseQuetionGeneratorChain::new(llm.clone_box());
/// let chain = ConversationalRetriverChainBuilder::new()
// let condense_question_chian = CondenseQuestionGeneratorChain::new(llm.clone_box());
/// let chain = ConversationalRetrieverChainBuilder::new()
/// .rephrase_question(true)
/// .combine_documents_chain(Box::new(combine_documents_chain))
/// .condense_question_chian(Box::new(condense_question_chian))
/// .retriver(RetriverMock {})
/// .retriever(RetrieverMock {})
/// .memory(SimpleMemory::new().into())
/// .build()
/// .expect("Error building ConversationalChain");
/// ```
///
pub struct ConversationalRetriverChainBuilder {
pub struct ConversationalRetrieverChainBuilder {
llm: Option<Box<dyn LLM>>,
retriver: Option<Box<dyn Retriever>>,
retriever: Option<Box<dyn Retriever>>,
memory: Option<Arc<Mutex<dyn BaseMemory>>>,
combine_documents_chain: Option<Box<dyn Chain>>,
condense_question_chian: Option<Box<dyn Chain>>,
Expand All @@ -52,11 +52,11 @@ pub struct ConversationalRetriverChainBuilder {
input_key: String,
output_key: String,
}
impl ConversationalRetriverChainBuilder {
impl ConversationalRetrieverChainBuilder {
pub fn new() -> Self {
ConversationalRetriverChainBuilder {
ConversationalRetrieverChainBuilder {
llm: None,
retriver: None,
retriever: None,
memory: None,
combine_documents_chain: None,
condense_question_chian: None,
Expand All @@ -67,8 +67,8 @@ impl ConversationalRetriverChainBuilder {
}
}

pub fn retriver<R: Into<Box<dyn Retriever>>>(mut self, retriver: R) -> Self {
self.retriver = Some(retriver.into());
pub fn retriever<R: Into<Box<dyn Retriever>>>(mut self, retriever: R) -> Self {
self.retriever = Some(retriever.into());
self
}

Expand Down Expand Up @@ -115,17 +115,17 @@ impl ConversationalRetriverChainBuilder {
self
}

pub fn build(mut self) -> Result<ConversationalRetriverChain, ChainError> {
pub fn build(mut self) -> Result<ConversationalRetrieverChain, ChainError> {
if let Some(llm) = self.llm {
let combine_documents_chain = StuffDocument::load_stuff_qa(llm.clone_box());
let condense_question_chian = CondenseQuetionGeneratorChain::new(llm.clone_box());
let condense_question_chian = CondenseQuestionGeneratorChain::new(llm.clone_box());
self.combine_documents_chain = Some(Box::new(combine_documents_chain));
self.condense_question_chian = Some(Box::new(condense_question_chian));
}

let retriver = self
.retriver
.ok_or_else(|| ChainError::MissingObject("Retriver must be set".into()))?;
let retriever = self
.retriever
.ok_or_else(|| ChainError::MissingObject("Retriever must be set".into()))?;

let memory = self
.memory
Expand All @@ -141,8 +141,8 @@ impl ConversationalRetriverChainBuilder {
"Condense question chain must be set or llm must be set".into(),
)
})?;
Ok(ConversationalRetriverChain {
retriver,
Ok(ConversationalRetrieverChain {
retriever,
memory,
combine_documents_chain,
condense_question_chian,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::{
const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_SOURCE_DOCUMENT_KEY: &str = "source_documents";
const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_GENERATED_QUESTION_KEY: &str = "generated_question";

pub struct ConversationalRetriverChain {
pub(crate) retriver: Box<dyn Retriever>,
pub struct ConversationalRetrieverChain {
pub(crate) retriever: Box<dyn Retriever>,
pub memory: Arc<Mutex<dyn BaseMemory>>,
pub(crate) combine_documents_chain: Box<dyn Chain>,
pub(crate) condense_question_chian: Box<dyn Chain>,
Expand All @@ -33,7 +33,7 @@ pub struct ConversationalRetriverChain {
pub(crate) input_key: String, //Default is `question`
pub(crate) output_key: String, //default is output
}
impl ConversationalRetriverChain {
impl ConversationalRetrieverChain {
async fn get_question(
&self,
history: &[Message],
Expand Down Expand Up @@ -67,7 +67,7 @@ impl ConversationalRetriverChain {
}

#[async_trait]
impl Chain for ConversationalRetriverChain {
impl Chain for ConversationalRetrieverChain {
async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
let output = self.execute(input_variables).await?;
let result: GenerateResult = serde_json::from_value(output[DEFAULT_RESULT_KEY].clone())?;
Expand Down Expand Up @@ -95,10 +95,10 @@ impl Chain for ConversationalRetriverChain {
}

let documents = self
.retriver
.retriever
.get_relevant_documents(&question)
.await
.map_err(|e| ChainError::RetreiverError(e.to_string()))?;
.map_err(|e| ChainError::RetrieverError(e.to_string()))?;

let mut output = self
.combine_documents_chain
Expand Down Expand Up @@ -166,10 +166,10 @@ impl Chain for ConversationalRetriverChain {
let (question, _) = self.get_question(&history, &human_message.content).await?;

let documents = self
.retriver
.retriever
.get_relevant_documents(&question)
.await
.map_err(|e| ChainError::RetreiverError(e.to_string()))?;
.map_err(|e| ChainError::RetrieverError(e.to_string()))?;

let stream = self
.combine_documents_chain
Expand Down Expand Up @@ -235,7 +235,7 @@ mod tests {
use std::error::Error;

use crate::{
chain::ConversationalRetriverChainBuilder,
chain::ConversationalRetrieverChainBuilder,
llm::openai::{OpenAI, OpenAIModel},
memory::SimpleMemory,
prompt_args,
Expand All @@ -244,9 +244,9 @@ mod tests {

use super::*;

struct RetriverTest {}
struct RetrieverTest {}
#[async_trait]
impl Retriever for RetriverTest {
impl Retriever for RetrieverTest {
async fn get_relevant_documents(
&self,
_question: &str,
Expand Down Expand Up @@ -276,9 +276,9 @@ mod tests {
#[ignore]
async fn test_invoke_retriever_conversational() {
let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
let chain = ConversationalRetriverChainBuilder::new()
let chain = ConversationalRetrieverChainBuilder::new()
.llm(llm)
.retriver(RetriverTest {})
.retriever(RetrieverTest {})
.memory(SimpleMemory::new().into())
.build()
.expect("Error building ConversationalChain");
Expand Down
5 changes: 5 additions & 0 deletions src/chain/conversational_retrieval_qa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod builder;
pub use builder::*;

mod conversational_retrieval_qa;
pub use conversational_retrieval_qa::*;
5 changes: 0 additions & 5 deletions src/chain/conversational_retrival_qa/mod.rs

This file was deleted.

4 changes: 2 additions & 2 deletions src/chain/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ pub enum ChainError {
#[error("LLM error: {0}")]
LLMError(#[from] LLMError),

#[error("Retreiver error: {0}")]
RetreiverError(String),
#[error("Retriever error: {0}")]
RetrieverError(String),

#[error("OutputParser error: {0}")]
OutputParser(#[from] OutputParserError),
Expand Down
4 changes: 2 additions & 2 deletions src/chain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ pub use stuff_documents::*;
mod question_answering;
pub use question_answering::*;

mod conversational_retrival_qa;
pub use conversational_retrival_qa::*;
mod conversational_retrieval_qa;
pub use conversational_retrieval_qa::*;

mod error;
pub use error::*;
Expand Down
6 changes: 3 additions & 3 deletions src/chain/question_answering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ impl CondenseQuestionPromptBuilder {
}
}

pub struct CondenseQuetionGeneratorChain {
pub struct CondenseQuestionGeneratorChain {
chain: LLMChain,
}

impl CondenseQuetionGeneratorChain {
impl CondenseQuestionGeneratorChain {
pub fn new<L: Into<Box<dyn LLM>>>(llm: L) -> Self {
let condense_question_prompt_template =
template_jinja2!(DEFAULTCONDENSEQUESTIONTEMPLATE, "chat_history", "question");
Expand All @@ -76,7 +76,7 @@ impl CondenseQuetionGeneratorChain {
}

#[async_trait]
impl Chain for CondenseQuetionGeneratorChain {
impl Chain for CondenseQuestionGeneratorChain {
async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
self.chain.call(input_variables).await
}
Expand Down
4 changes: 2 additions & 2 deletions src/schemas/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub use prompt::*;
pub mod document;
pub use document::*;

mod retrivers;
pub use retrivers::*;
mod retrievers;
pub use retrievers::*;

mod tools_openai_like;
pub use tools_openai_like::*;
Expand Down
4 changes: 2 additions & 2 deletions src/schemas/retrivers.rs → src/schemas/retrievers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl<R> From<R> for Box<dyn Retriever>
where
R: Retriever + 'static,
{
fn from(retriver: R) -> Self {
Box::new(retriver)
fn from(retriever: R) -> Self {
Box::new(retriever)
}
}

0 comments on commit 01ec67c

Please sign in to comment.