Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] update CommSpec to CommActions #1768

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,17 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:

mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

if self.node.args[2] is not None:
if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
# check if the other operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]),
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.args[2]._meta_data)
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand

return mapping

def post_process(self, strategy: ShardingStrategy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import StrategyGenerator
Expand Down Expand Up @@ -204,12 +209,13 @@ def split_input_batch(self, mesh_dim_0):
# For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
Expand Down Expand Up @@ -238,12 +244,13 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
Expand Down Expand Up @@ -282,12 +289,13 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0])
logical_process_axis=[mesh_dim_0],
comm_type=CommType.AFTER)

communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import FollowingStrategyGenerator
Expand Down Expand Up @@ -83,11 +88,13 @@ def collate_strategies(self) -> List[ShardingStrategy]:
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
if gather_input:
input_communication_spec = self.get_communication_spec(
input_communication_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=logical_process_axis)
communication_action_mapping["input"] = input_communication_spec
logical_process_axis=logical_process_axis,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping["input"] = input_communication_action

name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import StrategyGenerator
Expand Down Expand Up @@ -107,18 +114,20 @@ def _generate_strategy_with_dim_partition(self, dim_partition):
total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {}

other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action

if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action

strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
Expand Down
Loading