From b99a35a48836ce6d6db7330eab1d264b58bfd1fd Mon Sep 17 00:00:00 2001
From: Nima
Date: Sat, 20 Apr 2024 11:39:56 +0200
Subject: [PATCH] fix: replace assert with ValueError for llm validation

---
 src/ragas/metrics/_answer_correctness.py      | 6 ++++--
 src/ragas/metrics/_answer_relevance.py        | 6 ++++--
 src/ragas/metrics/_context_entities_recall.py | 3 ++-
 src/ragas/metrics/_context_precision.py       | 6 ++++--
 src/ragas/metrics/_context_recall.py          | 6 ++++--
 src/ragas/metrics/_context_relevancy.py       | 6 ++++--
 src/ragas/metrics/_faithfulness.py            | 9 ++++++---
 src/ragas/metrics/critique.py                 | 6 ++++--
 8 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/src/ragas/metrics/_answer_correctness.py b/src/ragas/metrics/_answer_correctness.py
index fe206b8459..785c4e5fe6 100644
--- a/src/ragas/metrics/_answer_correctness.py
+++ b/src/ragas/metrics/_answer_correctness.py
@@ -140,7 +140,8 @@ def _compute_statement_presence(
         return score
 
     async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
-        assert self.llm is not None, "LLM must be set"
+        if self.llm is None:
+            raise ValueError("LLM must be set.")
 
         q, a, g = row["question"], row["answer"], row["ground_truth"]
         p_value = self.correctness_prompt.format(question=q, ground_truth=g, answer=a)
@@ -174,7 +175,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
         return float(score)
 
     def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
-        assert self.llm is not None, "llm must be set to compute score"
+        if self.llm is None:
+            raise ValueError("llm must be set to compute score.")
 
         logger.info(f"Adapting AnswerCorrectness metric to {language}")
         self.correctness_prompt = self.correctness_prompt.adapt(
diff --git a/src/ragas/metrics/_answer_relevance.py b/src/ragas/metrics/_answer_relevance.py
index af95aa06da..39dcee0eeb 100644
--- a/src/ragas/metrics/_answer_relevance.py
+++ b/src/ragas/metrics/_answer_relevance.py
@@ -146,7 +146,8 @@ def _create_question_gen_prompt(self, row: t.Dict) -> PromptValue:
         return self.question_generation.format(answer=ans, context="\n".join(ctx))
 
     async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
-        assert self.llm is not None, "LLM is not set"
+        if self.llm is None:
+            raise ValueError("LLM is not set")
 
         prompt = self._create_question_gen_prompt(row)
         result = await self.llm.generate(
@@ -167,7 +168,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
         return self._calculate_score(answers, row)
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
-        assert self.llm is not None, "LLM is not set"
+        if self.llm is None:
+            raise ValueError("LLM is not set")
 
         logger.info(f"Adapting AnswerRelevancy metric to {language}")
         self.question_generation = self.question_generation.adapt(
diff --git a/src/ragas/metrics/_context_entities_recall.py b/src/ragas/metrics/_context_entities_recall.py
index ba249107f7..71b00a65cd 100644
--- a/src/ragas/metrics/_context_entities_recall.py
+++ b/src/ragas/metrics/_context_entities_recall.py
@@ -151,7 +151,8 @@ async def get_entities(
         callbacks: Callbacks,
         is_async: bool,
     ) -> t.Optional[ContextEntitiesResponse]:
-        assert self.llm is not None, "LLM is not initialized"
+        if self.llm is None:
+            raise ValueError("LLM is not initialized")
         p_value = self.context_entity_recall_prompt.format(
             text=text,
         )
diff --git a/src/ragas/metrics/_context_precision.py b/src/ragas/metrics/_context_precision.py
index e6a0ff41e2..8f523b1076 100644
--- a/src/ragas/metrics/_context_precision.py
+++ b/src/ragas/metrics/_context_precision.py
@@ -136,7 +136,8 @@ async def _ascore(
         callbacks: Callbacks,
         is_async: bool,
     ) -> float:
-        assert self.llm is not None, "LLM is not set"
+        if self.llm is None:
+            raise ValueError("LLM is not set")
 
         human_prompts = self._context_precision_prompt(row)
         responses = []
@@ -162,7 +163,8 @@ async def _ascore(
         return score
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
-        assert self.llm is not None, "LLM is not set"
+        if self.llm is None:
+            raise ValueError("LLM is not set")
 
         logging.info(f"Adapting Context Precision to {language}")
         self.context_precision_prompt = self.context_precision_prompt.adapt(
diff --git a/src/ragas/metrics/_context_recall.py b/src/ragas/metrics/_context_recall.py
index f77c49ce93..c1eb8b984a 100644
--- a/src/ragas/metrics/_context_recall.py
+++ b/src/ragas/metrics/_context_recall.py
@@ -142,7 +142,8 @@ def _compute_score(self, response: t.Any) -> float:
         return score
 
     async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
-        assert self.llm is not None, "set LLM before use"
+        if self.llm is None:
+            raise ValueError("set LLM before use")
         p_value = self._create_context_recall_prompt(row)
         result = await self.llm.generate(
             p_value,
@@ -160,7 +161,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
         return self._compute_score(answers)
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
-        assert self.llm is not None, "set LLM before use"
+        if self.llm is None:
+            raise ValueError("set LLM before use")
 
         logger.info(f"Adapting Context Recall to {language}")
         self.context_recall_prompt = self.context_recall_prompt.adapt(
diff --git a/src/ragas/metrics/_context_relevancy.py b/src/ragas/metrics/_context_relevancy.py
index efedc12e8a..81894229f0 100644
--- a/src/ragas/metrics/_context_relevancy.py
+++ b/src/ragas/metrics/_context_relevancy.py
@@ -67,7 +67,8 @@ def _compute_score(self, response: str, row: t.Dict) -> float:
         return min(len(indices) / len(context_sents), 1)
 
     async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
-        assert self.llm is not None, "LLM is not initialized"
+        if self.llm is None:
+            raise ValueError("LLM is not initialized")
 
         if self.show_deprecation_warning:
             logger.warning(
@@ -84,7 +85,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
         return self._compute_score(result.generations[0][0].text, row)
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
-        assert self.llm is not None, "set LLM before use"
+        if self.llm is None:
+            raise ValueError("set LLM before use")
 
         logger.info(f"Adapting Context Relevancy to {language}")
         self.context_relevancy_prompt = self.context_relevancy_prompt.adapt(
diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py
index ceacc45280..8cc09d7dda 100644
--- a/src/ragas/metrics/_faithfulness.py
+++ b/src/ragas/metrics/_faithfulness.py
@@ -168,7 +168,8 @@ def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
         return prompt_value
 
     def _create_nli_prompt(self, row: t.Dict, statements: t.List[str]) -> PromptValue:
-        assert self.llm is not None, "llm must be set to compute score"
+        if self.llm is None:
+            raise ValueError("llm must be set to compute score")
 
         contexts = row["contexts"]
         # check if the statements are support in the contexts
@@ -199,7 +200,8 @@ async def _ascore(
         """
         returns the NLI score for each (q, c, a) pair
         """
-        assert self.llm is not None, "LLM is not set"
+        if self.llm is None:
+            raise ValueError("LLM is not set")
         p_value = self._create_answer_prompt(row)
         answer_result = await self.llm.generate(
             p_value, callbacks=callbacks, is_async=is_async
@@ -226,7 +228,8 @@ async def _ascore(
         return self._compute_score(faithfulness)
 
     def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
-        assert self.llm is not None, "LLM is not set"
+        if self.llm is None:
+            raise ValueError("LLM is not set")
 
         logger.info(f"Adapting Faithfulness metric to {language}")
         self.long_form_answer_prompt = self.long_form_answer_prompt.adapt(
diff --git a/src/ragas/metrics/critique.py b/src/ragas/metrics/critique.py
index 48362c0627..16fac44803 100644
--- a/src/ragas/metrics/critique.py
+++ b/src/ragas/metrics/critique.py
@@ -120,7 +120,8 @@ def _compute_score(self, safe_loaded_responses: t.List[CriticClassification]):
     async def _ascore(
         self: t.Self, row: t.Dict, callbacks: Callbacks, is_async: bool
     ) -> float:
-        assert self.llm is not None, "set LLM before use"
+        if self.llm is None:
+            raise ValueError("set LLM before use")
 
         q, c, a = row["question"], row["contexts"], row["answer"]
 
@@ -143,7 +144,8 @@ async def _ascore(
         return self._compute_score(safe_loaded_responses)
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
-        assert self.llm is not None, "set LLM before use"
+        if self.llm is None:
+            raise ValueError("set LLM before use")
 
         logger.info(f"Adapting Critic to {language}")
         self.critic_prompt.adapt(language, self.llm, cache_dir)
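The same pattern is applied in every file: a bare `assert` becomes an explicit
`None` check that raises `ValueError`. The motivation is that assert statements
are stripped when Python runs with the `-O` flag, so the guard would silently
disappear in optimized builds and a missing LLM would only surface later as an
`AttributeError` deep inside the call. A minimal standalone sketch of the
before/after behaviour (the `ExampleMetric` class below is illustrative only,
not the ragas API):

    class ExampleMetric:
        def __init__(self, llm=None):
            self.llm = llm

        def score_before(self) -> float:
            # Removed entirely under `python -O`: with llm=None this
            # method then fails later with an opaque AttributeError.
            assert self.llm is not None, "LLM is not set"
            return 1.0

        def score_after(self) -> float:
            # Always runs, regardless of interpreter flags, and raises
            # a clear, catchable error at the point of misuse.
            if self.llm is None:
                raise ValueError("LLM is not set")
            return 1.0

A further benefit is that callers can handle the failure explicitly with
`except ValueError`, which is not possible once an assert has been
optimized away.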