Skip to content

Commit 8e66373

Browse files
committed
feat: use a generatic for the label's type
1 parent 6bd4669 commit 8e66373

File tree

2 files changed

+32
-32
lines changed

2 files changed

+32
-32
lines changed

src/__tests__/test.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,6 @@ describe('Confusion Matrix', () => {
5252

5353
it('should throw if trying to get the count for unexisting label', () => {
5454
const CM = new ConfusionMatrix(full.matrix, full.labels);
55-
expect(() => CM.getCount('A', 'B')).toThrow(/label does not exist/);
55+
expect(() => CM.getCount(4, 5)).toThrow(/label does not exist/);
5656
});
5757
});

src/index.ts

+31-31
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
* @param matrix - The confusion matrix, a 2D Array. Rows represent the actual label and columns the predicted label.
77
* @param labels - Labels of the confusion matrix, a 1D Array
88
*/
9-
export class ConfusionMatrix {
10-
private labels: Label[];
9+
export class ConfusionMatrix<T extends Label> {
10+
private labels: T[];
1111
private matrix: number[][];
12-
constructor(matrix: number[][], labels: Label[]) {
12+
constructor(matrix: number[][], labels: T[]) {
1313
if (matrix.length !== matrix[0].length) {
1414
throw new Error('Confusion matrix must be square');
1515
}
@@ -34,15 +34,15 @@ export class ConfusionMatrix {
3434
* @param [options.sort]
3535
* @return Confusion matrix
3636
*/
37-
static fromLabels(
38-
actual: Label[],
39-
predicted: Label[],
40-
options: FromLabelsOptions = {},
37+
static fromLabels<T extends Label>(
38+
actual: T[],
39+
predicted: T[],
40+
options: FromLabelsOptions<T> = {},
4141
) {
4242
if (predicted.length !== actual.length) {
4343
throw new Error('predicted and actual must have the same length');
4444
}
45-
let distinctLabels: Set<Label>;
45+
let distinctLabels: Set<T>;
4646
if (options.labels) {
4747
distinctLabels = new Set(options.labels);
4848
} else {
@@ -117,7 +117,7 @@ export class ConfusionMatrix {
117117
* Get the number of true positive predictions.
118118
* @param label - The label that should be considered "positive"
119119
*/
120-
getTruePositiveCount(label: Label): number {
120+
getTruePositiveCount(label: T): number {
121121
const index = this.getIndex(label);
122122
return this.matrix[index][index];
123123
}
@@ -126,7 +126,7 @@ export class ConfusionMatrix {
126126
* Get the number of true negative predictions.
127127
* @param label - The label that should be considered "positive"
128128
*/
129-
getTrueNegativeCount(label: Label) {
129+
getTrueNegativeCount(label: T) {
130130
const index = this.getIndex(label);
131131
let count = 0;
132132
for (let i = 0; i < this.matrix.length; i++) {
@@ -143,7 +143,7 @@ export class ConfusionMatrix {
143143
* Get the number of false positive predictions.
144144
* @param label - The label that should be considered "positive"
145145
*/
146-
getFalsePositiveCount(label: Label) {
146+
getFalsePositiveCount(label: T) {
147147
const index = this.getIndex(label);
148148
let count = 0;
149149
for (let i = 0; i < this.matrix.length; i++) {
@@ -158,7 +158,7 @@ export class ConfusionMatrix {
158158
* Get the number of false negative predictions.
159159
* @param label - The label that should be considered "positive"
160160
*/
161-
getFalseNegativeCount(label: Label): number {
161+
getFalseNegativeCount(label: T): number {
162162
const index = this.getIndex(label);
163163
let count = 0;
164164
for (let i = 0; i < this.matrix.length; i++) {
@@ -173,15 +173,15 @@ export class ConfusionMatrix {
173173
* Get the number of real positive samples.
174174
* @param label - The label that should be considered "positive"
175175
*/
176-
getPositiveCount(label: Label) {
176+
getPositiveCount(label: T) {
177177
return this.getTruePositiveCount(label) + this.getFalseNegativeCount(label);
178178
}
179179

180180
/**
181181
* Get the number of real negative samples.
182182
* @param label - The label that should be considered "positive"
183183
*/
184-
getNegativeCount(label: Label) {
184+
getNegativeCount(label: T) {
185185
return this.getTrueNegativeCount(label) + this.getFalsePositiveCount(label);
186186
}
187187

@@ -190,7 +190,7 @@ export class ConfusionMatrix {
190190
* @param label - The label to search for
191191
* @throws if the label is not found
192192
*/
193-
getIndex(label: Label): number {
193+
getIndex(label: T): number {
194194
const index = this.labels.indexOf(label);
195195
if (index === -1) throw new Error('The label does not exist');
196196
return index;
@@ -202,7 +202,7 @@ export class ConfusionMatrix {
202202
* @param label - The label that should be considered "positive"
203203
* @return The true positive rate [0-1]
204204
*/
205-
getTruePositiveRate(label: Label) {
205+
getTruePositiveRate(label: T) {
206206
return this.getTruePositiveCount(label) / this.getPositiveCount(label);
207207
}
208208

@@ -212,7 +212,7 @@ export class ConfusionMatrix {
212212
* @param label - The label that should be considered "positive"
213213
* @return The true negative rate a.k.a. specificity.
214214
*/
215-
getTrueNegativeRate(label: Label) {
215+
getTrueNegativeRate(label: T) {
216216
return this.getTrueNegativeCount(label) / this.getNegativeCount(label);
217217
}
218218

@@ -222,7 +222,7 @@ export class ConfusionMatrix {
222222
* @param label - The label that should be considered "positive"
223223
* @return the positive predictive value a.k.a. precision.
224224
*/
225-
getPositivePredictiveValue(label: Label) {
225+
getPositivePredictiveValue(label: T) {
226226
const TP = this.getTruePositiveCount(label);
227227
return TP / (TP + this.getFalsePositiveCount(label));
228228
}
@@ -232,7 +232,7 @@ export class ConfusionMatrix {
232232
* {@link https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values}
233233
* @param label - The label that should be considered "positive"
234234
*/
235-
getNegativePredictiveValue(label: Label) {
235+
getNegativePredictiveValue(label: T) {
236236
const TN = this.getTrueNegativeCount(label);
237237
return TN / (TN + this.getFalseNegativeCount(label));
238238
}
@@ -242,7 +242,7 @@ export class ConfusionMatrix {
242242
* {@link https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates}
243243
* @param label - The label that should be considered "positive"
244244
*/
245-
getFalseNegativeRate(label: Label) {
245+
getFalseNegativeRate(label: T) {
246246
return 1 - this.getTruePositiveRate(label);
247247
}
248248

@@ -251,7 +251,7 @@ export class ConfusionMatrix {
251251
* {@link https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates}
252252
* @param label - The label that should be considered "positive"
253253
*/
254-
getFalsePositiveRate(label: Label) {
254+
getFalsePositiveRate(label: T) {
255255
return 1 - this.getTrueNegativeRate(label);
256256
}
257257

@@ -260,7 +260,7 @@ export class ConfusionMatrix {
260260
* {@link https://en.wikipedia.org/wiki/False_discovery_rate}
261261
* @param label - The label that should be considered "positive"
262262
*/
263-
getFalseDiscoveryRate(label: Label) {
263+
getFalseDiscoveryRate(label: T) {
264264
const FP = this.getFalsePositiveCount(label);
265265
return FP / (FP + this.getTruePositiveCount(label));
266266
}
@@ -269,7 +269,7 @@ export class ConfusionMatrix {
269269
* False omission rate (FOR)
270270
* @param label - The label that should be considered "positive"
271271
*/
272-
getFalseOmissionRate(label: Label) {
272+
getFalseOmissionRate(label: T) {
273273
const FN = this.getFalseNegativeCount(label);
274274
return FN / (FN + this.getTruePositiveCount(label));
275275
}
@@ -279,7 +279,7 @@ export class ConfusionMatrix {
279279
* {@link https://en.wikipedia.org/wiki/F1_score}
280280
* @param label - The label that should be considered "positive"
281281
*/
282-
getF1Score(label: Label) {
282+
getF1Score(label: T) {
283283
const TP = this.getTruePositiveCount(label);
284284
return (
285285
(2 * TP) /
@@ -294,7 +294,7 @@ export class ConfusionMatrix {
294294
* {@link https://en.wikipedia.org/wiki/Matthews_correlation_coefficient}
295295
* @param label - The label that should be considered "positive"
296296
*/
297-
getMatthewsCorrelationCoefficient(label: Label) {
297+
getMatthewsCorrelationCoefficient(label: T) {
298298
const TP = this.getTruePositiveCount(label);
299299
const TN = this.getTrueNegativeCount(label);
300300
const FP = this.getFalsePositiveCount(label);
@@ -310,7 +310,7 @@ export class ConfusionMatrix {
310310
* {@link https://en.wikipedia.org/wiki/Youden%27s_J_statistic}
311311
* @param label - The label that should be considered "positive"
312312
*/
313-
getInformedness(label: Label) {
313+
getInformedness(label: T) {
314314
return (
315315
this.getTruePositiveRate(label) + this.getTrueNegativeRate(label) - 1
316316
);
@@ -320,7 +320,7 @@ export class ConfusionMatrix {
320320
* Markedness
321321
* @param label - The label that should be considered "positive"
322322
*/
323-
getMarkedness(label: Label) {
323+
getMarkedness(label: T) {
324324
return (
325325
this.getPositivePredictiveValue(label) +
326326
this.getNegativePredictiveValue(label) -
@@ -333,7 +333,7 @@ export class ConfusionMatrix {
333333
* @param label - The label that should be considered "positive"
334334
* @return The 2x2 confusion table. [[TP, FN], [FP, TN]]
335335
*/
336-
getConfusionTable(label: Label) {
336+
getConfusionTable(label: T) {
337337
return [
338338
[this.getTruePositiveCount(label), this.getFalseNegativeCount(label)],
339339
[this.getFalsePositiveCount(label), this.getTrueNegativeCount(label)],
@@ -362,7 +362,7 @@ export class ConfusionMatrix {
362362
* @param predicted - The predicted label
363363
* @return The element in the confusion matrix
364364
*/
365-
getCount(actual: Label, predicted: Label) {
365+
getCount(actual: T, predicted: T) {
366366
const actualIndex = this.getIndex(actual);
367367
const predictedIndex = this.getIndex(predicted);
368368
return this.matrix[actualIndex][predictedIndex];
@@ -388,7 +388,7 @@ export class ConfusionMatrix {
388388

389389
type Label = boolean | number | string;
390390

391-
interface FromLabelsOptions {
392-
labels?: Label[];
391+
interface FromLabelsOptions<T extends Label> {
392+
labels?: T[];
393393
sort?: (...args: Label[]) => number;
394394
}

0 commit comments

Comments
 (0)