Skip to content

Commit

Permalink
fix: handle edge cases in prompt processing (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahules786 authored Dec 12, 2023
1 parent b455475 commit 41e9e54
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
24 changes: 15 additions & 9 deletions src/ragas/metrics/_answer_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,21 @@ def _score_batch(
f1_score = []
for prediction in outputs:
prediction = json_loader.safe_load(prediction[0].text, self.llm)
prediction = [
item.get(key_map[k], np.nan)
for item in prediction
for k in key_map.keys()
]
tp, fp, fn = [
len(item) if isinstance(item, list) else np.nan for item in prediction
]
score = tp / (tp + 0.5 * (fp + fn))
prediction = prediction if isinstance(prediction, list) else []
if prediction:
prediction = [
item.get(key_map[k], np.nan)
for item in prediction
for k in key_map.keys()
]
tp, fp, fn = [
len(item) if isinstance(item, list) else np.nan
for item in prediction
]
score = tp / (tp + 0.5 * (fp + fn))
else:
score = np.nan

f1_score.append(score)

similarity_scores = self.answer_similarity._score_batch(dataset) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion src/ragas/metrics/_faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _score_batch(
scores = []
for output in outputs:
output = json_loader.safe_load(output[0].text, self.llm)
output = output if output else []
output = output if isinstance(output, list) else []
faithful_statements = sum(
verdict_score_map.get(dict.get("verdict", "").lower(), np.nan)
for dict in output
Expand Down
2 changes: 1 addition & 1 deletion src/ragas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _fix_to_json(
callbacks: t.Optional[CallbackManager] = None,
callback_group_name: str = "batch",
):
# TODO (executor)
# TODO (executor)
with trace_as_chain_group(
callback_group_name, callback_manager=callbacks
) as batch_group:
Expand Down

0 comments on commit 41e9e54

Please sign in to comment.