[autoparallel] add numerical test for node strategies (#1760)
* [autoparallel] add numerical test for node strategies

* polish code

* polish code
YuliangLiu0306 authored Oct 27, 2022
1 parent 25952b6 commit b4cc59b
Showing 10 changed files with 285 additions and 62 deletions.
52 changes: 39 additions & 13 deletions colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
"""
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]

return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


@@ -81,18 +80,24 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
if not hasattr(node, 'best_strategy') or node.op == 'output':
continue

for user_node in node.strategies_vector.successor_nodes:
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))

origin_index_args = user_node.args.index(node)
new_args = list(user_node.args)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional or keyword argument of the user node
if node in new_args:
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
user_node.kwargs = new_kwargs

return gm
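
To make the kwargs handling above concrete, here is a minimal standalone torch.fx sketch (not the repository's code; torch.clone stands in for the runtime_apply call): a producer node can reach its consumer either positionally or through a keyword argument, so both args and kwargs have to be checked when splicing in the bridge node.

import torch
import torch.fx

# Build a tiny graph by hand so the placeholder feeds its user both
# positionally and through a keyword argument.
graph = torch.fx.Graph()
x = graph.placeholder('x')
y = graph.call_function(torch.add, args=(x,), kwargs={'other': x})
graph.output(y)

# Route every use of `x` through a bridge node, patching args and kwargs separately.
for user in [n for n in graph.nodes if x in n.all_input_nodes]:
    with graph.inserting_before(user):
        bridge = graph.call_function(torch.clone, args=(x,))
    user.args = tuple(bridge if a is x else a for a in user.args)
    user.kwargs = {k: bridge if v is x else v for k, v in user.kwargs.items()}

graph.lint()
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm(torch.ones(2)))  # tensor([2., 2.])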

@@ -112,18 +117,31 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):

comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
comm_object = node.args[comm_action.arg_index]

if op_data.type == OperationDataType.PARAM:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
# the communication target may be a positional or keyword argument of the node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
new_kwargs = dict(node.kwargs)
new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
node.kwargs = new_kwargs
else:
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args

elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
@@ -135,8 +153,16 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
if user == comm_spec_apply_node:
continue
new_args = list(user.args)
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional or keyword argument of the user node
if node in new_args:
# substitute the origin node with comm_spec_apply_node
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs

return gm
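
As a hypothetical torch.fx sketch (again not this repository's code; torch.neg and torch.abs stand in for the communication ops), the BEFORE/AFTER placement used above works like this: a BEFORE action is spliced onto one input of the node, while an AFTER action is spliced onto the node's output and the downstream users are repointed to it.

import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder('x')
act = graph.call_function(torch.relu, args=(x,))
graph.output(act)

# BEFORE: transform an input of `act`.
with graph.inserting_before(act):
    pre = graph.call_function(torch.neg, args=(x,))
act.args = (pre,)

# AFTER: transform the output of `act` and repoint its users at the new node.
with graph.inserting_after(act):
    post = graph.call_function(torch.abs, args=(act,))
for user in list(act.users):
    if user is post:
        continue
    user.args = tuple(post if a is act else a for a in user.args)

graph.lint()
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm(torch.tensor([-2.0, 3.0])))  # abs(relu(neg(x))) -> tensor([2., 0.])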

@@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
@@ -4,18 +4,14 @@
from functools import reduce
from typing import List


from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)

from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception

from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import StrategyGenerator
@@ -135,7 +131,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}

if self.is_param("other"):
@@ -223,8 +220,7 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_action}

@@ -277,8 +273,7 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -316,8 +311,7 @@ def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_action}

@@ -351,7 +345,8 @@ def split_weight_out_channel(self, mesh_dim_0):
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)

communication_action_mapping = {"input": input_comm_action}

@@ -441,8 +436,7 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_action}

@@ -109,15 +109,17 @@ def get_communication_action(self,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1) -> CommAction:
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index)
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)

def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
1 change: 1 addition & 0 deletions colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -115,6 +115,7 @@ class CommAction:
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
key_for_kwarg: any = None
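
As a hedged illustration of the new field, the two addressing modes could be written as follows (comm_spec is omitted for brevity, and the 'bias' key is purely hypothetical):

from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType

# A communication applied BEFORE the node on a positional input names it by index ...
input_comm_action = CommAction(comm_type=CommType.BEFORE, arg_index=0)            # node.args[0]
# ... while an input passed as a keyword argument is named by its key.
bias_comm_action = CommAction(comm_type=CommType.BEFORE, key_for_kwarg='bias')    # node.kwargs['bias']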


@dataclass
19 changes: 16 additions & 3 deletions colossalai/device/device_mesh.py
@@ -1,5 +1,6 @@
from functools import reduce
import operator
from functools import reduce

import torch
import torch.distributed as dist

@@ -11,7 +12,7 @@ class DeviceMesh:
can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
own latency and bandwidth. We use the alpha-beta model to model the
communication cost.
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
@@ -64,6 +65,18 @@ def num_devices(self):
def logical_mesh_id(self):
return self._logical_mesh_id

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
setattr(result, k, v)

return result
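
The intent of this __deepcopy__, sketched below with a stand-in class (a real DeviceMesh would need an initialized process group), is that copies deep-copy everything except the live process-group handles, which stay shared by reference:

import copy

class MeshLike:
    """Stand-in with the same __deepcopy__ contract as DeviceMesh above."""

    def __init__(self):
        self.mesh_shape = (2, 2)
        self.process_groups_dict = {0: object()}   # stands in for live ProcessGroup handles

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            # deep-copy everything except the process-group handles
            setattr(result, k, v if k == 'process_groups_dict' else copy.deepcopy(v, memo))
        return result

mesh = MeshLike()
clone = copy.deepcopy(mesh)
assert clone.mesh_shape == mesh.mesh_shape
assert clone.process_groups_dict is mesh.process_groups_dict   # shared, not copied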

def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
@@ -90,7 +103,7 @@ def _global_rank_to_logical_rank_map(self, tensor, index_list):
def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used in communications
among the logical device mesh.
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
communication-related functions such as ShapeConsistencyManager.apply will raise errors.
'''
9 changes: 9 additions & 0 deletions colossalai/tensor/shape_consistency.py
@@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
pass


def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
global_sharding_spec)
return global_tensor
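
A hedged sketch of how to_global might be used in a numerical test; it assumes a 4-process launch, local_shard is a hypothetical name for this rank's slice of an 8x8 tensor sharded on dim 0 along mesh axis 0, and the argument names follow the signatures visible in this diff:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import to_global
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

# dim 0 of the 8x8 tensor is sharded along mesh axis 0; `local_shard` (hypothetical)
# is the 4x8 slice owned by the current rank.
sharding_spec = ShardingSpec(device_mesh, torch.Size([8, 8]), dim_partition_dict={0: [0]})
full_tensor = to_global(local_shard, sharding_spec)
assert full_tensor.shape == torch.Size([8, 8])   # ready to compare with a single-device result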


def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.
13 changes: 6 additions & 7 deletions colossalai/tensor/sharding_spec.py
@@ -6,7 +6,6 @@
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

@@ -23,7 +22,7 @@ class _DimSpec:
This class is used internally in ShardingSpec.
Argument:
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
Otherwise, the element in shard_list means the data will be sharded in that dimension.
'''

@@ -62,7 +61,7 @@ def _convert_str_to_shard_list(self, str_spec):

def build_difference_2d_dict(self):
'''
Build a difference mapping for the 2D device mesh case. It will be used to
Build a difference maping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''

@@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException):
class ShardingSpec:
'''
Sharding spec for a tensor. It contains info about the logical device mesh this tensor belongs
to, the entire shape of the tensor before sharding, and the sharding sequence, which looks like
[R, R, S0, S1].
Argument:
device_mesh(DeviceMesh): A logical view of a physical mesh.
entire_shape(torch.Size): The entire shape of the tensor before sharding.
@@ -260,10 +259,10 @@ def sharding_sequence_difference(self, other):
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to compare with.