Skip to content

Commit

Permalink
updated the degradation analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Oct 11, 2024
1 parent dc408e8 commit b7fef1f
Showing 1 changed file with 84 additions and 11 deletions.
95 changes: 84 additions & 11 deletions langtest/transform/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,14 +1146,14 @@ class DegradationAnalysis(BaseAccuracy):

supported_tasks = ["ner", "text-classification"]

result_data = defaultdict(dict)

@classmethod
def transform(cls, test: str, y_true: List[Any], params: Dict):
sample = MinScoreSample(
category="accuracy",
test_type="degradation_analysis",
)
# reset the result data
DegradationAnalysis.result_data.clear()

return [sample]
return []

@staticmethod
async def run(
Expand Down Expand Up @@ -1183,8 +1183,6 @@ async def run(

progress = kwargs.get("progress_bar", False)

output = defaultdict(dict)

for category, data in test_cases.items():
if category not in ["robustness", "bias"]:
continue
Expand All @@ -1205,24 +1203,30 @@ async def run(

degradation = accuracy_score2 - accuracy_score1

output[category][test_type] = degradation
DegradationAnalysis.result_data[category][test_type] = {
"before": accuracy_score1,
"after": accuracy_score2,
"difference": degradation,
}
if progress:
progress.update(1)

print(output)
return []

@staticmethod
def preprocess(y_true: List, y_pred: List):
def preprocess(y_true: Union[list, pd.Series], y_pred: Union[list, pd.Series]):
"""
Preprocesses the input data for the degradation analysis.
Args:
y_true (List): The true labels.
y_pred (List): The predicted labels.
Returns:
Tuple[pd.Series, pd.Series]: The preprocessed true and predicted labels.
y_true, y_pred (Tuple[pd.Series, pd.Series]):
The preprocessed true and predicted labels.
"""

if isinstance(y_true, list):
Expand All @@ -1246,3 +1250,72 @@ def preprocess(y_true: List, y_pred: List):
y_true = y_true.apply(lambda x: x.split("-")[-1])

return y_true, y_pred

@staticmethod
def show_results():
    """
    Plot the accuracy before and after each test as horizontal bar charts.

    Reads the scores stored in ``DegradationAnalysis.result_data`` and draws
    one figure per category. For every test type a wide grey bar shows the
    "before" accuracy and a narrower blue bar on top of it shows the "after"
    accuracy, each annotated with its value rounded to two decimals.

    Raises:
        ValueError: If ``DegradationAnalysis.result_data`` is empty.
    """
    import pandas as pd
    import matplotlib.pyplot as plt

    data = DegradationAnalysis.result_data
    if not data:
        raise ValueError("No data found for degradation analysis.")

    # One figure per category; the category name itself is not rendered.
    for tests in data.values():
        # Rows are test types, columns are "before" / "after" / "difference".
        df = pd.DataFrame(tests).T

        _, ax = plt.subplots(figsize=(12, 6))

        y_labels = df.index
        y_pos = range(len(y_labels))

        # Pull the scores out as plain lists and walk them by position.
        # Indexing a string-labelled Series with an integer (``series[i]``)
        # relies on a positional fallback that pandas has deprecated.
        before_scores = df["before"].tolist()
        after_scores = df["after"].tolist()

        for row, (before, after) in enumerate(zip(before_scores, after_scores)):
            # Only the first row carries legend labels so that each series
            # appears exactly once in the legend.
            first_row = row == 0

            # "After" accuracy: narrow bar drawn in the foreground.
            ax.broken_barh(
                [(0, after)],
                (row - 0.2, 0.4),
                color="#1f77b4",
                label="After" if first_row else "",
            )
            # "Before" accuracy: wider bar pushed to the back (zorder=0) so
            # the "after" bar stays visible on top of it.
            ax.broken_barh(
                [(0, before)],
                (row - 0.4, 0.8),
                color="#d3d3d3",
                zorder=0,
                label="Before" if first_row else "",
            )

            # Push the "before" annotation further right when both values
            # are close, so the two text labels do not overlap.
            offset = 0.03 if abs(before - after) < 0.05 else 0.01

            ax.text(
                after + 0.01,
                row,
                f"{after:.2f}",
                va="center",
                ha="left",
                color="#1f77b4",
            )
            ax.text(
                before + offset,
                row,
                f"{before:.2f}",
                va="center",
                ha="left",
                color="black",
            )

        ax.set_xlim(0, 1)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(y_labels)
        ax.set_xlabel("Accuracy Score Over Robustness and Bias Tests")
        ax.set_title(
            "Comparison of Accuracy Before and After Robustness and Bias Tests"
        )
        ax.legend()

        plt.tight_layout()
        plt.show()

0 comments on commit b7fef1f

Please sign in to comment.