[fx]add gpt2 passes for pipeline performance test #1366

Merged · 36 commits · Jul 26, 2022
Changes from all commits · 36 commits
df7e650
[CLI] add CLI launcher
YuliangLiu0306 Apr 13, 2022
73753aa
Merge branch 'feature/cli' into main
YuliangLiu0306 Apr 13, 2022
80da77a
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 15, 2022
551359c
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 18, 2022
a25697a
Revert "[CLI] add CLI launcher"
YuliangLiu0306 Apr 19, 2022
77b5704
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 19, 2022
e23d33e
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 20, 2022
997c625
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 23, 2022
961d950
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 24, 2022
2deaa40
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 24, 2022
9ff217f
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 28, 2022
501dc1a
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 12, 2022
21e43fd
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 21, 2022
cbd4579
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 23, 2022
1443291
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 30, 2022
e627cf5
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 10, 2022
289316e
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 15, 2022
689e047
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 15, 2022
0a83919
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 17, 2022
98c1ef9
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 20, 2022
9a3af67
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 21, 2022
7700793
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 28, 2022
3c77d1f
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 30, 2022
7c10323
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 4, 2022
11711d1
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 6, 2022
cee6276
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 8, 2022
8d00be0
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
af2a8f9
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
7b3899b
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
3eb8757
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 13, 2022
201b54c
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 14, 2022
c5a284d
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 18, 2022
8ee6650
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 19, 2022
d57e4fc
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 20, 2022
8563e64
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 25, 2022
e731e33
[fx]add gpt2 passes for pipeline performance test
YuliangLiu0306 Jul 26, 2022
36 changes: 25 additions & 11 deletions colossalai/fx/passes/meta_info_prop.py
@@ -4,6 +4,7 @@
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list


@compatibility(is_backward_compatible=True)
@@ -16,6 +17,7 @@ class TensorMetadata(NamedTuple):
requires_grad: bool
stride: Tuple[int]
numel: int
is_tensor: bool
# TODO: we can add a list of sharding spec here, and record the sharding
# behaviour by appending sharding spec into list.

@@ -29,8 +31,9 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
requires_grad = result.requires_grad
stride = result.stride()
numel = result.numel()
is_tensor = True

return TensorMetadata(shape, dtype, requires_grad, stride, numel)
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)


def _compute_node_numel(node_metadata: any) -> int:
@@ -51,6 +54,24 @@ def _compute_node_numel(node_metadata: any) -> int:
return node_numel


def _map_aggregate(arg, fn):
"""
Apply fn to each Node appearing in arg. arg may be a list, tuple, slice, or dict with string keys.
"""
if isinstance(arg, torch.Size):
return fn(arg)
if isinstance(arg, tuple):
return tuple(map_aggregate(elem, fn) for elem in arg)
elif isinstance(arg, list):
return immutable_list(map_aggregate(elem, fn) for elem in arg)
elif isinstance(arg, dict):
return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items())
elif isinstance(arg, slice):
return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn))
else:
return fn(arg)


@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
"""
@@ -85,23 +106,16 @@ class MetaInfoProp(torch.fx.Interpreter):

def run_node(self, n: Node) -> Any:
result = super().run_node(n)
found_tensor = False

def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
nonlocal found_tensor
found_tensor = True
return _extract_tensor_metadata(obj)
else:
return obj
return TensorMetadata(None, None, False, None, 0, False)

meta = map_aggregate(result, extract_tensor_meta)
meta = _map_aggregate(result, extract_tensor_meta)

if found_tensor:
n.meta['tensor_meta'] = meta
else:
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
# counting the total size of node outputs
n.meta['tensor_meta'] = meta
total_node_size = _compute_node_numel(n.meta['tensor_meta'])
# counting the total size of parameters
total_param_size = 0
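The hunk above adds an is_tensor flag to TensorMetadata and makes run_node record metadata for every output, marking non-tensor values with a placeholder entry whose is_tensor field is False. Below is a minimal usage sketch, not part of this diff: it assumes the standard torch.fx Interpreter interface (MetaInfoProp(gm).run(*sample_inputs)) and a toy module rather than the PR's GPT-2 model.

import torch
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


class Toy(torch.nn.Module):

    def forward(self, x):
        length = x.size(0)    # produces a non-tensor (int) value in the graph
        return x.view(length, -1)


gm = torch.fx.symbolic_trace(Toy())
MetaInfoProp(gm).run(torch.rand(4, 8))    # propagate metadata through every node

for node in gm.graph.nodes:
    # Tensor-producing nodes carry is_tensor=True; with this change the `size`
    # node also receives a TensorMetadata entry, with is_tensor=False, so a
    # downstream pass such as the GPT-2 split below can tell tensor and
    # non-tensor outputs apart.
    print(node.name, node.meta['tensor_meta'])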
288 changes: 288 additions & 0 deletions colossalai/fx/passes/passes_for_gpt2_test.py
@@ -0,0 +1,288 @@
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.meta_info_prop import TensorMetadata
import inspect
from colossalai.fx.passes.split_module import Partition
from colossalai.fx.passes.adding_split_node_pass import pipe_split
from torch.fx.node import Node


def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
'''
This pass is only used in the GPT-2 test; part of its changes may be merged into
split_with_split_nodes_pass, and it will be deprecated in the future.
'''
part_idx = 0

def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
if node.op == 'placeholder':
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm

def eliminate_unused_outputs(gm, next_partition_placeholders):
'''
Eliminate outputs of the previous partition that are unused in the next partition.
'''
for node in gm.graph.nodes:
if node.op == 'output':
output_type = node.args[0].__class__
output_args = list(node.args[0])
for n in node.args[0]:
if n.name not in next_partition_placeholders:
output_args.remove(n)
gm.graph.erase_node(node)
gm.graph.output(output_type(output_args))
gm.recompile()
return gm

def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
part_idx += 1
return part_idx

split_mod = split_module_for_gpt2_test(annotated_gm, None, split_callback)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)

submodules = list(split_mod.children())
placeholder_dict = {}
for submodule in submodules:
submodule = eliminate_unused_placeholders(submodule)
placeholder_dict[submodule] = []
for node in submodule.graph.nodes:
if node.op == 'placeholder':
placeholder_dict[submodule].append(node.name)

for index, submodule in enumerate(submodules):
if index >= len(submodules) - 1:
break
submodule = eliminate_unused_outputs(submodule, placeholder_dict[submodules[index + 1]])
submodule.recompile()
split_mod.recompile()

return split_mod, split_submodules


@compatibility(is_backward_compatible=True)
def split_module_for_gpt2_test(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
):
"""
This pass is only used in the GPT-2 pipeline-parallel performance test; part of its changes may be
merged into split_module, and it will be deprecated in the future.
"""
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}

def _node_with_all_tensor_element(node_metadata: Any) -> int:
"""
Return whether every element of the node's metadata is a tensor.
"""
all_tensor_node = True

if isinstance(node_metadata, TensorMetadata):
all_tensor_node = node_metadata.is_tensor and all_tensor_node
elif isinstance(node_metadata, dict):
value_list = [v for _, v in node_metadata.items()]
all_tensor_node += _node_with_all_tensor_element(value_list)
else:
for element in node_metadata:
all_tensor_node += _node_with_all_tensor_element(element)

return all_tensor_node

def _move_all_ancestors_into_partition(node, partition_name):
all_ancestors = set()

def _gen_all_ancestors_set(node):
all_ancestors.add(node)
for n in node.all_input_nodes:
if n in all_ancestors:
continue
_gen_all_ancestors_set(n)

_gen_all_ancestors_set(node)
for n in list(all_ancestors):
if n.op != 'placeholder':
n._fx_partition = partition_name

def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
if 'tensor_meta' in def_node.meta:
if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
_move_all_ancestors_into_partition(use_node, def_partition_name)
node_process_list.extend(use_node.all_input_nodes)
node_process_list.extend(list(use_node.users))
node_process_list.append(use_node)

return

if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)

if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)

node_process_list = list(m.graph.nodes)
# split nodes into partitions
while node_process_list:
node = node_process_list.pop(0)
orig_nodes[node.name] = node

if node.op in ["placeholder"]:
continue
if node.op == 'output':
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
# torch.fx.graph.map_arg(node.args[0], lambda n: _set_output_args_partition(n, partition_name))
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))

# add node to partitions
partition = partitions.get(partition_name)
if partition is None:
partitions[partition_name] = partition = Partition(partition_name)

partition.node_names.append(node.name)
origin_partition_name = getattr(node, '_fx_partition', None)
if origin_partition_name is None:
node._fx_partition = partition_name

torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950

# find partitions with no dependencies
root_partitions: List[str] = []
for partition_name, partition in partitions.items():
if not len(partition.partitions_dependent_on):
root_partitions.append(partition_name)

# check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].partition_dependents:
partitions[dependent].partitions_dependent_on.pop(root_partition)
if not partitions[dependent].partitions_dependent_on:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")

# add placeholders to partitions
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
placeholder = partition.graph.placeholder(input)
placeholder.meta = orig_nodes[input].meta.copy()
partition.environment[orig_nodes[input]] = placeholder

# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
partition = partitions[node._fx_partition]

# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])

if node.op not in ['call_module', 'get_attr']:
target = node.target
else:
target_atoms = node.target.split('.')
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
partition.targets[target] = target_attr

assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node

# Set up values to construct base module
base_mod_env: Dict[str, torch.fx.node.Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()

# Do some things iterating over the partitions in topological order again:
# 1) Finish off submodule Graphs by setting corresponding outputs
# 2) Construct GraphModules for each submodule
# 3) Construct the base graph by emitting calls to those submodules in
# topological order

for partition_name in sorted_partitions:
partition = partitions[partition_name]

# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)

# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950

# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
if len(partition.outputs) > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val

for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950

return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
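
For context, a hedged end-to-end sketch (not part of this diff) of how the pass might be driven: trace a module, run MetaInfoProp so every node carries tensor_meta, insert a pipe_split marker the way an annotation pass would, then split. The toy module and the manual marker insertion below are illustrative assumptions, not the PR's actual GPT-2 benchmark setup.

import torch
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import pipe_split
from colossalai.fx.passes.passes_for_gpt2_test import split_with_split_nodes_pass_for_gp2_test


class TwoStage(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.stage0 = torch.nn.Linear(16, 16)
        self.stage1 = torch.nn.Linear(16, 16)

    def forward(self, x):
        a = self.stage0(x)
        b = torch.relu(a)
        return self.stage1(a + b)


gm = torch.fx.symbolic_trace(TwoStage())
MetaInfoProp(gm).run(torch.rand(4, 16))    # the split pass reads node.meta['tensor_meta']

# Insert a pipe_split marker after the relu node, mimicking what an annotation
# pass would do before handing the graph to the split pass.
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is torch.relu:
        with gm.graph.inserting_after(node):
            gm.graph.create_node('call_function', pipe_split)
        break
gm.recompile()

split_mod, split_submodules = split_with_split_nodes_pass_for_gp2_test(gm)
print(split_mod)    # top-level module that calls submod_0 and submod_1 in order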