
Commit ddb8048

Myles Bartlett committed: Initial commit. (0 parents)

214 files changed: +13817 -0 lines changed


LICENSE

+674
Large diffs are not rendered by default.

README.md

+64

# Okapi: Generalising Better By Making Statistical Matches Match

Official code for the NeurIPS 2022 paper _Okapi: Generalising Better By Making Statistical Matches Match_.

> We propose Okapi, a simple, efficient, and general method for robust semi-supervised learning based on online statistical matching. Our method uses a nearest-neighbours-based matching procedure to generate cross-domain views for a consistency loss, while eliminating statistical outliers. In order to perform the online matching in a runtime- and memory-efficient way, we draw upon the self-supervised literature and combine a memory bank with a slow-moving momentum encoder. The consistency loss is applied within the feature space, rather than on the predictive distribution, making the method agnostic to both the modality and the task in question. We experiment on the WILDS 2.0 datasets (Sagawa et al.), which significantly expand the range of modalities, applications, and shifts available for studying and benchmarking real-world unsupervised adaptation. Contrary to Sagawa et al., we show that it is in fact possible to leverage additional unlabelled data to improve upon empirical risk minimisation (ERM) results with the right method. Our method outperforms the baseline methods in terms of out-of-distribution (OOD) generalisation on the iWildCam (a multi-class classification task) and PovertyMap (a regression task) image datasets as well as the CivilComments (a binary classification task) text dataset. Furthermore, from a qualitative perspective, we show the matches obtained from the learned encoder are strongly semantically related.

## Requirements

- python >=3.9
- [poetry](https://python-poetry.org/)
- CUDA >=11.3 (if installing with ``install.sh``)

## Installation

We use [poetry](https://python-poetry.org/) for dependency management; installing it is a prerequisite for installing the python dependencies. With poetry installed, the dependencies can then be installed by running ``install.sh``, which requires CUDA >=11.3 when installing to a CUDA-equipped machine. This constraint can be bypassed by manually executing the commands (see the sketch below):

- ``poetry install``
- install the appropriate version of PyTorch and ``torch-scatter`` (required for evaluation with [WILDS](https://github.com/p-lambda/wilds)) for the version of CUDA installed on your machine.
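
For reference, here is a minimal sketch of the manual route, assuming CUDA 11.3; the exact PyTorch version and wheel URLs are illustrative, so substitute the ones matching your machine:

```bash
# Install the python dependencies managed by poetry.
poetry install

# Illustrative: a PyTorch build matching CUDA 11.3.
poetry run pip install torch==1.11.0+cu113 \
  --extra-index-url https://download.pytorch.org/whl/cu113

# Illustrative: the matching torch-scatter wheel (needed for WILDS evaluation).
poetry run pip install torch-scatter \
  -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
```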

## Running the code

We use [hydra](https://github.com/facebookresearch/hydra) for managing the configuration of our experiments. Experiment configurations are grouped by dataset in ``external_confs/experiments`` and can be imported via the commandline with the command ``python main.py +experiment={dataset}/{method}``; one can then override any desired configs/arguments with the syntax ``{config}={name_of_config_file}`` or ``{config}.{attribute}={value}`` (e.g. ``seed=42`` (defined in the main config class), ``backbone=iw/rn50``, ``alg.lr=1.e-5``). A composed example is sketched below.
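
Putting the pieces together, a full invocation might look like the following; ``{dataset}/{method}`` is the README's placeholder, so consult ``external_confs/experiments`` for the available options:

```bash
# Select an experiment config, then override individual values inline.
python main.py +experiment={dataset}/{method} \
  seed=42 \
  backbone=iw/rn50 \
  alg.lr=1.e-5
```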

## Citation

```
@article{bartlett2022okapi,
  title={Okapi: Generalising Better by Making Statistical Matches Match},
  author={Bartlett, Myles and Romiti, Sara and Sharmanska, Viktoriia and Quadrianto, Novi},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}
```

external_confs/alg/iwildcam/clip.yaml

+25

---
defaults:
  - /schema/alg: erm
  - defaults
  - _self_

model:
evaluator:
lr: 5e-05
optimizer_cls: 'torch.optim.AdamW'
optimizer_kwargs: null
use_sam: false
sam_rho: 0.05
scheduler_cls: null
scheduler_kwargs: null
lr_sched_interval: step
lr_sched_freq: 1
loss_fn: null
batch_transforms:
  - _target_: ranzen.torch.transforms.RandomCutMix
    alpha: 1.0
    num_classes: 182
  - _target_: ranzen.torch.transforms.RandomMixUp.with_beta_dist
    alpha: 0.2
    num_classes: 182
    inplace: true
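
Configs like this one are selected through Hydra's config groups; assuming the group name mirrors the directory layout under ``external_confs`` (as with ``backbone=iw/rn50`` in the README), selection and inline overrides would look something like:

```bash
# Select the CLIP/ERM algorithm config for iWildCam and adjust its learning rate.
python main.py +experiment={dataset}/{method} alg=iwildcam/clip alg.lr=1.e-5
```
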
+20

---
defaults:
  - /schema/alg: fixmatch
  - defaults
  - _self_

model:
evaluator:
lr: 3e-05
optimizer_cls: 'torch.optim.AdamW'
optimizer_kwargs: null
use_sam: false
sam_rho: 0.05
scheduler_cls: null
scheduler_kwargs: null
lr_sched_interval: step
lr_sched_freq: 1
batch_transforms: null
confidence_threshold: 0.70
loss_u_weight: 1.0
temperature: 1.0
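
The FixMatch-specific knobs above (pseudo-label confidence threshold, unlabelled-loss weight, temperature) can likewise be overridden from the commandline; treating ``alg`` as the group name is again an assumption based on the directory layout:

```bash
# Require higher-confidence pseudo-labels and sharpen the predictive
# distribution (temperature < 1 sharpens).
python main.py +experiment={dataset}/{method} \
  alg.confidence_threshold=0.95 alg.temperature=0.5
```
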
+7

---
defaults:
  - /schema/backbone: convnext
  - _self_
version: TINY
in_channels: 3
pretrained: true

external_confs/backbone/iw/rn50.yaml

+7

---
defaults:
  - /schema/backbone: resnet
  - _self_
version: RN50
in_channels: 3
pretrained: true

+8

---
defaults:
  - /schema/backbone: convnext
  - _self_
in_channels: 8
pretrained: true
version: TINY
checkpoint_path: ''

+7

---
defaults:
  - /schema/backbone: resnet
  - _self_
version: RN18
in_channels: 8
pretrained: true

external_confs/checkpointer/cc.yaml

+6

---
defaults:
  - /schema/checkpointer: base
  - _self_
monitor: "validate/OOD/acc_wg"
mode: 'max'

+6

---
defaults:
  - /schema/checkpointer: base
  - _self_
monitor: "validate/OOD/F1-macro_all"
mode: 'max'

external_confs/checkpointer/pm.yaml

+6

---
defaults:
  - /schema/checkpointer: base
  - _self_
monitor: "validate/OOD/r_wg"
mode: 'max'
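
The three checkpointer configs each monitor a worst-group validation metric, presumably matching CivilComments (``acc_wg``), iWildCam (``F1-macro_all``), and PovertyMap (``r_wg``). Under the same directory-layout assumption, one would be selected with:

```bash
# Checkpoint on worst-group Pearson r over the OOD validation split (PovertyMap).
python main.py +experiment={dataset}/{method} checkpointer=pm
```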

external_confs/dm/erm_no_aug.yaml

+31

defaults:
  - iwildcam
  - _self_

groupby_fields: ['location']
train_batch_size_l: 24
training_mode: step

train_transforms_l:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

test_transforms:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

external_confs/dm/iwildcam/clip.yaml

+36

---
defaults:
  - iwildcam
  - _self_

train_batch_size_l: 24
training_mode: step
target_resolution: 224

train_transforms_l:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ target_resolution }
    - _target_: torchvision.transforms.RandomHorizontalFlip
    - _target_: torchvision.transforms.RandAugment
      num_ops: 2
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [0.48145466, 0.4578275, 0.40821073]
      std: [0.26862954, 0.26130258, 0.27577711]

test_transforms:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ target_resolution }
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [0.48145466, 0.4578275, 0.40821073]
      std: [0.26862954, 0.26130258, 0.27577711]

+35

---
defaults:
  - iwildcam
  - _self_

groupby_fields: ['location']
train_batch_size_l: 24
training_mode: step

train_transforms_l:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.RandomHorizontalFlip
    - _target_: torchvision.transforms.RandAugment
      num_ops: 2
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

test_transforms:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

+67

---
defaults:
  - iwildcam
  - _self_

training_mode: step
use_unlabeled: true
target_resolution: 448
train_batch_size_l: 16
train_batch_size_u: 16

train_transforms_l:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.RandomHorizontalFlip
    - _target_: torchvision.transforms.RandAugment
      num_ops: 2
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

train_transforms_u:
  _target_: src.transforms.FixMatchTransform
  shared_transform_start:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.Resize
        size: ${ dm.target_resolution }
  strong_transform:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.RandomHorizontalFlip
      - _target_: torchvision.transforms.RandomCrop
        size: ${ dm.target_resolution }
      - _target_: src.transforms.FixMatchRandAugment
        num_ops: 2
  weak_transform:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.RandomHorizontalFlip
      - _target_: torchvision.transforms.RandomCrop
        size: ${ dm.target_resolution }
  shared_transform_end:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: torchvision.transforms.Normalize
        mean: [ 0.485, 0.456, 0.406 ]
        std: [ 0.229, 0.224, 0.225 ]

test_transforms:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

external_confs/dm/iwildcam/okapi.yaml

+37

---
defaults:
  - iwildcam
  - _self_

training_mode: step
groupby_fields: [location]
train_batch_size_l: 16
train_batch_size_u: 16
use_unlabeled: true

train_transforms_l:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.RandomHorizontalFlip
    - _target_: torchvision.transforms.RandAugment
      num_ops: 2
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]

test_transforms:
  _target_: torchvision.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Resize
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.CenterCrop
      size: ${ dm.target_resolution }
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: [ 0.485, 0.456, 0.406 ]
      std: [ 0.229, 0.224, 0.225 ]
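
The ``${ dm.target_resolution }`` entries are OmegaConf interpolations that resolve against the datamodule's ``target_resolution`` when the config is composed, so a single override rescales every Resize/CenterCrop transform; treating ``dm`` as the config-group name is an assumption based on the directory layout:

```bash
# One override propagates to all interpolated transform sizes.
python main.py +experiment={dataset}/{method} dm=iwildcam/okapi dm.target_resolution=448
```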

external_confs/dm/pm/erm.yaml

+12

---
defaults:
  - /schema/dm: poverty_map
  - _self_

fold: A
train_batch_size_l: 128
training_mode: step
use_unlabeled: false
groupby_fields: ['country']
train_transforms_l:
  _target_: src.transforms.Identity
