-
Notifications
You must be signed in to change notification settings - Fork 378
/
Copy pathsampler_seed_hook.py
36 lines (29 loc) · 1.43 KB
/
sampler_seed_hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import HOOKS
from .hook import Hook
@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
    """Hook that forwards the current epoch to the training sampler.

    Before every training epoch this hook calls ``set_epoch(runner.epoch)``
    on the sampler of ``runner.train_loop.dataloader``. If that sampler does
    not provide ``set_epoch``, the sampler wrapped by the dataloader's batch
    sampler is tried instead. When neither exposes ``set_epoch`` the hook
    does nothing.
    """

    priority = 'NORMAL'

    def before_train_epoch(self, runner) -> None:
        """Set the seed for sampler and batch_sampler.

        Args:
            runner (Runner): The runner of the training process. Its
                ``train_loop.dataloader`` and ``epoch`` attributes are read.
        """
        dataloader = runner.train_loop.dataloader

        # The dataloader may have no sampler at all, or one without
        # ``set_epoch`` (e.g. PyTorch's ``SequentialSampler``).
        sampler = getattr(dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(runner.epoch)
            return

        # A PyTorch batch sampler wraps the real sampler as its ``sampler``
        # attribute. ``batch_sampler`` can be missing, ``None`` (automatic
        # batching disabled) or a custom iterable without ``sampler``, so
        # every level is resolved defensively instead of being dereferenced
        # directly.
        batch_sampler = getattr(dataloader, 'batch_sampler', None)
        wrapped_sampler = getattr(batch_sampler, 'sampler', None)
        if hasattr(wrapped_sampler, 'set_epoch'):
            wrapped_sampler.set_epoch(runner.epoch)