"""Data structure for all datasets that come with the framework."""
from __future__ import annotations
from abc import ABC, abstractmethod
- from dataclasses import dataclass
- from enum import Enum, auto
+ from dataclasses import dataclass, field
+ from enum import auto
from pathlib import Path
import typing
- from typing import ClassVar, Literal, NamedTuple, Sequence, TypedDict, final
+ from typing import ClassVar, Sequence, TypedDict, final

import pandas as pd
- from ranzen import implements
+ from ranzen import StrEnum, implements

from ethicml.common import ROOT_PATH
from ethicml.utility import DataTuple, undo_one_hot

from .util import (
-     DiscFeatureGroup,
+     DiscFeatureGroups,
    LabelSpec,
    flatten_dict,
    from_dummies,
@@ ... @@
    "LabelSpecsPair",
    "LegacyDataset",
    "StaticCSVDataset",
+     "one_hot_encode_and_combine",
]


- class LabelSpecsPair(NamedTuple):
-     """A pair of label specs."""
+ @dataclass
+ class LabelSpecsPair:
+     """A pair of label specs.
+
+     :param s: Spec for building the ``s`` label.
+     :param y: Spec for building the ``y`` label.
+     :param to_remove: List of feature groups that need to be removed because they are label building
+         blocks. (Default: ``[]``)
+     """

    s: LabelSpec
-     """Spec for building the s label."""
    y: LabelSpec
-     """Spec for building the y label."""
-     to_remove: list[str]
-     """List of feature groups that need to be removed because they are label building blocks."""
+     to_remove: list[str] = field(default_factory=list)


class FeatureSplit(TypedDict):
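A minimal usage sketch of the new ``LabelSpecsPair`` dataclass (assuming ``single_col_spec`` is importable from ``ethicml.data.util``; the column names are hypothetical):

.. code:: python

    from ethicml.data.util import single_col_spec

    # to_remove now defaults to an empty list, so it can simply be omitted
    pair = LabelSpecsPair(s=single_col_spec("sens"), y=single_col_spec("class"))
    assert pair.to_remove == []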
@@ -53,11 +58,13 @@ class FeatureSplit(TypedDict):
    y: list[str]


- class FeatureOrder(Enum):
+ class FeatureOrder(StrEnum):
    """Order of features in the loaded datatuple."""

    cont_first = auto()
+     """Continuous features first."""
    disc_first = auto()
+     """Discrete features first."""


class Dataset(ABC):
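Since ``FeatureOrder`` now derives from ranzen's ``StrEnum``, its members should also behave as plain strings (a sketch, assuming ``StrEnum`` subclasses ``str`` and ``auto()`` uses the member name as the value):

.. code:: python

    assert FeatureOrder.cont_first == "cont_first"
    assert FeatureOrder("disc_first") is FeatureOrder.disc_first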
@@ -74,8 +81,10 @@ def load(
    ) -> DataTuple:
        """Load dataset from its CSV file.

-         :param labels_as_features: if True, the s and y labels are included in the x features
-         :returns: DataTuple with dataframes of features, labels and sensitive attributes
+         :param labels_as_features: If ``True``, the s and y labels are included in the x features.
+         :param order: Order of the columns in the dataframes. Can be ``disc_first`` or
+             ``cont_first``. See :class:`FeatureOrder`.
+         :returns: ``DataTuple`` with dataframes of features, labels and sensitive attributes.
        """

    @property
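For illustration, a call with the new ``order`` parameter might look like this (``some_dataset`` stands in for any concrete ``Dataset`` instance):

.. code:: python

    data_tuple = some_dataset.load(labels_as_features=False, order=FeatureOrder.cont_first)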
@@ -99,7 +108,7 @@ def feature_split(self, order: FeatureOrder = FeatureOrder.disc_first) -> Featur
    @property
    @abstractmethod
-     def disc_feature_groups(self) -> DiscFeatureGroup:
+     def disc_feature_groups(self) -> DiscFeatureGroups:
        """Return Dictionary of feature groups."""

    @abstractmethod
@@ -139,7 +148,7 @@ def invert_sens_attr(self) -> bool:

    @property
    @abstractmethod
-     def unfiltered_disc_feat_groups(self) -> DiscFeatureGroup:
+     def unfiltered_disc_feat_groups(self) -> DiscFeatureGroups:
        """Discrete feature groups, including features for the labels."""

    @property
@@ -185,7 +194,7 @@ def discrete_features(self) -> list[str]:
        return flatten_dict(self.disc_feature_groups)

    @property
-     def disc_feature_groups(self) -> DiscFeatureGroup:
+     def disc_feature_groups(self) -> DiscFeatureGroups:
        """Return Dictionary of feature groups, without s and y labels."""
        dfgs = self.unfiltered_disc_feat_groups
        # select those feature groups that are not for the s and y labels
@@ -227,12 +236,12 @@ def load(
        label_specs = self.get_label_specs()

        # the following operations remove rows if a label group is not properly one-hot encoded
-         s_data, s_mask = self._one_hot_encode_and_combine(s_df, label_spec=label_specs.s)
+         s_data, s_mask = one_hot_encode_and_combine(s_df, label_specs.s, self.discard_non_one_hot)
        if s_mask is not None:
            x_data = x_data.loc[s_mask].reset_index(drop=True)
            s_data = s_data.loc[s_mask].reset_index(drop=True)
            y_df = y_df.loc[s_mask].reset_index(drop=True)
-         y_data, y_mask = self._one_hot_encode_and_combine(y_df, label_spec=label_specs.y)
+         y_data, y_mask = one_hot_encode_and_combine(y_df, label_specs.y, self.discard_non_one_hot)
        if y_mask is not None:
            x_data = x_data.loc[y_mask].reset_index(drop=True)
            s_data = s_data.loc[y_mask].reset_index(drop=True)
@@ -268,61 +277,6 @@ def _generate_missing_columns(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        )
        return dataframe

-     def _one_hot_encode_and_combine(
-         self, attributes: pd.DataFrame, label_spec: LabelSpec
-     ) -> tuple[pd.Series, pd.Series | None]:
-         """Construct a new label according to the LabelSpecs.
-
-         :param attributes: DataFrame containing the attributes.
-         :param label_spec: A label spec.
-         """
-         mask = None  # the mask is needed when we have to discard samples
-
-         # create a Series of zeroes with the same length as the dataframe
-         combination: pd.Series = pd.Series(
-             0, index=range(len(attributes)), name=",".join(label_spec)
-         )
-
-         for name, spec in label_spec.items():
-             if len(spec.columns) > 1:  # data is one-hot encoded
-                 raw_values = attributes[spec.columns]
-                 if self.discard_non_one_hot:
-                     # only use those samples where exactly one of the specified attributes is true
-                     mask = raw_values.sum(axis="columns") == 1
-                 else:
-                     assert (raw_values.sum(axis="columns") == 1).all(), f"{name} is not one-hot"
-                 values = undo_one_hot(raw_values)
-             else:
-                 values = attributes[spec.columns[0]]
-             combination += spec.multiplier * values
-         return combination, mask
-
-     def expand_labels(self, label: pd.Series, label_type: Literal["s", "y"]) -> pd.DataFrame:
-         """Expand a label in the form of an index into all the subfeatures.
-
-         :param label: DataFrame containing the labels.
-         :param label_type: Type of label to expand.
-         """
-         label_specs = self.get_label_specs()
-         label_mapping = label_specs.s if label_type == "s" else label_specs.y
-
-         # first order the multipliers; this is important for disentangling the values
-         names_ordered = sorted(label_mapping, key=lambda name: label_mapping[name].multiplier)
-
-         final_df = {}
-         for i, name in enumerate(names_ordered):
-             spec = label_mapping[name]
-             value = label
-             if i + 1 < len(names_ordered):
-                 next_group = label_mapping[names_ordered[i + 1]]
-                 value = label % next_group.multiplier
-             value = value // spec.multiplier
-             value.replace(list(range(len(spec.columns))), spec.columns, inplace=True)
-             restored = pd.get_dummies(value)
-             final_df[name] = restored  # for the multi-level column index
-
-         return pd.concat(final_df, axis=1)
-
    @typing.no_type_check
    def load_aif(self):  # Returns aif.360 Standard Dataset
        """Load the dataset as an AIF360 dataset.
@@ -351,6 +305,40 @@ def load_aif(self): # Returns aif.360 Standard Dataset
        )


+ def one_hot_encode_and_combine(
+     attributes: pd.DataFrame, label_spec: LabelSpec, discard_non_one_hot: bool
+ ) -> tuple[pd.Series, pd.Series | None]:
+     """Construct a new label according to the given :class:`~ethicml.data.util.LabelSpec`.
+
+     This function is at the heart of the label spec API in EthicML.
+
+     :param attributes: DataFrame containing the attributes.
+     :param label_spec: A label spec.
+     :param discard_non_one_hot: If ``True``, a mask is returned which masks out all rows which are
+         not properly one-hot (i.e., either all classes are 0 or more than one is 1).
+     :returns: A tuple of a Series with the new labels and -- if ``discard_non_one_hot`` is ``True``
+         -- a mask for filtering out the rows that were not properly one-hot.
+     """
+     mask = None  # the mask is needed when we have to discard samples
+
+     # create a Series of zeroes with the same length as the dataframe
+     combination: pd.Series = pd.Series(0, index=range(len(attributes)), name=",".join(label_spec))
+
+     for name, spec in label_spec.items():
+         if len(spec.columns) > 1:  # data is one-hot encoded
+             raw_values = attributes[spec.columns]
+             if discard_non_one_hot:
+                 # only use those samples where exactly one of the specified attributes is true
+                 mask = raw_values.sum(axis="columns") == 1
+             else:
+                 assert (raw_values.sum(axis="columns") == 1).all(), f"{name} is not one-hot"
+             values = undo_one_hot(raw_values)
+         else:
+             values = attributes[spec.columns[0]]
+         combination += spec.multiplier * values
+     return combination, mask
+
+
@dataclass
class CSVDatasetDC(CSVDataset, ABC):
    """Dataset that uses the default load function."""
@@ -371,7 +359,37 @@ def invert_sens_attr(self) -> bool:

@dataclass
class StaticCSVDataset(CSVDatasetDC, ABC):
-     """Dataset whose size and file location does not depend on constructor arguments."""
+     """Dataset whose size and file location do not depend on constructor arguments.
+
+     :example:
+         How to subclass this:
+
+         .. code:: python
+
+             @dataclass
+             class Toy(StaticCSVDataset):
+                 '''Dataset with toy data for testing.'''
+
+                 num_samples: ClassVar[int] = 400
+                 csv_file: ClassVar[str] = "toy.csv"
+
+                 @property
+                 def name(self) -> str:
+                     return "Toy"
+
+                 def get_label_specs(self) -> LabelSpecsPair:
+                     return LabelSpecsPair(
+                         s=single_col_spec("sens"), y=single_col_spec("class")
+                     )
+
+                 @property
+                 def unfiltered_disc_feat_groups(self) -> DiscFeatureGroups:
+                     return {"disc_1": ["a_1", "a_2", "a_3"], "disc_2": ["b_1", "b_2"]}
+
+                 @property
+                 def continuous_features(self) -> list[str]:
+                     return ["c1", "c2"]
+     """

    num_samples: ClassVar[int] = 0
    csv_file: ClassVar[str] = "<overwrite me>"
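With a subclass like the ``Toy`` example in the docstring above, loading reduces to a one-liner (hypothetical usage):

.. code:: python

    data = Toy().load(order=FeatureOrder.disc_first)
    assert len(data.x) == 400  # matches num_samples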
@@ -387,7 +405,11 @@ def get_filename_or_path(self) -> str | Path:

@dataclass(init=False)
class LegacyDataset(CSVDataset):
-     """Dataset that uses the default load function."""
+     """Dataset base class.
+
+     This base class is considered legacy now. Please use :class:`CSVDatasetDC` or
+     :class:`StaticCSVDataset` instead.
+     """

    discrete_only: bool = False
    invert_s: bool = False
@@ -423,7 +445,7 @@ def __init__(

    @property
    @implements(CSVDataset)
-     def unfiltered_disc_feat_groups(self) -> DiscFeatureGroup:
+     def unfiltered_disc_feat_groups(self) -> DiscFeatureGroups:
        return self._unfiltered_disc_feat_groups

    @property