[fx] tested the complete workflow for auto-parallel #1336

Merged 4 commits on Jul 20, 2022
145 changes: 103 additions & 42 deletions colossalai/fx/passes/shard_1d_pass.py
@@ -1,9 +1,16 @@
import torch
import torch.nn as nn
import operator
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d]
ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d]

def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
"""weight_split
@@ -21,6 +28,8 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
else:
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
return weight


def column_shard_linear_pass(gm: torch.fx.GraphModule):
# Split all the linear module with column shard. Currently for testing only.
mod_graph = gm.graph
@@ -48,43 +57,95 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
gm.recompile()
return gm

def transform_mlp_pass(gm: torch.fx.GraphModule):

def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
This IR pass checks for a transformer MLP-like structure and annotates column and row sharding on the linear layers.
"""
#TODO: Needs to handle special cases, like x = linear(x) + linear(x)
mod_graph = gm.graph
col_shard = True
element_op = []
all_linear_name = []
linear_name = []
# Get the name of element wise module(torch.nn.ReLU)
# Get the name of all the linear modules and repeated linear modules
for name, func in gm.named_children():
if not isinstance(func, torch.nn.Linear):
for i in ELEMENTWISE_MODULE_OP:
if isinstance(func, i):
element_op.append(name)
break
else:
if name in all_linear_name:
if name in linear_name:
linear_name.remove(name)
else:
all_linear_name.append(name)
linear_name.append(name)
# If the linear modules is called multiple times, set the dist spec as col shard
# If the module is element wise or the function/method is element wise, remains col_shard
for node in mod_graph.nodes:
if node.target in linear_name:
target_module = node.graph.owning_module.get_submodule(node.target)
dim = 0 if col_shard else -1
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
col_shard = not col_shard
elif node.target in all_linear_name:
target_module = node.graph.owning_module.get_submodule(node.target)
dim = 0 if col_shard else -1
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
col_shard = not col_shard
else:
if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
col_shard = True
gm.recompile()
return gm
graph = graph_module.graph
world_size = process_group.world_size()

def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# traverse the graph to look for consecutive linear layers
is_linear_module = False

if node.op == 'call_module':
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
is_linear_module = True
if start_tracking:
# when start_tracking = True
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
annotation_record['row'] = module

for shard_type, module in annotation_record.items():
# add row sharding spec
if shard_type == 'row':
dist_spec = shard(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col':
weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', weight_dist_spec)
setattr(module.weight, 'comp_spec', weight_comp_spec)

if module.bias is not None:
bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group)
setattr(module.bias, 'dist_spec', bias_dist_spec)
setattr(module.bias, 'comp_spec', bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
# when start_tracking = False
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
annotation_record['col'] = module

if start_tracking and not is_linear_module:
# check against the white list
# if non-element wise op is found, we reset the tracking
if node.op == 'call_module':
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
elif node.op == 'call_function' or node.op == 'call_method':
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
start_tracking = False

if not start_tracking:
annotation_record.clear()

# stop tracking for consecutive linear when branch is found
# e.g.
# out1 = self.linear1(x)
# out2 = self.linear2(x)
# return out1+out2
next_nodes = list(node.users.keys())
if len(next_nodes) > 1:
start_tracking = False
annotation_record.clear()

# traverse
for node in next_nodes:
_traverse_and_annotate(node, start_tracking, annotation_record, world_size)

placeholder_node = list(graph.nodes)[0]
annotate_record = {}
_traverse_and_annotate(placeholder_node, False, annotate_record, world_size)

return graph_module
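
For reference, a minimal sketch of how the new transformer_mlp_pass is intended to be applied, distilled from the test added below. TinyMLP is an illustrative module (not part of this PR), and the snippet assumes an already-initialized distributed context (e.g. after colossalai.launch), since ProcessGroup() relies on one:

import torch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.tensor import ProcessGroup


class TinyMLP(torch.nn.Module):
    # illustrative two-linear MLP matching the structure the pass looks for
    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


model = TinyMLP(16)

# trace the model into an FX graph, same as in the test below
tracer = ColoTracer()
graph = tracer.trace(model)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

# annotate: the first linear is marked column-sharded, the consecutive second
# linear row-sharded; the ReLU in between is on the element-wise whitelist
gm = transformer_mlp_pass(gm, process_group=ProcessGroup())

# the pass only attaches metadata; no tensor is resharded at this point
assert hasattr(model.linear1.weight, 'dist_spec')
assert hasattr(model.linear1.bias, 'dist_spec')      # bias annotated only for the col-sharded linear
assert hasattr(model.linear2.weight, 'dist_spec')
assert not hasattr(model.linear2.bias, 'dist_spec')  # row-sharded linear keeps its bias replicated

The actual materialization and resharding happen later in LazyInitContext.lazy_init_parameters, as changed in the next file.
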
12 changes: 7 additions & 5 deletions colossalai/utils/model/lazy_init_context.py
@@ -175,7 +175,7 @@ def __exit__(self, *args, **kwargs):
self._unpatch_nn_init_funcs()
self._unpatch_torch_tensor_funcs()

def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
"""
Initialize the weights of the meta-tensor model.

@@ -205,6 +205,7 @@ def _init_and_shard(module, name, tensor):
# get sharding spec
dist_spec = getattr(tensor, 'dist_spec', None)
pg = getattr(tensor, 'pg', None)
comp_spec = getattr(tensor, 'comp_spec', None)

# convert the tensor from meta to materialized one
if tensor.is_meta:
@@ -224,14 +225,15 @@ def _init_and_shard(module, name, tensor):
else:
tensor = ColoTensor.from_torch_tensor(tensor)

# apply sharding
if dist_spec:
tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)

# override the original tensor
with torch.no_grad():
setattr(module, name, tensor)

# apply sharding
if dist_spec:
tensor.process_group = pg
tensor.set_tensor_spec(dist_spec, comp_spec)

_init_recursively(model)

return model
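
The key change here is the ordering: the materialized parameter is wrapped and written back to the module first, and only then is sharding applied by assigning the process group and the tensor spec. A minimal sketch of that sequence under stated assumptions (the helper name materialize_and_shard is illustrative, the ColoTensor import path is assumed, and a distributed environment must already be initialized):

import torch
from colossalai.tensor import ColoTensor, ProcessGroup  # ColoTensor import path assumed
from colossalai.tensor.distspec import shard
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

def materialize_and_shard(param: torch.Tensor, pg: ProcessGroup) -> ColoTensor:
    # wrap the already-materialized tensor first
    colo_param = ColoTensor.from_torch_tensor(param)

    # then attach the process group and the spec, mirroring the order
    # used in lazy_init_parameters above
    dist_spec = shard(dims=[0], num_partitions=[pg.world_size()])
    comp_spec = ComputeSpec(ComputePattern.TP1D)
    colo_param.process_group = pg
    colo_param.set_tensor_spec(dist_spec, comp_spec)
    return colo_param
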
77 changes: 77 additions & 0 deletions tests/test_fx/test_complete_workflow.py
@@ -0,0 +1,77 @@
import colossalai
import torch
import torch.nn as nn
import pytest
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
from colossalai.fx import ColoTracer
from colossalai.utils.model.lazy_init_context import LazyInitContext
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup


class MLP(torch.nn.Module):

def __init__(self, dim: int):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim)
self.linear2 = torch.nn.Linear(dim, dim)
self.dropout = torch.nn.Dropout(0)
self.relu = torch.nn.ReLU()

def forward(self, x):
x = self.linear1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.linear2(x)
return x


def run_workflow(world_size):
# initialization
with LazyInitContext() as ctx:
model = MLP(16)

# tracing
tracer = ColoTracer()
graph = tracer.trace(model)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

# annotate
annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
annotated_gm.recompile()

# materialization and sharding
ctx.lazy_init_parameters(annotated_gm)

# check sharding
assert list(model.linear1.weight.shape) == [16 // world_size, 16]
assert list(model.linear1.bias.shape) == [16 // world_size]
assert list(model.linear2.weight.shape) == [16, 16 // world_size]

# test forward to make sure that IR transform will produce the same results
# like how ColoTensor would do it normally
data = torch.rand(4, 16)
non_fx_out = model(data)
fx_out = annotated_gm(data)
assert torch.equal(non_fx_out, fx_out)


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_workflow(world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_complete_workflow(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_complete_workflow(2)
59 changes: 0 additions & 59 deletions tests/test_fx/test_transform_mlp_pass.py

This file was deleted.