Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replace assert with ValueError for llm validation #884

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/ragas/metrics/_answer_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def _compute_statement_presence(
return score

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
assert self.llm is not None, "LLM must be set"
if self.llm is None:
raise ValueError("LLM must be set.")

q, a, g = row["question"], row["answer"], row["ground_truth"]
p_value = self.correctness_prompt.format(question=q, ground_truth=g, answer=a)
Expand Down Expand Up @@ -174,7 +175,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
return float(score)

def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
assert self.llm is not None, "llm must be set to compute score"
if self.llm is None:
raise ValueError("llm must be set to compute score.")

logger.info(f"Adapting AnswerCorrectness metric to {language}")
self.correctness_prompt = self.correctness_prompt.adapt(
Expand Down
6 changes: 4 additions & 2 deletions src/ragas/metrics/_answer_relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def _create_question_gen_prompt(self, row: t.Dict) -> PromptValue:
return self.question_generation.format(answer=ans, context="\n".join(ctx))

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
assert self.llm is not None, "LLM is not set"
if self.llm is None:
raise ValueError("LLM is not set")

prompt = self._create_question_gen_prompt(row)
result = await self.llm.generate(
Expand All @@ -167,7 +168,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
return self._calculate_score(answers, row)

def adapt(self, language: str, cache_dir: str | None = None) -> None:
assert self.llm is not None, "LLM is not set"
if self.llm is None:
raise ValueError("LLM is not set")

logger.info(f"Adapting AnswerRelevancy metric to {language}")
self.question_generation = self.question_generation.adapt(
Expand Down
3 changes: 2 additions & 1 deletion src/ragas/metrics/_context_entities_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ async def get_entities(
callbacks: Callbacks,
is_async: bool,
) -> t.Optional[ContextEntitiesResponse]:
assert self.llm is not None, "LLM is not initialized"
if self.llm is None:
raise ValueError("LLM is not initialized")
p_value = self.context_entity_recall_prompt.format(
text=text,
)
Expand Down
6 changes: 4 additions & 2 deletions src/ragas/metrics/_context_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ async def _ascore(
callbacks: Callbacks,
is_async: bool,
) -> float:
assert self.llm is not None, "LLM is not set"
if self.llm is None:
raise ValueError("LLM is not set")

human_prompts = self._context_precision_prompt(row)
responses = []
Expand All @@ -162,7 +163,8 @@ async def _ascore(
return score

def adapt(self, language: str, cache_dir: str | None = None) -> None:
assert self.llm is not None, "LLM is not set"
if self.llm is None:
raise ValueError("LLM is not set")

logging.info(f"Adapting Context Precision to {language}")
self.context_precision_prompt = self.context_precision_prompt.adapt(
Expand Down
6 changes: 4 additions & 2 deletions src/ragas/metrics/_context_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def _compute_score(self, response: t.Any) -> float:
return score

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
assert self.llm is not None, "set LLM before use"
if self.llm is None:
raise ValueError("set LLM before use")
p_value = self._create_context_recall_prompt(row)
result = await self.llm.generate(
p_value,
Expand All @@ -160,7 +161,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
return self._compute_score(answers)

def adapt(self, language: str, cache_dir: str | None = None) -> None:
assert self.llm is not None, "set LLM before use"
if self.llm is None:
raise ValueError("set LLM before use")

logger.info(f"Adapting Context Recall to {language}")
self.context_recall_prompt = self.context_recall_prompt.adapt(
Expand Down
6 changes: 4 additions & 2 deletions src/ragas/metrics/_context_relevancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def _compute_score(self, response: str, row: t.Dict) -> float:
return min(len(indices) / len(context_sents), 1)

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
assert self.llm is not None, "LLM is not initialized"
if self.llm is None:
raise ValueError("LLM is not initialized")

if self.show_deprecation_warning:
logger.warning(
Expand All @@ -84,7 +85,8 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
return self._compute_score(result.generations[0][0].text, row)

def adapt(self, language: str, cache_dir: str | None = None) -> None:
assert self.llm is not None, "set LLM before use"
if self.llm is None:
raise ValueError("set LLM before use")

logger.info(f"Adapting Context Relevancy to {language}")
self.context_relevancy_prompt = self.context_relevancy_prompt.adapt(
Expand Down
9 changes: 6 additions & 3 deletions src/ragas/metrics/_faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
return prompt_value

def _create_nli_prompt(self, row: t.Dict, statements: t.List[str]) -> PromptValue:
assert self.llm is not None, "llm must be set to compute score"
if self.llm is None:
raise ValueError("llm must be set to compute score")

contexts = row["contexts"]
# check if the statements are support in the contexts
Expand Down Expand Up @@ -199,7 +200,8 @@ async def _ascore(
"""
returns the NLI score for each (q, c, a) pair
"""
assert self.llm is not None, "LLM is not set"
if self.llm is None:
raise ValueError("LLM is not set")
p_value = self._create_answer_prompt(row)
answer_result = await self.llm.generate(
p_value, callbacks=callbacks, is_async=is_async
Expand All @@ -226,7 +228,8 @@ async def _ascore(
return self._compute_score(faithfulness)

def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
assert self.llm is not None, "LLM is not set"
if self.llm is None:
raise ValueError("LLM is not set")

logger.info(f"Adapting Faithfulness metric to {language}")
self.long_form_answer_prompt = self.long_form_answer_prompt.adapt(
Expand Down
6 changes: 4 additions & 2 deletions src/ragas/metrics/critique.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def _compute_score(self, safe_loaded_responses: t.List[CriticClassification]):
async def _ascore(
self: t.Self, row: t.Dict, callbacks: Callbacks, is_async: bool
) -> float:
assert self.llm is not None, "set LLM before use"
if self.llm is None:
raise ValueError("set LLM before use")

q, c, a = row["question"], row["contexts"], row["answer"]

Expand All @@ -143,7 +144,8 @@ async def _ascore(
return self._compute_score(safe_loaded_responses)

def adapt(self, language: str, cache_dir: str | None = None) -> None:
assert self.llm is not None, "set LLM before use"
if self.llm is None:
raise ValueError("set LLM before use")

logger.info(f"Adapting Critic to {language}")
self.critic_prompt.adapt(language, self.llm, cache_dir)
Expand Down
Loading