
Commit 8804370

Merge pull request #202 from neulab/ner_metrics
Change NER to use metrics class

Former-commit-id: 767353b
2 parents 4da06d8 + df745f0, commit 8804370

12 files changed: +211, -252 lines

explainaboard/metric.py (+73, -8)
@@ -12,6 +12,7 @@
 import numpy as np
 
 from explainaboard.utils.async_eaas import AsyncEaaSRequest
+from explainaboard.utils.span_utils import get_spans_from_bio
 from explainaboard.utils.typing_utils import unwrap
 
 
@@ -215,16 +216,23 @@ class F1Score(Metric):
     def default_name(cls) -> str:
         return 'F1'
 
-    def __init__(self, average: str = 'micro', separate_match: bool = False):
+    def __init__(
+        self,
+        average: str = 'micro',
+        separate_match: bool = False,
+        ignore_classes: Optional[list] = None,
+    ):
         """Constructor for f-measure
         :param average: What variety of average to measure
         :param separate_match: Whether to count matches separately for true and pred.
             This is useful in, for example bucketing, when ref and pred are not aligned
+        :param ignore_classes: Classes to ignore
         """
         self.average: str = average
         self.separate_match: bool = separate_match
         self._stat_mult: int = 4 if separate_match else 3
         self._pred_match_offfset: int = 3 if separate_match else 2
+        self.ignore_classes: Optional[list] = ignore_classes
         supported_averages = {'micro', 'macro'}
         if average not in supported_averages:
             raise ValueError(f'only {supported_averages} supported for now')
@@ -244,6 +252,9 @@ def calc_stats_from_data(self, true_data: list, pred_data: list) -> MetricStats:
            (when self.separate_match=True only)
         """
         id_map: dict[str, int] = {}
+        if self.ignore_classes is not None:
+            for ignore_class in self.ignore_classes:
+                id_map[ignore_class] = -1
         for word in itertools.chain(true_data, pred_data):
             if word not in id_map:
                 id_map[word] = len(id_map)
@@ -253,12 +264,14 @@ def calc_stats_from_data(self, true_data: list, pred_data: list) -> MetricStats:
         stats = np.zeros((n_data, n_classes * self._stat_mult))
         for i, (t, p) in enumerate(zip(true_data, pred_data)):
             tid, pid = id_map[t], id_map[p]
-            stats[i, tid * self._stat_mult + 0] += 1
-            stats[i, pid * self._stat_mult + 1] += 1
-            if tid == pid:
-                stats[i, tid * self._stat_mult + 2] += 1
-                if self.separate_match:
-                    stats[i, tid * self._stat_mult + 3] += 1
+            if tid != -1:
+                stats[i, tid * self._stat_mult + 0] += 1
+            if pid != -1:
+                stats[i, pid * self._stat_mult + 1] += 1
+                if tid == pid:
+                    stats[i, tid * self._stat_mult + 2] += 1
+                    if self.separate_match:
+                        stats[i, tid * self._stat_mult + 3] += 1
         return MetricStats(stats)
 
     def calc_metric_from_aggregate(self, agg_stats: np.ndarray) -> float:
@@ -295,6 +308,58 @@ def get_metadata(self) -> dict:
         return meta
 
 
+class BIOF1Score(F1Score):
+    """
+    Calculate F1 score over BIO-tagged spans.
+    """
+
+    def __init__(self, average: str = 'micro'):
+        """Constructor for BIO f-measure
+        :param average: What variety of average to measure
+        """
+        super().__init__(average=average)
+
+    def calc_stats_from_data(
+        self, true_data: list[list[str]], pred_data: list[list[str]]
+    ) -> MetricStats:
+        """
+        Return sufficient statistics necessary to compute f-score.
+        :param true_data: True outputs
+        :param pred_data: Predicted outputs
+        :return: Returns stats for each class (integer id c) in the following columns of
+            MetricStats
+            * c*self._stat_mult + 0: occurrences in the true output
+            * c*self._stat_mult + 1: occurrences in the predicted output
+            * c*self._stat_mult + 2: number of matches with the true output
+        """
+
+        # Identify the tag types
+        true_chain, pred_chain = (
+            itertools.chain.from_iterable(x) for x in (true_data, pred_data)
+        )
+        all_tags = set(itertools.chain(true_chain, pred_chain))
+        tag_ids = {
+            k: v for v, k in enumerate([x[2:] for x in all_tags if x.startswith('B-')])
+        }
+
+        # Create the sufficient statistics
+        n_data, n_classes = len(true_data), len(tag_ids)
+        # This is a bit memory inefficient if there's a large number of classes
+        stats = np.zeros((n_data, n_classes * self._stat_mult))
+
+        for i, (true_sent, pred_sent) in enumerate(zip(true_data, pred_data)):
+            true_spans, pred_spans = (
+                get_spans_from_bio(x) for x in (true_sent, pred_sent)
+            )
+            match_spans = [x for x in true_spans if x in pred_spans]
+            for offset, spans in enumerate((true_spans, pred_spans, match_spans)):
+                for chunk in spans:
+                    c = tag_ids[chunk[0]]
+                    stats[i, c * 3 + offset] += 1
+
+        return MetricStats(stats)
+
+
 class Hits(Metric):
     """
     Calculates the hits metric, telling whether the predicted output is in a set of true
@@ -375,7 +440,7 @@ def filter(self, indices: Union[list[int], np.ndarray]) -> MetricStats:
         """
         Return a view of these stats filtered down to the indicated indices
         """
-        sdata: np.ndarray = unwrap(self._data)
+        sdata: np.ndarray = self.get_data()
         if not isinstance(indices, np.ndarray):
             indices = np.array(indices)
         return MetricStats(sdata[indices])
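A rough usage sketch of the interface added above. The classes, the ignore_classes argument, and the evaluate()/conf_value/result.value usage are taken from this diff and from explainaboard/tests/test_metric.py below; the example labels themselves are made up for illustration.

import explainaboard.metric

# Label-level F1 that skips the 'O' class, mirroring the NER bucket evaluation.
label_f1 = explainaboard.metric.F1Score(average='micro', ignore_classes=['O'])
label_result = label_f1.evaluate(
    ['PER', 'O', 'LOC'], ['PER', 'O', 'O'], conf_value=0.05
)

# Span-level F1 computed directly over BIO tag sequences.
bio_f1 = explainaboard.metric.BIOF1Score(average='micro')
bio_result = bio_f1.evaluate(
    [['B-PER', 'I-PER', 'O'], ['O', 'B-LOC', 'O']],
    [['B-PER', 'I-PER', 'O'], ['O', 'B-LOC', 'O']],
    conf_value=0.05,
)

print(label_result.value)  # 2/3: precision 1/1, recall 1/2 once 'O' is ignored
print(bio_result.value)    # 1.0: every predicted span matches a true span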

explainaboard/processors/named_entity_recognition.py (+59, -76)
@@ -10,13 +10,13 @@
 
 from explainaboard import feature
 from explainaboard.info import BucketPerformance, Performance, SysOutputInfo
-from explainaboard.metric import MetricStats
+import explainaboard.metric
+from explainaboard.metric import Metric
 from explainaboard.processors.processor import Processor
 from explainaboard.processors.processor_registry import register_processor
 from explainaboard.tasks import TaskType
-from explainaboard.utils import bucketing, eval_basic_ner
+from explainaboard.utils import bucketing, span_utils
 from explainaboard.utils.analysis import cap_feature
-from explainaboard.utils.eval_bucket import f1_seqeval_bucket
 from explainaboard.utils.py_utils import sort_dict
 from explainaboard.utils.typing_utils import unwrap
 
@@ -206,7 +206,13 @@ def default_features(cls) -> feature.Features:
 
     @classmethod
     def default_metrics(cls) -> list[str]:
-        return ["f1_seqeval", "recall_seqeval", "precision_seqeval"]
+        return ["F1Score"]
+
+    def _get_true_label(self, data_point: dict):
+        return data_point["true_tags"]
+
+    def _get_predicted_label(self, data_point: dict):
+        return data_point["pred_tags"]
 
     def _get_statistics_resources(
         self, dataset_split: Dataset
@@ -312,13 +318,11 @@ def _get_fre_rank(self, tokens, statistics):
     # --- End feature functions
 
     # These return none because NER is not yet in the main metric interface
-    def _get_metrics(self, sys_info: SysOutputInfo):
-        return None
-
-    def _gen_metric_stats(
-        self, sys_info: SysOutputInfo, sys_output: list[dict]
-    ) -> Optional[list[MetricStats]]:
-        return None
+    def _get_metrics(self, sys_info: SysOutputInfo) -> list[Metric]:
+        return [
+            getattr(explainaboard.metric, f'BIO{name}')()
+            for name in unwrap(sys_info.metric_names)
+        ]
 
     def _complete_span_features(self, sentence, tags, statistics=None):
 
@@ -328,7 +332,7 @@ def _complete_span_features(self, sentence, tags, statistics=None):
         efre_dic = statistics["efre_dic"] if has_stats else None
 
         span_dics = []
-        chunks = eval_basic_ner.get_chunks(tags)
+        chunks = span_utils.get_spans_from_bio(tags)
         for tag, sid, eid in chunks:
             span_text = ' '.join(sentence[sid:eid])
             # Basic features
@@ -389,35 +393,8 @@ def _complete_features(
             dict_sysout["pred_entity_info"] = self._complete_span_features(
                 tokens, dict_sysout["pred_tags"], statistics=external_stats
             )
-        return None
-
-    def get_overall_performance(
-        self,
-        sys_info: SysOutputInfo,
-        sys_output: list[dict],
-        metric_stats: Any = None,
-    ) -> dict[str, Performance]:
-        """
-        Get the overall performance according to metrics
-        :param sys_info: Information about the system output
-        :param sys_output: The system output itself
-        :return: a dictionary of metrics to overall performance numbers
-        """
-
-        true_tags_list = [x['true_tags'] for x in sys_output]
-        pred_tags_list = [x['pred_tags'] for x in sys_output]
-
-        overall: dict[str, Performance] = {}
-        for metric_name in unwrap(sys_info.metric_names):
-            if not metric_name.endswith('_seqeval'):
-                raise NotImplementedError(f'Unsupported metric {metric_name}')
-            # This gets the appropriate metric from the eval_basic_ner package
-            score_func = getattr(eval_basic_ner, metric_name)
-            overall[metric_name] = Performance(
-                metric_name=metric_name,
-                value=score_func(true_tags_list, pred_tags_list),
-            )
-        return overall
+        # This is not used elsewhere, so just keep it as-is
+        return list()
 
     def _get_span_ids(
         self,
@@ -554,24 +531,24 @@ def get_bucket_cases_ner(
                 samples_over_bucket_true[bucket_interval], 'true', sample_dict
             )
 
-        error_case_list = []
+        case_list = []
         for pos, tags in sample_dict.items():
             true_label = tags.get('true', 'O')
             pred_label = tags.get('pred', 'O')
-            if true_label != pred_label:
-                split_pos = pos.split("|||")
-                sent_id = int(split_pos[0])
-                span = split_pos[-1]
-                system_output_id = sys_output[int(sent_id)]["id"]
-                error_case = {
-                    "span": span,
-                    "text": str(system_output_id),
-                    "true_label": true_label,
-                    "predicted_label": pred_label,
-                }
-                error_case_list.append(error_case)
-
-        return error_case_list
+
+            split_pos = pos.split("|||")
+            sent_id = int(split_pos[0])
+            span = split_pos[-1]
+            system_output_id = sys_output[int(sent_id)]["id"]
+            error_case = {
+                "span": span,
+                "text": str(system_output_id),
+                "true_label": true_label,
+                "predicted_label": pred_label,
+            }
+            case_list.append(error_case)
+
+        return case_list
 
     def get_bucket_performance_ner(
         self,
@@ -593,6 +570,12 @@ def get_bucket_performance_ner(
            bucket performance
         """
 
+        metric_names = unwrap(sys_info.metric_names)
+        bucket_metrics = [
+            getattr(explainaboard.metric, name)(ignore_classes=['O'])
+            for name in metric_names
+        ]
+
         bucket_name_to_performance = {}
         for bucket_interval, spans_true in samples_over_bucket_true.items():
 
@@ -611,29 +594,29 @@ def get_bucket_performance_ner(
                samples_over_bucket_pred,
            )
 
+            true_labels = [x['true_label'] for x in bucket_samples]
+            pred_labels = [x['predicted_label'] for x in bucket_samples]
+
            bucket_performance = BucketPerformance(
                bucket_name=bucket_interval,
                n_samples=len(spans_pred),
                bucket_samples=bucket_samples,
            )
-            for metric_name in unwrap(sys_info.metric_names):
-                """
-                # Note that: for NER task, the bucket-wise evaluation function is a
-                # little different from overall evaluation function
-                # for overall: f1_seqeval
-                # for bucket: f1_seqeval_bucket
-                """
-                f1, p, r = f1_seqeval_bucket(spans_pred, spans_true)
-                if metric_name == 'f1_seqeval':
-                    my_score = f1
-                elif metric_name == 'precision_seqeval':
-                    my_score = p
-                elif metric_name == 'recall_seqeval':
-                    my_score = r
-                else:
-                    raise NotImplementedError(f'Unsupported metric {metric_name}')
-                # TODO(gneubig): It'd be better to have significance tests here
-                performance = Performance(metric_name=metric_name, value=my_score)
+            for metric in bucket_metrics:
+
+                metric_val = metric.evaluate(
+                    true_labels, pred_labels, conf_value=sys_info.conf_value
+                )
+                conf_low, conf_high = (
+                    metric_val.conf_interval if metric_val.conf_interval else None,
+                    None,
+                )
+                performance = Performance(
+                    metric_name=metric.name,
+                    value=metric_val.value,
+                    confidence_score_low=conf_low,
+                    confidence_score_high=conf_high,
+                )
                 bucket_performance.performances.append(performance)
 
             bucket_name_to_performance[bucket_interval] = bucket_performance
@@ -647,7 +630,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
    Note: when matching, the text span and tag have been lowercased.
    """
    econ_dic = dict()
-    chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
+    chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))
 
    # print('tags: ', tags)
    count_idx = 0
@@ -722,7 +705,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
 # Global functions for training set dependent features
 def get_efre_dic(train_word_sequences, tag_sequences_train):
    efre_dic = dict()
-    chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
+    chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))
    count_idx = 0
    word_sequences_train_str = ' '.join(train_word_sequences).lower()
    for true_chunk in tqdm(chunks_train):
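A note on how the two evaluation paths fit together, inferred only from the hunks above: default_metrics() now returns metric class names such as "F1Score"; _get_metrics() prefixes the name with "BIO" so that corpus-level NER evaluation uses the span-level BIOF1Score over full tag sequences, while get_bucket_performance_ner() instantiates the plain class with ignore_classes=['O'] over the span labels collected for each bucket. A minimal sketch of that lookup:

import explainaboard.metric

metric_names = ['F1Score']  # what default_metrics() returns after this change

# Corpus-level: span-level metrics over BIO tag sequences (_get_metrics above).
overall_metrics = [
    getattr(explainaboard.metric, f'BIO{name}')() for name in metric_names
]

# Bucket-level: label-level F1 over extracted spans, ignoring 'O'
# (get_bucket_performance_ner above).
bucket_metrics = [
    getattr(explainaboard.metric, name)(ignore_classes=['O']) for name in metric_names
]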

explainaboard/processors/processor.py (+1, -4)
@@ -164,10 +164,7 @@ def _get_feature_func(self, func_name: str):
     def _get_eaas_client(self):
         if not self._eaas_client:
             self._eaas_config = Config()
-            self._eaas_client = AsyncEaaSClient()
-            self._eaas_client.load_config(
-                self._eaas_config
-            )  # The config you have created above
+            self._eaas_client = AsyncEaaSClient(self._eaas_config)
         return self._eaas_client
 
     def _get_true_label(self, data_point: dict):

explainaboard/tests/test_metric.py (+20, -2)
@@ -55,6 +55,25 @@ def test_mrr(self):
         result = metric.evaluate(true, pred, conf_value=0.05)
         self.assertAlmostEqual(result.value, 2.5 / 6.0)
 
+    def test_ner_f1(self):
+
+        true = [
+            ['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'O', 'O'],
+            ['B-PER', 'I-PER', 'O'],
+        ]
+        pred = [
+            ['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'I-MISC', 'O'],
+            ['B-PER', 'I-PER', 'O'],
+        ]
+
+        metric = explainaboard.metric.BIOF1Score(average='micro')
+        result = metric.evaluate(true, pred, conf_value=0.05)
+        self.assertAlmostEqual(result.value, 2.0 / 3.0)
+
+        metric = explainaboard.metric.BIOF1Score(average='macro')
+        result = metric.evaluate(true, pred, conf_value=0.05)
+        self.assertAlmostEqual(result.value, 3.0 / 4.0)
+
     def _get_eaas_request(
         self,
         sys_output: list[dict],
@@ -92,8 +111,7 @@ def test_eaas_decomposabiltiy(self):
         sys_output = list(loader.load())
 
         # Initialize client and decide which metrics to test
-        eaas_client = AsyncEaaSClient()
-        eaas_client.load_config(Config())
+        eaas_client = AsyncEaaSClient(Config())
         metric_names = ['rouge1', 'bleu', 'chrf']
         # Uncomment the following line to test all metrics,
         # but beware that it will be very slow
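As a sanity check on the expected values in test_ner_f1 (worked out from the test data shown above, not part of the commit itself): the true tags contain three spans, a MISC span over tokens 2-3, a single-token MISC span at token 4, and a PER span over tokens 0-1 of the second sentence. The predictions also contain three spans, of which the MISC span at 2-3 and the PER span match exactly, while the second MISC span is extended to tokens 4-5 and so does not match. Micro-averaged precision and recall are therefore both 2/3, giving F1 = 2/3; per class, MISC has F1 = 1/2 and PER has F1 = 1, so the macro average is 3/4.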
