diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py
new file mode 100644
index 000000000000..0b575e7dba77
--- /dev/null
+++ b/colossalai/communication/p2p_v2.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import List, Tuple, Union, Any
+import pickle
+import io
+
+import torch
+import torch.distributed as dist
+from torch.distributed import distributed_c10d as c10d
+from torch.distributed import ProcessGroupNCCL
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]
+_pg_manager = {}
+_unpickler = pickle.Unpickler
+
+
+def init_process_group():
+    """Initialize a process group with dist.new_group for every pair of adjacent pipeline stages.
+
+    Args:
+        None
+
+    Returns:
+        None
+    """
+    world_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    for i in range(world_size - 1):
+        _pg_manager[(i, i + 1)] = dist.new_group([i, i + 1])
+
+
+def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGroupNCCL:
+    """Get the group handle of the two given ranks.
+
+    Args:
+        first_rank (int): first rank in the pair
+        second_rank (int): second rank in the pair
+
+    Returns:
+        :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
+    """
+    if len(_pg_manager) == 0:
+        init_process_group()
+    if first_rank > second_rank:
+        first_rank, second_rank = second_rank, first_rank
+    pair_key = (first_rank, second_rank)
+    return _pg_manager[pair_key]
+
+
+def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
+    """Unpickle a byte tensor into a Python object.
+    Any device info embedded in the byte stream is rewritten to the current device before unpickling.
+
+    Args:
+        tensor (:class:`torch.Tensor`): tensor holding the pickled bytes
+        tensor_size (:class:`torch.Size`): number of valid bytes in the tensor
+
+    Returns:
+        Any: the unpickled object
+    """
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    if b'cuda' in buf:
+        buf_array = bytearray(buf)
+        device_index = torch.cuda.current_device()
+        buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index    # 48 == ord('0'): patch the device digit after b'cuda:'
+        buf = bytes(buf_array)
+
+    io_bytes = io.BytesIO(buf)
+    byte_pickler = _unpickler(io_bytes)
+    unpickle = byte_pickler.load()
+
+    return unpickle
+
+
+def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None):
+    """This is a modified version of broadcast_object_list in torch.distributed.
+    The only difference is that each object is moved to the correct device after unpickling.
+    If local_rank == src, the object list is broadcast to rank dst. Otherwise, the object list
+    is updated in place with the data sent from rank src.
+
+    Args:
+        object_list (List[Any]): list of objects to broadcast
+        src (int): source rank of the broadcast
+        dst (int): destination rank of the broadcast
+        device (:class:`torch.device`): device on which to broadcast. Defaults to the current device.
+
+    """
+    group = _acquire_pair_group_handle(src, dst)
+
+    if c10d._rank_not_in_group(group):
+        c10d._warn_not_in_group("broadcast_object_list")
+        return
+
+    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    # Serialize object_list elements to tensors on src rank.
+    if local_rank == src:
+        tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
+    is_nccl_backend = c10d._check_for_nccl_backend(group)
+    current_device = None
+
+    if device is not None:
+        if is_nccl_backend and device.type != "cuda":
+            raise ValueError("device type must be cuda for nccl backend")
+        current_device = device
+    else:
+        current_device = torch.device("cpu")
+        if is_nccl_backend:
+            current_device = torch.device("cuda", torch.cuda.current_device())
+    if is_nccl_backend:
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)
+
+    # Concatenate and broadcast serialized object tensors
+    if local_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(    # type: ignore[call-overload]
+            torch.sum(object_sizes_tensor).item(),    # type: ignore[arg-type]
+            dtype=torch.uint8,
+        )
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    c10d.broadcast(object_tensor, src=src, group=group, async_op=False)
+
+    # Deserialize objects using their stored sizes.
+    offset = 0
+
+    if local_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device("cpu"):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            # unpickle
+            unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)
+
+            # move the unpickled tensor to the current device if it was pickled on a different one
+            if isinstance(unpickle_object,
+                          torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
+                unpickle_object = unpickle_object.cuda()
+
+            object_list[i] = unpickle_object
+
+
+def _send_object(object: Any, dst: int) -> None:
+    """Send any picklable object to the dst rank.
+    Args:
+        object (Any): object to be sent
+        dst (int): rank of the destination
+
+    Returns:
+        None
+    """
+    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    # handler = _acquire_pair_group_handle(local_rank, dst)
+
+    # wrap a bare tensor in a list
+    if isinstance(object, torch.Tensor):
+        object = [object]
+
+    # broadcast the length first
+    # TODO: find a more elegant way that avoids the extra _broadcast_object_list
+    _broadcast_object_list([len(object)], local_rank, dst)
+    # then broadcast the payload itself
+    _broadcast_object_list(object, local_rank, dst)
+
+
+def _recv_object(src: int) -> Any:
+    """Receive any picklable object from the src rank.
+
+    Args:
+        src (int): source rank of the data; the local rank will receive from it.
+
+    Returns:
+        Any: the object received from src.
+    """
+    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    # handler = _acquire_pair_group_handle(local_rank, src)
+    # receive the length first
+    length = [0]
+    _broadcast_object_list(length, src, local_rank)
+
+    # then create a receive buffer with length[0] slots and broadcast the payload
+    object = [None] * length[0]
+    _broadcast_object_list(object, src, local_rank)
+
+    if length[0] == 1:
+        object = object[0]
+
+    return object
+
+
+def recv_forward(prev_rank: int = None) -> Any:
+    """Copy the forward output from the previous pipeline stage as the input of this stage.
+
+    Args:
+        prev_rank (int, optional): The rank of the source of the tensor. Defaults to the rank
+            of the previous pipeline stage.
+
+    Returns:
+        Any: The input tensor or input tensor list.
+    """
+    if gpc.is_pipeline_first_stage():
+        input_tensor = None
+    else:
+        if prev_rank is None:
+            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
+        input_tensor = _recv_object(prev_rank)
+
+    return input_tensor
+
+
+def recv_backward(next_rank: int = None) -> Any:
+    """Copy the gradient from the next pipeline stage as the gradient of this stage's output.
+
+    Args:
+        next_rank (int, optional): The rank of the source of the tensor. Defaults to the rank
+            of the next pipeline stage.
+
+    Returns:
+        Any: The output gradient tensor or gradient tensor list.
+    """
+    if gpc.is_pipeline_last_stage():
+        output_tensor_grad = None
+    else:
+        if next_rank is None:
+            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
+        output_tensor_grad = _recv_object(next_rank)
+
+    return output_tensor_grad
+
+
+def send_forward(output_object: Any, next_rank: int = None) -> None:
+    """Sends the output object to the next pipeline stage.
+
+    Args:
+        output_object (Any): Object to be sent.
+        next_rank (int, optional): The rank of the recipient of the object.
+    """
+    if not gpc.is_pipeline_last_stage():
+        if next_rank is None:
+            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
+        _send_object(output_object, next_rank)
+
+
+def send_backward(input_object: Any, prev_rank: int = None) -> None:
+    """Sends the gradient object to the previous pipeline stage.
+
+    Args:
+        input_object (Any): Object to be sent.
+        prev_rank (int, optional): The rank of the recipient of the object.
+    """
+    if not gpc.is_pipeline_first_stage():
+        if prev_rank is None:
+            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
+        _send_object(input_object, prev_rank)
diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py
new file mode 100644
index 000000000000..1520d6054043
--- /dev/null
+++ b/tests/test_comm/test_boardcast_send_recv_v2.py
@@ -0,0 +1,54 @@
+from functools import partial
+from typing import List
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.logging import disable_existing_loggers
+
+disable_existing_loggers()
+world_size = 4
+CONFIG = dict(parallel=dict(pipeline=world_size))
+torch.manual_seed(123)
+
+
+def check_layer(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False)
+    rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    if rank == 0:
+        obj = [torch.randn(3,)]
+        _send_object(obj, 1)
+
+    if rank == 1:
+        _recv_object(0)
+
+    if rank == 2:
+        _recv_object(3)
+
+    if rank == 3:
+        obj = [torch.randn(3,)]
+        _send_object(obj, 2)
+
+    gpc.destroy()
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_object_list_p2p():
+    disable_existing_loggers()
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_object_list_p2p()
diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py
new file mode 100644
index 000000000000..c639ac9f8ef3
--- /dev/null
+++ b/tests/test_comm/test_object_list_p2p_v2.py
@@ -0,0 +1,132 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group
+from colossalai.context import ParallelMode, Initializer_Pipeline
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.logging import disable_existing_loggers
+
+disable_existing_loggers()
+
+# config
+world_size = 4
+CONFIG = dict(parallel=dict(pipeline=4))
+torch.manual_seed(123)
+use_scatter_gather_tensors = False
+
+# data
+torch.manual_seed(123)
+LIST_LENGTH = 3
+TENSOR_SIZE = torch.Size((3, 3))
+TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)]
+data = torch.rand(3, 3)
+data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
+grad = torch.rand(3, 3)
+grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
+
+
+def check_send_recv_forward():
+    disable_existing_loggers()
+    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    if local_rank == 0:
+        device = torch.device('cuda:0')
+        data_to_send = data.to(device)
+        data_list_to_send = []
+        for data_in_list in data_list:
+            data_list_to_send.append(data_in_list.to(device))
+
+        send_forward(data_to_send, scatter_gather_tensors=use_scatter_gather_tensors)
+        send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors)
+
+    elif local_rank == 1:
+        device = torch.device('cuda:1')
+
+        data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors)
+        data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors)
+
+        data_to_check = data.to(device)
+
+        assert data_recv.equal(data_to_check)
+
+        for data_recv, data_send in zip(data_list_recv, data_list):
+            data_to_check = data_send.to(device)
+            data_recv = data_recv.to(device)
+            assert data_recv.equal(data_to_check)
+
+
+def check_send_recv_backward():
+    disable_existing_loggers()
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        grad_recv = recv_backward(TENSOR_SIZE)
+        grad_list_recv = recv_backward(TENSOR_SIZE_LIST)
+
+        grad_to_check = grad.to(device)
+        grad_recv = grad_recv[0].to(device)
+
+        assert grad_recv.equal(grad_to_check)
+        for grad_recv, grad_send in zip(grad_list_recv, grad_list):
+            grad_recv = grad_recv.to(device)
+            grad_to_check = grad_send.to(device)
+            assert grad_recv.equal(grad_to_check)
+    else:
+        device = torch.device('cuda:1')
+        grad_to_send = grad.to(device)
+        grad_list_to_send = []
+        for grad_in_list in grad_list:
+            grad_list_to_send.append(grad_in_list.to(device))
+        send_backward(grad_to_send)
+        send_backward(grad_list_to_send)
+
+
+def check_small_pipeline():
+    disable_existing_loggers()
+    # make sure the world size is 4
+    assert gpc.world_size == 4, "make sure to set world size to 4 to start the training process"
+    local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    if local_rank == 0:
+        obj = [1, torch.randn(2, 2).cuda(), None]
+        send_forward(obj)
+    elif local_rank == 1:
+        obj = recv_forward()
+        send_forward(obj)
+    elif local_rank == 2:
+        obj = recv_forward()
+        send_forward(obj)
+    elif local_rank == 3:
+        obj = recv_forward()
+    else:
+        pass
+
+
+def check_layer(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    disable_existing_loggers()
+    # check_send_recv_forward()
+    check_small_pipeline()
+
+    gpc.destroy()
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_object_list_p2p():
+    disable_existing_loggers()
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
+    disable_existing_loggers()
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    disable_existing_loggers()
+    test_object_list_p2p()
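# A minimal usage sketch, not part of the PR: how a naive pipeline schedule could chain the new
# object-based API for one microbatch. It assumes colossalai.initialize.launch has already set up
# a pipeline-parallel context; model_stage, microbatch and input_grad are placeholder names, and
# the local backward pass is elided.
from typing import Any

from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward
from colossalai.core import global_context as gpc


def run_one_microbatch(model_stage, microbatch: Any) -> None:
    # forward: the first stage consumes its own data; later stages receive activations
    input_obj = microbatch if gpc.is_pipeline_first_stage() else recv_forward()
    output_obj = model_stage(input_obj)
    send_forward(output_obj)    # silently skipped on the last stage

    # backward: the last stage starts from its own loss; earlier stages receive the gradient
    # of their output from the next stage
    output_grad = None if gpc.is_pipeline_last_stage() else recv_backward()
    input_grad = output_grad    # placeholder for the real local backward pass
    send_backward(input_grad)    # silently skipped on the first stage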