
[CheckpointIO] a uniform checkpoint I/O module #1689

Merged: 43 commits, Nov 8, 2022

Commits
237c4f0
add meta info and utils
ver217 Sep 30, 2022
f149d3b
fix meta args
ver217 Oct 8, 2022
3e190ed
add test unmerge param for zero
ver217 Oct 8, 2022
89f938e
rename unit test
ver217 Oct 8, 2022
ce23dd2
polish unit test
ver217 Oct 8, 2022
9f1c6fc
add merge param
ver217 Oct 8, 2022
b393b81
polish merge param unit test
ver217 Oct 9, 2022
66870bf
add unmerge param unit test
ver217 Oct 10, 2022
8fa6cb8
add writer
ver217 Oct 11, 2022
8bbbe78
fix writer
ver217 Oct 11, 2022
3c01199
refactor writer
ver217 Oct 12, 2022
30580fb
refactor io
ver217 Oct 12, 2022
5478be1
add reader
ver217 Oct 13, 2022
a6802df
fix utils
ver217 Oct 13, 2022
34926ed
fix reader
ver217 Oct 13, 2022
67780ae
fix utils
ver217 Oct 13, 2022
4c67e23
add test build checkpoints
ver217 Oct 13, 2022
4fba79a
fix writer and save
ver217 Oct 14, 2022
21cd756
add test save
ver217 Oct 14, 2022
1795768
refactor writer
ver217 Oct 14, 2022
473bbd8
refactor sharder
ver217 Oct 14, 2022
279528d
add backend
ver217 Oct 14, 2022
9694bde
refactor codes
ver217 Oct 19, 2022
89d8dd8
add merge
ver217 Oct 19, 2022
e919521
add test merge
ver217 Oct 19, 2022
983efd0
polish is_duplicated_list
ver217 Oct 19, 2022
b9b241b
fix merge
ver217 Oct 19, 2022
ce81f6d
refactor unmerge param
ver217 Oct 20, 2022
465900d
fix unit test
ver217 Oct 20, 2022
85554e0
refactor convertor
ver217 Oct 20, 2022
2658b48
add optimizer checkpoint redistor
ver217 Oct 20, 2022
f88117b
add redist
ver217 Oct 20, 2022
66c537a
refactor reader
ver217 Oct 20, 2022
a48e3b3
fix convertor
ver217 Oct 20, 2022
d458062
add test redist
ver217 Oct 20, 2022
df5bac0
refactor convertor
ver217 Oct 20, 2022
6d192ed
add load
ver217 Oct 21, 2022
0d036b8
do not eliminate dp replica when saving
ver217 Oct 21, 2022
24a1450
fix optimizer load state dict
ver217 Oct 21, 2022
9f94fe8
fix load
ver217 Oct 21, 2022
2f5e4b2
add test load
ver217 Oct 21, 2022
878ea11
add __init__
ver217 Oct 21, 2022
60b7d27
fix bugs
ver217 Oct 21, 2022
2 changes: 2 additions & 0 deletions colossalai/utils/checkpoint_io/__init__.py
@@ -0,0 +1,2 @@
from .io import load, merge, redist, save
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
74 changes: 74 additions & 0 deletions colossalai/utils/checkpoint_io/backend.py
@@ -0,0 +1,74 @@
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type

from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter

_backends: Dict[str, Type['CheckpointIOBackend']] = {}


def register(name: str):
    assert name not in _backends, f'"{name}" is registered'

    def wrapper(cls):
        _backends[name] = cls
        return cls

    return wrapper


def get_backend(name: str) -> 'CheckpointIOBackend':
    assert name in _backends, f'Unsupported backend "{name}"'
    return _backends[name]()


class CheckpointIOBackend(ABC):

    def __init__(self) -> None:
        super().__init__()
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        pass

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        pass

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        pass

    @abstractmethod
    def clean_temp(self) -> None:
        pass


@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        for temp_dir_name in self.temps:
            shutil.rmtree(temp_dir_name)
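
A minimal caller-side sketch of the registry (not part of the PR): it only uses register/get_backend and the abstract methods defined above; '/tmp/ckpt' is a placeholder path assumed to already exist on disk.

# Hypothetical usage sketch; '/tmp/ckpt' is a placeholder path.
from colossalai.utils.checkpoint_io.backend import get_backend

backend = get_backend('disk')    # instantiates the CheckpointDiskIO registered above
writer = backend.get_writer('/tmp/ckpt', overwrite=True, rank=0, world_size=1)
reader = backend.get_reader('/tmp/ckpt')
temp_dir = backend.get_temp('/tmp/ckpt')    # tempfile.mkdtemp under the checkpoint dir
backend.clean_temp()                        # removes every temp dir created by get_temp

A new storage backend would subclass CheckpointIOBackend, implement the four abstract methods, and decorate the class with @register('<name>') so that get_backend('<name>') can construct it.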
9 changes: 9 additions & 0 deletions colossalai/utils/checkpoint_io/constant.py
@@ -0,0 +1,9 @@
import re

GLOBAL_META_FILE_NAME = 'global_meta.bin'
MODEL_CKPT_FILE_NAME = 'model.bin'
OPTIM_CKPT_FILE_NAME = 'optim.bin'
META_CKPT_FILE_NAME = 'meta.bin'
OTHER_CKPT_FILE_NAME = 'other.bin'

CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
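
A small illustration (not from the PR) of how these constants relate to CKPT_PAT: the pattern is an unanchored alternation, so re.match succeeds for any file name that begins with one of the five checkpoint prefixes. The 'events.out.log' entry below is a made-up non-checkpoint file.

import re

CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
names = ['global_meta.bin', 'model.bin', 'optim.bin', 'meta.bin', 'other.bin', 'events.out.log']
ckpt_files = [n for n in names if CKPT_PAT.match(n)]
assert ckpt_files == names[:-1]    # everything except the log file is a checkpoint file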
227 changes: 227 additions & 0 deletions colossalai/utils/checkpoint_io/convertor.py
@@ -0,0 +1,227 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

from torch import Tensor

from .distributed import merge_param, unmerge_param
from .meta import ParamDistMeta, RedistMeta
from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)


class CheckpointConvertor(ABC):

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        pass

    @abstractmethod
    def complete(self) -> None:
        pass


class ModelCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        self.param_count = param_count
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            for k, tensor in state_dict.items():
                self.buffer[k][rank] = tensor
        converted_keys = set()
        for k, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[k]:
                tensors = []
                dist_metas = []
                for rank, tensor in rank_dict.items():
                    tensors.append(tensor)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][k])
                self.convert_tensors(k, tensors, dist_metas)
                converted_keys.add(k)
        for k in converted_keys:
            del self.buffer[k]

    def complete(self) -> None:
        assert len(self.buffer) == 0


class ModelCheckpointMerger(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(tensors)
        tensor = merge_param(tensors, dist_metas)
        shard = self.sharder.append(key, tensor)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())


class ModelCheckpointRedistor(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        nprocs = len(save_fns)
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        if len(dist_metas) == 0:
            # already global
            tensor = tensors[0]
        else:
            assert len(dist_metas) == len(tensors)
            tensor = merge_param(tensors, dist_metas)
        for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
            for dp_rank, t in enumerate(tensor_list):
                for rank in self.rank_map[key][tp_rank][dp_rank]:
                    shard = self.sharders[rank].append(key, t)
                    run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())


class OptimizerCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        self.os_to_param = {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        pass

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            self.setup(state_dict['param_groups'])
            for idx, state in state_dict['state'].items():
                self.buffer[idx][rank] = state
        converted_indices = set()
        for idx, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
                states = []
                dist_metas = []
                for rank, state in rank_dict.items():
                    states.append(state)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
                self.convert_states(idx, states, dist_metas)
                converted_indices.add(idx)
        for idx in converted_indices:
            del self.buffer[idx]

    def complete(self) -> None:
        assert len(self.buffer) == 0


class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
            else:
                new_state[state_key] = state_tensor
        shard = self.sharder.append(idx, new_state)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())


class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        self.sharders: List[OptimizerCheckpointSharder] = []
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        if len(self.sharders) == 0:
            nprocs = len(self.save_fns)
            for _ in range(nprocs):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        need_merge: bool = True
        if len(dist_metas) == 0:
            need_merge = False
        else:
            assert len(dist_metas) == len(states)
        new_states = [{} for _ in range(len(self.save_fns))]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                if need_merge:
                    tensor = merge_param([state[state_key] for state in states], dist_metas)
                else:
                    tensor = state_tensor
                for tp_rank, tensor_list in enumerate(
                        unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
                    for dp_rank, t in enumerate(tensor_list):
                        for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
                            new_states[rank][state_key] = t
            else:
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            shard = self.sharders[rank].append(idx, new_state)
            run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
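
All four convertors share the same streaming pattern: append() buffers per-rank shards keyed by parameter name (or optimizer-state index) and flushes a key to convert_tensors/convert_states only once all of its expected shards have arrived (len(rank_dict) == param_count[k]), so shards can be consumed incrementally. Below is a self-contained sketch of that buffering idea; the names and the concatenation-based toy_merge stand in for the real merge_param/ParamDistMeta machinery and are illustrative only.

from collections import defaultdict
from typing import Dict

import torch
from torch import Tensor


def toy_merge(shards: Dict[int, Tensor]) -> Tensor:
    # Stand-in for merge_param: concatenate shards along dim 0 in rank order.
    return torch.cat([shards[r] for r in sorted(shards)], dim=0)


param_count = {'fc.weight': 2, 'fc.bias': 1}    # shards expected per key
buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)
merged: Dict[str, Tensor] = {}


def append(shard_dict: Dict[int, dict]) -> None:
    # shard_dict maps rank -> partial state dict, mirroring the convertor API.
    for rank, state_dict in shard_dict.items():
        for k, tensor in state_dict.items():
            buffer[k][rank] = tensor
    # Flush every key whose shards are all present, then drop it from the buffer.
    for k in [k for k, d in buffer.items() if len(d) == param_count[k]]:
        merged[k] = toy_merge(buffer.pop(k))


append({0: {'fc.weight': torch.zeros(2, 4), 'fc.bias': torch.zeros(4)}})
append({1: {'fc.weight': torch.ones(2, 4)}})
assert set(merged) == {'fc.weight', 'fc.bias'}
assert merged['fc.weight'].shape == (4, 4)    # two (2, 4) shards concatenated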