Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx] fixed compatiblity issue with torch 1.10 #1331

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions colossalai/fx/passes/split_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
import inspect


Expand Down Expand Up @@ -233,10 +234,13 @@ def record_cross_partition_use(def_node: torch.fx.node.Node,
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()

# Do some things iterating over the partitions in topological order again:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


@meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul') # for built-in op @
def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx
d1 = input.dim()
Expand Down
3 changes: 3 additions & 0 deletions colossalai/fx/tracer/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target

Expand Down
6 changes: 1 addition & 5 deletions tests/test_fx/test_pipeline/test_timm_model/test_timm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import torch
import pytest
try:
import timm.models as tm
except:
pass
import timm.models as tm
from timm_utils import split_model_and_compare_output


Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import torch
try:
import torchvision.models as tm
except:
pass
import torchvision
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule

from packaging import version
import random
import numpy as np
import inspect
import pytest

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
Expand All @@ -22,9 +19,12 @@
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
]

if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])

tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import torch
import pytest
try:
import timm.models as tm
except:
pass
import timm.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule

Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import torch
import pytest
try:
import torchvision.models as tm
except:
pass
import torchvision
import torchvision.models as tm
from packaging import version
from colossalai.fx import ColoTracer
from torch.fx import GraphModule


def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
]

RANDOMIZED_MODELS = [tm.efficientnet_b0]

if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
RANDOMIZED_MODELS.append(tm.convnext_small)

torch.backends.cudnn.deterministic = True

tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)

for model_cls in MODEL_LIST:
if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
if model_cls in RANDOMIZED_MODELS:
# remove the impact of randomicity
model = model_cls(stochastic_depth_prob=0)
else:
Expand Down