Commit: [zero] add chunk size searching algorithm for parameters in different groups
Showing 3 changed files with 164 additions and 0 deletions.
colossalai/gemini/update/__init__.py
@@ -1 +1,2 @@
from .chunkv2 import ChunkV2
from .search_utils import clasify_params, search_chunk_configuration
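This re-export makes the new utilities importable from the package root, which is how the test file below pulls them in:

# Import path made available by this __init__.py (used by the test file below).
from colossalai.gemini.update import clasify_params, search_chunk_configuration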
colossalai/gemini/update/search_utils.py
@@ -0,0 +1,96 @@
from typing import Dict, List

import numpy as np
import torch.nn as nn

from colossalai.tensor import ColoParameter


def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """Filter out parameters whose size is extremely large compared with the others
    (more than three standard deviations above the mean parameter size).
    """
    params_size = [p.numel() for p in model.parameters()]
    params_size_arr = np.array(params_size)

    std = np.std(params_size_arr)
    mean = np.mean(params_size_arr)
    upper_limit = mean + 3 * std

    for key in size_dict:
        org_list = size_dict[key]
        size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))


def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Get the number of wasted bytes for a certain chunk size.

    Parameters are packed greedily in order: whenever the next parameter does
    not fit into the space left in the current chunk, that space is wasted and
    a new chunk is opened.
    """
    acc = 0
    left = 0
    for s in size_list:
        if s > left:
            acc += left
            left = chunk_size
        left -= s
    return left + acc


def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
    """Classify the model's parameters by their data-parallel world size."""
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in model.parameters():
        assert isinstance(param, ColoParameter), "please initialize the model in the ColoInitContext"
        param_key = param.process_group.dp_world_size()

        if param_key not in params_dict:
            params_dict[param_key] = []
        params_dict[param_key].append(param)

    return params_dict


def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: int,
        search_interval_byte: int,    # the model's hidden size is the best value for the interval
        min_chunk_size_mb: int = 32,
        filter_exlarge_params: bool = True
):
    search_range_byte = search_range_mb * 1024 ** 2
    min_chunk_size_byte = min_chunk_size_mb * 1024 ** 2
    assert search_range_byte % search_interval_byte == 0

    params_dict = clasify_params(model)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for key in params_dict:
        params_list = params_dict[key]
        size_list = [p.numel() for p in params_list]
        # keep small groups of parameters gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[key] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # the chunk size must be at least as large as the largest remaining parameter
    max_size = min_chunk_size_byte
    for key in size_dict:
        max_size = max(max_size, max(size_dict[key]))

    # brute-force search: pick the chunk size that wastes the fewest bytes over all groups
    min_chunk_waste = float('+inf')
    best_chunk_size = max_size

    for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte):
        temp_waste = 0
        for key in size_dict:
            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    for key in params_dict:
        if key in config_dict:
            continue
        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict
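To make the greedy packing in _get_unused_byte concrete, here is a small standalone sketch (illustrative only, not part of the commit; the helper name unused_byte and the toy sizes are invented for the example):

# Parameters are placed into chunks in order; a new chunk is opened whenever
# the next parameter does not fit into the space left in the current one.
def unused_byte(size_list, chunk_size):
    acc = 0    # bytes wasted at the tails of already-closed chunks
    left = 0   # bytes left in the currently open chunk
    for s in size_list:
        if s > left:
            acc += left        # the remainder of the current chunk is wasted
            left = chunk_size  # open a fresh chunk
        left -= s
    return left + acc          # waste in closed chunks plus the last chunk's tail

# Sizes 6, 5 and 4 with chunk_size 10: chunk [6] wastes 4, chunk [5, 4] leaves 1.
assert unused_byte([6, 5, 4], 10) == 5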
@@ -0,0 +1,67 @@
import pytest

from functools import partial

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.gemini.update import search_chunk_configuration
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_spec(model, pg: ProcessGroup):
    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        if 'weight' in n and 'ln' not in n:
            p.set_process_group(pg)
            p.set_tensor_spec(*tensor_spec)


def exam_search_chunk_size():
    world_size = torch.distributed.get_world_size()
    pg_tp = ProcessGroup(tp_degree=world_size)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # build the model inside ColoInitContext so its parameters are ColoParameters
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    init_1d_row_spec(model, pg_tp)
    config_dict = search_chunk_configuration(
        model,
        search_range_mb=1,
        search_interval_byte=16,
        min_chunk_size_mb=0,
        filter_exlarge_params=True)

    for key in config_dict:
        chunk_size = config_dict[key]['chunk_size']
        if world_size == 1:
            assert chunk_size == 31616
        else:
            assert chunk_size == 1024


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_search_chunk_size()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_search(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_search(4)
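For reference, search_chunk_configuration returns a mapping from each data-parallel world size (the keys produced by clasify_params) to that group's chunk setting. A hypothetical result is sketched below; the keys and sizes are invented for illustration:

# Hypothetical shape of the returned config_dict: groups whose total size is
# below min_chunk_size_byte stay gathered, every other group shares the
# best_chunk_size found by the search.
example_config = {
    1: dict(chunk_size=2048, keep_gathered=True),    # small group, always kept in CUDA
    4: dict(chunk_size=31616, keep_gathered=False),  # large group, split into chunks
}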