From c1f9993c1d4711c3da74bdaedb2c2b57e2b737aa Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Feb 2025 17:34:24 -0800 Subject: [PATCH 1/8] Extract function via rewrite --- onnxscript/rewriter/pattern.py | 68 ++++++++++++++++++++++++++++- onnxscript/rewriter/pattern_test.py | 15 +++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 868da6244..4ef78cd26 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1292,6 +1292,7 @@ def __init__( remove_nodes: bool = True, graph_pre_visitor: Callable[[], None] | None = None, graph_post_visitor: Callable[[], None] | None = None, + as_function: bool = False, ) -> None: """Create a rewrite rule. @@ -1313,7 +1314,10 @@ def __init__( graph_post_visitor: A function that will be called after the rewriting is complete for a graph or function. """ - + if as_function and not remove_nodes: + raise ValueError( + "as_function=True is only supported when remove_nodes=True." + ) if not isinstance(target_pattern, GraphPattern): target_pattern = _to_graph_pattern(target_pattern) self._target_pattern = target_pattern @@ -1338,6 +1342,7 @@ def __init__( self.remove_nodes = remove_nodes self.graph_pre_visitor = graph_pre_visitor self.graph_post_visitor = graph_post_visitor + self.as_function = as_function def __str__(self) -> str: return self.name if self.name else "Anonymous Rule" @@ -1528,6 +1533,49 @@ def check(self, op, *args, **kwargs): def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") +def _copy_for_function(nodes: Sequence[ir.Node], outputs: Sequence[ir.Value]): + value_map : dict[ir.Value, ir.Value] = {} + def copy_value(value: ir.Value) -> ir.Value: + if value is None: + return None + if value in value_map: + return value_map[value] + # Create a formal-parameter value to represent this value: + new_value = ir.Value( + name = value.name, + shape = value.shape, + type = value.type, + doc_string= value.doc_string, + ) + value_map[value] = new_value + return new_value + def copy_attr_value(attr_value): + return attr_value # TODO + def copy_node(node: ir.Node) -> ir.Node: + new_inputs = [copy_value(v) for v in node.inputs] + new_attributes = {k: copy_attr_value(v) for k, v in node.attributes.items()} + new_node = ir.Node( + node.domain, + node.op_type, + new_inputs, + new_attributes, + overload=node.overload, + num_outputs=len(node.outputs), + graph=None, + name=node.name, + doc_string=node.doc_string, # type: ignore + metadata_props=node.metadata_props.copy(), + ) + new_outputs = new_node.outputs + for i, output in enumerate(node.outputs): + value_map[output] = new_outputs[i] + if output.name is not None: + new_outputs[i].name = output.name + return new_node + function_nodes = [copy_node(node) for node in nodes] + function_inputs = list(value_map.values()) + function_outputs = [copy_value(v) for v in outputs] + return (function_inputs, function_nodes, function_outputs) class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: @@ -1599,10 +1647,28 @@ def _apply_to_graph_or_function( delta.match.outputs, delta.new_outputs, ) + if rule.as_function: + # Create new function from delta.match.nodes and add it to model.functions. + # Determine: inputs/outputs, domain, name, overload, opset_imports. + # Create a copy of nodes, replacing actuals by formals. + original_nodes = delta.match.nodes + used_domains: set[str] = set(node.domain for node in original_nodes) + parent_opset_imports = graph_or_function.opset_imports + used_opset_imports = { k: v for k, v in parent_opset_imports.items() if k in used_domains } + inputs, nodes, outputs = _copy_for_function(original_nodes, delta.match.outputs) + assert len(delta.new_nodes) == 1 + call_node = delta.new_nodes[0] + domain = call_node.domain + name = call_node.op_type + overload = "" # TODO + graph = ir.Graph(inputs, outputs, nodes=nodes, opset_imports=used_opset_imports) + f = ir.Function(domain, name, overload, graph=graph, attributes={}) + model.functions[f.identifier()] = f count += 1 if rule.graph_post_visitor: rule.graph_post_visitor() + return count def apply_to_model( diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index ca865ecde..ddd38168a 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -577,6 +577,21 @@ def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]: self.assertIn(init_name, model.graph.initializers) self.assertIs(last_node.inputs[1], model.graph.initializers[init_name]) + def test_extract_function(self): + def source_pattern(op, x, y, z): + sum = op.Add(x, y) + return op.Mul(sum, z) + def replacement(op, x, y, z): + return op.AddMul(x, y, z, _domain = "some.domain") + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: + return op.Mul(op.Add(x, y), z) + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + # self.assertEqual(len(model.functions), 1) + model.display() class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 65479668ad67e47c0bef957ce3b270e778828bf8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 19 Feb 2025 11:18:11 -0800 Subject: [PATCH 2/8] Finalize function extraction --- onnxscript/optimizer/__init__.py | 2 + onnxscript/rewriter/pattern.py | 94 ++++++++++++++++++----------- onnxscript/rewriter/pattern_test.py | 16 ++++- 3 files changed, 73 insertions(+), 39 deletions(-) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 8ba6229c1..65936a9ff 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -8,6 +8,7 @@ import onnxscript.optimizer._legacy._optimizer as legacy_optimizer import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding from onnxscript import ir +from onnxscript.optimizer._inliner import inline from onnxscript.optimizer._optimizer import optimize_ir from onnxscript.optimizer._remove_unused import remove_unused_nodes @@ -36,4 +37,5 @@ def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs): "optimize", "optimize_ir", "basic_constant_propagation", + "inline", ] diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 4ef78cd26..99bd7925b 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1315,9 +1315,7 @@ def __init__( is complete for a graph or function. """ if as_function and not remove_nodes: - raise ValueError( - "as_function=True is only supported when remove_nodes=True." - ) + raise ValueError("as_function=True is only supported when remove_nodes=True.") if not isinstance(target_pattern, GraphPattern): target_pattern = _to_graph_pattern(target_pattern) self._target_pattern = target_pattern @@ -1533,24 +1531,33 @@ def check(self, op, *args, **kwargs): def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") -def _copy_for_function(nodes: Sequence[ir.Node], outputs: Sequence[ir.Value]): - value_map : dict[ir.Value, ir.Value] = {} - def copy_value(value: ir.Value) -> ir.Value: - if value is None: - return None - if value in value_map: - return value_map[value] + +def _copy_for_function( + inputs: Sequence[ir.Value], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value] +): + value_map: dict[ir.Value, ir.Value] = {} + function_inputs: list[ir.Value] = [] + for input in inputs: # Create a formal-parameter value to represent this value: new_value = ir.Value( - name = value.name, - shape = value.shape, - type = value.type, - doc_string= value.doc_string, + name=input.name, + shape=input.shape, + type=input.type, + doc_string=input.doc_string, ) - value_map[value] = new_value - return new_value + value_map[input] = new_value + function_inputs.append(new_value) + + def copy_value(value: ir.Value) -> ir.Value: + if value is None: + return None + if value not in value_map: + raise ValueError(f"Value {value} not found in value_map.") + return value_map[value] + def copy_attr_value(attr_value): - return attr_value # TODO + return attr_value # TODO + def copy_node(node: ir.Node) -> ir.Node: new_inputs = [copy_value(v) for v in node.inputs] new_attributes = {k: copy_attr_value(v) for k, v in node.attributes.items()} @@ -1572,11 +1579,12 @@ def copy_node(node: ir.Node) -> ir.Node: if output.name is not None: new_outputs[i].name = output.name return new_node + function_nodes = [copy_node(node) for node in nodes] - function_inputs = list(value_map.values()) function_outputs = [copy_value(v) for v in outputs] return (function_inputs, function_nodes, function_outputs) + class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if commute: @@ -1639,36 +1647,50 @@ def _apply_to_graph_or_function( # is sufficient for patterns with a single output-node "node", which can serve as the # insertion-point. onnxscript.optimizer.basic_constant_propagation(delta.new_nodes) - _convenience.replace_nodes_and_values( - graph_or_function, - node, - delta.match.nodes if rule.remove_nodes else [], - delta.new_nodes, - delta.match.outputs, - delta.new_outputs, - ) if rule.as_function: + if len(delta.new_nodes) != 1: + raise ValueError( + "as_function=True is only supported for patterns with a single replacement node." + ) + call_node = delta.new_nodes[0] + domain = call_node.domain + name = call_node.op_type + overload = "" # TODO + + # Create topologically sorted list of nodes to be replaced. + unsorted_nodes = set(delta.match.nodes) + original_nodes = [n for n in graph_or_function if n in unsorted_nodes] + inputs, nodes, outputs = _copy_for_function( + call_node.inputs, original_nodes, delta.match.outputs + ) # Create new function from delta.match.nodes and add it to model.functions. # Determine: inputs/outputs, domain, name, overload, opset_imports. # Create a copy of nodes, replacing actuals by formals. - original_nodes = delta.match.nodes + used_domains: set[str] = set(node.domain for node in original_nodes) parent_opset_imports = graph_or_function.opset_imports - used_opset_imports = { k: v for k, v in parent_opset_imports.items() if k in used_domains } - inputs, nodes, outputs = _copy_for_function(original_nodes, delta.match.outputs) - assert len(delta.new_nodes) == 1 - call_node = delta.new_nodes[0] - domain = call_node.domain - name = call_node.op_type - overload = "" # TODO - graph = ir.Graph(inputs, outputs, nodes=nodes, opset_imports=used_opset_imports) + used_opset_imports = { + k: v for k, v in parent_opset_imports.items() if k in used_domains + } + + graph = ir.Graph( + inputs, outputs, nodes=nodes, opset_imports=used_opset_imports + ) f = ir.Function(domain, name, overload, graph=graph, attributes={}) model.functions[f.identifier()] = f + _convenience.replace_nodes_and_values( + graph_or_function, + node, + delta.match.nodes if rule.remove_nodes else [], + delta.new_nodes, + delta.match.outputs, + delta.new_outputs, + ) + count += 1 if rule.graph_post_visitor: rule.graph_post_visitor() - return count def apply_to_model( diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index ddd38168a..ad16ab464 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -9,6 +9,7 @@ import onnx.checker import onnx.parser +import onnxscript.optimizer from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op from onnxscript.rewriter import cast_constant_of_shape, pattern @@ -581,17 +582,26 @@ def test_extract_function(self): def source_pattern(op, x, y, z): sum = op.Add(x, y) return op.Mul(sum, z) + def replacement(op, x, y, z): - return op.AddMul(x, y, z, _domain = "some.domain") + return op.AddMul(x, y, z, _domain="some.domain") + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + @script() def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: return op.Mul(op.Add(x, y), z) + model_proto = test_model.to_model_proto() model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) - # self.assertEqual(len(model.functions), 1) - model.display() + self.assertEqual([x.op_type for x in model.graph], ["AddMul"]) + self.assertEqual([f.name for f in model.functions.values()], ["AddMul"]) + function = model.functions[("some.domain", "AddMul", "")] + self.assertEqual([x.op_type for x in function], ["Add", "Mul"]) + onnxscript.optimizer.inline(model) + self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul"]) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 67d275d39a695a81ad10189f8df1745fcbb04fa9 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 19 Feb 2025 12:21:35 -0800 Subject: [PATCH 3/8] Fix attr copy and overload --- onnxscript/rewriter/pattern.py | 45 +++++++++++++++++++++++------ onnxscript/rewriter/pattern_test.py | 11 +++++-- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 99bd7925b..695cced06 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1535,10 +1535,11 @@ def rewrite(self, op, *args, **kwargs): def _copy_for_function( inputs: Sequence[ir.Value], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value] ): + """Utility function to extract a subgraph out as a function.""" value_map: dict[ir.Value, ir.Value] = {} function_inputs: list[ir.Value] = [] for input in inputs: - # Create a formal-parameter value to represent this value: + # Create a function input (formal-parameter value) to represent this value: new_value = ir.Value( name=input.name, shape=input.shape, @@ -1548,15 +1549,19 @@ def _copy_for_function( value_map[input] = new_value function_inputs.append(new_value) - def copy_value(value: ir.Value) -> ir.Value: + def copy_value(value: ir.Value | None) -> ir.Value | None: if value is None: return None if value not in value_map: raise ValueError(f"Value {value} not found in value_map.") return value_map[value] - def copy_attr_value(attr_value): - return attr_value # TODO + def copy_attr_value(attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr: + if not isinstance(attr, ir.Attr): + raise ValueError("RefAttr not supported.") + if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}: + raise ValueError("Graph attributes not supported.") + return attr.copy() def copy_node(node: ir.Node) -> ir.Node: new_inputs = [copy_value(v) for v in node.inputs] @@ -1585,6 +1590,27 @@ def copy_node(node: ir.Node) -> ir.Node: return (function_inputs, function_nodes, function_outputs) +def _get_new_overload(model: ir.Model, domain: str, name: str) -> str: + """Get a new overload for the given domain and name. + + Args: + model: The model to which the new overload will be added. + domain: The domain of the new overload. + name: The opname of the new overload. + + Returns: + The new overload name. + """ + existing_functions = model.functions + # Just a simple implementation for now + overload = 1 + while True: + overload_name = str(overload) + if (domain, name, overload_name) not in existing_functions: + return overload_name + overload += 1 + + class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if commute: @@ -1648,6 +1674,7 @@ def _apply_to_graph_or_function( # insertion-point. onnxscript.optimizer.basic_constant_propagation(delta.new_nodes) if rule.as_function: + # Create a function out of a copy of the matched nodes if len(delta.new_nodes) != 1: raise ValueError( "as_function=True is only supported for patterns with a single replacement node." @@ -1655,17 +1682,16 @@ def _apply_to_graph_or_function( call_node = delta.new_nodes[0] domain = call_node.domain name = call_node.op_type - overload = "" # TODO + overload = _get_new_overload(model, domain, name) + call_node.overload = overload # Create topologically sorted list of nodes to be replaced. unsorted_nodes = set(delta.match.nodes) original_nodes = [n for n in graph_or_function if n in unsorted_nodes] + # Create new inputs/nodes/outputs for the function inputs, nodes, outputs = _copy_for_function( call_node.inputs, original_nodes, delta.match.outputs ) - # Create new function from delta.match.nodes and add it to model.functions. - # Determine: inputs/outputs, domain, name, overload, opset_imports. - # Create a copy of nodes, replacing actuals by formals. used_domains: set[str] = set(node.domain for node in original_nodes) parent_opset_imports = graph_or_function.opset_imports @@ -1711,10 +1737,11 @@ def apply_to_model( assert isinstance(model, ir.Model) tracer = MatchingTracer() if debug else None onnxscript.optimizer.basic_constant_propagation(model.graph) + original_functions = list(model.functions.values()) count = self._apply_to_graph_or_function( model, model.graph, verbose=verbose, tracer=tracer ) - for function in model.functions.values(): + for function in original_functions: onnxscript.optimizer.basic_constant_propagation(function) count += self._apply_to_graph_or_function( model, function, verbose=verbose, tracer=tracer diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index ad16ab464..f9b25e99c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -595,9 +595,14 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: model_proto = test_model.to_model_proto() model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) - self.assertEqual([x.op_type for x in model.graph], ["AddMul"]) - self.assertEqual([f.name for f in model.functions.values()], ["AddMul"]) - function = model.functions[("some.domain", "AddMul", "")] + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.graph), 1) + call_node = model.graph.node(0) + self.assertEqual(call_node.domain, "some.domain") + self.assertEqual(call_node.op_type, "AddMul") + function_id = call_node.op_identifier() + self.assertIn(function_id, model.functions) + function = model.functions[function_id] self.assertEqual([x.op_type for x in function], ["Add", "Mul"]) onnxscript.optimizer.inline(model) self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul"]) From 990ea74d3587d0b650aaab60a5c6580e4c148b0b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 19 Feb 2025 14:01:19 -0800 Subject: [PATCH 4/8] Add tests for attributes and overload --- onnxscript/rewriter/pattern.py | 4 +- onnxscript/rewriter/pattern_test.py | 60 +++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 695cced06..1d1d7aac8 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1561,11 +1561,11 @@ def copy_attr_value(attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr: raise ValueError("RefAttr not supported.") if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}: raise ValueError("Graph attributes not supported.") - return attr.copy() + return ir.Attr(attr.name, attr.type, attr.value, doc_string=attr.doc_string) def copy_node(node: ir.Node) -> ir.Node: new_inputs = [copy_value(v) for v in node.inputs] - new_attributes = {k: copy_attr_value(v) for k, v in node.attributes.items()} + new_attributes = [copy_attr_value(v) for v in node.attributes.values()] new_node = ir.Node( node.domain, node.op_type, diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index f9b25e99c..1906b28d0 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -607,6 +607,66 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: onnxscript.optimizer.inline(model) self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul"]) + def test_extract_function_with_attr(self): + def source_pattern(op, x, y): + sum = op.Add(x, y) + return op.Transpose(sum, perm=[1, 0]) + + def replacement(op, x, y): + return op.AddTranspose(x, y, _domain="some.domain") + + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + + @script() + def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]: + return op.Transpose(op.Add(x, y), perm=[1, 0]) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.graph), 1) + call_node = model.graph.node(0) + self.assertEqual(call_node.domain, "some.domain") + self.assertEqual(call_node.op_type, "AddTranspose") + function_id = call_node.op_identifier() + self.assertIn(function_id, model.functions) + function = model.functions[function_id] + self.assertEqual([x.op_type for x in function], ["Add", "Transpose"]) + transpose_node = function[1] + self.assertEqual(transpose_node.attributes["perm"].value, [1, 0]) + onnxscript.optimizer.inline(model) + self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"]) + + def test_extract_repeated_function(self): + def source_pattern(op, x, y, z): + sum = op.Add(x, y) + return op.Mul(sum, z) + + def replacement(op, x, y, z): + return op.AddMul(x, y, z, _domain="some.domain") + + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: + t1 = op.Mul(op.Add(x, y), z) + t2 = op.Mul(op.Add(t1, y), z) + return t2 + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.functions), 2) + self.assertEqual(len(model.graph), 2) + for call_node in model.graph: + self.assertEqual(call_node.domain, "some.domain") + self.assertEqual(call_node.op_type, "AddMul") + function_id = call_node.op_identifier() + self.assertIn(function_id, model.functions) + onnxscript.optimizer.inline(model) + self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul", "Add", "Mul"]) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 4fdade21c7b2e8f9ca49d8ec548ebb39395f22d6 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 20 Feb 2025 09:05:39 -0800 Subject: [PATCH 5/8] Update onnxscript/rewriter/pattern.py Co-authored-by: Justin Chu --- onnxscript/rewriter/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 1d1d7aac8..da8baa050 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1702,7 +1702,7 @@ def _apply_to_graph_or_function( graph = ir.Graph( inputs, outputs, nodes=nodes, opset_imports=used_opset_imports ) - f = ir.Function(domain, name, overload, graph=graph, attributes={}) + f = ir.Function(domain, name, overload, graph=graph, attributes=()) model.functions[f.identifier()] = f _convenience.replace_nodes_and_values( graph_or_function, From bbb09288c45136e7ae9058e69a46091e7ebacce4 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 20 Feb 2025 11:14:57 -0800 Subject: [PATCH 6/8] Fix lint errors --- onnxscript/rewriter/pattern.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index da8baa050..18cda4ab4 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1533,13 +1533,15 @@ def rewrite(self, op, *args, **kwargs): def _copy_for_function( - inputs: Sequence[ir.Value], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value] + inputs: Sequence[ir.Value | None], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value] ): """Utility function to extract a subgraph out as a function.""" value_map: dict[ir.Value, ir.Value] = {} function_inputs: list[ir.Value] = [] for input in inputs: # Create a function input (formal-parameter value) to represent this value: + if input is None: + raise NotImplementedError("None inputs not supported.") new_value = ir.Value( name=input.name, shape=input.shape, @@ -1558,9 +1560,13 @@ def copy_value(value: ir.Value | None) -> ir.Value | None: def copy_attr_value(attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr: if not isinstance(attr, ir.Attr): - raise ValueError("RefAttr not supported.") + # No need to support this currently, as rewriting inside a function is + # not used, as it has several challenges. + raise NotImplementedError("RefAttr not supported.") if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}: - raise ValueError("Graph attributes not supported.") + # No need to support this currently, as rewriting control-flow constructs + # is not used and has several challenges. + raise NotImplementedError("Graph attributes not supported.") return ir.Attr(attr.name, attr.type, attr.value, doc_string=attr.doc_string) def copy_node(node: ir.Node) -> ir.Node: @@ -1693,7 +1699,7 @@ def _apply_to_graph_or_function( call_node.inputs, original_nodes, delta.match.outputs ) - used_domains: set[str] = set(node.domain for node in original_nodes) + used_domains: set[str] = {node.domain for node in original_nodes} parent_opset_imports = graph_or_function.opset_imports used_opset_imports = { k: v for k, v in parent_opset_imports.items() if k in used_domains From 0cec9083f08b0f27f7bfc7e7057a4d7db3e94558 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 20 Feb 2025 11:27:49 -0800 Subject: [PATCH 7/8] Add documentation --- onnxscript/rewriter/pattern.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 18cda4ab4..74457a1fe 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1313,6 +1313,10 @@ def __init__( rewriting to the top-level graph or a function. graph_post_visitor: A function that will be called after the rewriting is complete for a graph or function. + as_function: If True, the matched nodes will be extracted into a model + local function. This is only supported when remove_nodes=True and + when the replacement subgraph has a single node, representing the + function call. """ if as_function and not remove_nodes: raise ValueError("as_function=True is only supported when remove_nodes=True.") From 0e4bad7367e35a0be818d8a2c07bcc88017cf787 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 20 Feb 2025 12:07:20 -0800 Subject: [PATCH 8/8] Address PR feedback --- onnxscript/rewriter/pattern.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 74457a1fe..4dc95b29b 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1571,7 +1571,8 @@ def copy_attr_value(attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr: # No need to support this currently, as rewriting control-flow constructs # is not used and has several challenges. raise NotImplementedError("Graph attributes not supported.") - return ir.Attr(attr.name, attr.type, attr.value, doc_string=attr.doc_string) + # Primitive attributes are immutable by design and can be shared. + return attr def copy_node(node: ir.Node) -> ir.Node: new_inputs = [copy_value(v) for v in node.inputs] @@ -1747,6 +1748,8 @@ def apply_to_model( assert isinstance(model, ir.Model) tracer = MatchingTracer() if debug else None onnxscript.optimizer.basic_constant_propagation(model.graph) + # Rewriting may introduce new functions. In the following loop, + # we restrict rewriting to original functions, not newly introduced ones. original_functions = list(model.functions.values()) count = self._apply_to_graph_or_function( model, model.graph, verbose=verbose, tracer=tracer