From 6257c22e01b79abd824391f1943ea8e513206ba0 Mon Sep 17 00:00:00 2001 From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com> Date: Fri, 5 Aug 2022 17:58:33 +0800 Subject: [PATCH 1/5] Delete _pipeline_schedule_v2.py --- .../engine/schedule/_pipeline_schedule_v2.py | 807 ------------------ 1 file changed, 807 deletions(-) delete mode 100644 colossalai/engine/schedule/_pipeline_schedule_v2.py diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py deleted file mode 100644 index ebf216167381..000000000000 --- a/colossalai/engine/schedule/_pipeline_schedule_v2.py +++ /dev/null @@ -1,807 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import inspect -from typing import Callable, List, Tuple, Union - -import colossalai.communication.p2p_v2 as comm -import torch.cuda -from colossalai.amp.naive_amp import NaiveAMPModel -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import switch_virtual_pipeline_parallel_rank -from colossalai.utils.cuda import get_current_device -# TODO remove it when release -from colorama import Back, Style - -from ._base_schedule import BaseSchedule - - -def get_tensor_shape(): - if hasattr(gpc.config, 'TENSOR_SHAPE'): - return gpc.config.TENSOR_SHAPE - - if not gpc.is_initialized(ParallelMode.PIPELINE): - return None - - if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr( - gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'): - if gpc.is_initialized(ParallelMode.DATA): - dp_size = gpc.get_world_size(ParallelMode.DATA) - else: - dp_size = 1 - if gpc.is_initialized(ParallelMode.SEQUENCE): - seq_size = gpc.get_world_size(ParallelMode.SEQUENCE) - else: - seq_size = 1 - - tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, - gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE) - return tensor_shape - else: - return None - - -def pack_return_tensors(return_tensors): - output, label = tuple(zip(*return_tensors)) - if isinstance(output[0], torch.Tensor): - output = torch.cat(output, dim=0) - elif isinstance(output[0], (list, tuple)): - output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) - else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') - if isinstance(label[0], torch.Tensor): - label = torch.cat(label, dim=0) - else: - merged_label = {k: [] for k in label[0].keys()} - for d in label: - for k, v in d.items(): - merged_label[k].append(v) - label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()} - return output, label - - -class PipelineSchedule(BaseSchedule): - """A helper schedule class for pipeline parallelism running environment. - It uses non-interleaved 1F1B strategy. Other properties are similar as - :class:`NonPipelineSchedule`. - - Args: - num_microbatches (int): The number of microbatches. - data_process_func (Callable, optional): - The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. - tensor_shape (torch.Size, optional): Specified shape in pipeline communication. - scatter_gather_tensors (bool, optional): - If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. 
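The deleted `pack_return_tensors` above concatenates the per-microbatch `(output, label)` pairs collected during a run back into full-batch results. A minimal standalone sketch of that merging rule using plain `torch` (the name `merge_microbatch_results` is mine, not part of the deleted file):

```python
import torch

def merge_microbatch_results(return_tensors):
    """Concatenate per-microbatch (output, label) pairs along the batch dim.

    Mirrors the merging rules of the deleted ``pack_return_tensors``:
    tensor outputs are concatenated directly, dict labels are merged key-wise.
    """
    outputs, labels = zip(*return_tensors)
    merged_output = torch.cat(outputs, dim=0)
    if isinstance(labels[0], dict):
        merged_label = {k: torch.cat([d[k] for d in labels], dim=0) for k in labels[0]}
    else:
        merged_label = torch.cat(labels, dim=0)
    return merged_output, merged_label

# Two microbatches of size 2 merge back into a batch of size 4.
mb = [(torch.randn(2, 8), torch.zeros(2, dtype=torch.long)) for _ in range(2)]
out, lbl = merge_microbatch_results(mb)
assert out.shape == (4, 8) and lbl.shape == (4,)
```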
- - Example: - - # this shows an example of customized data_process_func - def data_process_func(stage_output, dataloader_output): - output1, output2 = stage_output - item1, item2, item3 = dataloader_output - - # assume item2 is not needed - data = (output1, output2, item1) - label = item3 - return data, label - - """ - - def __init__(self, - num_microbatches, - data_process_func: Callable = None): - - # we need to make sure that the signature of the data_process_func is valid - if data_process_func: - sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 2, \ - 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \ - 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \ - 'i.e. data_process_func(stage_output, dataloader_output).' - - super().__init__(data_process_func=data_process_func) - - assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}' - - self.num_microbatches = num_microbatches - self.dtype = torch.float - - self._logger = get_dist_logger() - - # cache for the batch data - self.batch_data = None - - - def load_batch(self, data_iter): - # Pipeline schedule just puts data in memory - batch_data = super().load_batch(data_iter, to_gpu=False) - self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches - self.batch_data = batch_data - - def _get_data_slice(self, data, offset): - if isinstance(data, torch.Tensor): - return data[offset:offset + self.microbatch_size] - elif isinstance(data, (list, tuple)): - data_dict = {} - for element in data: - if isinstance(element, dict): - data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) - elif data_dict: - data_dict['label'] = element[offset:offset + self.microbatch_size] - if data_dict: - return data_dict - return [val[offset:offset + self.microbatch_size] for val in data] - elif isinstance(data, dict): - return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} - else: - raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") - - def load_micro_batch(self): - mciro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset) - self.microbatch_offset += self.microbatch_size - return self._move_to_device(mciro_batch_data) - - def pre_processing(self, engine): - from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 - # TODO: remove this after testing new zero with pipeline parallelism - model = engine.model - if isinstance(model, NaiveAMPModel): - self.dtype = torch.half - model = model.model - if isinstance(model, ShardedModelV2): - self.dtype = torch.half - model = model.module - # sig = inspect.signature(model.forward) - # for p in sig.parameters.values(): - # assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' - - @staticmethod - def _call_engine(model, data): - if data is not None: - if isinstance(data, torch.Tensor): - return model(data) - elif isinstance(data, (list, tuple)): - return model(*data) - elif isinstance(data, dict): - stage_output = None - if 'stage_output' in data: - stage_output = data.pop('stage_output') - if stage_output is None: - return model(**data) - elif isinstance(stage_output, torch.Tensor): - return model(stage_output, **data) - elif 
isinstance(stage_output, (tuple, list)): - return model(*stage_output, **data) - else: - raise TypeError( - f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}" - ) - else: - raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") - - def _get_actual_forward_func(self, module): - if isinstance(module, NaiveAMPModel): - sig = inspect.signature(module.model.forward) - elif hasattr(module, 'colo_attr'): - sig = inspect.signature(module.module.forward) - else: - sig = inspect.signature(module.forward) - return sig - - def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model): - if self.data_process_func: - # use customized function to get data and label - data, label = self.data_process_func(stage_output, micro_batch_data) - else: - if isinstance(micro_batch_data, (tuple, list)): - if gpc.is_first_rank(ParallelMode.PIPELINE): - # for the first stage, we use the data from the - # dataloader output by default - data, label = micro_batch_data - else: - # for non-first stage, we use the output passed - # by the previous as the model input - data = stage_output - _, label = micro_batch_data - elif isinstance(micro_batch_data, dict): - data = {} - data['stage_output'] = stage_output - if 'label' in micro_batch_data: - label = micro_batch_data.pop('label') - else: - label = None - load_data = micro_batch_data - data.update(load_data) - return data, label - - def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): - """Forward step for passed-in model. If it is the first stage, the input tensor - is obtained from data_iterator, otherwise the passed-in input_obj is used. - Returns output tensor. This is a helper function and can be ignored by users. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. - return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. - return_output_label (bool, optional): Whether returns output labels. - accum_loss (optional): Where accumulated loss stores. - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. - """ - micro_batch_data = self.load_micro_batch() - - data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model) - - output_obj = self._call_engine(engine.model, data) - - if gpc.is_last_rank(ParallelMode.PIPELINE): - if return_output_label: - return_tensors.append((output_obj, label)) - if accum_loss is not None: - loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches - accum_loss.add_(loss_reduced.detach()) - return loss_reduced - else: - # forward only, it's useless since backward is not needed - return output_obj - else: - if isinstance(output_obj, torch.Tensor): - self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' - ) - return output_obj - - def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): - """Backward step through the passed-in output tensor. If it is the last stage, the - output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. 
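The `_get_data_slice` / `load_micro_batch` pair above cuts the cached batch into `num_microbatches` equal slices along the batch dimension, advancing an offset on each call. A runnable sketch of the same slicing rule for the tensor and dict cases (`slice_microbatch` is an illustrative name):

```python
import torch

def slice_microbatch(batch, offset, microbatch_size):
    """Cut one microbatch out of a batch, mirroring ``_get_data_slice``.

    Tensors are sliced along dim 0; dicts are sliced value-wise.
    """
    if isinstance(batch, torch.Tensor):
        return batch[offset:offset + microbatch_size]
    if isinstance(batch, dict):
        return {k: v[offset:offset + microbatch_size] for k, v in batch.items()}
    raise TypeError(f"unsupported batch type {type(batch)}")

# batch_size=8 split into num_microbatches=4 -> microbatch_size=2
batch = {"input_ids": torch.arange(8).view(8, 1), "label": torch.arange(8)}
for step in range(4):
    mb = slice_microbatch(batch, offset=step * 2, microbatch_size=2)
    assert mb["input_ids"].shape[0] == 2
```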
- Returns the gradients with respect to the input tensor (None if first stage). - This is a helper function and can be ignored by users. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage. - output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage. - output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage. - - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor. - """ - - # Retain the grad on the input_obj. - if input_obj is not None: - if isinstance(input_obj, torch.Tensor): - input_obj.retain_grad() - else: - for in_tensor in input_obj: - if in_tensor is not None: - in_tensor.retain_grad() - # Backward pass. - if output_obj_grad is None: - engine.backward(output_obj) - else: - engine.backward_by_grad(output_obj, output_obj_grad) - - # Collect the grad of the input_obj. - input_obj_grad = None - if input_obj is not None: - if isinstance(input_obj, torch.Tensor): - input_obj_grad = input_obj.grad - else: - input_obj_grad = [] - for in_tensor in input_obj: - input_obj_grad.append(in_tensor.grad) - - return input_obj_grad - - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): - """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. - Returns a tuple with losses if the last stage, an empty tuple otherwise. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). - forward_only (bool, optional): - Whether run forward step only. Default is false. If true, no backward will be run. - return_loss (bool, optional): Whether returns the loss value. Default is true. - return_output_label (bool, optional): If False, the output and label won't be returned. - - Returns: - Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. - """ - - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - self.load_batch(data_iter) - - # num_warmup_microbatches is the step when not all the processers are working - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) - num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) - num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches - - # Input, output tensors only need to be saved when doing backward passes - input_objs = None - output_objs = None - local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - - if not forward_only: - input_objs = [] - output_objs = [] - return_tensors = [] - if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None - - # Run warmup forward passes. 
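The warmup count computed above is what keeps the 1F1B pipeline from deadlocking: stage `rank` must run `pipeline_size - rank - 1` forwards before its first backward can possibly arrive from downstream. A worked version of that arithmetic:

```python
def warmup_and_steady(pipeline_world_size, pipeline_rank, num_microbatches):
    """Number of warmup forwards before a rank enters 1F1B steady state.

    Matches the computation in ``forward_backward_step`` above: earlier
    stages must run more forwards before their first backward arrives.
    """
    warmup = min(pipeline_world_size - pipeline_rank - 1, num_microbatches)
    return warmup, num_microbatches - warmup

# 4 pipeline stages, 8 microbatches:
for rank in range(4):
    print(rank, warmup_and_steady(4, rank, 8))
# rank 0 -> (3, 5), rank 1 -> (2, 6), rank 2 -> (1, 7), rank 3 -> (0, 8)
```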
- for i in range(num_warmup_microbatches): - # print(Back.BLUE, "rank {}".format(local_rank), Style.RESET_ALL, "ready to recv_forward") - input_obj = comm.recv_forward() - # print(Back.BLUE, "rank {}".format(local_rank), Style.RESET_ALL, "finish recv_forward") - - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - - comm.send_forward(output_obj) - # print(Back.BLUE, "rank {}".format(local_rank), Style.RESET_ALL, "finish send_forward") - - if not forward_only: - input_objs.append(input_obj) - output_objs.append(output_obj) - - # print(Back.GREEN, "rank {}".format(local_rank), Style.RESET_ALL, "warmup finish") - - # Before running 1F1B, need to receive first forward tensor. - # If all microbatches are run in warmup / cooldown phase, then no need to - # receive this tensor here. - if num_microbatches_remaining > 0: - input_obj = comm.recv_forward() - - - # Run 1F1B in steady state. - for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) - - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - if forward_only: - comm.send_forward(output_obj) - - if not last_iteration: - input_obj = comm.recv_forward() - - else: - # TODO adjust here - comm.send_forward(output_obj) - output_obj_grad = comm.recv_backward() - - # Add input_obj and output_obj to end of list. - input_objs.append(input_obj) - output_objs.append(output_obj) - - # Pop output_obj and output_obj from the start of the list for - # the backward pass. - input_obj = input_objs.pop(0) - output_obj = output_objs.pop(0) - - input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - - if last_iteration: - input_obj = None - comm.send_backward(input_obj_grad) - else: - input_obj = comm.recv_forward() - comm.send_backward(input_obj_grad) - - - # Run cooldown backward passes. - if not forward_only: - for i in range(num_warmup_microbatches): - input_obj = input_objs.pop(0) - output_obj = output_objs.pop(0) - - output_obj_grad = comm.recv_backward() - input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - comm.send_backward(input_obj_grad) - - if len(return_tensors) > 0: - output, label = pack_return_tensors(return_tensors) - return output, label, accum_loss - else: - return None, None, accum_loss - - -class InterleavedPipelineSchedule(PipelineSchedule): - - def __init__(self, - num_microbatches: int, - num_model_chunks: int, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): - """A helper schedule class for pipeline parallelism running environment. - It uses interleaved 1F1B strategy. Other properties are similar as - :class:`NonPipelineSchedule`. - - Args: - num_microbatches (int): The number of microbatches. - num_model_chunks (int): The number of model chunks. - data_process_func (Callable, optional): - The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. - tensor_shape (torch.Size, optional): Specified shape in pipeline communication. - scatter_gather_tensors (bool, optional): - If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. 
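Putting the three loops of the non-interleaved schedule together (warmup forwards, alternating 1F1B in steady state, cooldown backwards), the per-rank event order can be reproduced offline. `one_f_one_b_order` below is a pure-Python illustration of that ordering, not part of the schedule itself:

```python
def one_f_one_b_order(pipeline_world_size, rank, num_microbatches):
    """Reproduce the per-rank order of forward (F) and backward (B) steps
    in the non-interleaved schedule above: warmup forwards, alternating
    1F1B in steady state, then cooldown backwards."""
    warmup = min(pipeline_world_size - rank - 1, num_microbatches)
    remaining = num_microbatches - warmup
    events = [f"F{i}" for i in range(warmup)]
    for i in range(remaining):
        events.append(f"F{warmup + i}")
        events.append(f"B{i}")
    events += [f"B{remaining + i}" for i in range(warmup)]
    return events

print(one_f_one_b_order(4, 0, 6))
# ['F0', 'F1', 'F2', 'F3', 'B0', 'F4', 'B1', 'F5', 'B2', 'B3', 'B4', 'B5']
```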
- """ - assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ - 'num_microbatches must be an integer multiple of pipeline parallel world size' - assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \ - f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}' - super().__init__(num_microbatches, - data_process_func=data_process_func, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather_tensors) - gpc.set_virtual_pipeline_parallel_size(num_model_chunks) - gpc.set_virtual_pipeline_parallel_rank(0) - self.num_model_chunks = num_model_chunks - - def pre_processing(self, engine): - from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 - if isinstance(engine.model, ShardedModelV2): - self.dtype = torch.half - elif isinstance(engine.model[0], NaiveAMPModel): - self.dtype = torch.half - for model in engine.model: - if isinstance(model, NaiveAMPModel): - model = model.model - sig = inspect.signature(model.forward) - for p in sig.parameters.values(): - assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' - - def load_batch(self, data_iter): - super().load_batch(data_iter) - # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset - self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - - def load_micro_batch(self, model_chunk_id): - data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id]) - self.microbatch_offset[model_chunk_id] += self.microbatch_size - return self._move_to_device(data) - - def _forward_step(self, - engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=True, - accum_loss=None): - """Forward step for passed-in model. If it is the first stage, the input tensor - is obtained from data_iterator, otherwise the passed-in input_obj is used. - Returns output tensor. This is a helper function and can be ignored by users. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - model_chunk_id (int): The id of model chunks. - input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. - return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. - return_output_label (bool, optional): Whether returns output labels. - accum_loss (optional): Where accumulated loss stores. - Returns: - Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. 
- """ - micro_batch_data = self.load_micro_batch(model_chunk_id) - data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, - engine.model[model_chunk_id]) - - output_obj = self._call_engine(engine.model[model_chunk_id], data) - - if gpc.is_pipeline_last_stage(): - if return_output_label: - return_tensors.append((output_obj, label)) - if accum_loss is not None: - loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches - accum_loss.add_(loss_reduced.detach()) - return loss_reduced - else: - # forward only, it's useless since backward is not needed - return output_obj - else: - if isinstance(output_obj, torch.Tensor): - self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' - ) - return output_obj - - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): - """Run interleaved 1F1B schedule (model split into model chunks), with - communication between pipeline stages as needed. - - Args: - engine (colossalai.engine.Engine): Colossalai engine for training and inference. - data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader). - forward_only (bool, optional): - Whether run forward step only. Default is false. If true, no backward will be run. - return_loss (bool, optional): Whether returns the loss value. Default is true. - return_output_label (bool, optional): If False, the output and label won't be returned. - - Returns: - Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. - The loss would be returned only in the last stage. - """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' - self.load_batch(data_iter) - model = engine.model - input_objs = [[] for _ in range(len(model))] - output_objs = [[] for _ in range(len(model))] - return_tensors = [] - if not forward_only: - output_obj_grads = [[] for _ in range(len(model))] - if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None - - # Used for obj meta information communication - input_obj_shapes = [self.tensor_shape for _ in range(len(model))] - output_obj_shapes = [None for _ in range(len(model))] - send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] - - pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - # Compute number of warmup and remaining microbatches. - num_model_chunks = len(model) - num_microbatches = self.num_microbatches * num_model_chunks - all_warmup_microbatches = False - if forward_only: - num_warmup_microbatches = num_microbatches - else: - # Run all forward passes and then all backward passes if number of - # microbatches is just the number of pipeline stages. - # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on - # all workers, followed by more microbatches after depending on - # stage ID (more forward passes for earlier stages, later stages can - # immediately start with 1F1B). 
- if self.num_microbatches == pipeline_parallel_size: - num_warmup_microbatches = num_microbatches - all_warmup_microbatches = True - else: - num_warmup_microbatches = \ - (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = \ - num_microbatches - num_warmup_microbatches - - def get_model_chunk_id(microbatch_id, forward): - """Helper method to get the model chunk ID given the iteration number.""" - microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size - if not forward: - model_chunk_id = (num_model_chunks - model_chunk_id - 1) - return model_chunk_id - - def _forward_step_helper(microbatch_id): - """Helper method to run forward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - forward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) - gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) - - # forward step - if gpc.is_pipeline_first_stage(): - if len(input_objs[model_chunk_id]) == \ - len(output_objs[model_chunk_id]): - input_objs[model_chunk_id].append(None) - input_obj = input_objs[model_chunk_id][-1] - output_obj = self._forward_step(engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - output_objs[model_chunk_id].append(output_obj) - - # if forward-only, no need to save tensors for a backward pass - if forward_only: - input_objs[model_chunk_id].pop() - output_objs[model_chunk_id].pop() - - return output_obj - - def _backward_step_helper(microbatch_id): - """Helper method to run backward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - backward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) - gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) - - if gpc.is_pipeline_last_stage(): - if len(output_obj_grads[model_chunk_id]) == 0: - output_obj_grads[model_chunk_id].append(None) - input_obj = input_objs[model_chunk_id].pop(0) - output_obj = output_objs[model_chunk_id].pop(0) - output_obj_grad = output_obj_grads[model_chunk_id].pop(0) - input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - - return input_obj_grad - - # Run warmup forward passes. - gpc.set_virtual_pipeline_parallel_rank(0) - if not gpc.is_pipeline_first_stage(): - input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) - input_objs[0].append( - comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) - - for k in range(num_warmup_microbatches): - model_chunk_id = get_model_chunk_id(k, forward=True) - output_obj = _forward_step_helper(k) - if not gpc.is_pipeline_last_stage(): - if isinstance(output_obj, torch.Tensor): - output_obj_shapes[model_chunk_id] = output_obj.shape - else: - output_obj_shapes[model_chunk_id] = [] - for out_tensor in output_obj: - output_obj_shapes[model_chunk_id].append(out_tensor.shape) - send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, - send_tensor_shape_flags[model_chunk_id]) - # Determine if tensor should be received from previous stage. 
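`get_model_chunk_id` is the heart of the interleaving: consecutive microbatch indices are dealt to model chunks in groups of `pipeline_parallel_size`, and the backward order mirrors the forward order. A standalone copy with the cycling behaviour made visible:

```python
def get_model_chunk_id(microbatch_id, pipeline_parallel_size, num_model_chunks, forward):
    """Standalone copy of the schedule's helper: map a virtual microbatch
    index onto the model chunk that should run it on this rank."""
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# pp=2, chunks=2: forward passes cycle through chunks 0,0,1,1,0,0,1,1,...
print([get_model_chunk_id(i, 2, 2, True) for i in range(8)])   # [0, 0, 1, 1, 0, 0, 1, 1]
print([get_model_chunk_id(i, 2, 2, False) for i in range(8)])  # [1, 1, 0, 0, 1, 1, 0, 0]
```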
- next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - recv_prev = True - if gpc.is_pipeline_first_stage(ignore_virtual=True): - if next_forward_model_chunk_id == 0: - recv_prev = False - if k == (num_microbatches - 1): - recv_prev = False - - # Don't send tensor downstream if on last stage. - if gpc.is_pipeline_last_stage(): - output_obj = None - - with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): - if not gpc.is_pipeline_first_stage(): - input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( - input_obj_shapes[next_forward_model_chunk_id]) - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration). - input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None - if k == (num_warmup_microbatches - 1) and not forward_only and \ - not all_warmup_microbatches: - input_obj_grad = None - recv_next = True - if gpc.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - output_obj_grads[num_model_chunks - 1].append(output_obj_grad) - else: - input_obj = \ - comm.send_forward_recv_forward( - output_obj, - input_shape, - recv_prev=recv_prev, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - input_objs[next_forward_model_chunk_id].append(input_obj) - - # Run 1F1B in steady state. - for k in range(num_microbatches_remaining): - # Forward pass. - forward_k = k + num_warmup_microbatches - output_obj = _forward_step_helper(forward_k) - - # Backward pass. - backward_k = k - input_obj_grad = _backward_step_helper(backward_k) - - # Send output_obj and input_obj_grad, receive input_obj - # and output_obj_grad. - - # Determine if current stage has anything to send in either direction, - # otherwise set obj to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) - if gpc.is_pipeline_last_stage(): - output_obj = None - - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) - if gpc.is_pipeline_first_stage(): - input_obj_grad = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if gpc.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) - - recv_next = True - if gpc.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
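`switch_virtual_pipeline_parallel_rank` is used above to do shape bookkeeping for a chunk other than the one currently executing. Its implementation lives elsewhere in `colossalai.utils`, but the pattern is presumably a save/restore context manager along these lines (a sketch with a stand-in `_ctx` dict, not the real global context):

```python
from contextlib import contextmanager

_ctx = {"virtual_pipeline_rank": 0}

@contextmanager
def switch_virtual_rank(new_rank):
    """Sketch of the pattern behind ``switch_virtual_pipeline_parallel_rank``:
    temporarily retarget the (stand-in) global context at another model chunk
    so shape metadata is exchanged for that chunk, then restore on exit."""
    old_rank = _ctx["virtual_pipeline_rank"]
    _ctx["virtual_pipeline_rank"] = new_rank
    try:
        yield
    finally:
        _ctx["virtual_pipeline_rank"] = old_rank

with switch_virtual_rank(1):
    assert _ctx["virtual_pipeline_rank"] == 1
assert _ctx["virtual_pipeline_rank"] == 0
```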
- next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), - forward=False) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None - output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None - # Communicate objs. - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - - # Put input_obj and output_obj_grad in data structures in the - # right location. - if recv_prev: - input_objs[next_forward_model_chunk_id].append(input_obj) - if recv_next: - output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad) - - # Run cooldown backward passes (flush out pipeline). - if not forward_only: - if all_warmup_microbatches: - output_obj_grads[num_model_chunks - 1].append( - comm.recv_backward(output_obj_shapes[num_model_chunks - 1], - scatter_gather_tensors=self.scatter_gather_tensors)) - for k in range(num_microbatches_remaining, num_microbatches): - input_obj_grad = _backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - recv_next = True - if gpc.is_pipeline_last_stage(ignore_virtual=True): - if next_backward_model_chunk_id == (num_model_chunks - 1): - recv_next = False - if k == (num_microbatches - 1): - recv_next = False - output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None - output_obj_grads[next_backward_model_chunk_id].append( - comm.send_backward_recv_backward(input_obj_grad, - output_shape, - recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) - - if len(return_tensors) > 0: - output, label = pack_return_tensors(return_tensors) - return output, label, accum_loss - else: - return None, None, accum_loss From 5633e9b1ccb3e9eb44f6cb3d15dd4d12d75b0f5d Mon Sep 17 00:00:00 2001 From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com> Date: Fri, 5 Aug 2022 17:59:09 +0800 Subject: [PATCH 2/5] Delete test_cifar_with_data_pipeline_tensor_v2.py --- ...test_cifar_with_data_pipeline_tensor_v2.py | 117 ------------------ 1 file changed, 117 deletions(-) delete mode 100644 tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py deleted file mode 100644 index 25c56cb4a62d..000000000000 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py +++ /dev/null @@ -1,117 +0,0 @@ -import os - -from functools import partial -from pathlib import Path - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks -from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port 
-from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.logging import disable_existing_loggers -from torchvision.datasets import CIFAR10 -from torchvision import transforms - -from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineSchedule as PSV2 - -disable_existing_loggers() -BATCH_SIZE = 4 -NUM_EPOCHS = 10 -WARMUP_EPOCHS = 5 -CONFIG = dict(NUM_MICRO_BATCHES=2, - parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), - fp16=dict(mode=AMP_TYPE.NAIVE), - gradient_accumulation=2) - - -def run_trainer(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - disable_existing_loggers() - # get logger - logger = get_dist_logger() - - pipelinable = PipelinableContext() - try: - from titans.model.vit import vit_tiny_patch4_32 - except ImportError: - logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') - logger.warning('please install titan from https://github.com/hpcaitech/Titans') - return - with pipelinable: - model = vit_tiny_patch4_32() - pipelinable.to_layer_list() - pipelinable.policy = "uniform" - model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) - - # craete dataloaders - root = Path(os.environ['DATA']) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) - - # create loss function - criterion = CrossEntropyLoss(label_smoothing=0.1) - - # create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) - - # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) - - # intiailize - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - engine._schedule = PSV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) - # print("enter" * 20) - # # test v2 schedule - # try: - # engine._schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES) - # except Exception as e: - # print(e) - # return - logger = get_dist_logger() - - - trainer = Trainer(engine=engine, logger=logger) - - hook_list = [ - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - ] - - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - hooks=hook_list, - display_progress=True) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_hybrid_parallel(): - world_size = 2 - run_func = partial(run_trainer, world_size=world_size, port=free_port()) - disable_existing_loggers() - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_hybrid_parallel() From b342701629da81fc261d88f6b8b98273d9d36405 Mon Sep 17 00:00:00 2001 From: LSTM-Kirigaya 
<1193466151@qq.com> Date: Fri, 5 Aug 2022 18:15:52 +0800 Subject: [PATCH 3/5] [engin/schedule] use p2p_v2 to recontruct pipeline_schedule --- colossalai/communication/p2p_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py index 71a78e1a8943..c17204de4537 100644 --- a/colossalai/communication/p2p_v2.py +++ b/colossalai/communication/p2p_v2.py @@ -17,6 +17,7 @@ TensorShape = Union[torch.Size, List[int], Tuple[int]] _pg_manager = {} +_unpickler = pickle.Unpickler def init_process_group(): @@ -70,7 +71,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - buf = bytes(buf_array) io_bytes = io.BytesIO(buf) - byte_pickler = pickle.Unpickler(io_bytes) + byte_pickler = _unpickler(io_bytes) unpickle = byte_pickler.load() return unpickle From f97a76f2717408febebb387a6cb73cc1da4e2838 Mon Sep 17 00:00:00 2001 From: LSTM-Kirigaya <1193466151@qq.com> Date: Mon, 8 Aug 2022 15:34:38 +0800 Subject: [PATCH 4/5] [communication] remove print code --- colossalai/communication/p2p_v2.py | 19 ++--------------- .../test_comm/test_boardcast_send_recv_v2.py | 10 --------- tests/test_comm/test_object_list_p2p_v2.py | 21 ------------------- 3 files changed, 2 insertions(+), 48 deletions(-) diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py index c17204de4537..9b96d622f539 100644 --- a/colossalai/communication/p2p_v2.py +++ b/colossalai/communication/p2p_v2.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Tuple, Union, Any, Dict +from typing import List, Tuple, Union, Any import pickle import io @@ -9,8 +9,6 @@ import torch.distributed as dist from torch.distributed import distributed_c10d as c10d from torch.distributed import ProcessGroupNCCL -# TODO remove it when release -from colorama import Back, Style from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -93,9 +91,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No group = _acquire_pair_group_handle(src, dst) if c10d._rank_not_in_group(group): - # c10d._warn_not_in_group("broadcast_object_list") - print(Back.RED, "ERROR", Style.RESET_ALL, - "{} and {} has abnormal reflection, broadcast failed!".format(src, dst)) + c10d._warn_not_in_group("broadcast_object_list") return local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -122,7 +118,6 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No # Broadcast object sizes c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) - # print(Back.CYAN, "inner broadcast length", Style.RESET_ALL, "{} finish {} {}".format(local_rank, local_rank, src)) # Concatenate and broadcast serialized object tensors if local_rank == src: @@ -137,7 +132,6 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No object_tensor = object_tensor.to(current_device) c10d.broadcast(object_tensor, src=src, group=group, async_op=False) - # print(Back.CYAN, "inner broadcast content", Style.RESET_ALL, "rank_{} finish {} {}".format(local_rank, local_rank, src)) # Deserialize objects using their stored sizes. 
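The patched `_cuda_safe_tensor_to_object` above reconstructs a Python object from a byte tensor; the module-level `_unpickler = pickle.Unpickler` alias simply avoids repeating the attribute lookup on every call. A CPU-only round-trip sketch of the representation involved (`object_to_tensor` / `tensor_to_object` are illustrative names, not the library's):

```python
import io
import pickle

import torch

_unpickler = pickle.Unpickler  # module-level alias, as in the patch

def object_to_tensor(obj):
    """Serialize an arbitrary picklable object into a byte tensor plus its
    length -- the representation p2p_v2 broadcasts between pipeline ranks."""
    buf = pickle.dumps(obj)
    return torch.ByteTensor(list(buf)), torch.LongTensor([len(buf)])

def tensor_to_object(tensor, size):
    """Inverse of ``object_to_tensor``: a CPU-only sketch of what
    ``_cuda_safe_tensor_to_object`` does after moving the buffer off-GPU."""
    buf = bytes(tensor[:size].tolist())
    return _unpickler(io.BytesIO(buf)).load()

payload = {"rank": 1, "msg": "hello", "t": torch.randn(3)}
t, n = object_to_tensor(payload)
restored = tensor_to_object(t, int(n))
assert restored["msg"] == "hello"
```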
offset = 0 @@ -158,9 +152,6 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No unpickle_object = unpickle_object.cuda() object_list[i] = unpickle_object - # print(Back.BLUE, "this is rank_{}".format(gpc.get_local_rank(ParallelMode.PIPELINE)), Style.RESET_ALL, object_list) - - # print(Back.GREEN, "rank_{} finish _broadcast_object_list".format(local_rank), Style.RESET_ALL) def _send_object(object: Any, dst: int) -> None: @@ -182,12 +173,9 @@ def _send_object(object: Any, dst: int) -> None: # broadcast length first # TODO : more elegant ? P.S. reduce a _broadcast_object_list _broadcast_object_list([len(object)], local_rank, dst) - # print(Back.LIGHTMAGENTA_EX, "[_send_object]", Style.RESET_ALL, "rank_{} send length {} to rank_{}".format(local_rank, [len(object)], dst)) # then broadcast safely _broadcast_object_list(object, local_rank, dst) - # print(Back.LIGHTGREEN_EX, "[_send_object]", Style.RESET_ALL, "rank_{} send {} to rank_{}".format(local_rank, type(object), dst)) - def _recv_object(src: int) -> Any: """recv anything from src @@ -202,17 +190,14 @@ def _recv_object(src: int) -> Any: # handler = _acquire_pair_group_handle(local_rank, src) # recv length first length = [0] - # print(Back.LIGHTYELLOW_EX, "[_recv_object]", Style.RESET_ALL, "rank_{} waiting for msg from rank_{}".format(local_rank, src)) _broadcast_object_list(length, src, local_rank) # then create recv buff from length[0] and broadcast object = [None] * length[0] - # print(Back.MAGENTA, "[_recv_object]", Style.RESET_ALL, "rank_{} recv length {} from rank_{}".format(local_rank, length, src)) _broadcast_object_list(object, src, local_rank) if length[0] == 1: object = object[0] - # print(Back.GREEN, "[_recv_object]", Style.RESET_ALL, "rank_{} recv {} from rank_{}".format(local_rank, type(object), src)) return object diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py index d709d49d84cf..1520d6054043 100644 --- a/tests/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_comm/test_boardcast_send_recv_v2.py @@ -23,15 +23,6 @@ def check_layer(rank, world_size, port): disable_existing_loggers() launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False) rank = gpc.get_local_rank(ParallelMode.PIPELINE) - # obj = [ - # rank, - # f"hello, I am rank {rank}", - # torch.randn((3, )), - # None, - # (1, 2, 3), - # [2, 3, 3], - # {1 : 1, 3 : 4} - # ] if rank == 0: obj = [torch.randn(3,)] @@ -47,7 +38,6 @@ def check_layer(rank, world_size, port): obj = [torch.randn(3,)] _send_object(obj, 2) - print(f"rank {rank} fin") gpc.destroy() torch.cuda.empty_cache() diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py index c9c196329a27..c639ac9f8ef3 100644 --- a/tests/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_comm/test_object_list_p2p_v2.py @@ -43,20 +43,14 @@ def check_send_recv_forward(): data_list_to_send.append(data_in_list.to(device)) send_forward(data_to_send, scatter_gather_tensors=use_scatter_gather_tensors) - # print("finish send_forward(data_to_send)") send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors) - # print("finish send_forward(data_list_to_send)") - print("rank1 {}".format(data_to_send)) elif local_rank == 1: device = torch.device('cuda:1') data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors) - # print("finish data_recv = recv_forward(TENSOR_SIZE)") 
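The `_send_object` / `_recv_object` pair shown in this hunk implements a two-phase protocol: broadcast the list length first so the receiver can size its buffer, then broadcast the payload, unwrapping single-element lists on arrival. The same handshake over an in-process queue, to make the control flow concrete (names are mine):

```python
import queue

channel = queue.Queue()

def send_object(obj_list):
    """The two-phase protocol used by ``_send_object``: announce the list
    length first so the receiver can size its buffer, then send the payload."""
    channel.put([len(obj_list)])   # phase 1: length
    channel.put(list(obj_list))    # phase 2: payload

def recv_object():
    """Mirror of ``_recv_object``: read the length, allocate, read payload;
    a single-element list is unwrapped, matching the deleted code."""
    (length,) = channel.get()
    payload = channel.get()
    assert len(payload) == length
    return payload[0] if length == 1 else payload

send_object(["only-item"])
print(recv_object())  # 'only-item' (unwrapped single element)
```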
data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors) - # print("finish data_recv = recv_forward(TENSOR_SIZE_LIST)") - print("rank1 {}".format(data_recv)) data_to_check = data.to(device) assert data_recv.equal(data_to_check) @@ -66,8 +60,6 @@ def check_send_recv_forward(): data_recv = data_recv.to(device) assert data_recv.equal(data_to_check) - print("[forward] rank 1, recv and check all right") - def check_send_recv_backward(): disable_existing_loggers() @@ -84,7 +76,6 @@ def check_send_recv_backward(): grad_recv = grad_recv.to(device) grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) - print("[backward] rank 0, recv and check all right") else: device = torch.device('cuda:1') grad_to_send = grad.to(device) @@ -93,7 +84,6 @@ def check_send_recv_backward(): grad_list_to_send.append(grad_in_list.to(device)) send_backward(grad_to_send) send_backward(grad_list_to_send) - print("[backward] rank 1, send") def check_small_pipeline(): @@ -102,30 +92,19 @@ def check_small_pipeline(): assert gpc.world_size == 4, "make sure to set world size to 4 to start the training process" local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if local_rank == 0: - # print("I am {}, next is {}, prev is {}".format(local_rank, gpc.get_next_global_rank(ParallelMode.PIPELINE), gpc.get_prev_global_rank(ParallelMode.PIPELINE))) obj = [1, torch.randn(2, 2).cuda(), None] send_forward(obj) elif local_rank == 1: - # print("I am {}, next is {}, prev is {}".format(local_rank, gpc.get_next_global_rank(ParallelMode.PIPELINE), gpc.get_prev_global_rank(ParallelMode.PIPELINE))) obj = recv_forward() send_forward(obj) - print("rank_{} {}".format(local_rank, obj)) elif local_rank == 2: - # print("I am {}, next is {}, prev is {}".format(local_rank, gpc.get_next_global_rank(ParallelMode.PIPELINE), gpc.get_prev_global_rank(ParallelMode.PIPELINE))) obj = recv_forward() - print("rank_{} {}".format(local_rank, obj)) send_forward(obj) elif local_rank == 3: - # print("I am {}, next is {}, prev is {}".format(local_rank, gpc.get_next_global_rank(ParallelMode.PIPELINE), gpc.get_prev_global_rank(ParallelMode.PIPELINE))) - # import time - # time.sleep(5) obj = recv_forward() - print("rank_{} {}".format(local_rank, obj)) else: pass - print("rank {} fin".format(local_rank)) - def check_layer(rank, world_size, port): disable_existing_loggers() From 3998bae9b8abba7fe7a4035c2ef0fb78195151ea Mon Sep 17 00:00:00 2001 From: LSTM-Kirigaya <1193466151@qq.com> Date: Mon, 8 Aug 2022 15:52:30 +0800 Subject: [PATCH 5/5] [communication] remove print code --- colossalai/communication/p2p_v2.py | 40 ------------------------------ 1 file changed, 40 deletions(-) diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py index 9b96d622f539..0b575e7dba77 100644 --- a/colossalai/communication/p2p_v2.py +++ b/colossalai/communication/p2p_v2.py @@ -266,43 +266,3 @@ def send_backward(input_object: Any, prev_rank: int = None) -> None: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) _send_object(input_object, prev_rank) - - -def send_forward_recv_backward(): - """reserve - """ - pass - - -def send_backward_recv_forward(): - """reserve - """ - pass - - -def send_forward_recv_forward(): - """reserve - """ - pass - - -def send_backward_recv_backward(): - """reserve - """ - pass - - -def send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - input_tensor_shape, - output_grad_shape, - recv_prev=True, - recv_next=True, 
- prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: - """reserve - """ - pass
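With the reserve stubs removed, the surviving p2p_v2 surface that the v2 schedule actually exercises is `send_forward`, `recv_forward`, `send_backward`, and `recv_backward`. A queue-backed toy version of that four-call handshake (the real calls broadcast over NCCL pair groups; the queues here only keep the sketch runnable in one process):

```python
import queue

# In-memory stand-ins for the four p2p_v2 primitives that remain after this
# patch series (the reserve stubs above were removed, not implemented).
_fwd, _bwd = queue.Queue(), queue.Queue()

send_forward = _fwd.put
recv_forward = _fwd.get
send_backward = _bwd.put
recv_backward = _bwd.get

# One microbatch through a two-stage pipeline, 1F1B style.
send_forward({"activation": [1.0, 2.0]})   # stage 0: send forward output
stage1_in = recv_forward()                 # stage 1: receive activation
send_backward({"grad": [0.1, 0.2]})        # stage 1: send grad w.r.t. its input
stage0_grad = recv_backward()              # stage 0: receive grad of its output
print(stage1_in, stage0_grad)
```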