Skip to content

Commit

Permalink
[colotensor] add Tensor.view op and its unit test
Browse files Browse the repository at this point in the history
[colotensor] add megatron initialization for gpt2
  • Loading branch information
1SAA committed Jul 20, 2022
1 parent 92b0b13 commit 2c4c227
Show file tree
Hide file tree
Showing 13 changed files with 291 additions and 46 deletions.
1 change: 1 addition & 0 deletions colossalai/nn/_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .embedding import colo_embedding
from .addmm import colo_addmm
from .embedding_bag import colo_embedding_bag
from .view import colo_view
4 changes: 3 additions & 1 deletion colossalai/nn/_ops/addmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def colo_addmm(input_tensor: GeneralTensor,
if not mat2.has_compute_spec(): # No Model Parallel Applied
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
ret_tensor = ColoTensor.from_torch_tensor(
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():
mode = 'row'
Expand Down
22 changes: 12 additions & 10 deletions colossalai/nn/_ops/embedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \
ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


Expand Down Expand Up @@ -110,17 +111,18 @@ def colo_embedding(input_tensor: GeneralTensor,
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())

if not weight.has_compute_spec(): # No Model Parallel Applied
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
tensor=F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1drow():
mode = 'row'
elif weight.is_shard_1dcol():
Expand Down
30 changes: 16 additions & 14 deletions colossalai/nn/_ops/embedding_bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Optional
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \
ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor


Expand Down Expand Up @@ -89,21 +90,22 @@ def colo_embedding_bag(input_tensor: GeneralTensor,

# Handle differen parallel actions.

if not weight.has_compute_spec(): # No Model Parallel Applied
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
tensor=F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol():
tp_mode = 'col'
else:
Expand Down
6 changes: 5 additions & 1 deletion colossalai/nn/_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,9 @@ def colo_layernorm(
input_tensor = input_tensor.redistribute(ReplicaSpec())

output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
output = ColoTensor.from_torch_tensor(
tensor=output,
spec=ColoTensorSpec(
pg=input_tensor.get_process_group(),
dist_attr=input_tensor.dist_spec))
return output
94 changes: 94 additions & 0 deletions colossalai/nn/_ops/view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import math
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, ColoTensorSpec
from typing import Optional, Union


def _all_int(my_iter):
return all(isinstance(i, int) for i in my_iter)


def _get_valid_shape(shape):
if isinstance(shape, list):
if _all_int(shape):
return tuple(shape)
else:
raise RuntimeError("expects type(int) but finds an other type")
elif isinstance(shape, tuple):
if _all_int(shape):
return shape
else:
return _get_valid_shape(shape[0])
else:
raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))


def _shape_infer(org_sp, tgt_sp):
cnt = 0
pos = 0
for idx, dim in enumerate(tgt_sp):
if dim < -1:
raise RuntimeError("invalid shape dimension {}".format(dim))
elif dim == -1:
cnt += 1
pos = idx

if cnt > 1:
raise RuntimeError("only one dimension can be inferred")

org_prod = math.prod(org_sp)
tgt_prod = math.prod(tgt_sp)

if cnt == 0:
if org_prod != tgt_prod:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
else:
return tgt_sp
elif org_prod % tgt_prod != 0:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))

infer_dim = -(org_prod // tgt_prod)
return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]


@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
Changes the shape of the current tensor.
"""
assert isinstance(self, ColoTensor)
# apply original `view` function for replicated colo tensors
if self.is_replicate():
return self.view(*shape)

cur_sp = self.size()
org_sp = self.size_global()
# parse the passed arguments
tgt_sp = _get_valid_shape(shape)
# get the correct shape from inference
inf_sp = _shape_infer(org_sp, tgt_sp)

if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
new_shape = (cur_sp[0],) + tgt_sp[1:]
res = self.view(*new_shape)
elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
new_shape = tgt_sp[:-1] + (cur_sp[-1],)
res = self.view(*new_shape)
else:
return self.view_global(*inf_sp)

return ColoTensor.from_torch_tensor(
tensor=res,
spec=ColoTensorSpec(
pg=self.get_process_group(),
dist_attr=self.dist_spec))


@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
size = self.size_global()
if dim is None:
return size
else:
return size[dim]
26 changes: 14 additions & 12 deletions colossalai/tensor/colo_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,30 @@ def _get_my_nowrap_functions() -> Set[Callable]:
}


def _convert_output(output, pg: ProcessGroup):
def _convert_output(output, colo_spec: ColoTensorSpec):
if type(output) == torch.Tensor:
return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
return ColoTensor.from_torch_tensor(output, colo_spec)
elif isinstance(output, (list, tuple)):
return type(output)(_convert_output(o, pg) for o in output)
return type(output)(_convert_output(o, colo_spec) for o in output)
else:
return output


def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
for elem in args:
if isinstance(elem, ColoTensor):
pg = elem.get_process_group()
return pg
dp = elem.dist_spec
return ColoTensorSpec(pg, dp)
elif isinstance(elem, (list, tuple)):
pg = _scan_for_pg_from_args(elem, {})
if pg is not None:
return pg
spec = _get_spec_from_args(elem, {})
if spec is not None:
return spec
for k, v in kwargs.items():
if isinstance(v, ColoTensor):
pg = v.get_process_group()
return pg
dp = v.dist_spec
return ColoTensorSpec(pg, dp)
return None


Expand Down Expand Up @@ -170,11 +172,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
if func in _get_my_nowrap_functions():
return ret
else:
pg = _scan_for_pg_from_args(args, kwargs)
return _convert_output(ret, pg)
colo_spec = _get_spec_from_args(args, kwargs)
return _convert_output(ret, colo_spec)

def __repr__(self):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'

def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute
Expand Down
5 changes: 4 additions & 1 deletion colossalai/tensor/compute_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ def __init__(self, compute_pattern: ComputePattern) -> None:
self.output_replicate = True

def __repr__(self):
return f'compute pattern: {self.compute_pattern}'
return f'Compute pattern: {self.compute_pattern}'

def set_output_replicate(self, flag: bool = True):
self.output_replicate = flag
4 changes: 2 additions & 2 deletions colossalai/utils/checkpoint/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern


def gather_tensor(colo_tensor: ColoTensor) -> None:
Expand All @@ -26,7 +26,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
"""Reversal operation of `gather_tensor`.
"""
if dist_spec.placement == 'r':
if dist_spec.placement == DistPlacementPattern.REPLICATE:
dist.broadcast(colo_tensor.data, 0)
else:
global_size = colo_tensor.size_global()
Expand Down
6 changes: 6 additions & 0 deletions tests/test_tensor/common_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,9 @@ def split_param_row_tp1d(param, pg):

def split_param_col_tp1d(param, pg):
split_param_single_dim_tp1d(-1, param, pg)


def debug_print(ranks, *args):
if dist.get_rank() in ranks:
print(*args)
dist.barrier()
35 changes: 32 additions & 3 deletions tests/test_tensor/model/test_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.data_parallel import ColoDDP
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print


def init_1d_row_spec(model, pg: ProcessGroup):
Expand All @@ -34,6 +35,32 @@ def init_1d_col_spec(model, pg: ProcessGroup):
p.set_tensor_spec(*spec)


def init_megatron_spec(model, pg: ProcessGroup):
for mn, module in model.named_modules():
# debug_print([0], mn)
for pn, param in module.named_parameters(recurse=False):
# debug_print([0], '\t', pn, param.compute_spec, param.shape)
param.set_process_group(pg)

if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg)
param.compute_spec.set_output_replicate(False)
else:
raise RuntimeError
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg)
else:
assert 'bias' in pn
elif 'wte' in mn or 'wpe' in mn:
assert 'weight' in pn
split_param_col_tp1d(param, pg)
elif 'c_fc' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg)
# debug_print([0], '\t', param.compute_spec, param.shape)


def check_param_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
Expand Down Expand Up @@ -102,8 +129,10 @@ def run_dist(rank, world_size, port, use_ddp):
if use_ddp and world_size == 1:
return
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt(init_1d_row_spec, use_ddp)
run_gpt(init_1d_col_spec, use_ddp)
# Comments below tests for speed concern
# run_gpt(init_1d_row_spec, use_ddp)
# run_gpt(init_1d_col_spec, use_ddp)
run_gpt(init_megatron_spec, use_ddp)


@pytest.mark.dist
Expand All @@ -116,4 +145,4 @@ def test_gpt(world_size, use_ddp):


if __name__ == '__main__':
test_gpt(4, use_ddp=True)
test_gpt(4, use_ddp=False)
Loading

0 comments on commit 2c4c227

Please sign in to comment.