[fx] refactor tracer to trace complete graph #1342

Merged
Changes from all commits (35 commits)
df7e650
[CLI] add CLI launcher
YuliangLiu0306 Apr 13, 2022
73753aa
Merge branch 'feature/cli' into main
YuliangLiu0306 Apr 13, 2022
80da77a
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 15, 2022
551359c
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 18, 2022
a25697a
Revert "[CLI] add CLI launcher"
YuliangLiu0306 Apr 19, 2022
77b5704
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 19, 2022
e23d33e
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 20, 2022
997c625
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 23, 2022
961d950
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 24, 2022
2deaa40
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 24, 2022
9ff217f
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Apr 28, 2022
501dc1a
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 12, 2022
21e43fd
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 21, 2022
cbd4579
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 23, 2022
1443291
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 May 30, 2022
e627cf5
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 10, 2022
289316e
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 15, 2022
689e047
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 15, 2022
0a83919
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 17, 2022
98c1ef9
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 20, 2022
9a3af67
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 21, 2022
7700793
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 28, 2022
3c77d1f
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jun 30, 2022
7c10323
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 4, 2022
11711d1
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 6, 2022
cee6276
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 8, 2022
8d00be0
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
af2a8f9
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
7b3899b
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 12, 2022
3eb8757
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 13, 2022
201b54c
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 14, 2022
c5a284d
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 18, 2022
8ee6650
Merge branch 'hpcaitech:main' into main
YuliangLiu0306 Jul 19, 2022
64a0bc6
[fx] refactor tracer to trace complete graph
YuliangLiu0306 Jul 19, 2022
6042054
add comments and solve conflicts.
YuliangLiu0306 Jul 20, 2022
55 changes: 48 additions & 7 deletions colossalai/fx/proxy.py
@@ -2,6 +2,7 @@
import torch
from torch.fx.proxy import Proxy, Attribute
from typing import List, Union, Any
from colossalai.fx.tracer.meta_patch import meta_patched_function

__all__ = ['ColoProxy']

@@ -45,6 +46,14 @@ def __len__(self):
        self._assert_has_meta_data()
        return len(self.meta_data)

    def __int__(self):
        self._assert_has_meta_data()
        return int(self.meta_data)

    def __float__(self):
        self._assert_has_meta_data()
        return float(self.meta_data)

    def __bool__(self):
        self._assert_has_meta_data()
        return self.meta_data
@@ -53,9 +62,6 @@ def __getattr__(self, k):

        return ColoAttribute(self, k)

    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle like
@@ -65,11 +71,26 @@ def __contains__(self, key):
        return super().__contains__(key)


def extract_meta(*args, **kwargs):
    """
    This function is copied from _tracer_utils.py to avoid a circular import issue.
    """

    def _convert(val):
        if isinstance(val, ColoProxy):
            return val.meta_data
        elif isinstance(val, (list, tuple)):
            return type(val)([_convert(ele) for ele in val])
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs


class ColoAttribute(ColoProxy):

    def __init__(self, root, attr: str):
        # this class is copied from torch.fx.Attribute,
        # but it inherits from ColoProxy
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
@@ -78,8 +99,28 @@ def __init__(self, root, attr: str):
    @property
    def node(self):
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
            proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
            if not isinstance(proxy, ColoProxy):
                meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
                meta_out = getattr(*meta_args, **meta_kwargs)
                proxy = ColoProxy(proxy.node)
                proxy.meta_data = meta_out
            self._node = proxy.node

        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
        proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
        if not isinstance(proxy, ColoProxy):
            meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
            method = getattr(meta_args[0].__class__, self.attr)
            if meta_patched_function.has(method):
                meta_target = meta_patched_function.get(method)
            elif meta_patched_function.has(method.__name__):
                meta_target = meta_patched_function.get(method.__name__)
            else:
                meta_target = method
            meta_out = meta_target(*meta_args, **meta_kwargs)
            proxy = ColoProxy(proxy.node)
            proxy.meta_data = meta_out
        return proxy
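
Aside: the recursion in extract_meta above is the standard convert pattern for unwrapping proxies nested in containers. A minimal standalone sketch of that pattern (the FakeProxy class below is an illustrative stand-in, not part of this PR):

import torch

class FakeProxy:
    """Illustrative stand-in for ColoProxy: it only carries meta_data."""

    def __init__(self, meta_data):
        self.meta_data = meta_data

def extract_meta(*args, **kwargs):
    # Recursively swap proxies for their meta tensors, preserving
    # list/tuple container types so call signatures stay intact.
    def _convert(val):
        if isinstance(val, FakeProxy):
            return val.meta_data
        elif isinstance(val, (list, tuple)):
            return type(val)([_convert(ele) for ele in val])
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs

p = FakeProxy(torch.empty(2, 3, device="meta"))
new_args, new_kwargs = extract_meta((p, 1), scale=p)
print(new_args)    # [(meta tensor of shape (2, 3), 1)]
print(new_kwargs)  # {'scale': meta tensor of shape (2, 3)}

The copy in proxy.py exists only to break the import cycle noted in its docstring.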
19 changes: 19 additions & 0 deletions colossalai/fx/tracer/_tracer_utils.py
@@ -1,5 +1,7 @@
from typing import List, Union, Any
from ..proxy import ColoProxy, ColoAttribute
import torch
from .meta_patch import meta_patched_function, meta_patched_module

__all__ = ['is_element_in_list', 'extract_meta']

@@ -29,3 +31,20 @@ def _convert(val):
    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs


def compute_meta_data_for_functions_proxy(target, args, kwargs):
    args_metas, kwargs_metas = extract_meta(*args, **kwargs)

    # fetch patched function
    if meta_patched_function.has(target):
        meta_target = meta_patched_function.get(target)
    elif meta_patched_function.has(target.__name__):
        meta_target = meta_patched_function.get(target.__name__)
    else:
        meta_target = target
    meta_out = meta_target(*args_metas, **kwargs_metas)
    if isinstance(meta_out, torch.Tensor):
        meta_out = meta_out.to(device="meta")

    return meta_out
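
The lookup order here is: a patch registered for the target itself, then one registered under the target's __name__, then the original target. Since every argument has already been converted to a meta tensor, the call only propagates shapes and dtypes. A hand-rolled sketch of the same dispatch (the _patched dict and register helper are illustrative stand-ins for the real registry in colossalai.fx.tracer.meta_patch):

import torch

_patched = {}  # illustrative stand-in for the meta_patched_function registry

def register(target):
    def wrapper(fn):
        _patched[target] = fn
        return fn
    return wrapper

@register(torch.matmul)
def meta_matmul(a, b):
    # Shape-only matmul: no FLOPs, no memory for real data.
    return torch.empty(a.shape[0], b.shape[1], device="meta")

def compute_meta(target, args, kwargs):
    if target in _patched:
        meta_target = _patched[target]
    elif getattr(target, "__name__", None) in _patched:
        meta_target = _patched[target.__name__]
    else:
        meta_target = target
    meta_out = meta_target(*args, **kwargs)
    if isinstance(meta_out, torch.Tensor):
        meta_out = meta_out.to(device="meta")
    return meta_out

a = torch.empty(4, 8, device="meta")
b = torch.empty(8, 2, device="meta")
print(compute_meta(torch.matmul, (a, b), {}).shape)  # torch.Size([4, 2])

The final .to(device="meta") appears to guard against unpatched targets that return a real tensor.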
@@ -24,6 +24,11 @@ def torch_arange(*args, **kwargs):
    return torch.empty((end - start) // step, dtype=dtype, device="meta")


@meta_patched_function.register(torch.finfo)
def torch_finfo(*args):
    return torch.finfo(*args)


@meta_patched_function.register(torch.where)
def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
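
torch.finfo inspects a dtype rather than tensor data, so its patch can simply forward the call. Registering it (together with the "finfo" entry added to _TORCH_METHODS_TO_PATCH below) presumably supports the dtype-minimum masking idiom common in HuggingFace attention code; that motivation is an assumption, not stated in the PR, and the module below is a hypothetical illustration:

import torch
import torch.nn as nn

class MaskedSoftmax(nn.Module):
    """Hypothetical module: fill masked scores with the dtype's minimum."""

    def forward(self, scores, mask):
        # torch.finfo(dtype).min is the most negative representable value,
        # so masked positions get ~0 probability after softmax.
        neg_inf = torch.finfo(scores.dtype).min
        return torch.softmax(scores.masked_fill(mask, neg_inf), dim=-1)

m = MaskedSoftmax()
scores = torch.randn(2, 4)
mask = torch.tensor([[False, True, False, True],
                     [False, False, True, True]])
print(m(scores, mask))  # masked columns receive ~0 probability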
62 changes: 59 additions & 3 deletions colossalai/fx/tracer/tracer.py
@@ -7,6 +7,7 @@
import enum
import inspect
import functools
import operator
from colossalai.fx.tracer.meta_patch import meta_patched_module
import torch
import torch.nn as nn
@@ -16,8 +17,9 @@
from torch.fx.proxy import Proxy, ParameterProxy
from ..proxy import ColoProxy
from typing import Optional, Dict, Any
from ._tracer_utils import is_element_in_list, extract_meta
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
from .meta_patch import meta_patched_function, meta_patched_module
from torch.fx.graph import magic_methods, reflectable_magic_methods

__all__ = ['ColoTracer']

@@ -61,7 +63,7 @@ def __init__(self, *args, **kwargs):
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True

    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
        """
@@ -344,11 +346,15 @@ def look_for_proxy(*args, **kwargs):
    for arg in args:
        if isinstance(arg, Proxy):
            return arg
        if isinstance(arg, (tuple, list)):
            return look_for_proxy(*arg)

    # find in keyword vars
    for k, v in kwargs.items():
        if isinstance(v, Proxy):
            return v
        if isinstance(v, (tuple, list)):
            return look_for_proxy(*v)
    return None

    @functools.wraps(target)
@@ -358,10 +364,60 @@ def wrapper(*args, **kwargs):
        if proxy is not None:
            # if the arg is a proxy, then need to record this function called on this proxy
            # e.g. torch.ones(size) where size is an input proxy
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
            colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
            if not isinstance(colo_proxy, ColoProxy):
                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
                colo_proxy = ColoProxy(colo_proxy.node)
                colo_proxy.meta_data = meta_out
            return colo_proxy
        else:
            # this is called directly when the inputs do not contain a proxy
            # e.g. torch.ones(4) where the input is static
            return target(*args, **kwargs)

    return wrapper, target


# Patch magic methods onto ColoProxy so that the tracer records magic method
# calls such as __sub__ and attaches a meta_data attribute to the created proxy.
for method in magic_methods:

    def _scope(method):

        def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
            proxy = tracer.create_proxy('call_function', target, args, kwargs)
            if not isinstance(proxy, ColoProxy):
                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
                proxy = ColoProxy(proxy.node)
                proxy.meta_data = meta_out
            return proxy

        impl.__name__ = method
        as_magic = f'__{method.strip("_")}__'
        setattr(ColoProxy, as_magic, impl)

    _scope(method)


def _define_reflectable(orig_method_name):
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
        if not isinstance(proxy, ColoProxy):
            meta_out = compute_meta_data_for_functions_proxy(target, (rhs, self), {})
            proxy = ColoProxy(proxy.node)
            proxy.meta_data = meta_out
        return proxy

    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(ColoProxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
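
A note on the _scope indirection above: Python closures capture variables by name, not value, so without the extra function every generated impl would resolve method to the loop's final value. A standalone sketch of the same patch-in-a-loop pattern (the Box class is illustrative only, not part of this PR):

import operator

class Box:
    """Illustrative value wrapper standing in for a proxy class."""

    def __init__(self, value):
        self.value = value

for method in ['add', 'sub', 'mul']:

    def _scope(method):
        # Bind `method` now; a bare closure over the loop variable
        # would make every impl dispatch to operator.mul.
        def impl(self, other):
            target = getattr(operator, method)
            rhs = other.value if isinstance(other, Box) else other
            return Box(target(self.value, rhs))

        impl.__name__ = method
        setattr(Box, f'__{method}__', impl)

    _scope(method)

print((Box(3) + Box(4)).value)  # 7
print((Box(10) - 4).value)      # 6

ColoProxy's impl follows the same shape, except that instead of evaluating eagerly it emits a call_function node and computes meta data for it.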
35 changes: 29 additions & 6 deletions tests/test_fx/test_coloproxy.py
@@ -1,17 +1,40 @@
import torch
import torch.nn as nn
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from torch.fx import GraphModule
import pytest


@pytest.mark.skip('skip due to tracer')
class Conv1D(nn.Module):

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.shape[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


def test_coloproxy():
    # create a dummy node only for testing purposes
    model = torch.nn.Linear(10, 10)
    gm = torch.fx.symbolic_trace(model)

    tracer = ColoTracer()
    model = Conv1D(3, 3)
    input_sample = {'x': torch.rand(3, 3).to('meta')}

    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    node = list(gm.graph.nodes)[0]

    # create proxy
    proxy = ColoProxy(node=node)
    proxy = ColoProxy(node=node, tracer=tracer)
    proxy.meta_data = torch.empty(4, 2, device='meta')

    assert len(proxy) == 4
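
The rewritten test doubles as a usage example for the new API. A hedged end-to-end sketch (same Conv1D as in the test above, assuming the meta_args keyword shown there) that should then execute the traced module on real tensors:

import torch
from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer

model = Conv1D(3, 3)  # the Conv1D module defined in the test above
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# Tracing used only meta tensors (shapes, no data), but the recompiled
# module runs on real inputs.
out = gm(torch.rand(3, 3))
print(out.shape)  # torch.Size([3, 3])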
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_hf_model/test_opt.py
@@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip('skip due to tracer')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_timm_model/test_timm.py
@@ -24,7 +24,6 @@ def test_timm_models_without_control_flow():
    split_model_and_compare_output(model, data)


@pytest.mark.skip('skip due to tracer')
def test_timm_models_with_control_flow():
    torch.backends.cudnn.deterministic = True

1 change: 0 additions & 1 deletion tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip('skip due to tracer')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
@@ -54,7 +54,6 @@ def test_timm_models_without_control_flow():
    trace_and_compare(model_cls, tracer, data)


@pytest.mark.skip('skip due to tracer')
def test_timm_models_with_control_flow():
    torch.backends.cudnn.deterministic = True

59 changes: 0 additions & 59 deletions tests/test_fx/test_transform_mlp_pass.py

This file was deleted.