Skip to content

Commit 678b296

Browse files
author
Stephen Hausler
committed
Adding 2D training code
1 parent d44506c commit 678b296

24 files changed

+1467
-14
lines changed

README.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# WildScenes: A Benchmark for 2D and 3D Semantic Segmentation in Natural Environments
22

3-
This is the official repo for the WildScenes dataset, which provides benchmarks for semantic segmentation in 2D and 3D.
3+
This is the official repo for the WildScenes dataset, which provides benchmarks for semantic segmentation in 2D and 3D. Training is performed using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) and [mmdetection3d](https://github.com/open-mmlab/mmdetection3d) toolboxs. We thank and acknowledge the contributions of these toolboxes.
44

55
### Abstract
66

@@ -48,6 +48,15 @@ Will be released soon.
4848

4949
### Training code
5050

51+
All training and eval scripts are located in the directory scripts/benchmark.
52+
53+
#### 2D Training
54+
55+
Using mmsegmentation, 2D models can be trained using train2d.py.
56+
So far we have released Mask2Former training on WildScenes.
57+
58+
#### 3D Training
59+
5160
Will be released soon.
5261

5362
### Evaluation code

installation.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Mamba is a replacement for conda and behaves the same except is faster especiall
1010
These installation instructions are written for CUDA version 12.1.
1111

1212
```shell
13-
mamba create --name wildscenes3 python=3.10
13+
mamba create --name wildscenes python=3.10
1414
mamba activate wildscenes
1515
```
1616

@@ -27,7 +27,7 @@ Step 1: Install Pytorch
2727
Using CUDA 12.1:
2828

2929
```shell
30-
mamba install pytorch torchvision pytorch-cuda -c pytorch -c nvidia -c
30+
mamba install pytorch torchvision pytorch-cuda -c pytorch -c nvidia
3131
```
3232

3333
On CPU only platforms:
@@ -69,6 +69,10 @@ pip install pillow
6969
pip install tensorboard
7070
pip install matplotlib
7171
pip install open3d==0.1.8
72+
pip install pyntcloud
73+
pip install wand
74+
pip install ftfy
75+
pip install regex
7276
```
7377

7478
Step 6: Install other required mamba packages (or install using pip)

scripts/benchmark/eval2d.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
# File modified from original file: https://github.com/open-mmlab/mmsegmentation/blob/main/tools/test.py
3+
# Modified to run 2D segmentation on the WildScenes dataset.
4+
import argparse
5+
import os
6+
import os.path as osp
7+
import sys
8+
from pathlib import Path
9+
10+
from mmengine.config import Config, DictAction
11+
from mmengine.runner import Runner
12+
13+
root_dir = Path(__file__).parent.parent.parent
14+
sys.path.append(str(root_dir))
15+
16+
from wildscenes.mmseg_wildscenes.registry import RUNNERS
17+
18+
19+
def parse_args():
20+
parser = argparse.ArgumentParser(
21+
description='MMSeg test (and eval) a model')
22+
parser.add_argument('config', help='train config file path')
23+
parser.add_argument('checkpoint', help='checkpoint file')
24+
parser.add_argument(
25+
'--work-dir',
26+
help=('if specified, the evaluation metric results will be dumped'
27+
'into the directory as json'))
28+
parser.add_argument(
29+
'--out',
30+
type=str,
31+
help='The directory to save output prediction for offline evaluation')
32+
parser.add_argument(
33+
'--show', action='store_true', help='show prediction results')
34+
parser.add_argument(
35+
'--show-dir',
36+
help='directory where painted images will be saved. '
37+
'If specified, it will be automatically saved '
38+
'to the work_dir/timestamp/show_dir')
39+
parser.add_argument(
40+
'--wait-time', type=float, default=2, help='the interval of show (s)')
41+
parser.add_argument(
42+
'--cfg-options',
43+
nargs='+',
44+
action=DictAction,
45+
help='override some settings in the used config, the key-value pair '
46+
'in xxx=yyy format will be merged into config file. If the value to '
47+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
48+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
49+
'Note that the quotation marks are necessary and that no white space '
50+
'is allowed.')
51+
parser.add_argument(
52+
'--launcher',
53+
choices=['none', 'pytorch', 'slurm', 'mpi'],
54+
default='none',
55+
help='job launcher')
56+
parser.add_argument(
57+
'--tta', action='store_true', help='Test time augmentation')
58+
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
59+
# will pass the `--local-rank` parameter to `tools/train.py` instead
60+
# of `--local_rank`.
61+
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
62+
args = parser.parse_args()
63+
if 'LOCAL_RANK' not in os.environ:
64+
os.environ['LOCAL_RANK'] = str(args.local_rank)
65+
66+
return args
67+
68+
69+
def trigger_visualization_hook(cfg, args):
70+
default_hooks = cfg.default_hooks
71+
if 'visualization' in default_hooks:
72+
visualization_hook = default_hooks['visualization']
73+
# Turn on visualization
74+
visualization_hook['draw'] = True
75+
if args.show:
76+
visualization_hook['show'] = True
77+
visualization_hook['wait_time'] = args.wait_time
78+
if args.show_dir:
79+
visualizer = cfg.visualizer
80+
visualizer['save_dir'] = args.show_dir
81+
else:
82+
raise RuntimeError(
83+
'VisualizationHook must be included in default_hooks.'
84+
'refer to usage '
85+
'"visualization=dict(type=\'VisualizationHook\')"')
86+
87+
return cfg
88+
89+
90+
def main():
91+
args = parse_args()
92+
93+
# load config
94+
cfg = Config.fromfile(args.config)
95+
cfg.launcher = args.launcher
96+
if args.cfg_options is not None:
97+
cfg.merge_from_dict(args.cfg_options)
98+
99+
# work_dir is determined in this priority: CLI > segment in file > filename
100+
if args.work_dir is not None:
101+
# update configs according to CLI args if args.work_dir is not None
102+
cfg.work_dir = args.work_dir
103+
elif cfg.get('work_dir', None) is None:
104+
# use config filename as default work_dir if cfg.work_dir is None
105+
cfg.work_dir = osp.join('./work_dirs',
106+
osp.splitext(osp.basename(args.config))[0])
107+
108+
cfg.load_from = args.checkpoint
109+
110+
if args.show or args.show_dir:
111+
cfg = trigger_visualization_hook(cfg, args)
112+
113+
if args.tta:
114+
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
115+
cfg.tta_model.module = cfg.model
116+
cfg.model = cfg.tta_model
117+
118+
# add output_dir in metric
119+
if args.out is not None:
120+
cfg.test_evaluator['output_dir'] = args.out
121+
cfg.test_evaluator['keep_results'] = True
122+
123+
# build the runner from config
124+
if 'runner_type' not in cfg:
125+
# build the default runner
126+
runner = Runner.from_cfg(cfg)
127+
else:
128+
# build customized runner from the registry if 'runner_type' is set in the cfg
129+
runner = RUNNERS.build(cfg)
130+
131+
# start testing
132+
runner.test()
133+
134+
135+
if __name__ == '__main__':
136+
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
export CUDA_VISIBLE_DEVICES=5
2+
python -m torch.distributed.launch \
3+
--nnodes=1 \
4+
--node_rank=0 \
5+
--master_addr="127.0.0.1" \
6+
--nproc_per_node=1 \
7+
--master_port=29500 \
8+
scripts/benchmark/eval2d.py \
9+
"wildscenes/configs/mask2former/mask2former_swin-l-in22k-384x384-pre_2xb20-80k_wildscenes_standard-512x512.py" \
10+
"/raid/work/hau047/trained_models_wildscenes/2dmodels_oldvalset/mask2former_swin-l-in22k-384x384-pre_2xb20-80k_wildscenes_standard-512x512_dgx/best_mIoU_iter_24000_primarysplit.pth" \
11+
--show-dir \
12+
"/raid/work/hau047/wildscenes/Dev_IJRR_rebuttal/visualizations/mask2former_swin-l-in22k-384x384-pre_2xb20-80k_wildscenes_standard-512x512_dgx/images" \
13+
--launcher pytorch

scripts/benchmark/train2d.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
# File modified from original file: https://github.com/open-mmlab/mmsegmentation/blob/main/tools/train.py
3+
# Modified to run 2D segmentation on the WildScenes dataset.
4+
import argparse
5+
import logging
6+
import os
7+
import os.path as osp
8+
import sys
9+
from pathlib import Path
10+
11+
from mmengine.config import Config, DictAction
12+
from mmengine.logging import print_log
13+
from mmengine.runner import Runner
14+
15+
root_dir = Path(__file__).parent.parent.parent
16+
sys.path.append(str(root_dir))
17+
18+
from wildscenes.mmseg_wildscenes.registry import RUNNERS
19+
20+
21+
def parse_args():
22+
parser = argparse.ArgumentParser(description='Train a segmentor on wildscenes')
23+
parser.add_argument('config', help='train config file path')
24+
parser.add_argument('--work-dir', help='the dir to save logs and models')
25+
parser.add_argument(
26+
'--resume',
27+
action='store_true',
28+
default=False,
29+
help='resume from the latest checkpoint in the work_dir automatically')
30+
parser.add_argument(
31+
'--amp',
32+
action='store_true',
33+
default=False,
34+
help='enable automatic-mixed-precision training')
35+
parser.add_argument(
36+
'--cfg-options',
37+
nargs='+',
38+
action=DictAction,
39+
help='override some settings in the used config, the key-value pair '
40+
'in xxx=yyy format will be merged into config file. If the value to '
41+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
42+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
43+
'Note that the quotation marks are necessary and that no white space '
44+
'is allowed.')
45+
parser.add_argument(
46+
'--launcher',
47+
choices=['none', 'pytorch', 'slurm', 'mpi'],
48+
default='none',
49+
help='job launcher')
50+
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
51+
# will pass the `--local-rank` parameter to `tools/train.py` instead
52+
# of `--local_rank`.
53+
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
54+
args = parser.parse_args()
55+
if 'LOCAL_RANK' not in os.environ:
56+
os.environ['LOCAL_RANK'] = str(args.local_rank)
57+
58+
return args
59+
60+
61+
def main():
62+
args = parse_args()
63+
64+
# load config
65+
cfg = Config.fromfile(args.config)
66+
cfg.launcher = args.launcher
67+
if args.cfg_options is not None:
68+
cfg.merge_from_dict(args.cfg_options)
69+
70+
# work_dir is determined in this priority: CLI > segment in file > filename
71+
if args.work_dir is not None:
72+
# update configs according to CLI args if args.work_dir is not None
73+
cfg.work_dir = args.work_dir
74+
elif cfg.get('work_dir', None) is None:
75+
# use config filename as default work_dir if cfg.work_dir is None
76+
cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
77+
78+
# enable automatic-mixed-precision training
79+
if args.amp is True:
80+
optim_wrapper = cfg.optim_wrapper.type
81+
if optim_wrapper == 'AmpOptimWrapper':
82+
print_log(
83+
'AMP training is already enabled in your config.',
84+
logger='current',
85+
level=logging.WARNING)
86+
else:
87+
assert optim_wrapper == 'OptimWrapper', (
88+
'`--amp` is only supported when the optimizer wrapper type is '
89+
f'`OptimWrapper` but got {optim_wrapper}.')
90+
cfg.optim_wrapper.type = 'AmpOptimWrapper'
91+
cfg.optim_wrapper.loss_scale = 'dynamic'
92+
93+
# resume training
94+
cfg.resume = args.resume
95+
96+
# build the runner from config
97+
if 'runner_type' not in cfg:
98+
# build the default runner
99+
runner = Runner.from_cfg(cfg)
100+
else:
101+
# build customized runner from the registry if 'runner_type' is set in the cfg
102+
runner = RUNNERS.build(cfg)
103+
104+
# start training
105+
runner.train()
106+
107+
108+
if __name__ == '__main__':
109+
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
python -m torch.distributed.launch \
3+
--nnodes=1 \
4+
--node_rank=0 \
5+
--master_addr="127.0.0.1" \
6+
--nproc_per_node=1 \
7+
--master_port=29500 \
8+
scripts/benchmark/train2d.py \
9+
"wildscenes/configs/mask2former/mask2former_r50_2xb20-80k_wildscenes_standard-512x512_dgx.py" \
10+
--launcher pytorch

scripts/data/setup_data.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from pathlib import Path
44
import argparse
55

6-
from wildscenes.tools import wildscenes_converter
7-
8-
96
root_dir = Path(__file__).parent.parent.parent
107
sys.path.append(str(root_dir))
118

9+
from wildscenes.tools import wildscenes_converter
10+
1211

1312
'''
1413
Need to run setup_data each time start using our dataset. This file converts the raw split data into full path info
@@ -38,7 +37,8 @@ def main(args):
3837

3938
if __name__ == '__main__':
4039
parser = argparse.ArgumentParser()
41-
parser.add_argument('--dataset_rootdir', type=str, required=True, default=None)
40+
parser.add_argument('--dataset_rootdir', type=str, required=True, default=None,
41+
help='This is the full path to the root directory of WildScenes')
4242
parser.add_argument('--splitdir', type=Path, default=root_dir / "data" / "splits")
4343
parser.add_argument('--savedir', type=Path, default=root_dir / "data" / "processed")
4444
parser.add_argument('--overwrite', default=False, action='store_true',

0 commit comments

Comments
 (0)