better support for other langchain llm clients
ciekawy committed May 8, 2024
1 parent be87de6 commit d9fe689
Showing 6 changed files with 91 additions and 33 deletions.
3 changes: 0 additions & 3 deletions src/ragas/embeddings/base.py
@@ -14,12 +14,9 @@
from ragas.run_config import RunConfig, add_async_retry, add_retry
import logging

# logging.basicConfig(level=logging.DEBUG)

DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"



class BaseRagasEmbeddings(Embeddings, ABC):
run_config: RunConfig

5 changes: 3 additions & 2 deletions src/ragas/evaluation.py
@@ -18,7 +18,7 @@
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
from ragas.llms import llm_factory
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LLMConfig
from ragas.metrics._answer_correctness import AnswerCorrectness
from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM
from ragas.metrics.critique import AspectCritique
@@ -41,6 +41,7 @@ def evaluate(
dataset: Dataset,
metrics: list[Metric] | None = None,
llm: t.Optional[BaseRagasLLM | LangchainLLM] = None,
llm_config: t.Optional[LLMConfig] = None,
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
callbacks: Callbacks = None,
is_async: bool = True,
@@ -148,7 +149,7 @@ def evaluate(

# set the llm and embeddings
if isinstance(llm, LangchainLLM):
llm = LangchainLLMWrapper(llm, run_config=run_config)
llm = LangchainLLMWrapper(llm, llm_config=llm_config, run_config=run_config)
if isinstance(embeddings, LangchainEmbeddings):
embeddings = LangchainEmbeddingsWrapper(embeddings)

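Not part of the diff: a minimal usage sketch of the new llm_config argument, assuming an existing LangChain chat model (chat_model) and a prepared evaluation Dataset (eval_dataset), both hypothetical; the stop token mirrors the one referenced in src/ragas/llms/base.py below.

    # Illustrative sketch only; chat_model and eval_dataset are hypothetical placeholders.
    from ragas import evaluate
    from ragas.llms import LLMConfig
    from ragas.metrics import faithfulness

    llm_config = LLMConfig(stop=["<|eot_id|>"])  # extra keyword args are collected into LLMConfig.params

    result = evaluate(
        dataset=eval_dataset,       # a datasets.Dataset with question/answer/contexts columns
        metrics=[faithfulness],
        llm=chat_model,             # any LangChain BaseLanguageModel
        llm_config=llm_config,      # forwarded to LangchainLLMWrapper
    )
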
3 changes: 2 additions & 1 deletion src/ragas/llms/__init__.py
@@ -1,7 +1,8 @@
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, llm_factory
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LLMConfig, llm_factory

__all__ = [
"BaseRagasLLM",
"LangchainLLMWrapper",
"LLMConfig",
"llm_factory",
]
90 changes: 71 additions & 19 deletions src/ragas/llms/base.py
@@ -10,15 +10,20 @@
from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.llms import VertexAI
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.llms.base import BaseOpenAI

from ragas.run_config import RunConfig, add_async_retry, add_retry
import re
import hashlib
import traceback


if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks

@@ -110,6 +115,17 @@ async def generate(
)
return await loop.run_in_executor(None, generate_text)

@dataclass
class LLMConfig:
    stop: t.Optional[t.List[str]] = None
    params: t.Optional[t.Dict[str, t.Any]] = None
    prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None
    result_callback: t.Optional[t.Callable[[LLMResult], LLMResult]] = None

    def __init__(
        self,
        stop: t.Optional[t.List[str]] = None,
        prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None,
        result_callback: t.Optional[t.Callable[[LLMResult], LLMResult]] = None,
        **kwargs,
    ):
        self.stop = stop
        self.params = kwargs
        self.prompt_callback = prompt_callback
        self.result_callback = result_callback

class LangchainLLMWrapper(BaseRagasLLM):
"""
@@ -120,12 +136,18 @@ class LangchainLLMWrapper(BaseRagasLLM):
"""

def __init__(
self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
self,
langchain_llm: BaseLanguageModel,
run_config: t.Optional[RunConfig] = None,
llm_config: t.Optional[LLMConfig] = None,
):
self.langchain_llm = langchain_llm
if run_config is None:
run_config = RunConfig()
self.set_run_config(run_config)
if llm_config is None:
llm_config = LLMConfig()
self.llm_config = llm_config

def generate_text(
self,
@@ -136,21 +158,38 @@ def generate_text(
callbacks: Callbacks = None,
) -> LLMResult:
temperature = self.get_temperature(n=n)
stop = stop or self.llm_config.stop

if self.llm_config.prompt_callback:
prompts, extra_params = self.llm_config.prompt_callback(prompt)
else:
prompts = [prompt]
extra_params = {}

if is_multiple_completion_supported(self.langchain_llm):
return self.langchain_llm.generate_prompt(
prompts=[prompt],
result = self.langchain_llm.generate_prompt(
prompts=prompts,
n=n,
temperature=temperature,
stop=stop,
callbacks=callbacks,
stop=stop,
**self.llm_config.params,
**extra_params,
)
if self.llm_config.result_callback:
return self.llm_config.result_callback(result)
return result
else:
result = self.langchain_llm.generate_prompt(
prompts=[prompt] * n,
temperature=temperature,
stop=stop,
callbacks=callbacks,
**self.llm_config.params,
**extra_params,
)
if self.llm_config.result_callback:
result = self.llm_config.result_callback(result)
# make LLMResult.generation appear as if it was n_completions
# note that LLMResult.runs is still a list that represents each run
generations = [[g[0] for g in result.generations]]
@@ -162,43 +201,56 @@ async def agenerate_text(
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
stop: t.Optional[t.List[str]] = None, #["<|eot_id|>"], #None,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
# traceback.print_stack()
logger.debug(f"Generating text with prompt: {str(prompt).encode('utf-8').decode('unicode_escape')}...")
stop = ["<|eot_id|>"]
# ["</s>", "[/INST]"] #
prompt.prompt_str =f"<human>: {prompt.prompt_str}\n<bot>:"
# to trace request/response for multi-threaded execution
gen_id = hashlib.md5(str(prompt).encode('utf-8')).hexdigest()[:4]
stop = stop or self.llm_config.stop
prompt_str = prompt.prompt_str
logger.debug(f"Generating text for [{gen_id}] with prompt: {prompt_str}")
temperature = self.get_temperature(n=n)
if self.llm_config.prompt_callback:
prompts, extra_params = self.llm_config.prompt_callback(prompt)
else:
prompts = [prompt] * n
extra_params = {}
if is_multiple_completion_supported(self.langchain_llm):
response = await self.langchain_llm.agenerate_prompt(
prompts=[prompt],
result = await self.langchain_llm.agenerate_prompt(
prompts=prompts,
n=n,
temperature=temperature,
stop=stop,
callbacks=callbacks,
**self.llm_config.params,
**extra_params,
)
logger.debug(f"got result (m): {response.generations[0][0].text}")
return response
if self.llm_config.result_callback:
result = self.llm_config.result_callback(result)
logger.debug(f"got result (m): {result.generations[0][0].text}")
return result
else:
result = await self.langchain_llm.agenerate_prompt(
prompts=[prompt] * n,
prompts=prompts,
temperature=temperature,
stop=stop,
callbacks=callbacks,
**self.llm_config.params,
**extra_params,
)
if self.llm_config.result_callback:
result = self.llm_config.result_callback(result)
# make LLMResult.generation appear as if it was n_completions
# note that LLMResult.runs is still a list that represents each run
generations = [[g[0] for g in result.generations]]
result.generations = generations

# this part should go to LLMConfig.result_callback
if len(result.generations[0][0].text) > 0:
# while the <human>/<bot> tags improve answer quality, the </bot> tag sometimes leaks into the response
result.generations[0][0].text = re.sub(r"</?bot>", '', result.generations[0][0].text)
logger.debug(f"got result: {result.generations[0][0].text}")
logger.debug(f"got result [{gen_id}]: {result.generations[0][0].text}")
# TODO: make this length check configurable per question?
if len(result.generations[0][0].text) < 24:
logger.warn(f"truncated response?: {result.generations}")
logger.warning(f"truncated response?: {result.generations}")
return result

def set_run_config(self, run_config: RunConfig):
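Not part of the diff: a sketch of how the previously hard-coded <human>/<bot> prompt wrapping, <|eot_id|> stop token, and tag stripping (removed or flagged above) could be expressed through the new LLMConfig callbacks. The template, stop token, and regex come from the removed/commented code; the function names and everything else are assumptions.

    import re

    from langchain_core.outputs import LLMResult
    from langchain_core.prompt_values import StringPromptValue

    from ragas.llms import LLMConfig


    def wrap_prompt(prompt):
        # Re-wrap the ragas prompt in a chat-style template; must return (prompts, extra_params).
        wrapped = StringPromptValue(text=f"<human>: {prompt.prompt_str}\n<bot>:")
        return [wrapped], {}


    def strip_bot_tags(result: LLMResult) -> LLMResult:
        # Remove stray <bot>/</bot> tags that sometimes leak into completions.
        for generation_list in result.generations:
            for generation in generation_list:
                generation.text = re.sub(r"</?bot>", "", generation.text)
        return result


    llm_config = LLMConfig(
        stop=["<|eot_id|>"],
        prompt_callback=wrap_prompt,
        result_callback=strip_bot_tags,
    )
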
14 changes: 9 additions & 5 deletions src/ragas/testset/evolutions.py
@@ -186,9 +186,13 @@ async def generate_datarow(
):
assert self.generator_llm is not None, "generator_llm cannot be None"

node_content = [
f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
]
# an explicit index-to-content mapping makes the contexts easier for the LLM to distinguish, especially for long, complex contexts
node_content = {
str(i + 1): n.page_content for i, n in enumerate(current_nodes.nodes)
}
# node_content = [
# f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
# ]
results = await self.generator_llm.generate(
prompt=self.find_relevant_context_prompt.format(
question=question, contexts=node_content
@@ -208,9 +212,9 @@
)
else:
selected_nodes = [
current_nodes.nodes[i - 1]
current_nodes.nodes[int(i) - 1]
for i in relevant_context_indices
if i - 1 < len(current_nodes.nodes)
if int(i) - 1 < len(current_nodes.nodes)
]
relevant_context = (
CurrentNodes(root_node=selected_nodes[0], nodes=selected_nodes)
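Not part of the diff: a tiny illustration of the payload-shape change above, with made-up contents.

    # Before: a flat list with tab-separated indices.
    node_content_old = ["1\tParis is the capital of France.", "2\tThe Louvre is in Paris."]
    # After: an explicit index -> content mapping, easier for the LLM to reference;
    # the indices it returns (strings) are then parsed back with int().
    node_content_new = {"1": "Paris is the capital of France.", "2": "The Louvre is in Paris."}
    relevant_context_indices = ["2"]
    selected = [node_content_new[i] for i in relevant_context_indices if int(i) - 1 < len(node_content_new)]
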
9 changes: 6 additions & 3 deletions src/ragas/testset/generator.py
@@ -16,7 +16,7 @@
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LLMConfig
from ragas.run_config import RunConfig
from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
from ragas.testset.evolutions import (
Expand All @@ -32,6 +32,7 @@
from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
from ragas.utils import check_if_sum_is_close, deprecated, get_feature_language, is_nan


if t.TYPE_CHECKING:
from langchain_core.documents import Document as LCDocument
from llama_index.core.schema import Document as LlamaindexDocument
@@ -81,9 +82,11 @@ def from_langchain(
docstore: t.Optional[DocumentStore] = None,
run_config: t.Optional[RunConfig] = None,
chunk_size: int = 1024,
generator_llm_config: t.Optional[LLMConfig] = None,
critic_llm_config: t.Optional[LLMConfig] = None,
) -> "TestsetGenerator":
generator_llm_model = LangchainLLMWrapper(generator_llm)
critic_llm_model = LangchainLLMWrapper(critic_llm)
generator_llm_model = LangchainLLMWrapper(generator_llm, llm_config=generator_llm_config)
critic_llm_model = LangchainLLMWrapper(critic_llm, llm_config=critic_llm_config)
embeddings_model = LangchainEmbeddingsWrapper(embeddings)

keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
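Not part of the diff: a sketch of the extended from_langchain signature in use, targeting a Llama-style model behind an OpenAI-compatible endpoint. The model name, endpoint URL, stop token, embeddings choice, and documents list are illustrative assumptions; only generator_llm_config, critic_llm_config, and LLMConfig come from this commit.

    from langchain_openai import OpenAIEmbeddings
    from langchain_openai.chat_models import ChatOpenAI

    from ragas.llms import LLMConfig
    from ragas.testset.generator import TestsetGenerator

    # Hypothetical: a Llama-style model served through an OpenAI-compatible endpoint.
    local_llm = ChatOpenAI(
        model="llama-3-8b-instruct",
        base_url="http://localhost:8000/v1",
        api_key="not-needed",
    )
    llm_config = LLMConfig(stop=["<|eot_id|>"])

    generator = TestsetGenerator.from_langchain(
        generator_llm=local_llm,
        critic_llm=local_llm,
        embeddings=OpenAIEmbeddings(),  # any LangChain embeddings would do
        generator_llm_config=llm_config,
        critic_llm_config=llm_config,
    )
    # documents: a list of langchain_core.documents.Document prepared elsewhere (hypothetical).
    testset = generator.generate_with_langchain_docs(documents, test_size=10)
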
