6
6
* @param matrix - The confusion matrix, a 2D Array. Rows represent the actual label and columns the predicted label.
7
7
* @param labels - Labels of the confusion matrix, a 1D Array
8
8
*/
9
- export class ConfusionMatrix {
10
- private labels : Label [ ] ;
9
+ export class ConfusionMatrix < T extends Label > {
10
+ private labels : T [ ] ;
11
11
private matrix : number [ ] [ ] ;
12
- constructor ( matrix : number [ ] [ ] , labels : Label [ ] ) {
12
+ constructor ( matrix : number [ ] [ ] , labels : T [ ] ) {
13
13
if ( matrix . length !== matrix [ 0 ] . length ) {
14
14
throw new Error ( 'Confusion matrix must be square' ) ;
15
15
}
@@ -34,15 +34,15 @@ export class ConfusionMatrix {
34
34
* @param [options.sort]
35
35
* @return Confusion matrix
36
36
*/
37
- static fromLabels (
38
- actual : Label [ ] ,
39
- predicted : Label [ ] ,
40
- options : FromLabelsOptions = { } ,
37
+ static fromLabels < T extends Label > (
38
+ actual : T [ ] ,
39
+ predicted : T [ ] ,
40
+ options : FromLabelsOptions < T > = { } ,
41
41
) {
42
42
if ( predicted . length !== actual . length ) {
43
43
throw new Error ( 'predicted and actual must have the same length' ) ;
44
44
}
45
- let distinctLabels : Set < Label > ;
45
+ let distinctLabels : Set < T > ;
46
46
if ( options . labels ) {
47
47
distinctLabels = new Set ( options . labels ) ;
48
48
} else {
@@ -117,7 +117,7 @@ export class ConfusionMatrix {
117
117
* Get the number of true positive predictions.
118
118
* @param label - The label that should be considered "positive"
119
119
*/
120
- getTruePositiveCount ( label : Label ) : number {
120
+ getTruePositiveCount ( label : T ) : number {
121
121
const index = this . getIndex ( label ) ;
122
122
return this . matrix [ index ] [ index ] ;
123
123
}
@@ -126,7 +126,7 @@ export class ConfusionMatrix {
126
126
* Get the number of true negative predictions.
127
127
* @param label - The label that should be considered "positive"
128
128
*/
129
- getTrueNegativeCount ( label : Label ) {
129
+ getTrueNegativeCount ( label : T ) {
130
130
const index = this . getIndex ( label ) ;
131
131
let count = 0 ;
132
132
for ( let i = 0 ; i < this . matrix . length ; i ++ ) {
@@ -143,7 +143,7 @@ export class ConfusionMatrix {
143
143
* Get the number of false positive predictions.
144
144
* @param label - The label that should be considered "positive"
145
145
*/
146
- getFalsePositiveCount ( label : Label ) {
146
+ getFalsePositiveCount ( label : T ) {
147
147
const index = this . getIndex ( label ) ;
148
148
let count = 0 ;
149
149
for ( let i = 0 ; i < this . matrix . length ; i ++ ) {
@@ -158,7 +158,7 @@ export class ConfusionMatrix {
158
158
* Get the number of false negative predictions.
159
159
* @param label - The label that should be considered "positive"
160
160
*/
161
- getFalseNegativeCount ( label : Label ) : number {
161
+ getFalseNegativeCount ( label : T ) : number {
162
162
const index = this . getIndex ( label ) ;
163
163
let count = 0 ;
164
164
for ( let i = 0 ; i < this . matrix . length ; i ++ ) {
@@ -173,15 +173,15 @@ export class ConfusionMatrix {
173
173
* Get the number of real positive samples.
174
174
* @param label - The label that should be considered "positive"
175
175
*/
176
- getPositiveCount ( label : Label ) {
176
+ getPositiveCount ( label : T ) {
177
177
return this . getTruePositiveCount ( label ) + this . getFalseNegativeCount ( label ) ;
178
178
}
179
179
180
180
/**
181
181
* Get the number of real negative samples.
182
182
* @param label - The label that should be considered "positive"
183
183
*/
184
- getNegativeCount ( label : Label ) {
184
+ getNegativeCount ( label : T ) {
185
185
return this . getTrueNegativeCount ( label ) + this . getFalsePositiveCount ( label ) ;
186
186
}
187
187
@@ -190,7 +190,7 @@ export class ConfusionMatrix {
190
190
* @param label - The label to search for
191
191
* @throws if the label is not found
192
192
*/
193
- getIndex ( label : Label ) : number {
193
+ getIndex ( label : T ) : number {
194
194
const index = this . labels . indexOf ( label ) ;
195
195
if ( index === - 1 ) throw new Error ( 'The label does not exist' ) ;
196
196
return index ;
@@ -202,7 +202,7 @@ export class ConfusionMatrix {
202
202
* @param label - The label that should be considered "positive"
203
203
* @return The true positive rate [0-1]
204
204
*/
205
- getTruePositiveRate ( label : Label ) {
205
+ getTruePositiveRate ( label : T ) {
206
206
return this . getTruePositiveCount ( label ) / this . getPositiveCount ( label ) ;
207
207
}
208
208
@@ -212,7 +212,7 @@ export class ConfusionMatrix {
212
212
* @param label - The label that should be considered "positive"
213
213
* @return The true negative rate a.k.a. specificity.
214
214
*/
215
- getTrueNegativeRate ( label : Label ) {
215
+ getTrueNegativeRate ( label : T ) {
216
216
return this . getTrueNegativeCount ( label ) / this . getNegativeCount ( label ) ;
217
217
}
218
218
@@ -222,7 +222,7 @@ export class ConfusionMatrix {
222
222
* @param label - The label that should be considered "positive"
223
223
* @return the positive predictive value a.k.a. precision.
224
224
*/
225
- getPositivePredictiveValue ( label : Label ) {
225
+ getPositivePredictiveValue ( label : T ) {
226
226
const TP = this . getTruePositiveCount ( label ) ;
227
227
return TP / ( TP + this . getFalsePositiveCount ( label ) ) ;
228
228
}
@@ -232,7 +232,7 @@ export class ConfusionMatrix {
232
232
* {@link https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values}
233
233
* @param label - The label that should be considered "positive"
234
234
*/
235
- getNegativePredictiveValue ( label : Label ) {
235
+ getNegativePredictiveValue ( label : T ) {
236
236
const TN = this . getTrueNegativeCount ( label ) ;
237
237
return TN / ( TN + this . getFalseNegativeCount ( label ) ) ;
238
238
}
@@ -242,7 +242,7 @@ export class ConfusionMatrix {
242
242
* {@link https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates}
243
243
* @param label - The label that should be considered "positive"
244
244
*/
245
- getFalseNegativeRate ( label : Label ) {
245
+ getFalseNegativeRate ( label : T ) {
246
246
return 1 - this . getTruePositiveRate ( label ) ;
247
247
}
248
248
@@ -251,7 +251,7 @@ export class ConfusionMatrix {
251
251
* {@link https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates}
252
252
* @param label - The label that should be considered "positive"
253
253
*/
254
- getFalsePositiveRate ( label : Label ) {
254
+ getFalsePositiveRate ( label : T ) {
255
255
return 1 - this . getTrueNegativeRate ( label ) ;
256
256
}
257
257
@@ -260,7 +260,7 @@ export class ConfusionMatrix {
260
260
* {@link https://en.wikipedia.org/wiki/False_discovery_rate}
261
261
* @param label - The label that should be considered "positive"
262
262
*/
263
- getFalseDiscoveryRate ( label : Label ) {
263
+ getFalseDiscoveryRate ( label : T ) {
264
264
const FP = this . getFalsePositiveCount ( label ) ;
265
265
return FP / ( FP + this . getTruePositiveCount ( label ) ) ;
266
266
}
@@ -269,7 +269,7 @@ export class ConfusionMatrix {
269
269
* False omission rate (FOR)
270
270
* @param label - The label that should be considered "positive"
271
271
*/
272
- getFalseOmissionRate ( label : Label ) {
272
+ getFalseOmissionRate ( label : T ) {
273
273
const FN = this . getFalseNegativeCount ( label ) ;
274
274
return FN / ( FN + this . getTruePositiveCount ( label ) ) ;
275
275
}
@@ -279,7 +279,7 @@ export class ConfusionMatrix {
279
279
* {@link https://en.wikipedia.org/wiki/F1_score}
280
280
* @param label - The label that should be considered "positive"
281
281
*/
282
- getF1Score ( label : Label ) {
282
+ getF1Score ( label : T ) {
283
283
const TP = this . getTruePositiveCount ( label ) ;
284
284
return (
285
285
( 2 * TP ) /
@@ -294,7 +294,7 @@ export class ConfusionMatrix {
294
294
* {@link https://en.wikipedia.org/wiki/Matthews_correlation_coefficient}
295
295
* @param label - The label that should be considered "positive"
296
296
*/
297
- getMatthewsCorrelationCoefficient ( label : Label ) {
297
+ getMatthewsCorrelationCoefficient ( label : T ) {
298
298
const TP = this . getTruePositiveCount ( label ) ;
299
299
const TN = this . getTrueNegativeCount ( label ) ;
300
300
const FP = this . getFalsePositiveCount ( label ) ;
@@ -310,7 +310,7 @@ export class ConfusionMatrix {
310
310
* {@link https://en.wikipedia.org/wiki/Youden%27s_J_statistic}
311
311
* @param label - The label that should be considered "positive"
312
312
*/
313
- getInformedness ( label : Label ) {
313
+ getInformedness ( label : T ) {
314
314
return (
315
315
this . getTruePositiveRate ( label ) + this . getTrueNegativeRate ( label ) - 1
316
316
) ;
@@ -320,7 +320,7 @@ export class ConfusionMatrix {
320
320
* Markedness
321
321
* @param label - The label that should be considered "positive"
322
322
*/
323
- getMarkedness ( label : Label ) {
323
+ getMarkedness ( label : T ) {
324
324
return (
325
325
this . getPositivePredictiveValue ( label ) +
326
326
this . getNegativePredictiveValue ( label ) -
@@ -333,7 +333,7 @@ export class ConfusionMatrix {
333
333
* @param label - The label that should be considered "positive"
334
334
* @return The 2x2 confusion table. [[TP, FN], [FP, TN]]
335
335
*/
336
- getConfusionTable ( label : Label ) {
336
+ getConfusionTable ( label : T ) {
337
337
return [
338
338
[ this . getTruePositiveCount ( label ) , this . getFalseNegativeCount ( label ) ] ,
339
339
[ this . getFalsePositiveCount ( label ) , this . getTrueNegativeCount ( label ) ] ,
@@ -362,7 +362,7 @@ export class ConfusionMatrix {
362
362
* @param predicted - The predicted label
363
363
* @return The element in the confusion matrix
364
364
*/
365
- getCount ( actual : Label , predicted : Label ) {
365
+ getCount ( actual : T , predicted : T ) {
366
366
const actualIndex = this . getIndex ( actual ) ;
367
367
const predictedIndex = this . getIndex ( predicted ) ;
368
368
return this . matrix [ actualIndex ] [ predictedIndex ] ;
@@ -388,7 +388,7 @@ export class ConfusionMatrix {
388
388
389
389
type Label = boolean | number | string ;
390
390
391
- interface FromLabelsOptions {
392
- labels ?: Label [ ] ;
391
+ interface FromLabelsOptions < T extends Label > {
392
+ labels ?: T [ ] ;
393
393
sort ?: ( ...args : Label [ ] ) => number ;
394
394
}
0 commit comments