[DRAFT] Generalize MHA pattern #2092

Status: Open. Wants to merge 5 commits into base: main.
14 changes: 14 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -124,3 +124,17 @@ def has_rank(value: ir.Value | None, rank: int) -> bool:
        return False
    shape = value.shape
    return (shape is not None) and (shape.rank() == rank)


def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None:
    """Returns the value of the given dimension, or None if it is not statically known."""
    if value is None:
        return None
    shape = value.shape
    if shape is None:
        return None
    if dim < 0:
        dim += shape.rank()
Comment on lines +136 to +137 (Collaborator):

Can this be absorbed into return shape[dim]?

    if dim < 0 or dim >= shape.rank():
        return None
    return shape[dim]
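For reference, a minimal sketch of how this helper is intended to be consumed, mirroring the call site added to mha.py later in this PR. The wrapper function and variable names below are illustrative only, not part of the change:

from onnxscript.rewriter import _ir_utils

def extract_num_heads(key_BSHDh):
    # key_BSHDh is expected to have shape (B, S, H, Dh); dimension 2 is the head count H.
    num_heads = _ir_utils.get_dim(key_BSHDh, 2)
    if not isinstance(num_heads, int):
        # Symbolic or unknown dimension: the caller should bail out of the rewrite.
        return None
    return num_heads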
1 change: 1 addition & 0 deletions onnxscript/rewriter/llama_rule_sets.py
@@ -304,5 +304,6 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
            transpose_identity_rule,
            transpose_transpose_rule,
            unsqueeze_unsqueeze_rule,
            squeeze_reshape_1d_rule,
        ]
    )
4 changes: 4 additions & 0 deletions onnxscript/rewriter/ort_fusions/_test_utils.py
@@ -8,6 +8,7 @@
import numpy as np
import onnx
import onnxruntime
import packaging.version

import onnxscript.ir as ir
import onnxscript.ir._io as io
@@ -21,6 +22,9 @@ def _save(model, modelpath):
    io.save(model, modelpath)


ort_version = packaging.version.Version(onnxruntime.__version__)
Comment (Collaborator):

Is this used? If so

Suggested change
-ort_version = packaging.version.Version(onnxruntime.__version__)
+_ORT_VERSION = packaging.version.Version(onnxruntime.__version__)



def ort_run(model_name: str, model, inputs):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
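The reviewer asks whether the captured version is actually used. If it is kept, one plausible use is gating tests on the installed onnxruntime release. The sketch below is an assumption for illustration only; the threshold 1.20 and the test name are made up:

import unittest

import onnxruntime
import packaging.version

_ORT_VERSION = packaging.version.Version(onnxruntime.__version__)

class FusionTest(unittest.TestCase):
    @unittest.skipIf(
        _ORT_VERSION < packaging.version.Version("1.20"),
        "fused MHA op not expected in this onnxruntime release",
    )
    def test_mha_fusion(self):
        ...  # hypothetical test body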
243 changes: 143 additions & 100 deletions onnxscript/rewriter/ort_fusions/mha.py
@@ -2,37 +2,36 @@
# Licensed under the MIT License.
from __future__ import annotations

from typing import Sequence
from typing import Sequence, Union

import onnxscript.ir as ir
from onnxscript.rewriter import pattern
from onnxscript.rewriter import _ir_utils, pattern

"""
The MultiHeadAttention pattern:
The MultiHeadAttention pattern: generate an instance
MHA (query, key, value, None, None, mask, past_key, past_value)
where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias)
must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).

We use the following abbreviations for the dimensions:
B: Batch size
S: Sequence length
D: input embedding dimension
Dv: value hidden size (usually, Dv = D)
H: number of heads
d_h: head size (usually, D = H * d_h)
Dh: head size or embedding dimension per head (usually, D = H * Dh)
Skv: key/value sequence length
St: total sequence length

thus, weights are usually of shape (D, D) and (D, D) and (D, D)

for each of Q, K, and V, we have the following pattern:
MatMul (Input, W), producing output of shape (B, S, D)
Reshape to produce a matrix of shape (B, S, H, d_h)
Transpose middle two axes to produce a matrix of shape (B, H, S, d_h)

This is followed by a RotaryEmbedding pattern for Q and K

The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)

The dot-product attention is then computed using SDPA.
Finally, the output is transposed and reshaped back to (B, S, D) shape
In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
"""

Dim = Union[int, ir.SymbolicDim]

def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str]) -> bool:

def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
if val.shape is None:
return False
if val.shape.rank() != len(shape):
@@ -46,131 +45,171 @@

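The body of _check_shape is collapsed in this diff view. As a rough mental model only (an assumption, not the PR's actual implementation), a binding-based shape check of this kind records the first value seen for each symbolic dimension name and requires later occurrences to agree:

def _check_shape_sketch(bindings, val, shape):
    # Assumed behaviour: returns False on a rank or dimension mismatch, otherwise
    # binds each name in `shape` ("B", "S", "D", ...) to the corresponding actual
    # dimension and checks consistency with earlier bindings.
    if val.shape is None:
        return False
    if val.shape.rank() != len(shape):
        return False
    for actual_dim, expected_name in zip(val.shape, shape):
        if expected_name in bindings:
            if bindings[expected_name] != actual_dim:
                return False
        else:
            bindings[expected_name] = actual_dim
    return True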

class MultiHeadAttention(pattern.RewriteRuleClassBase):
def __init__(self, name: str, *, use_2d_matmul: bool):
super().__init__(name)
self._use_2d_matmul = use_2d_matmul

def _compute_QKV(self, op, input, weight, reshape_var: str):
"""Applied to generate each of Q, K, and V from input."""
if self._use_2d_matmul:
# Convert batched input of shape (B, S, D) to 2D input (B*S, D)
input = op.Reshape(input, _allow_other_inputs=True)
projected = op.MatMul(input, weight)
if self._use_2d_matmul:
# Convert 2D output back to batched output of shape (B, S, D)
projected = op.Reshape(projected, _allow_other_inputs=True)
# Reshape from (B, S, D) to (B, S, H, D/H)
reshaped = op.Reshape(
projected,
_allow_other_inputs=True,
_allow_other_attributes=True,
_outputs=[reshape_var],
)
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
return transposed
def __init__(self):
super().__init__("MHA")

def pattern(
self,
op,
input,
query_weight,
key_weight,
value_weight,
qkv_weight,
query_BSD,
key_BSD,
value_BSD,
mask,
cos,
sin,
past_key,
past_value,
position_ids,
cos,
sin,
):
query = self._compute_QKV(op, input, query_weight, "query_mm_reshaped")
key = self._compute_QKV(op, input, key_weight, "key_mm_reshaped")
value = self._compute_QKV(op, input, value_weight, "value_mm_reshaped")
# First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)

# Reshape from (B, S, D) to (B, S, H, D/H)
query_BSHDh = op.Reshape(
query_BSD,
_allow_other_inputs=True,
_allow_other_attributes=True,
_outputs=["query_BSHDh"],
)
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])

# Reshape from (B, S, D) to (B, S, H, D/H)
key_BSHDh = op.Reshape(
key_BSD,
_allow_other_inputs=True,
_allow_other_attributes=True,
_outputs=["key_BSHDh"],
)
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 1, 3])

# Reshape from (B, S, D) to (B, S, H, D/H)
value_BSHDh = op.Reshape(
value_BSD,
_allow_other_inputs=True,
_allow_other_attributes=True,
_outputs=["value_BSHDh"],
)
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])

query_BHSDh_rope = op.RotaryEmbedding(
query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
)
key_BHSDh_rope = op.RotaryEmbedding(
key_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
)

query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
# Concatenate past_key cache and current key, and transpose to enable
# dot-product attention computation.

key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
key_rope = op.Concat(past_key, key_rope, axis=-2)
# Transpose last two axes of key_rope to compute dot-product via matmul.
key_reshaped = op.Reshape(
key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]
key_seq = op.Concat(past_key, key_BHSDh_rope, axis=-2)
# Transpose last two axes of key_seq to compute dot-product via matmul.
key_seq_BH_Skv_Dh = op.Reshape(
key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"]
)
key_reshaped_transposed = op.Transpose(key_reshaped, perm=[0, 2, 1])
key_transposed = op.Reshape(
key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]
key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
key_seq_B_H_Dh_Skv = op.Reshape(
key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"]
)

value = op.Concat(past_value, value, axis=-2)
# Concatenate past_value cache and current value
value_seq = op.Concat(past_value, value_BHSDh, axis=-2)

attention = op.SDPA(
query_rope, key_transposed, value, mask, _domain="ai.onnxruntime.fusion"
query_BHSDh_rope,
key_seq_B_H_Dh_Skv,
value_seq,
mask,
_domain="ai.onnxruntime.fusion",
)
# Transpose back to (B, S, H, D/H)

# Transpose attention back to (B, S, H, D/H)
attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
# Reshape back to (B, S, D)
attention_reshaped = op.Reshape(
attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]
)
return attention_reshaped, key_rope, value
return attention_reshaped, key_seq, value_seq

def check(
self,
op,
query_mm_reshaped,
key_mm_reshaped,
value_mm_reshaped,
key_reshaped,
key_transposed,
attention_reshaped,
query_BSD,
key_BSD,
value_BSD,
mask,
past_key,
past_value,
query_BSHDh,
key_BSHDh,
value_BSHDh,
**_,
):
bindings: dict[str, int] = {}
status = (
_check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"])
and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"])
and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"])
and _check_shape(bindings, key_reshaped, ["B*H", "KVS", "d_h"])
and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "KVS"])
and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"])
)
if not status:
bindings: dict[str, Dim] = {}

def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
return not _check_shape(bindings, val, dims)

if no_match(query_BSD, ["B", "S", "D"]):
return False
if no_match(key_BSD, ["B", "Skv", "D"]):
return False
if no_match(value_BSD, ["B", "Skv", "D"]):
return False
# TODO: broadcast check
# if no_match(mask, ["B", "H", "S", "St"]):
# return False
Comment on lines +161 to +162

Check notice (Code scanning / CodeQL): Commented-out code
This comment appears to contain commented-out code.

Copilot Autofix (AI, 3 days ago):

To fix the problem, we need to either remove the commented-out code or reinstate it if it is necessary for the functionality. Given the presence of TODO comments, it is likely that the commented-out code was intended to be revisited and possibly reinstated. Therefore, the best approach is to reinstate the commented-out code and ensure that it is functional.

  • Reinstate the commented-out code on lines 161-162 and 173-178.
  • Ensure that the reinstated code is properly integrated and does not cause any issues.
Suggested changeset 1: onnxscript/rewriter/ort_fusions/mha.py

Autofix patch
Run the following command in your local git repository to apply this patch:

cat << 'EOF' | git apply
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -160,4 +160,4 @@
         # TODO: broadcast check
-        # if no_match(mask, ["B", "H", "S", "St"]):
-        #     return False
+        if no_match(mask, ["B", "H", "S", "St"]):
+            return False
         if no_match(past_key, ["B", "H", "Spast", "Dh"]):
@@ -172,8 +172,8 @@
             return False
-        # if not status:
-        #     return False
-        # if bindings["B"] * bindings["H"] != bindings["B*H"]:
-        #     return False
-        # if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
-        #     return False
+        if not status:
+            return False
+        if bindings["B"] * bindings["H"] != bindings["B*H"]:
+            return False
+        if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
+            return False
         return True
EOF
if no_match(past_key, ["B", "H", "Spast", "Dh"]):
return False
if no_match(past_value, ["B", "H", "Spast", "Dv"]):
return False
if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
return False
if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
return False
if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
return False
# if not status:
# return False
# if bindings["B"] * bindings["H"] != bindings["B*H"]:
# return False
# if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
# if bindings["H"] * bindings["Dh"] != bindings["H*Dh"]:
# return False

Check notice (Code scanning / CodeQL): Commented-out code
This comment appears to contain commented-out code.
return True

def rewrite(
self,
op,
input,
query_weight,
key_weight,
value_weight,
query_BSD,
key_BSD,
value_BSD,
mask,
cos,
sin,
past_key,
past_value,
key_BSHDh,
position_ids,
query_mm_reshaped,
cos,
sin,
**_,
):
num_heads = query_mm_reshaped.shape[2]
query = op.MatMul(input, query_weight)
key = op.MatMul(input, key_weight)
value = op.MatMul(input, value_weight)

query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
num_heads = _ir_utils.get_dim(key_BSHDh, 2)
if not isinstance(num_heads, int):
return None

# Switch to 3D RotaryEmbedding
# TODO: forward other attributes
query_BSD_rope = op.RotaryEmbedding(
query_BSD, position_ids, cos, sin, _domain="com.microsoft"
)
key_BSD_rope = op.RotaryEmbedding(
key_BSD, position_ids, cos, sin, _domain="com.microsoft"
)

return op.MultiHeadAttention(
query_rope,
key_rope,
value,
query_BSD_rope,
key_BSD_rope,
value_BSD,
None, # bias
None, # key padding mask
mask, # attention mask/bias
@@ -182,11 +221,15 @@
)


_rule1 = MultiHeadAttention.rule("MHA_2dmm", use_2d_matmul=False)
_rule1 = MultiHeadAttention.rule()

mha_rules = pattern.RewriteRuleSet([_rule1])


def fuse_mha(model: ir.Model) -> int:
def fuse_mha(model: ir.Model, *, debug: bool = False) -> int:
count = mha_rules.apply_to_model(model)
if debug and count == 0:
tracer = pattern.MatchingTracer()
mha_rules.apply_to_model(model, tracer=tracer)
tracer.report()
return count
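A possible way to exercise the new debug path (illustrative usage only; the model path is hypothetical, and it assumes the standard onnxscript.ir load entry point):

import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions import mha

model = ir.load("model.onnx")            # hypothetical model file
count = mha.fuse_mha(model, debug=True)  # on zero fusions, prints the tracer's match report
print(f"MHA fusions applied: {count}")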