diff --git a/src/ragas/embeddings/base.py b/src/ragas/embeddings/base.py
index fcb9d1dc2d..87e1b0ea8e 100644
--- a/src/ragas/embeddings/base.py
+++ b/src/ragas/embeddings/base.py
@@ -14,12 +14,9 @@ from ragas.run_config import RunConfig, add_async_retry, add_retry

 import logging

-# logging.basicConfig(level=logging.DEBUG)
-
 DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"

-
 class BaseRagasEmbeddings(Embeddings, ABC):
     run_config: RunConfig
diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py
index 2b205c6d06..71ee2423a8 100644
--- a/src/ragas/evaluation.py
+++ b/src/ragas/evaluation.py
@@ -18,7 +18,7 @@ from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
 from ragas.llms import llm_factory
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LLMConfig
 from ragas.metrics._answer_correctness import AnswerCorrectness
 from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM
 from ragas.metrics.critique import AspectCritique
@@ -41,6 +41,7 @@ def evaluate(
     dataset: Dataset,
     metrics: list[Metric] | None = None,
     llm: t.Optional[BaseRagasLLM | LangchainLLM] = None,
+    llm_config: t.Optional[LLMConfig] = None,
     embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
     callbacks: Callbacks = None,
     is_async: bool = True,
@@ -148,7 +149,7 @@ def evaluate(

     # set the llm and embeddings
     if isinstance(llm, LangchainLLM):
-        llm = LangchainLLMWrapper(llm, run_config=run_config)
+        llm = LangchainLLMWrapper(llm, llm_config=llm_config, run_config=run_config)
     if isinstance(embeddings, LangchainEmbeddings):
         embeddings = LangchainEmbeddingsWrapper(embeddings)
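The new `llm_config` argument only takes effect when `llm` is a plain LangChain model, since it is applied in the `LangchainLLMWrapper` branch above. A minimal usage sketch, assuming a locally served Llama-3-style chat model; the model name, stop token, metric choice and toy dataset are illustrative, not part of this patch:

    from datasets import Dataset
    from langchain_community.chat_models import ChatOllama
    from ragas import evaluate
    from ragas.llms import LLMConfig
    from ragas.metrics import faithfulness

    # toy single-row dataset with the column names ragas expects
    ds = Dataset.from_dict({
        "question": ["When was the first Super Bowl played?"],
        "answer": ["The first Super Bowl was played on January 15, 1967."],
        "contexts": [["The First AFL-NFL World Championship Game was played on January 15, 1967."]],
        "ground_truth": ["January 15, 1967"],
    })

    # stop sequences come from LLMConfig; any extra keyword arguments would be
    # collected into LLMConfig.params and forwarded to generate_prompt
    llm_config = LLMConfig(stop=["<|eot_id|>"])

    result = evaluate(
        ds,
        metrics=[faithfulness],
        llm=ChatOllama(model="llama3"),
        llm_config=llm_config,
    )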
diff --git a/src/ragas/llms/__init__.py b/src/ragas/llms/__init__.py
index dc1e37fb53..0a24f1c06c 100644
--- a/src/ragas/llms/__init__.py
+++ b/src/ragas/llms/__init__.py
@@ -1,7 +1,8 @@
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, llm_factory
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LLMConfig, llm_factory

 __all__ = [
     "BaseRagasLLM",
     "LangchainLLMWrapper",
+    "LLMConfig",
     "llm_factory",
 ]
diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index 64259041d1..b4bd9ff098 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -10,15 +10,20 @@ from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.llms import VertexAI
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import HumanMessage
 from langchain_core.outputs import LLMResult
+from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
+from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI

 from ragas.run_config import RunConfig, add_async_retry, add_retry
 import re
+import hashlib
 import traceback
+

 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
@@ -110,6 +115,17 @@ async def generate(
         )
         return await loop.run_in_executor(None, generate_text)

+@dataclass
+class LLMConfig:
+    stop: t.Optional[t.List[str]] = None
+    params: t.Optional[t.Dict[str, t.Any]] = None
+    prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None
+    result_callback: t.Optional[t.Callable[[LLMResult], t.Tuple[t.List[LLMResult]]]] = None
+
+    def __init__(self, stop: t.Optional[t.List[str]] = None, prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None, **kwargs):
+        self.stop = stop
+        self.params = kwargs
+        self.prompt_callback = prompt_callback

 class LangchainLLMWrapper(BaseRagasLLM):
     """
@@ -120,12 +136,18 @@ class LangchainLLMWrapper(BaseRagasLLM):
     """

     def __init__(
-        self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
+        self,
+        langchain_llm: BaseLanguageModel,
+        run_config: t.Optional[RunConfig] = None,
+        llm_config: LLMConfig = None,
     ):
         self.langchain_llm = langchain_llm
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
+        if llm_config is None:
+            llm_config = LLMConfig()
+        self.llm_config = llm_config

     def generate_text(
         self,
@@ -136,21 +158,38 @@ def generate_text(
         callbacks: Callbacks = None,
     ) -> LLMResult:
         temperature = self.get_temperature(n=n)
+        stop = stop or self.llm_config.stop
+
+        if self.llm_config.prompt_callback:
+            prompts, extra_params = self.llm_config.prompt_callback(prompt)
+        else:
+            prompts = [prompt]
+            extra_params = {}
+
         if is_multiple_completion_supported(self.langchain_llm):
-            return self.langchain_llm.generate_prompt(
-                prompts=[prompt],
+            result = self.langchain_llm.generate_prompt(
+                prompts=prompts,
                 n=n,
                 temperature=temperature,
-                stop=stop,
                 callbacks=callbacks,
+                stop=stop,
+                **self.llm_config.params,
+                **extra_params,
             )
+            if self.llm_config.result_callback:
+                return self.llm_config.result_callback(result)
+            return result
         else:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt] * n,
                 temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
+                **self.llm_config.params,
+                **extra_params,
             )
+            if self.llm_config.result_callback:
+                result = self.llm_config.result_callback(result)
             # make LLMResult.generation appear as if it was n_completions
             # note that LLMResult.runs is still a list that represents each run
             generations = [[g[0] for g in result.generations]]
@@ -162,43 +201,56 @@ async def agenerate_text(
         prompt: PromptValue,
         n: int = 1,
         temperature: float = 1e-8,
-        stop: t.Optional[t.List[str]] = None, #["<|eot_id|>"], #None,
+        stop: t.Optional[t.List[str]] = None,
         callbacks: Callbacks = None,
     ) -> LLMResult:
-        # traceback.print_stack()
-        logger.debug(f"Generating text with prompt: {str(prompt).encode('utf-8').decode('unicode_escape')}...")
-        stop = ["<|eot_id|>"]
-        # ["", "[/INST]"] #
-        prompt.prompt_str =f": {prompt.prompt_str}\n:"
+        # to trace request/response for multi-threaded execution
+        gen_id = hashlib.md5(str(prompt).encode('utf-8')).hexdigest()[:4]
+        stop = stop or self.llm_config.stop
+        prompt_str = prompt.prompt_str
+        logger.debug(f"Generating text for [{gen_id}] with prompt: {prompt_str}")
         temperature = self.get_temperature(n=n)
+        if self.llm_config.prompt_callback:
+            prompts, extra_params = self.llm_config.prompt_callback(prompt)
+        else:
+            prompts = [prompt] * n
+            extra_params = {}
         if is_multiple_completion_supported(self.langchain_llm):
-            response = await self.langchain_llm.agenerate_prompt(
-                prompts=[prompt],
+            result = await self.langchain_llm.agenerate_prompt(
+                prompts=prompts,
                 n=n,
                 temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
+                **self.llm_config.params,
+                **extra_params,
             )
-            logger.debug(f"got result (m): {response.generations[0][0].text}")
-            return response
+            if self.llm_config.result_callback:
+                result = self.llm_config.result_callback(result)
+            logger.debug(f"got result (m): {result.generations[0][0].text}")
+            return result
         else:
             result = await self.langchain_llm.agenerate_prompt(
-                prompts=[prompt] * n,
+                prompts=prompts,
                 temperature=temperature,
-                stop=stop,
                 callbacks=callbacks,
+                stop=stop,
+                **self.llm_config.params,
+                **extra_params,
             )
+            if self.llm_config.result_callback:
+                result = self.llm_config.result_callback(result)
             # make LLMResult.generation appear as if it was n_completions
             # note that LLMResult.runs is still a list that represents each run
             generations = [[g[0] for g in result.generations]]
             result.generations = generations
+
+            # this part should go to LLMConfig.result_callback
             if len(result.generations[0][0].text) > 0:
-                # while the / tags improves answer quality, I observed sometimes the to leak into the response
                 result.generations[0][0].text = re.sub(r"", '', result.generations[0][0].text)
-            logger.debug(f"got result: {result.generations[0][0].text}")
+            logger.debug(f"got result [{gen_id}]: {result.generations[0][0].text}")
             # todo configure on question?
             if len(result.generations[0][0].text) < 24:
-                logger.warn(f"truncated response?: {result.generations}")
+                logger.warning(f"truncated response?: {result.generations}")
             return result

     def set_run_config(self, run_config: RunConfig):
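A sketch of how the `LLMConfig` hooks can be wired directly around `LangchainLLMWrapper`. The chat-template tokens and the model are assumptions used to illustrate the callback contracts: `prompt_callback` maps one `PromptValue` to a list of prompts plus extra generate kwargs, and `result_callback` post-processes the `LLMResult`; since `result_callback` is not an `__init__` argument in this patch, it is assigned as an attribute:

    from langchain_community.chat_models import ChatOllama
    from langchain_core.outputs import LLMResult
    from langchain_core.prompt_values import StringPromptValue
    from ragas.llms import LangchainLLMWrapper, LLMConfig

    def wrap_prompt(prompt):
        # prompt_callback contract: PromptValue in -> (list of PromptValues, extra kwargs)
        wrapped = StringPromptValue(text=f"<|user|>\n{prompt.to_string()}\n<|assistant|>\n")
        return [wrapped], {}

    def strip_special_tokens(result: LLMResult) -> LLMResult:
        # result_callback contract: clean up the generations before ragas parses them
        for generations in result.generations:
            for g in generations:
                g.text = g.text.replace("<|eot_id|>", "").strip()
        return result

    llm_config = LLMConfig(stop=["<|eot_id|>"], prompt_callback=wrap_prompt)
    llm_config.result_callback = strip_special_tokens  # not settable via __init__ here

    ragas_llm = LangchainLLMWrapper(ChatOllama(model="llama3"), llm_config=llm_config)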
diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py
index 1bb5837c55..1bea209082 100644
--- a/src/ragas/testset/evolutions.py
+++ b/src/ragas/testset/evolutions.py
@@ -186,9 +186,13 @@ async def generate_datarow(
     ):
         assert self.generator_llm is not None, "generator_llm cannot be None"

-        node_content = [
-            f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
-        ]
+        # clear index distinction helps in getting it more clear for LLM - especially for long, complex contexts
+        node_content = {
+            str(i + 1): n.page_content for i, n in enumerate(current_nodes.nodes)
+        }
+        # node_content = [
+        #     f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
+        # ]
         results = await self.generator_llm.generate(
             prompt=self.find_relevant_context_prompt.format(
                 question=question, contexts=node_content
@@ -208,9 +212,9 @@ async def generate_datarow(
             )
         else:
             selected_nodes = [
-                current_nodes.nodes[i - 1]
+                current_nodes.nodes[int(i) - 1]
                 for i in relevant_context_indices
-                if i - 1 < len(current_nodes.nodes)
+                if int(i) - 1 < len(current_nodes.nodes)
             ]
             relevant_context = (
                 CurrentNodes(root_node=selected_nodes[0], nodes=selected_nodes)
diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py
index 60c9dca692..3f9e1908a2 100644
--- a/src/ragas/testset/generator.py
+++ b/src/ragas/testset/generator.py
@@ -16,7 +16,7 @@ from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
-from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
+from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LLMConfig
 from ragas.run_config import RunConfig
 from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
 from ragas.testset.evolutions import (
@@ -32,6 +32,7 @@ from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.utils import check_if_sum_is_close, deprecated, get_feature_language, is_nan

+
 if t.TYPE_CHECKING:
     from langchain_core.documents import Document as LCDocument
     from llama_index.core.schema import Document as LlamaindexDocument
@@ -81,9 +82,11 @@ def from_langchain(
         docstore: t.Optional[DocumentStore] = None,
         run_config: t.Optional[RunConfig] = None,
         chunk_size: int = 1024,
+        generator_llm_config: t.Optional[LLMConfig] = None,
+        critic_llm_config: t.Optional[LLMConfig] = None,
     ) -> "TestsetGenerator":
-        generator_llm_model = LangchainLLMWrapper(generator_llm)
-        critic_llm_model = LangchainLLMWrapper(critic_llm)
+        generator_llm_model = LangchainLLMWrapper(generator_llm, llm_config=generator_llm_config)
+        critic_llm_model = LangchainLLMWrapper(critic_llm, llm_config=critic_llm_config)
         embeddings_model = LangchainEmbeddingsWrapper(embeddings)
         keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
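Finally, a sketch of the per-LLM configuration on the test set generator side; the models, embeddings, document loader and test size are assumptions for illustration:

    from langchain_community.chat_models import ChatOllama
    from langchain_community.document_loaders import DirectoryLoader
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from ragas.llms import LLMConfig
    from ragas.testset.generator import TestsetGenerator

    documents = DirectoryLoader("data/").load()  # any list of LangChain documents

    generator = TestsetGenerator.from_langchain(
        generator_llm=ChatOllama(model="llama3"),
        critic_llm=ChatOllama(model="llama3"),
        embeddings=HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
        generator_llm_config=LLMConfig(stop=["<|eot_id|>"]),
        critic_llm_config=LLMConfig(stop=["<|eot_id|>"]),
    )
    testset = generator.generate_with_langchain_docs(documents, test_size=10)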