Skip to content

Commit 6791efc

Browse files
authored
Merge pull request #774 from wearepal/improve-docs
Fill in a lot of missing documentation
2 parents 0a7659f + 54ba63f commit 6791efc

28 files changed

+236
-164
lines changed

.github/workflows/continuous_integration.yml

+10-10
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ jobs:
3131
runs-on: ubuntu-latest
3232

3333
steps:
34-
- uses: actions/checkout@v2
34+
- uses: actions/checkout@v3
3535
- name: Set up Python
36-
uses: actions/setup-python@v2
36+
uses: actions/setup-python@v4
3737
with:
3838
python-version: ${{ env.PYTHON_VERSION }}
3939
- name: Install pylint
@@ -50,9 +50,9 @@ jobs:
5050
runs-on: ubuntu-latest
5151

5252
steps:
53-
- uses: actions/checkout@v2
53+
- uses: actions/checkout@v3
5454
- name: Set up Python
55-
uses: actions/setup-python@v2
55+
uses: actions/setup-python@v4
5656
with:
5757
python-version: ${{ env.PYTHON_VERSION }}
5858
- name: Install black
@@ -72,9 +72,9 @@ jobs:
7272
runs-on: ubuntu-latest
7373

7474
steps:
75-
- uses: actions/checkout@v2
75+
- uses: actions/checkout@v3
7676
- name: Set up Python 3.10
77-
uses: actions/setup-python@v2
77+
uses: actions/setup-python@v4
7878
with:
7979
python-version: "3.10"
8080
- name: Install pydocstyle
@@ -96,9 +96,9 @@ jobs:
9696

9797
runs-on: ubuntu-latest
9898
steps:
99-
- uses: actions/checkout@v2
99+
- uses: actions/checkout@v3
100100
- name: Set up Python
101-
uses: actions/setup-python@v2
101+
uses: actions/setup-python@v4
102102
with:
103103
python-version: ${{ env.PYTHON_VERSION }}
104104
- name: Install darglint
@@ -113,7 +113,7 @@ jobs:
113113
runs-on: ubuntu-latest
114114
steps:
115115
- name: Check out repository
116-
uses: actions/checkout@v2
116+
uses: actions/checkout@v3
117117
- name: Set up Python
118118
uses: actions/setup-python@v4
119119
with:
@@ -152,7 +152,7 @@ jobs:
152152
# --- check-out repo and set-up python ---
153153
#----------------------------------------------
154154
- name: Check out repository
155-
uses: actions/checkout@v2
155+
uses: actions/checkout@v3
156156
- name: Set up Python
157157
uses: actions/setup-python@v4
158158
with:

docs/ethicml.data.rst

+6
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,9 @@ ethicml.data
66
:members:
77
:imported-members:
88
:autosummary:
9+
10+
Aliases
11+
=======
12+
13+
.. autoclass:: ethicml.data.util.DiscFeatureGroups
14+
.. autoclass:: ethicml.data.util.LabelSpec

ethicml/data/dataset.py

+99-77
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
"""Data structure for all datasets that come with the framework."""
22
from __future__ import annotations
33
from abc import ABC, abstractmethod
4-
from dataclasses import dataclass
5-
from enum import Enum, auto
4+
from dataclasses import dataclass, field
5+
from enum import auto
66
from pathlib import Path
77
import typing
8-
from typing import ClassVar, Literal, NamedTuple, Sequence, TypedDict, final
8+
from typing import ClassVar, Sequence, TypedDict, final
99

1010
import pandas as pd
11-
from ranzen import implements
11+
from ranzen import StrEnum, implements
1212

1313
from ethicml.common import ROOT_PATH
1414
from ethicml.utility import DataTuple, undo_one_hot
1515

1616
from .util import (
17-
DiscFeatureGroup,
17+
DiscFeatureGroups,
1818
LabelSpec,
1919
flatten_dict,
2020
from_dummies,
@@ -31,18 +31,23 @@
3131
"LabelSpecsPair",
3232
"LegacyDataset",
3333
"StaticCSVDataset",
34+
"one_hot_encode_and_combine",
3435
]
3536

3637

37-
class LabelSpecsPair(NamedTuple):
38-
"""A pair of label specs."""
38+
@dataclass
39+
class LabelSpecsPair:
40+
"""A pair of label specs.
41+
42+
:param s: Spec for building the ``s`` label.
43+
:param y: Spec for building the ``y`` label.
44+
:param to_remove: List of feature groups that need to be removed because they are label building
45+
blocks. (Default: ``[]``)
46+
"""
3947

4048
s: LabelSpec
41-
"""Spec for building the s label."""
4249
y: LabelSpec
43-
"""Spec for building the y label."""
44-
to_remove: list[str]
45-
"""List of feature groups that need to be removed because they are label building blocks."""
50+
to_remove: list[str] = field(default_factory=list)
4651

4752

4853
class FeatureSplit(TypedDict):
@@ -53,11 +58,13 @@ class FeatureSplit(TypedDict):
5358
y: list[str]
5459

5560

56-
class FeatureOrder(Enum):
61+
class FeatureOrder(StrEnum):
5762
"""Order of features in the loaded datatuple."""
5863

5964
cont_first = auto()
65+
"""Continuous features first."""
6066
disc_first = auto()
67+
"""Discrete features first."""
6168

6269

6370
class Dataset(ABC):
@@ -74,8 +81,10 @@ def load(
7481
) -> DataTuple:
7582
"""Load dataset from its CSV file.
7683
77-
:param labels_as_features: if True, the s and y labels are included in the x features
78-
:returns: DataTuple with dataframes of features, labels and sensitive attributes
84+
:param labels_as_features: If ``True``, the s and y labels are included in the x features.
85+
:param order: Order of the columns in the dataframes. Can be ``disc_first`` or
86+
``cont_first``. See :class:`FeatureOrder`.
87+
:returns: ``DataTuple`` with dataframes of features, labels and sensitive attributes.
7988
"""
8089

8190
@property
@@ -99,7 +108,7 @@ def feature_split(self, order: FeatureOrder = FeatureOrder.disc_first) -> Featur
99108

100109
@property
101110
@abstractmethod
102-
def disc_feature_groups(self) -> DiscFeatureGroup:
111+
def disc_feature_groups(self) -> DiscFeatureGroups:
103112
"""Return Dictionary of feature groups."""
104113

105114
@abstractmethod
@@ -139,7 +148,7 @@ def invert_sens_attr(self) -> bool:
139148

140149
@property
141150
@abstractmethod
142-
def unfiltered_disc_feat_groups(self) -> DiscFeatureGroup:
151+
def unfiltered_disc_feat_groups(self) -> DiscFeatureGroups:
143152
"""Discrete feature groups, including features for the labels."""
144153

145154
@property
@@ -185,7 +194,7 @@ def discrete_features(self) -> list[str]:
185194
return flatten_dict(self.disc_feature_groups)
186195

187196
@property
188-
def disc_feature_groups(self) -> DiscFeatureGroup:
197+
def disc_feature_groups(self) -> DiscFeatureGroups:
189198
"""Return Dictionary of feature groups, without s and y labels."""
190199
dfgs = self.unfiltered_disc_feat_groups
191200
# select those feature groups that are not for the s and y labels
@@ -227,12 +236,12 @@ def load(
227236
label_specs = self.get_label_specs()
228237

229238
# the following operations remove rows if a label group is not properly one-hot encoded
230-
s_data, s_mask = self._one_hot_encode_and_combine(s_df, label_spec=label_specs.s)
239+
s_data, s_mask = one_hot_encode_and_combine(s_df, label_specs.s, self.discard_non_one_hot)
231240
if s_mask is not None:
232241
x_data = x_data.loc[s_mask].reset_index(drop=True)
233242
s_data = s_data.loc[s_mask].reset_index(drop=True)
234243
y_df = y_df.loc[s_mask].reset_index(drop=True)
235-
y_data, y_mask = self._one_hot_encode_and_combine(y_df, label_spec=label_specs.y)
244+
y_data, y_mask = one_hot_encode_and_combine(y_df, label_specs.y, self.discard_non_one_hot)
236245
if y_mask is not None:
237246
x_data = x_data.loc[y_mask].reset_index(drop=True)
238247
s_data = s_data.loc[y_mask].reset_index(drop=True)
@@ -268,61 +277,6 @@ def _generate_missing_columns(self, dataframe: pd.DataFrame) -> pd.DataFrame:
268277
)
269278
return dataframe
270279

271-
def _one_hot_encode_and_combine(
272-
self, attributes: pd.DataFrame, label_spec: LabelSpec
273-
) -> tuple[pd.Series, pd.Series | None]:
274-
"""Construct a new label according to the LabelSpecs.
275-
276-
:param attributes: DataFrame containing the attributes.
277-
:param label_spec: A label spec.
278-
"""
279-
mask = None # the mask is needed when we have to discard samples
280-
281-
# create a Series of zeroes with the same length as the dataframe
282-
combination: pd.Series = pd.Series(
283-
0, index=range(len(attributes)), name=",".join(label_spec)
284-
)
285-
286-
for name, spec in label_spec.items():
287-
if len(spec.columns) > 1: # data is one-hot encoded
288-
raw_values = attributes[spec.columns]
289-
if self.discard_non_one_hot:
290-
# only use those samples where exactly one of the specified attributes is true
291-
mask = raw_values.sum(axis="columns") == 1
292-
else:
293-
assert (raw_values.sum(axis="columns") == 1).all(), f"{name} is not one-hot"
294-
values = undo_one_hot(raw_values)
295-
else:
296-
values = attributes[spec.columns[0]]
297-
combination += spec.multiplier * values
298-
return combination, mask
299-
300-
def expand_labels(self, label: pd.Series, label_type: Literal["s", "y"]) -> pd.DataFrame:
301-
"""Expand a label in the form of an index into all the subfeatures.
302-
303-
:param label: DataFrame containing the labels.
304-
:param label_type: Type of label to expand.
305-
"""
306-
label_specs = self.get_label_specs()
307-
label_mapping = label_specs.s if label_type == "s" else label_specs.y
308-
309-
# first order the multipliers; this is important for disentangling the values
310-
names_ordered = sorted(label_mapping, key=lambda name: label_mapping[name].multiplier)
311-
312-
final_df = {}
313-
for i, name in enumerate(names_ordered):
314-
spec = label_mapping[name]
315-
value = label
316-
if i + 1 < len(names_ordered):
317-
next_group = label_mapping[names_ordered[i + 1]]
318-
value = label % next_group.multiplier
319-
value = value // spec.multiplier
320-
value.replace(list(range(len(spec.columns))), spec.columns, inplace=True)
321-
restored = pd.get_dummies(value)
322-
final_df[name] = restored # for the multi-level column index
323-
324-
return pd.concat(final_df, axis=1)
325-
326280
@typing.no_type_check
327281
def load_aif(self): # Returns aif.360 Standard Dataset
328282
"""Load the dataset as an AIF360 dataset.
@@ -351,6 +305,40 @@ def load_aif(self): # Returns aif.360 Standard Dataset
351305
)
352306

353307

308+
def one_hot_encode_and_combine(
309+
attributes: pd.DataFrame, label_spec: LabelSpec, discard_non_one_hot: bool
310+
) -> tuple[pd.Series, pd.Series | None]:
311+
"""Construct a new label according to the given :class:`~ethicml.data.util.LabelSpec`.
312+
313+
This function is at the heart of the label spec API in EthicML.
314+
315+
:param attributes: DataFrame containing the attributes.
316+
:param label_spec: A label spec.
317+
:param discard_non_one_hot: If ``True``, a mask is returned which masks out all rows which are
318+
not properly one-hot (i.e., either all classes are 0 or more than one is 1).
319+
:returns: A tuple of a Series with the new labels and -- if ``discard_non_one_hot`` is ``True``
320+
-- a mask for filtering out the rows that were not properly one-hot.
321+
"""
322+
mask = None # the mask is needed when we have to discard samples
323+
324+
# create a Series of zeroes with the same length as the dataframe
325+
combination: pd.Series = pd.Series(0, index=range(len(attributes)), name=",".join(label_spec))
326+
327+
for name, spec in label_spec.items():
328+
if len(spec.columns) > 1: # data is one-hot encoded
329+
raw_values = attributes[spec.columns]
330+
if discard_non_one_hot:
331+
# only use those samples where exactly one of the specified attributes is true
332+
mask = raw_values.sum(axis="columns") == 1
333+
else:
334+
assert (raw_values.sum(axis="columns") == 1).all(), f"{name} is not one-hot"
335+
values = undo_one_hot(raw_values)
336+
else:
337+
values = attributes[spec.columns[0]]
338+
combination += spec.multiplier * values
339+
return combination, mask
340+
341+
354342
@dataclass
355343
class CSVDatasetDC(CSVDataset, ABC):
356344
"""Dataset that uses the default load function."""
@@ -371,7 +359,37 @@ def invert_sens_attr(self) -> bool:
371359

372360
@dataclass
373361
class StaticCSVDataset(CSVDatasetDC, ABC):
374-
"""Dataset whose size and file location does not depend on constructor arguments."""
362+
"""Dataset whose size and file location do not depend on constructor arguments.
363+
364+
:example:
365+
How to subclass this:
366+
367+
.. code:: python
368+
369+
@dataclass
370+
class Toy(StaticCSVDataset):
371+
'''Dataset with toy data for testing.'''
372+
373+
num_samples: ClassVar[int] = 400
374+
csv_file: ClassVar[str] = "toy.csv"
375+
376+
@property
377+
def name(self) -> str:
378+
return "Toy"
379+
380+
def get_label_specs(self) -> LabelSpecsPair:
381+
return LabelSpecsPair(
382+
s=single_col_spec("sens"), y=single_col_spec("class")
383+
)
384+
385+
@property
386+
def unfiltered_disc_feat_groups(self) -> DiscFeatureGroups:
387+
return {"disc_1": ["a_1", "a_2", "a_3"], "disc_2": ["b_1", "b_2"]}
388+
389+
@property
390+
def continuous_features(self) -> list[str]:
391+
return ["c1", "c2"]
392+
"""
375393

376394
num_samples: ClassVar[int] = 0
377395
csv_file: ClassVar[str] = "<overwrite me>"
@@ -387,7 +405,11 @@ def get_filename_or_path(self) -> str | Path:
387405

388406
@dataclass(init=False)
389407
class LegacyDataset(CSVDataset):
390-
"""Dataset that uses the default load function."""
408+
"""Dataset base class.
409+
410+
This base class is considered legacy now. Please use :class:`CSVDatasetDC` or
411+
:class:`StaticCSVDataset` instead.
412+
"""
391413

392414
discrete_only: bool = False
393415
invert_s: bool = False
@@ -423,7 +445,7 @@ def __init__(
423445

424446
@property
425447
@implements(CSVDataset)
426-
def unfiltered_disc_feat_groups(self) -> DiscFeatureGroup:
448+
def unfiltered_disc_feat_groups(self) -> DiscFeatureGroups:
427449
return self._unfiltered_disc_feat_groups
428450

429451
@property

ethicml/data/load.py

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def load_data(dataset: Dataset) -> DataTuple:
1919
This function only exists for backwards compatibility. Use dataset.load() instead.
2020
2121
:param dataset: dataset object
22-
:param ordered: if True, return features such that discrete come first, then continuous (Default: False)
2322
:returns: DataTuple with dataframes of features, labels and sensitive attributes
2423
"""
2524
return dataset.load()

ethicml/data/lookup.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def get_dataset_obj_by_name(name: str) -> Callable[[], Dataset]:
3434
"""Given a dataset name, get the corresponding dataset object.
3535
3636
:param name: Name of the dataset.
37+
:returns: A callable that can be used to construct the dataset object.
38+
:raises NotImplementedError: If the given name does not correspond to a dataset.
3739
"""
3840
lookup = _lookup_table()
3941
lowercase_name = name.lower()

0 commit comments

Comments
 (0)