Skip to content

Commit 4da06d8

Browse files
authored
Merge pull request #205 from neulab/fix_qa_bug
Fixed bug in QA metrics Former-commit-id: c462f38
2 parents bc49336 + f53f456 commit 4da06d8

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

explainaboard/metric.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,17 @@ class QAMetric(Metric):
418418

419419
def normalize_answer(self, s: str) -> str:
420420
"""Lower text and remove punctuation, articles and extra whitespace."""
421-
s = re.sub(r'\b(a|an|the)\b', ' ', s)
422-
s = ' '.join(s.split())
421+
s = s.lower()
423422
exclude_punc = set(string.punctuation)
424423
s = ''.join(ch for ch in s if ch not in exclude_punc)
425-
s = s.lower()
424+
s = re.sub(r'\b(a|an|the)\b', ' ', s)
425+
s = ' '.join(s.split())
426426
return s
427427

428-
def calc_stats_from_data(self, true_data: list, pred_data: list) -> MetricStats:
428+
def calc_stats_from_data(
429+
self, true_data: list[Union[str, list[str]]], pred_data: list[str]
430+
) -> MetricStats:
431+
true_data = [[x] if isinstance(x, str) else x for x in true_data]
429432
return MetricStats(
430433
np.array(
431434
[

0 commit comments

Comments
 (0)