Skip to content

Commit

Permalink
refactor: linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Ki-Seki committed Aug 20, 2024
1 parent 6b7d43f commit 23203e8
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion eval/benchs/halueval/eval_halueval_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def scoring(self, data_point: dict) -> dict:
)
response = self.model.safe_request(query)

answer = response.strip().split()
answer = response.strip().split()
# Extract the first word, such as "Yes", "No", "#Yes", "No."
# Note: "".split() returns [] instead of [""], so guard against an empty list
answer = answer[0] if answer else ""
Expand Down
3 changes: 1 addition & 2 deletions eval/benchs/halueval/eval_halueval_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ def scoring(self, data_point: dict) -> dict:
)
response = self.model.safe_request(query)

answer = response.strip().split()
answer = response.strip().split()
# Extract the first word, such as "Yes", "No", "#Yes", "No."
# Note: "".split() returns [] instead of [""], so guard against an empty list
answer = answer[0] if answer else ""
# Remove leading and trailing "#", ".", "," characters
answer = answer.strip("#").strip(".").strip(",")


return {
"metrics": {
"correct": ground_truth.lower() == answer.lower(),
Expand Down
2 changes: 1 addition & 1 deletion eval/benchs/halueval/eval_halueval_summa.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def scoring(self, data_point: dict) -> dict:
)
response = self.model.safe_request(query)

answer = response.strip().split()
answer = response.strip().split()
# Extract the first word, such as "Yes", "No", "#Yes", "No."
# Note: "".split() returns [] instead of [""], so guard against an empty list
answer = answer[0] if answer else ""
Expand Down
4 changes: 2 additions & 2 deletions eval/benchs/uhgeval/eval_sele.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def which_is_true(self, contn1: str, contn2: str, obj: dict) -> tuple[int, str]:
"""Given two continuations, determine which one is more accurate.
Returns:
tuple: (answer, response). `answer` is 1 if the first continuation is more
accurate, 2 if the second one is more accurate, and -1 if error. `response`
tuple: (answer, response). `answer` is 1 if the first continuation is more
accurate, 2 if the second one is more accurate, and -1 if error. `response`
is the model's response.
"""
query = PROMPT_TEMPLATE.format(
Expand Down

0 comments on commit 23203e8

Please sign in to comment.