-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[colotensor] add Tensor.view op and its unit test
[colotensor] add megatron initialization for gpt2
- Loading branch information
Showing
16 changed files
with
309 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import math | ||
import torch | ||
from colossalai.tensor.op_wrapper import colo_op_impl | ||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec | ||
from typing import Optional, Union | ||
|
||
|
||
def _all_int(my_iter): | ||
return all(isinstance(i, int) for i in my_iter) | ||
|
||
|
||
def _get_valid_shape(shape):
    """Normalize the arguments passed to ``view`` into a flat tuple of ints.

    Accepted forms:
      * a list of ints, returned as a tuple;
      * a tuple of ints, returned unchanged;
      * a one-element tuple wrapping either of the above (what ``*shape``
        produces for a call such as ``view((2, 3))`` or ``view([2, 3])``),
        which is unwrapped recursively.

    Raises:
        RuntimeError: if ``shape`` is neither a list nor a tuple, or if it
            contains non-int entries that cannot be unwrapped.
    """
    if isinstance(shape, list):
        if _all_int(shape):
            return tuple(shape)
        raise RuntimeError("expects type(int) but finds an other type")

    if isinstance(shape, tuple):
        if _all_int(shape):
            return shape
        # Only a lone wrapped sequence may be unwrapped. Unwrapping shape[0]
        # of a longer tuple would silently discard the remaining arguments
        # (e.g. ``((2, 3), 4)`` would be treated as ``(2, 3)``).
        if len(shape) == 1:
            return _get_valid_shape(shape[0])
        raise RuntimeError("expects type(int) but finds an other type")

    raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
|
||
|
||
def _shape_infer(org_sp, tgt_sp): | ||
cnt = 0 | ||
pos = 0 | ||
for idx, dim in enumerate(tgt_sp): | ||
if dim < -1: | ||
raise RuntimeError("invalid shape dimension {}".format(dim)) | ||
elif dim == -1: | ||
cnt += 1 | ||
pos = idx | ||
|
||
if cnt > 1: | ||
raise RuntimeError("only one dimension can be inferred") | ||
|
||
org_prod = math.prod(org_sp) | ||
tgt_prod = math.prod(tgt_sp) | ||
|
||
if cnt == 0: | ||
if org_prod != tgt_prod: | ||
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) | ||
else: | ||
return tgt_sp | ||
elif org_prod % tgt_prod != 0: | ||
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod)) | ||
|
||
infer_dim = -(org_prod // tgt_prod) | ||
return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:] | ||
|
||
|
||
@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.

    Replicated tensors are viewed directly. For a 1D-row (or 1D-col) sharded
    tensor whose first (or last) global dimension is kept by the requested
    shape, the view is applied to this tensor with that dimension replaced by
    the one from ``self.size()``, and the result keeps ``self.dist_spec``.
    In every other case the tensor is redistributed with ``ReplicaSpec()``
    and viewed with the caller's original arguments.
    """
    assert isinstance(self, ColoTensor)

    # Replicated colo tensors go straight through the original `view`.
    if self.is_replicate():
        return self.view(*shape)

    # NOTE(review): `size()` is presumably the local shard's shape here,
    # while `size_global()` is the full tensor's shape - confirm.
    local_size = self.size()
    global_size = self.size_global()

    # Normalize the caller's arguments, then resolve any -1 against the
    # global shape.
    requested = _get_valid_shape(shape)
    resolved = _shape_infer(global_size, requested)

    if self.is_shard_1drow() and global_size[0] == resolved[0]:
        # First dimension is kept: swap in the first dim from `size()`.
        local_shape = (local_size[0],) + requested[1:]
    elif self.is_shard_1dcol() and global_size[-1] == resolved[-1]:
        # Last dimension is kept: swap in the last dim from `size()`.
        local_shape = requested[:-1] + (local_size[-1],)
    else:
        # Fallback: redistribute with a replica spec and view that result.
        gathered = self.redistribute(dist_spec=ReplicaSpec())
        return ColoTensor.from_torch_tensor(
            tensor=gathered.view(*shape),
            spec=ColoTensorSpec(self.get_process_group()))

    return ColoTensor.from_torch_tensor(
        tensor=self.view(*local_shape),
        spec=ColoTensorSpec(
            pg=self.get_process_group(),
            dist_attr=self.dist_spec))
|
||
|
||
@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.size``.

    Returns the result of ``self.size_global()`` when ``dim`` is None,
    otherwise the entry of that result at index ``dim``.
    """
    global_shape = self.size_global()
    return global_shape if dim is None else global_shape[dim]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.