diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py
index 784f0abc6691..c91da3ad179a 100644
--- a/colossalai/nn/_ops/__init__.py
+++ b/colossalai/nn/_ops/__init__.py
@@ -5,3 +5,4 @@
 from .embedding import colo_embedding
 from .addmm import colo_addmm
 from .embedding_bag import colo_embedding_bag
+from .view import colo_view
diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py
index 235c45376c46..ce7e8bef63e7 100644
--- a/colossalai/nn/_ops/addmm.py
+++ b/colossalai/nn/_ops/addmm.py
@@ -69,7 +69,9 @@ def colo_addmm(input_tensor: GeneralTensor,
     if not mat2.has_compute_spec():    # No Model Parallel Applied
         assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
         assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
-        ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
+        ret_tensor = ColoTensor.from_torch_tensor(
+            tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
+            spec=ColoTensorSpec(mat2.get_process_group()))
     elif mat2.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if mat2.is_shard_1drow() and input_tensor.is_replicate():
             mode = 'row'
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 0ff5b02aec24..2040d83c1fca 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -1,7 +1,8 @@
 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
+from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \
+    ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
@@ -110,17 +111,18 @@ def colo_embedding(input_tensor: GeneralTensor,
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
 
-    if not weight.has_compute_spec():  # No Model Parallel Applied
+    if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
-            F.embedding(input_tensor,
-                        weight,
-                        padding_idx=padding_idx,
-                        max_norm=max_norm,
-                        norm_type=norm_type,
-                        scale_grad_by_freq=scale_grad_by_freq,
-                        sparse=sparse))
-    elif weight.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
+            tensor=F.embedding(input_tensor,
+                               weight,
+                               padding_idx=padding_idx,
+                               max_norm=max_norm,
+                               norm_type=norm_type,
+                               scale_grad_by_freq=scale_grad_by_freq,
+                               sparse=sparse),
+            spec=ColoTensorSpec(weight.get_process_group()))
+    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1drow():
             mode = 'row'
         elif weight.is_shard_1dcol():
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py
index 1a9af107cc6d..cdab44856acf 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/nn/_ops/embedding_bag.py
@@ -2,7 +2,8 @@
 from typing import Optional
 from torch import Tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \
+    ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -89,21 +90,22 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
 
     # Handle differen parallel actions.
 
-    if not weight.has_compute_spec():  # No Model Parallel Applied
+    if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
-            F.embedding_bag(input_tensor,
-                            weight,
-                            offsets=offsets,
-                            max_norm=max_norm,
-                            norm_type=norm_type,
-                            scale_grad_by_freq=scale_grad_by_freq,
-                            mode=mode,
-                            sparse=sparse,
-                            per_sample_weights=per_sample_weights,
-                            include_last_offset=include_last_offset,
-                            padding_idx=padding_idx))
-    elif weight.has_compute_pattern(ComputePattern.TP1D):  # Single Model Parallel Applied
+            tensor=F.embedding_bag(input_tensor,
+                                   weight,
+                                   offsets=offsets,
+                                   max_norm=max_norm,
+                                   norm_type=norm_type,
+                                   scale_grad_by_freq=scale_grad_by_freq,
+                                   mode=mode,
+                                   sparse=sparse,
+                                   per_sample_weights=per_sample_weights,
+                                   include_last_offset=include_last_offset,
+                                   padding_idx=padding_idx),
+            spec=ColoTensorSpec(weight.get_process_group()))
+    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1dcol():
             tp_mode = 'col'
         else:
diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py
index 4134408d1037..e3eef9b182a4 100644
--- a/colossalai/nn/_ops/layernorm.py
+++ b/colossalai/nn/_ops/layernorm.py
@@ -19,5 +19,9 @@ def colo_layernorm(
         input_tensor = input_tensor.redistribute(ReplicaSpec())
 
     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
-    output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
+    output = ColoTensor.from_torch_tensor(
+        tensor=output,
+        spec=ColoTensorSpec(
+            pg=input_tensor.get_process_group(),
+            dist_attr=input_tensor.dist_spec))
     return output
diff --git a/colossalai/nn/_ops/view.py b/colossalai/nn/_ops/view.py
new file mode 100644
index 000000000000..3197e7568d6f
--- /dev/null
+++ b/colossalai/nn/_ops/view.py
@@ -0,0 +1,97 @@
+import math
+import torch
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
+from typing import Optional, Union
+
+
+def _all_int(my_iter):
+    return all(isinstance(i, int) for i in my_iter)
+
+
+def _get_valid_shape(shape):
+    if isinstance(shape, list):
+        if _all_int(shape):
+            return tuple(shape)
+        else:
+            raise RuntimeError("expects type(int) but finds another type")
+    elif isinstance(shape, tuple):
+        if _all_int(shape):
+            return shape
+        else:
+            return _get_valid_shape(shape[0])
+    else:
+        raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
+
+
+def _shape_infer(org_sp, tgt_sp):
+    cnt = 0
+    pos = 0
+    for idx, dim in enumerate(tgt_sp):
+        if dim < -1:
+            raise RuntimeError("invalid shape dimension {}".format(dim))
+        elif dim == -1:
+            cnt += 1
+            pos = idx
+
+    if cnt > 1:
+        raise RuntimeError("only one dimension can be inferred")
+
+    org_prod = math.prod(org_sp)
+    tgt_prod = math.prod(tgt_sp)
+
+    if cnt == 0:
+        if org_prod != tgt_prod:
+            raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
+        else:
+            return tgt_sp
+    elif org_prod % tgt_prod != 0:
+        raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
+
+    infer_dim = -(org_prod // tgt_prod)
+    return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
+
+
+@colo_op_impl(torch.Tensor.view)
+def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
+    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
+    Changes the shape of the current tensor.
+    """
+    assert isinstance(self, ColoTensor)
+    # apply original `view` function for replicated colo tensors
+    if self.is_replicate():
+        return self.view(*shape)
+
+    cur_sp = self.size()
+    org_sp = self.size_global()
+    # parse the passed arguments
+    tgt_sp = _get_valid_shape(shape)
+    # get the correct shape from inference
+    inf_sp = _shape_infer(org_sp, tgt_sp)
+
+    if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
+        new_shape = (cur_sp[0],) + tgt_sp[1:]
+        res = self.view(*new_shape)
+    elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
+        new_shape = tgt_sp[:-1] + (cur_sp[-1],)
+        res = self.view(*new_shape)
+    else:
+        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
+        return ColoTensor.from_torch_tensor(
+            tensor=replicated_t.view(*shape),
+            spec=ColoTensorSpec(self.get_process_group()))
+
+    return ColoTensor.from_torch_tensor(
+        tensor=res,
+        spec=ColoTensorSpec(
+            pg=self.get_process_group(),
+            dist_attr=self.dist_spec))
+
+
+@colo_op_impl(torch.Tensor.size)
+def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
+    size = self.size_global()
+    if dim is None:
+        return size
+    else:
+        return size[dim]
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 51cc4619a5cc..05b34de7f1ff 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -22,28 +22,30 @@ def _get_my_nowrap_functions() -> Set[Callable]:
     }
 
 
-def _convert_output(output, pg: ProcessGroup):
+def _convert_output(output, colo_spec: ColoTensorSpec):
     if type(output) == torch.Tensor:
-        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
+        return ColoTensor.from_torch_tensor(output, colo_spec)
     elif isinstance(output, (list, tuple)):
-        return type(output)(_convert_output(o, pg) for o in output)
+        return type(output)(_convert_output(o, colo_spec) for o in output)
     else:
         return output
 
 
-def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
+def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
     for elem in args:
         if isinstance(elem, ColoTensor):
             pg = elem.get_process_group()
-            return pg
+            dp = elem.dist_spec
+            return ColoTensorSpec(pg, dp)
         elif isinstance(elem, (list, tuple)):
-            pg = _scan_for_pg_from_args(elem, {})
-            if pg is not None:
-                return pg
+            spec = _get_spec_from_args(elem, {})
+            if spec is not None:
+                return spec
     for k, v in kwargs.items():
         if isinstance(v, ColoTensor):
             pg = v.get_process_group()
-            return pg
+            dp = v.dist_spec
+            return ColoTensorSpec(pg, dp)
     return None
@@ -170,11 +172,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
             if func in _get_my_nowrap_functions():
                 return ret
             else:
-                pg = _scan_for_pg_from_args(args, kwargs)
-                return _convert_output(ret, pg)
+                colo_spec = _get_spec_from_args(args, kwargs)
+                return _convert_output(ret, colo_spec)
 
     def __repr__(self):
-        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
+        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
 
     def _redistribute(self, dist_spec: _DistSpec) -> None:
         """_redistribute
@@ -243,50 +245,32 @@ def __deepcopy__(self, memo):
             memo[id(self)] = tensor
             return tensor
 
-    ##### override builtin functions which must use tensor in replicate placement ####
+    # override builtin functions which must use tensor in replicate placement #
 
-    def view_local(self, *args) -> 'ColoTensor':
-        return super().view(*args)
-
-    def size_local(self, *args, **kwargs) -> torch.Size:
-        return super().size(*args, **kwargs)
-
-    def view_global(self, *args) -> 'ColoTensor':
-        """override the torch buildin view()
-        the args passed in must be in a replicate placement.
-        Returns:
-            ColoTensor: a tensor after viewed.
-        """
-        if self.is_replicate():
-            return super().view(*args)
-        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
-        return replicated_t.view(*args)
+    def size_local(self, *args) -> torch.Size:
+        with torch._C.DisableTorchFunction():
+            return super().size(*args)
 
-    def size_global(self, args: Optional[int] = None) -> torch.Size:
+    def size_global(self, *args) -> torch.Size:
         """override the torch buildin size()
        the shape passed in must be in a replicate placement.
        Returns:
            ColoTensor: a tensor after viewed.
        """
         if self.is_replicate():
-            if args is not None:
-                return super().size(args)
-            else:
-                return super().size()
-
+            return self.size_local(*args)
         spec = self.dist_spec
         dims = spec.dims
         num_partitions = spec.num_partitions
         # import inspect
         # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
-
-        size_list = list(super().size())
+        size_list = list(self.size_local())
         for dim, num_partition in zip(dims, num_partitions):
             size_list[dim] *= num_partition
-        if args is not None:
-            return size_list[args]
-        else:
+        if args == ():
             return torch.Size(size_list)
+        else:
+            return size_list[args[0]]
 
     # Some API for dist spec check
diff --git a/colossalai/tensor/compute_spec.py b/colossalai/tensor/compute_spec.py
index acaba2a46c26..bc0ea98ccc7a 100644
--- a/colossalai/tensor/compute_spec.py
+++ b/colossalai/tensor/compute_spec.py
@@ -22,4 +22,7 @@ def __init__(self, compute_pattern: ComputePattern) -> None:
         self.output_replicate = True
 
     def __repr__(self):
-        return f'compute pattern: {self.compute_pattern}'
+        return f'Compute pattern: {self.compute_pattern}'
+
+    def set_output_replicate(self, flag: bool = True):
+        self.output_replicate = flag
diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py
index e018d37119d7..3b8b83c15072 100644
--- a/colossalai/utils/checkpoint/utils.py
+++ b/colossalai/utils/checkpoint/utils.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 
 
 def gather_tensor(colo_tensor: ColoTensor) -> None:
@@ -26,7 +26,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
 def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
     """Reversal operation of `gather_tensor`.
     """
-    if dist_spec.placement == 'r':
+    if dist_spec.placement == DistPlacementPattern.REPLICATE:
         dist.broadcast(colo_tensor.data, 0)
     else:
         global_size = colo_tensor.size_global()
diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py
index 1c86871add8c..5c5d06622beb 100644
--- a/tests/test_tensor/common_utils/_utils.py
+++ b/tests/test_tensor/common_utils/_utils.py
@@ -73,3 +73,9 @@ def split_param_row_tp1d(param, pg):
 
 def split_param_col_tp1d(param, pg):
     split_param_single_dim_tp1d(-1, param, pg)
+
+
+def debug_print(ranks, *args):
+    if dist.get_rank() in ranks:
+        print(*args)
+    dist.barrier()
diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py
index ad9547ef6e8d..b48d9e9a2dfa 100644
--- a/tests/test_tensor/core/test_tensor.py
+++ b/tests/test_tensor/core/test_tensor.py
@@ -75,7 +75,7 @@ def _run_view(world_size):
     assert t.size_global(1) == 5
     assert t.size_global() == torch.Size([4 * world_size, 5])
 
-    t = t.view_global(4 * 5 * world_size)
+    t = t.view(4 * 5 * world_size)
     assert t.shape == torch.Size([4 * 5 * world_size])
@@ -129,9 +129,9 @@ def _run_set_tensor_spec(world_size):
     spec1 = ColoTensorSpec(pg)
     t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
-    dist_spec2 = (ShardSpec([-1], [pg.tp_world_size()]), None)
+    dist_spec2 = ShardSpec([-1], [pg.tp_world_size()])
     assert t1.is_replicate()
-    t1.set_dist_spec(*dist_spec2)
+    t1.set_dist_spec(dist_spec2)
     assert t1.is_shard_1dcol()
diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py
index 19c6716b0fc1..b74016451389 100644
--- a/tests/test_tensor/model/test_gpt2.py
+++ b/tests/test_tensor/model/test_gpt2.py
@@ -15,6 +15,7 @@
 from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print
 
 
 def init_1d_row_spec(model, pg: ProcessGroup):
@@ -34,6 +35,32 @@ def init_1d_col_spec(model, pg: ProcessGroup):
             p.set_tensor_spec(*spec)
 
 
+def init_megatron_spec(model, pg: ProcessGroup):
+    for mn, module in model.named_modules():
+        # debug_print([0], mn)
+        for pn, param in module.named_parameters(recurse=False):
+            # debug_print([0], '\t', pn, param.compute_spec, param.shape)
+            param.set_process_group(pg)
+
+            if 'mlp.c_fc' in mn:
+                if 'weight' in pn or 'bias' in pn:
+                    split_param_col_tp1d(param, pg)
+                    param.compute_spec.set_output_replicate(False)
+                else:
+                    raise RuntimeError
+            elif 'mlp.c_proj' in mn:
+                if 'weight' in pn:
+                    split_param_row_tp1d(param, pg)
+                else:
+                    assert 'bias' in pn
+            elif 'wte' in mn or 'wpe' in mn:
+                assert 'weight' in pn
+                split_param_col_tp1d(param, pg)
+            elif 'c_fc' in mn or 'c_proj' in mn:
+                split_param_col_tp1d(param, pg)
+            # debug_print([0], '\t', param.compute_spec, param.shape)
+
+
 def check_param_equal(model, torch_model, pg: ProcessGroup):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
@@ -102,8 +129,10 @@ def run_dist(rank, world_size, port, use_ddp):
     if use_ddp and world_size == 1:
         return
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_gpt(init_1d_row_spec, use_ddp)
-    run_gpt(init_1d_col_spec, use_ddp)
+    # The two tests below are commented out for speed concerns
+    # run_gpt(init_1d_row_spec, use_ddp)
+    # run_gpt(init_1d_col_spec, use_ddp)
+    run_gpt(init_megatron_spec, use_ddp)
@@ -116,4 +145,4 @@ def test_gpt(world_size, use_ddp):
 
 
 if __name__ == '__main__':
-    test_gpt(4, use_ddp=True)
+    test_gpt(4, use_ddp=False)
diff --git a/tests/test_tensor/ops/test_view.py b/tests/test_tensor/ops/test_view.py
new file mode 100644
index 000000000000..fc6fc2d3c291
--- /dev/null
+++ b/tests/test_tensor/ops/test_view.py
@@ -0,0 +1,100 @@
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec
+from colossalai.tensor.distspec import DistPlacementPattern
+from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print
+
+
+def exam_view_core(pg):
+    # the case of replicated ColoTensors
+    x = torch.randn(4, 4).cuda()
+    x_colo = ColoTensor(x, ColoTensorSpec(pg))
+
+    y = x.view(2, -1, 2)
+    y_colo = x_colo.view(2, -1, 2)
+
+    assert torch.all(y == y_colo)
+    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
+    # the perfect case of col-sliced ColoTensors
+    split_param_col_tp1d(x_colo, pg)
+
+    z = x.view(torch.Size((2, 1, 2, -1)))
+    z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
+    if dist.get_rank() == 0:
+        z = z[:, :, :, 0:2]
+    else:
+        z = z[:, :, :, 2:]
+    assert torch.all(z == z_colo)
+    assert z_colo.dist_spec == x_colo.dist_spec
+    # the perfect case of row-sliced ColoTensors
+    split_param_row_tp1d(x_colo, pg)
+
+    z = x.view(torch.Size((-1, 2, 2)))
+    z_colo = x_colo.view(torch.Size((-1, 2, 2)))
+    if dist.get_rank() == 0:
+        z = z[0:2, :, :]
+    else:
+        z = z[2:, :, :]
+    assert torch.all(z == z_colo)
+    assert z_colo.dist_spec == x_colo.dist_spec
+    # the normal case of row-sliced ColoTensors
+    z = x.view(-1, 2, 2, 2)
+    z_colo = x_colo.view(-1, 2, 2, 2)
+    assert torch.all(z == z_colo)
+    assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
+
+
+def exam_view_autograd(pg):
+    x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
+    y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
+    with torch.no_grad():
+        y.copy_(x)
+    y = ColoTensor(y, ColoTensorSpec(pg))
+    y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
+
+    xx = x.view(2, 2, -1)
+    yy_slice = y_slice.view(2, 2, -1)
+    yy = yy_slice.to_replicate()
+    grad = torch.randn(2, 2, 4, device=get_current_device())
+
+    xx.backward(grad)
+    yy.backward(grad)
+    assert torch.all(x.grad == y.grad)
+
+
+def exam_view_errors(pg):
+    x = torch.randn(8, 2, device=get_current_device())
+    x = ColoTensor(x, ColoTensorSpec(pg))
+    split_param_row_tp1d(x, pg)
+
+    x.view('a', 'b', 'c')
+    x.view(8, -1)
+    x.view([-2, -2, -2])
+    x.view((-1, -1, -1))
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
+    exam_view_core(pg)
+    exam_view_autograd(pg)
+    # exam_view_errors(pg)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_view(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_view(2)
diff --git a/tests/test_utils/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py
similarity index 91%
rename from tests/test_utils/test_colo_checkpoint_tools.py
rename to tests/test_tensor/test_colo_checkpoint_tools.py
index 551886f25b6f..ec80c06e87de 100644
--- a/tests/test_utils/test_colo_checkpoint_tools.py
+++ b/tests/test_tensor/test_colo_checkpoint_tools.py
@@ -11,7 +11,7 @@
 from colossalai.utils import free_port
 from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
 from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
-from tests.test_tensor._utils import tensor_shard_equal
+from tests.test_tensor.common_utils import tensor_shard_equal
 
 
 def run_dist(rank, world_size, port, dp_degree, tp_degree):
@@ -24,7 +24,7 @@ def run_dist(rank, world_size, port, dp_degree, tp_degree):
     gather_tensor(param)
 
     if dist.get_rank() == 0:
-        assert torch.allclose(x, param.data, rtol=0, atol=0)
+        assert torch.all(x == param)
     else:
         assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
     dist.barrier()
diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_sharded_optim_state_dict.py
index c8143901d813..f8c42930b281 100644
--- a/tests/test_zero/test_sharded_optim_state_dict.py
+++ b/tests/test_zero/test_sharded_optim_state_dict.py
@@ -6,7 +6,7 @@
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from functools import partial
-from tests.test_tensor._utils import set_seed
+from tests.test_tensor.common_utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.testing import parameterize
 from colossalai.nn.optimizer import HybridAdam
diff --git a/tests/test_zero/test_zero_optim_state_dict.py b/tests/test_zero/test_zero_optim_state_dict.py
index 7ecb917951b9..5ab6ee0be6eb 100644
--- a/tests/test_zero/test_zero_optim_state_dict.py
+++ b/tests/test_zero/test_zero_optim_state_dict.py
@@ -9,7 +9,7 @@
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.core import global_context as gpc
 from functools import partial
-from tests.test_tensor._utils import set_seed
+from tests.test_tensor.common_utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.nn.parallel.data_parallel import ZeroDDP
 from colossalai.gemini import ChunkManager, GeminiManager