[autoparallel] add numerical test for node strategies #1760

Merged

Changes from 2 commits
45 changes: 32 additions & 13 deletions colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
"""
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]

return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


@@ -81,18 +80,22 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
if not hasattr(node, 'best_strategy') or node.op == 'output':
continue

for user_node in node.strategies_vector.successor_nodes:
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
# user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))

origin_index_args = user_node.args.index(node)
new_args = list(user_node.args)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
new_kwargs = dict(user_node.kwargs)
if node in new_args:
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
elif str(node) in new_kwargs:
new_kwargs[str(node)] = shape_consistency_node
user_node.kwargs = new_kwargs

return gm

@@ -112,18 +115,29 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):

comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
comm_object = node.args[comm_action.arg_index]

if op_data.type == OperationDataType.PARAM:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args

if comm_action.key_for_kwarg is not None:
new_kwargs = dict(node.kwargs)
new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
node.kwargs = new_kwargs
else:
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args

elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
@@ -135,8 +149,13 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
if user == comm_spec_apply_node:
continue
new_args = list(user.args)
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
new_kwargs = dict(user.kwargs)
if node in new_args:
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs

return gm
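
For context, the rewiring both passes above perform — creating a node in front of a consumer and splicing it into that consumer's args — can be reproduced in a minimal, self-contained torch.fx sketch. The function names here (fn, torch.neg) are illustrative and not part of this PR:

import operator

import torch
import torch.fx


def fn(x):
    # Traced as a single call_function node with target operator.add.
    return operator.add(x, x)


gm = torch.fx.symbolic_trace(fn)
graph = gm.graph
for node in list(graph.nodes):
    if node.op == 'call_function' and node.target is operator.add:
        producer = node.args[0]
        # Insert a new node right before the consumer, as the passes above do
        # with mod_graph.inserting_before(...).
        with graph.inserting_before(node):
            new_node = graph.create_node('call_function', torch.neg, args=(producer,))
        # Splice it into the consumer's positional args by index, mirroring
        # new_args[new_args.index(node)] = ... in the passes above.
        new_args = list(node.args)
        new_args[new_args.index(producer)] = new_node
        node.args = tuple(new_args)

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))    # add(neg(x), x) == tensor([0., 0.])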

@@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
@@ -4,18 +4,14 @@
from functools import reduce
from typing import List


from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)

from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception

from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import StrategyGenerator
@@ -135,7 +131,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}

if self.is_param("other"):
@@ -223,8 +220,7 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_action}

@@ -277,8 +273,7 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -316,8 +311,7 @@ def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_action}

@@ -351,7 +345,8 @@ def split_weight_out_channel(self, mesh_dim_0):
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)

communication_action_mapping = {"input": input_comm_action}

@@ -441,8 +436,7 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_action}

@@ -109,15 +109,17 @@ def get_communication_action(self,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1) -> CommAction:
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index)
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)

def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
@@ -115,6 +115,7 @@ class CommAction:
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
key_for_kwarg: any = None
Contributor:

What is this used for?

Contributor Author:

If a node is a kwarg of another node rather than a positional argument, we can use this field to locate it in the runtime pass.
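
For context, a minimal torch.fx sketch (illustrative, not part of this PR) of the case described above: when an input reaches an op as a keyword argument, the consumer's args tuple does not contain it, so arg_index cannot locate it and the runtime pass has to go through node.kwargs — which is what key_for_kwarg records.

import torch
import torch.fx


def fn(x, w):
    # `other` is passed as a keyword argument.
    return torch.add(x, other=w)


gm = torch.fx.symbolic_trace(fn)
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is torch.add:
        print(node.args)      # (x,)          -> w is not reachable via arg_index
        print(node.kwargs)    # {'other': w}  -> must be rewired through node.kwargs['other']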



@dataclass
19 changes: 16 additions & 3 deletions colossalai/device/device_mesh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import reduce
import operator
from functools import reduce

import torch
import torch.distributed as dist

@@ -11,7 +12,7 @@ class DeviceMesh:
can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
own latency and bandwidth. We use alpha-beta model to model the
communication cost.

Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
@@ -64,6 +65,18 @@ def num_devices(self):
def logical_mesh_id(self):
return self._logical_mesh_id

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
setattr(result, k, v)

return result

def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
@@ -90,7 +103,7 @@ def _global_rank_to_logical_rank_map(self, tensor, index_list):
def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
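
A note on the __deepcopy__ override added above: it deep-copies every attribute except process_groups_dict, which is shared by reference, presumably because live process-group handles should not be duplicated. A standalone sketch of the same pattern (class and attribute names are illustrative, not ColossalAI APIs):

import copy


class MeshLike:

    def __init__(self):
        self.shape = [2, 2]
        self.process_groups_dict = {0: object()}    # stands in for live process-group handles

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != 'process_groups_dict':
                setattr(result, k, copy.deepcopy(v, memo))    # copy ordinary state
            else:
                setattr(result, k, v)                         # share the handles
        return result


m = MeshLike()
m2 = copy.deepcopy(m)
assert m2.shape is not m.shape                            # deep-copied
assert m2.process_groups_dict is m.process_groups_dict    # shared by reference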