from explainaboard import feature
from explainaboard.info import BucketPerformance, Performance, SysOutputInfo
-from explainaboard.metric import MetricStats
+import explainaboard.metric
+from explainaboard.metric import Metric
from explainaboard.processors.processor import Processor
from explainaboard.processors.processor_registry import register_processor
from explainaboard.tasks import TaskType
-from explainaboard.utils import bucketing, eval_basic_ner
+from explainaboard.utils import bucketing, span_utils
from explainaboard.utils.analysis import cap_feature
-from explainaboard.utils.eval_bucket import f1_seqeval_bucket
from explainaboard.utils.py_utils import sort_dict
from explainaboard.utils.typing_utils import unwrap

@@ -206,7 +206,13 @@ def default_features(cls) -> feature.Features:

    @classmethod
    def default_metrics(cls) -> list[str]:
-        return ["f1_seqeval", "recall_seqeval", "precision_seqeval"]
+        return ["F1Score"]
+
+    def _get_true_label(self, data_point: dict):
+        return data_point["true_tags"]
+
+    def _get_predicted_label(self, data_point: dict):
+        return data_point["pred_tags"]

    def _get_statistics_resources(
        self, dataset_split: Dataset
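
The two accessors added here are what plug NER into the shared metric pipeline: the base `Processor` uses them to pull the gold and predicted tag sequences out of each system-output record. A minimal sketch of the record shape they assume (token and tag values are illustrative):

```python
# Illustrative system-output record; the accessors above read only the
# "true_tags" and "pred_tags" fields. All values here are made up.
data_point = {
    "tokens": ["John", "lives", "in", "Boston"],
    "true_tags": ["B-PER", "O", "O", "B-LOC"],
    "pred_tags": ["B-PER", "O", "O", "O"],
}
```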
@@ -312,13 +318,11 @@ def _get_fre_rank(self, tokens, statistics):
    # --- End feature functions

    # These return none because NER is not yet in the main metric interface
-    def _get_metrics(self, sys_info: SysOutputInfo):
-        return None
-
-    def _gen_metric_stats(
-        self, sys_info: SysOutputInfo, sys_output: list[dict]
-    ) -> Optional[list[MetricStats]]:
-        return None
+    def _get_metrics(self, sys_info: SysOutputInfo) -> list[Metric]:
+        return [
+            getattr(explainaboard.metric, f'BIO{name}')()
+            for name in unwrap(sys_info.metric_names)
+        ]

    def _complete_span_features(self, sentence, tags, statistics=None):
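
With `default_metrics()` now returning `["F1Score"]`, the rewritten `_get_metrics` resolves each configured metric name to a class in `explainaboard.metric` by prepending `BIO`. A sketch of the lookup, assuming the module exposes a `BIOF1Score` class as the `f'BIO{name}'` pattern implies:

```python
import explainaboard.metric

# "F1Score" comes from default_metrics(); the f'BIO{name}' pattern turns
# it into the class name "BIOF1Score" (assumed to exist in the module).
metric_names = ["F1Score"]
metrics = [getattr(explainaboard.metric, f'BIO{name}')() for name in metric_names]
```

One consequence of the dynamic lookup: an unsupported metric name now surfaces as an `AttributeError` from `getattr` rather than the explicit `NotImplementedError` the old code raised.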
@@ -328,7 +332,7 @@ def _complete_span_features(self, sentence, tags, statistics=None):
        efre_dic = statistics["efre_dic"] if has_stats else None

        span_dics = []
-        chunks = eval_basic_ner.get_chunks(tags)
+        chunks = span_utils.get_spans_from_bio(tags)
        for tag, sid, eid in chunks:
            span_text = ' '.join(sentence[sid:eid])
            # Basic features
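
Since the loop above unpacks each chunk as `(tag, sid, eid)` and slices `sentence[sid:eid]`, `span_utils.get_spans_from_bio` is assumed to return `(tag, start, end)` tuples with an exclusive end index:

```python
tags = ["B-PER", "I-PER", "O", "B-LOC"]
# Assumed output shape, inferred from the `for tag, sid, eid in chunks`
# loop and the `sentence[sid:eid]` slice above:
# get_spans_from_bio(tags) -> [("PER", 0, 2), ("LOC", 3, 4)]
```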
@@ -389,35 +393,8 @@ def _complete_features(
            dict_sysout["pred_entity_info"] = self._complete_span_features(
                tokens, dict_sysout["pred_tags"], statistics=external_stats
            )
-        return None
-
-    def get_overall_performance(
-        self,
-        sys_info: SysOutputInfo,
-        sys_output: list[dict],
-        metric_stats: Any = None,
-    ) -> dict[str, Performance]:
-        """
-        Get the overall performance according to metrics
-        :param sys_info: Information about the system output
-        :param sys_output: The system output itself
-        :return: a dictionary of metrics to overall performance numbers
-        """
-
-        true_tags_list = [x['true_tags'] for x in sys_output]
-        pred_tags_list = [x['pred_tags'] for x in sys_output]
-
-        overall: dict[str, Performance] = {}
-        for metric_name in unwrap(sys_info.metric_names):
-            if not metric_name.endswith('_seqeval'):
-                raise NotImplementedError(f'Unsupported metric {metric_name}')
-            # This gets the appropriate metric from the eval_basic_ner package
-            score_func = getattr(eval_basic_ner, metric_name)
-            overall[metric_name] = Performance(
-                metric_name=metric_name,
-                value=score_func(true_tags_list, pred_tags_list),
-            )
-        return overall
+        # This is not used elsewhere, so just keep it as-is
+        return list()

    def _get_span_ids(
        self,
@@ -554,24 +531,24 @@ def get_bucket_cases_ner(
            samples_over_bucket_true[bucket_interval], 'true', sample_dict
        )

-        error_case_list = []
+        case_list = []
        for pos, tags in sample_dict.items():
            true_label = tags.get('true', 'O')
            pred_label = tags.get('pred', 'O')
-            if true_label != pred_label:
-                split_pos = pos.split("|||")
-                sent_id = int(split_pos[0])
-                span = split_pos[-1]
-                system_output_id = sys_output[int(sent_id)]["id"]
-                error_case = {
-                    "span": span,
-                    "text": str(system_output_id),
-                    "true_label": true_label,
-                    "predicted_label": pred_label,
-                }
-                error_case_list.append(error_case)
-
-        return error_case_list
+
+            split_pos = pos.split("|||")
+            sent_id = int(split_pos[0])
+            span = split_pos[-1]
+            system_output_id = sys_output[int(sent_id)]["id"]
+            error_case = {
+                "span": span,
+                "text": str(system_output_id),
+                "true_label": true_label,
+                "predicted_label": pred_label,
+            }
+            case_list.append(error_case)
+
+        return case_list

    def get_bucket_performance_ner(
        self,
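
The rewritten `get_bucket_cases_ner` records every span rather than only mismatches, and it reads the `"|||"`-delimited keys of `sample_dict` positionally: only the first field (sentence index) and the last field (span text) are used. An illustrative key (the middle fields are assumed here):

```python
# Hypothetical sample_dict key; the code above reads only the first and
# last "|||"-separated fields, whatever sits in between.
pos = "12|||3|||5|||New York"
split_pos = pos.split("|||")
sent_id = int(split_pos[0])  # 12 -> index into sys_output
span = split_pos[-1]         # "New York"
```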
@@ -593,6 +570,12 @@ def get_bucket_performance_ner(
            bucket performance
        """

+        metric_names = unwrap(sys_info.metric_names)
+        bucket_metrics = [
+            getattr(explainaboard.metric, name)(ignore_classes=['O'])
+            for name in metric_names
+        ]
+
        bucket_name_to_performance = {}
        for bucket_interval, spans_true in samples_over_bucket_true.items():
@@ -611,29 +594,29 @@
                samples_over_bucket_pred,
            )

+            true_labels = [x['true_label'] for x in bucket_samples]
+            pred_labels = [x['predicted_label'] for x in bucket_samples]
+
            bucket_performance = BucketPerformance(
                bucket_name=bucket_interval,
                n_samples=len(spans_pred),
                bucket_samples=bucket_samples,
            )
-            for metric_name in unwrap(sys_info.metric_names):
-                """
-                # Note that: for NER task, the bucket-wise evaluation function is a
-                # little different from overall evaluation function
-                # for overall: f1_seqeval
-                # for bucket: f1_seqeval_bucket
-                """
-                f1, p, r = f1_seqeval_bucket(spans_pred, spans_true)
-                if metric_name == 'f1_seqeval':
-                    my_score = f1
-                elif metric_name == 'precision_seqeval':
-                    my_score = p
-                elif metric_name == 'recall_seqeval':
-                    my_score = r
-                else:
-                    raise NotImplementedError(f'Unsupported metric {metric_name}')
-                # TODO(gneubig): It'd be better to have significance tests here
-                performance = Performance(metric_name=metric_name, value=my_score)
+            for metric in bucket_metrics:
+                metric_val = metric.evaluate(
+                    true_labels, pred_labels, conf_value=sys_info.conf_value
+                )
+                conf_low, conf_high = (
+                    metric_val.conf_interval if metric_val.conf_interval else (None, None)
+                )
+                performance = Performance(
+                    metric_name=metric.name,
+                    value=metric_val.value,
+                    confidence_score_low=conf_low,
+                    confidence_score_high=conf_high,
+                )
                bucket_performance.performances.append(performance)

            bucket_name_to_performance[bucket_interval] = bucket_performance
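
At the bucket level the metric is looked up without the `BIO` prefix (`F1Score` rather than `BIOF1Score`), because here it scores flat lists of span labels instead of full tag sequences, with `ignore_classes=['O']` keeping unmatched positions out of the score. A sketch of one evaluation step; the `conf_value` of 0.95 is illustrative (the real code passes `sys_info.conf_value`), and the `evaluate()` return shape is assumed from how the loop above consumes it:

```python
import explainaboard.metric

# Assumed API, matching the loop above: evaluate() returns an object with
# a .value and an optional .conf_interval (low, high) pair.
metric = explainaboard.metric.F1Score(ignore_classes=['O'])
true_labels = ["PER", "LOC", "O"]  # illustrative span labels in one bucket
pred_labels = ["PER", "O", "O"]    # 'O' marks a span missing on one side
result = metric.evaluate(true_labels, pred_labels, conf_value=0.95)
print(result.value)          # e.g. 0.67
print(result.conf_interval)  # e.g. (0.41, 0.88), or None if not computed
```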
@@ -647,7 +630,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
    Note: when matching, the text span and tag have been lowercased.
    """
    econ_dic = dict()
-    chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
+    chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))

    # print('tags: ', tags)
    count_idx = 0
@@ -722,7 +705,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
# Global functions for training set dependent features
def get_efre_dic(train_word_sequences, tag_sequences_train):
    efre_dic = dict()
-    chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
+    chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))
    count_idx = 0
    word_sequences_train_str = ' '.join(train_word_sequences).lower()
    for true_chunk in tqdm(chunks_train):