Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] fix bias addition module #1800

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions colossalai/auto_parallel/passes/runtime_apply_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
user_node.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
Expand All @@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():

if op_data.type == OperationDataType.PARAM:
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
if op_data.type == OperationDataType.OUTPUT:
comm_object = node
elif comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
Expand All @@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
node.args = tuple(new_args)

elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
Expand All @@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs

return gm


Expand Down
71 changes: 69 additions & 2 deletions colossalai/auto_parallel/passes/runtime_preparation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from torch.fx import symbolic_trace
from torch.fx.node import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
Expand Down Expand Up @@ -42,7 +47,32 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs

# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
new_sharding_spec = target_sharding_specs[0]
user_node = node.strategies_vector.successor_nodes[0]
user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
op_data_in_user = user_strategy.get_op_data_by_name(str(node))
origin_node_sharding_spec_dict[index] = new_sharding_spec
origin_pending_strategy = node.best_strategy
origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
new_sharding_specs = origin_pending_strategy.sharding_specs
new_sharding_specs[origin_op_data] = new_sharding_spec
new_communication_actions = {}
if op_data_in_user in user_strategy.communication_actions:
new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
new_communication_action.arg_index = 0
new_communication_actions[origin_op_data] = new_communication_action
new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
sharding_specs=new_sharding_specs,
compute_cost=origin_pending_strategy.compute_cost,
communication_cost=origin_pending_strategy.communication_cost,
memory_cost=origin_pending_strategy.memory_cost,
communication_actions=new_communication_actions)
setattr(node, 'best_strategy', new_strategy)
setattr(node, 'sharding_spec', new_sharding_spec)
comm_action_dict = {}
for op_data, comm_action in node.best_strategy.communication_actions.items():
comm_action_dict[op_data.name] = comm_action
Expand Down Expand Up @@ -111,6 +141,43 @@ def hook_fn(grad):
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())

if node.op == 'get_attr':
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
if attr_len == 1:
target_module = root
target = getattr(root, atoms[0])
else:
target_module = root.get_submodule(atoms[-2])
target = getattr(target_module, atoms[-1])

target_sharding_spec = node.sharding_spec
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParamter class to manager the distributed parameters
target_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
target_sharding_spec).detach().clone())
else:
target_sharded = target
setattr(target_module, atoms[-1], target_sharded)

comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

def wrapper(param, comm_spec):

def hook_fn(grad):
_all_reduce(grad, comm_spec)

param.register_hook(hook_fn)

wrapper(target_sharded, comm_spec_to_use)
return gm


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process

# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG

physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
type=data_type,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,19 @@ def collate_strategies(self) -> List[ShardingStrategy]:
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list

else:
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)

communication_action_mapping["input"] = input_comm_action
else:
input_comm_action = None

if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def create_bias_reshape_proxy(self, dimensions):
bias_shape[0] = -1
bias_reshape_node_kind = 'call_method'
bias_reshape_node_target = 'view'
bias_reshape_node_args = (self.bias_proxy, bias_shape)
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
bias_reshape_node_args, {})
return bias_reshape_proxy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def torch_bmm(input, mat2, *, out=None):


@meta_patched_function.register(torch.nn.functional.linear)
def torch_linear(input, mat2, *, out=None):
def torch_linear(input, mat2, bias=None, *, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
output_shape = list(input.shape)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class LinearModel(torch.nn.Module):

def __init__(self, in_features, out_features):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)

def forward(self, x):
x = self.linear(x)
x = x * 2

return x


class ConvModel(torch.nn.Module):

def __init__(self, in_channels, out_channels, kernel_size, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
bias=bias)

def forward(self, x):
x = self.conv(x)
x = x * 2

return x


def check_linear_module(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel(4, 8).cuda()
input = torch.rand(4, 4).cuda()
output_compare = model(input)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
# def forward(self, x : torch.Tensor):
# linear_weight = self.linear.weight
# linear_bias = self.linear.bias
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)
gm.recompile()
node_list = list(graph.nodes)

solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
linear_node = node_list[3]
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)

gm = runtime_apply_pass(gm)
gm.recompile()
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, output_compare)


def check_conv_module(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel(3, 6, 2).cuda()
input = torch.rand(4, 3, 64, 64).cuda()
output_compare = model(input)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
# def forward(self, x : torch.Tensor):
# conv_weight = self.conv.weight
# conv_bias = self.conv.bias
# conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
# view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
# add = conv2d + view; conv2d = view = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)

gm.recompile()

node_list = list(graph.nodes)
conv_node = node_list[3]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])

gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)

gm = runtime_apply_pass(gm)
gm.recompile()
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, output_compare)


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
world_size = 4
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
mp.spawn(run_func_linear, nprocs=world_size)
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
mp.spawn(run_func_conv, nprocs=world_size)


if __name__ == '__main__':
test_bias_addition_module()
Loading