Skip to content

Commit

Permalink
fixed: Unbound Error and Key Error.
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Sep 3, 2024
1 parent 16fee46 commit 258a0f7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
2 changes: 1 addition & 1 deletion langtest/datahandler/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def _import_data(self, file_name, **kwargs) -> List[Sample]:
import ast

i["transformations"] = ast.literal_eval(temp)
else:
elif "transformations" in i:
i.pop("transformations")
sample = self.task.get_sample_class(**i)
samples.append(sample)
Expand Down
38 changes: 16 additions & 22 deletions langtest/modelhandler/jsl_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,31 +450,25 @@ def predict(
SequenceClassificationOutput: Classification output from SparkNLP LightPipeline.
"""
prediction_metadata = self.model.fullAnnotate(text)[0][self.output_col]
prediction = []

if len(prediction_metadata) > 0:
prediction_metadata = prediction_metadata[0].metadata
prediction = [
{"label": x, "score": y} for x, y in prediction_metadata.items()
]

if self.multi_label_classifier:
multi_label = True
if len(prediction_metadata) > 0:
prediction_metadata = prediction_metadata[0].metadata

prediction = [
{"label": x, "score": y} for x, y in prediction_metadata.items()
]
# filter based on the threshold value with score greater than threshold
prediction = [x for x in prediction if float(x["score"]) > self.threshold]

return SequenceClassificationOutput(
text=text,
predictions=prediction,
multi_label=multi_label,
)
else:
return SequenceClassificationOutput(
text=text, predictions=[], multi_label=multi_label
)
prediction = [x for x in prediction if float(x["score"]) > self.threshold]

else:
if not return_all_scores:
prediction = [max(prediction, key=lambda x: x["score"])]
return SequenceClassificationOutput(
text=text,
predictions=prediction,
multi_label=self.multi_label_classifier,
)

if not return_all_scores:
prediction = [max(prediction, key=lambda x: x["score"])]

return SequenceClassificationOutput(text=text, predictions=prediction)

Expand Down

0 comments on commit 258a0f7

Please sign in to comment.