
Add DiT attention fusion for F5-TTS and diffusion transformer models #27999

Open
Rishi-Dave wants to merge 4 commits into microsoft:main from Rishi-Dave:rishidave/feat/dit-attention-fusion

Conversation

@Rishi-Dave
Contributor

Summary

  • Add FusionMultiHeadAttentionDiT to recognize DiT-style attention patterns (F5-TTS, etc.) and fuse them into MultiHeadAttention, enabling Flash Attention dispatch.
  • Register the new fusion as a second pass in MmditOnnxModel.fuse_multi_head_attention(), alongside the existing MMDit fusion for SD3/Flux.
  • Add test model generator and three test cases covering FP32, FP16 cast, and custom-scale variants.

Motivation

Fixes #27983

DiT models like F5-TTS use an attention pattern where Q, K, V are pre-computed (e.g., after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g., 100.0) is applied via Mul before Softmax. Optional Cast nodes (FP16↔FP32) may wrap Softmax for mixed-precision inference.

The existing MMDit fusion (for SD3/Flux) expects a specific Mul→Sqrt→Div→Sqrt→Cast→Slice→Shape scaling path and does not match the simpler Mul(scalar_constant) pattern, so the attention is never fused and Flash Attention is never dispatched. This causes ~44 extra Cast ops per inference and ~200ms overhead per forward pass.
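For reference, the unfused pattern described above can be sketched in NumPy. The BNSH/BNHS layouts and the 100.0 scale follow the description; the function name and tensor sizes are illustrative, not taken from the PR's code:

```python
import numpy as np

def dit_attention_reference(q_bnsh, k_bnhs, v_bnsh, scale):
    """Unfused DiT attention: Q/V in BNSH, K pre-transposed to BNHS,
    scalar scale applied via Mul before Softmax (illustrative names)."""
    scores = (q_bnsh @ k_bnhs) * scale            # MatMul(Q, K^T) -> Mul(scale)
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)    # Softmax over last axis
    return probs @ v_bnsh                         # MatMul(attn, V), still BNSH

# Assumed toy sizes: batch=1, num_heads=2, seq=3, head_dim=4
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 3, 4))
k_t = rng.standard_normal((1, 2, 4, 3))  # K already transposed to BNHS
v = rng.standard_normal((1, 2, 3, 4))
out = dit_attention_reference(q, k_t, v, scale=100.0)
print(out.shape)  # (1, 2, 3, 4)
```

The fused MultiHeadAttention op computes the same function in one kernel, which is what makes Flash Attention dispatch possible.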

Changes

New files:

  • onnxruntime/python/tools/transformers/fusion_mha_dit.py — Core fusion class that matches the pattern:

    MatMul(Q, K^T) → [Cast FP16→FP32] → Mul(scale) → Softmax → [Cast FP32→FP16] → MatMul(attn, V)
        → Transpose(0,2,1,3) → Reshape → output
    

    and replaces it with a single MultiHeadAttention op (with scale attribute).

  • onnxruntime/test/python/transformers/dit_model_generator.py — Synthetic ONNX graph generators for testing.

Modified files:

  • onnxruntime/python/tools/transformers/onnx_model_mmdit.py — Register FusionMultiHeadAttentionDiT as a second fusion pass after the existing MMDit fusion.
  • onnxruntime/test/python/transformers/test_attention_fusion.py — Three new test cases:
    • test_dit_attention_fusion — FP32 with K pre-transpose, scale=100.0
    • test_dit_attention_fusion_with_fp16_casts — FP16 Cast nodes around Softmax
    • test_dit_attention_fusion_custom_scale — Standard 1/√d_k scale

Test Plan

  • All three new DiT fusion tests pass, verifying:
    • Exactly 1 MultiHeadAttention node is produced
    • num_heads attribute is correctly detected from upstream Reshape shapes
    • scale attribute matches the original scalar constant
    • No Softmax nodes remain after fusion
  • Existing attention fusion tests remain unaffected
  • ruff check, ruff format, and lintrunner -a pass clean

… models

DiT models like F5-TTS use an attention pattern where Q, K, V are
pre-computed (e.g. after RoPE) in BNSH format, K is pre-transposed to
BNHS, and a custom scalar scale (e.g. 100.0) is applied before Softmax.
Optional Cast nodes (FP16<->FP32) may wrap Softmax for mixed-precision inference.

The existing MMDit fusion (for SD3/Flux) expects a very specific
Mul->Sqrt->Div->Sqrt->Cast->Slice->Shape scaling path and does not
match the simpler Mul(scalar) pattern used in DiT models, so Flash
Attention is never dispatched.

This commit adds FusionMultiHeadAttentionDiT which recognizes:
  MatMul(Q, K^T) -> [Cast] -> Mul(scale) -> Softmax -> [Cast] -> MatMul(attn, V)

and fuses it into a single MultiHeadAttention op with the custom scale
attribute, enabling Flash Attention dispatch.

Fixes microsoft#27983

Copilot AI left a comment


Pull request overview

This PR extends the Python transformers graph fuser to recognize and fuse DiT-style attention patterns (e.g., F5-TTS) into a com.microsoft::MultiHeadAttention node so Flash Attention can be dispatched for these diffusion transformer models.

Changes:

  • Add a new DiT attention fusion pass (FusionMultiHeadAttentionDiT) that detects MatMul → (Cast) → Mul(scale) → Softmax → (Cast) → MatMul patterns and replaces them with MultiHeadAttention (including scale).
  • Register the new fusion as a second attention fusion pass in MmditOnnxModel.fuse_multi_head_attention().
  • Add a synthetic DiT ONNX model generator and new unit tests for FP32, optional Cast-wrapped Softmax, and custom scale variants.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
onnxruntime/python/tools/transformers/fusion_mha_dit.py Implements DiT-specific pattern matching and replacement with MultiHeadAttention + scale.
onnxruntime/python/tools/transformers/onnx_model_mmdit.py Runs the new DiT attention fusion pass after the existing MMDiT fusion pass.
onnxruntime/test/python/transformers/dit_model_generator.py Adds synthetic DiT attention subgraph generators used by fusion tests.
onnxruntime/test/python/transformers/test_attention_fusion.py Adds three unit tests validating DiT attention fusion behavior and attributes.


Cast V to FP16 in the FP16-cast test model so the attention MatMul
has type-consistent inputs. Add Softmax-count-is-zero assertions to
the FP16-cast and custom-scale tests to match the base test coverage.
@tianleiwu requested a review from Copilot on April 7, 2026 at 19:12

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.



def fuse(self, node, input_name_to_nodes, output_name_to_node):
    assert node.op_type == "Softmax"
    softmax = node


Copilot AI Apr 7, 2026


The fusion assumes Softmax is applied along the last axis of the attention scores, but it never validates Softmax's axis attribute. If axis differs, fusing into MultiHeadAttention would change semantics. Please add an explicit check that axis is -1 (or equivalently the last dimension) before applying the fusion.

Suggested change:

    # MultiHeadAttention normalizes attention scores along the last dimension.
    # Only fuse when Softmax explicitly uses the last axis.
    axis = None
    for attr in softmax.attribute:
        if attr.name == "axis":
            axis = helper.get_attribute_value(attr)
            break
    if axis is None or axis not in (-1, 3):
        return

Contributor Author


Added — the fusion now bails out if Softmax.axis is set to anything other than -1 or 3.

Contributor


Thanks for adding the explicit axis guard. There is one remaining default-axis case that still worries me: when axis is omitted, this code treats it as safe, but that only means last-axis for opset >= 13. For opset < 13, ONNX Softmax defaults axis to 1, so an older model with no axis attribute can still be fused and silently change semantics. I verified this by clearing the axis in the synthetic model, setting opset 11, and seeing the pass still produce one MultiHeadAttention and remove Softmax. Please gate axis is None on self.model.get_opset_version() >= 13 or require an explicit last-axis value before fusing.

Contributor Author


Good catch — fixed in d74277c. The fusion now explicitly gates axis is None on the opset version:

    if axis is None and self.model.get_opset_version() < 13:
        return

For opset < 13, the ONNX Softmax default is axis=1 (not last-axis), so fusing would silently change semantics. For opset >= 13, the default is last-axis which matches MHA's behavior.

This was in my working tree but not pushed when you reviewed — sorry for the timing confusion.
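A quick NumPy check illustrates why the opset guard matters: for a 4D score tensor, normalizing over axis 1 (the pre-opset-13 default) and over the last axis produce different results. This is a standalone demonstration, not code from the PR:

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

scores = np.arange(24, dtype=np.float64).reshape(1, 2, 3, 4)
old_default = softmax(scores, axis=1)    # opset < 13: omitted axis means axis=1
new_default = softmax(scores, axis=-1)   # opset >= 13: omitted axis means last axis
print(np.allclose(old_default, new_default))  # False
```

Fusing an opset-11 Softmax with no axis attribute into MultiHeadAttention would therefore silently switch from the axis-1 result to the last-axis result.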

…, shrink test models

- Fix get_data_input_of_mul to handle Python int/float scalars (not just np.ndarray)
- Validate Softmax axis=-1 before fusing to avoid semantic changes
- Add detect_num_heads_from_input_shape fallback for graph inputs
- Switch to numpy_helper.from_array for initializer creation
- Reduce default test tensor sizes (num_heads=4, head_dim=8)
- Fix FP16 type consistency: cast attn output back to FP32 before o_matmul
- Add test_dit_attention_fusion_no_k_transpose for the inserted-Transpose path
Contributor

@tianleiwu left a comment


The fusion is well scoped and follows the existing MMDiT pattern, but I think there are still a couple of correctness issues to address before merge.

Blocking items:

  • The cast-wrapped path can create a MultiHeadAttention with mixed element types for Q/K vs V. I reproduced this with the new FP16-cast synthetic model: after saving with the MS opset, ORT rejects the optimized model because MultiHeadAttention input type parameter T is bound to both tensor(float) and tensor(float16).
  • The Softmax axis guard handles explicit non-last axes, but still treats a missing axis as safe. That is only true for opset >= 13. For opset < 13, omitted Softmax axis defaults to 1, so this fusion can change semantics. I added that detail as a reply on the existing Softmax-axis thread.

Also worth addressing: add a single-consumer guard for the matched intermediate tensors, and make the tests load or run parity on the optimized model so node-count assertions do not miss invalid fused graphs.

# ========================================================================
q_bnsh = matmul_qk.input[0]
k_transposed_input = matmul_qk.input[1]
v_bnsh = matmul_sv.input[1]
Contributor


This takes the value input from the second MatMul after any post-Softmax cast. In the cast-wrapped topology, that can be v_bnsh_fp16 while Q/K are still float from the QK path. The resulting fused MultiHeadAttention(q=float, k=float, v=float16) is type-invalid: saving the optimized model with the MS opset and loading it in ORT fails with Type parameter (T) ... bound to different types (tensor(float) and tensor(float16)). Please either require matching Q/K/V element types before fusing or insert casts so the fused MHA obeys the op type contract.

Contributor Author


Fixed in d74277c. Three-layer defense now:

  1. V trace-back: When cast_after_softmax exists, trace V back through any Cast to recover the pre-cast tensor (same type as Q/K).
  2. Explicit dtype check: get_dtype on Q, K, V — bail on any confirmed mismatch.
  3. Conservative bail-out: When casts are present and V was not traced through a Cast (e.g., V is natively a different type), bail if we can't verify Q/V types via get_dtype.

The test validation now runs onnx.shape_inference.infer_shapes() before checking types, so get_dtype resolves intermediate tensor types that would otherwise return None. This catches the exact scenario you described — a fused model that passes structural assertions but produces a type-invalid MHA.
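The three checks above can be condensed into a small predicate. This is a simplified sketch; the helper name and the dtype-enum arguments are illustrative, not the PR's actual code:

```python
def can_fuse_dtypes(q_dtype, k_dtype, v_dtype, casts_present):
    """Decide whether Q/K/V element types are proven compatible for fusion.
    Arguments are ONNX TensorProto dtype enums, or None when unresolvable.
    Hypothetical helper mirroring the bail-out rules in the discussion."""
    known = [d for d in (q_dtype, k_dtype, v_dtype) if d is not None]
    if len(set(known)) > 1:
        return False  # confirmed mismatch: never fuse
    if casts_present and len(known) < 3:
        return False  # casts in the path and a type unverified: bail conservatively
    return True

assert can_fuse_dtypes(1, 1, 1, casts_present=True)         # all float32: OK
assert not can_fuse_dtypes(1, 1, 10, casts_present=False)   # float32 vs float16: bail
assert not can_fuse_dtypes(1, None, 1, casts_present=True)  # unverified type with casts: bail
assert can_fuse_dtypes(1, None, 1, casts_present=False)     # no casts: structural match suffices
```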

Contributor


I agree the new V trace-back fixes the original failure mode, but I still think the casted K path is under-validated here.

Right now the conservative bail-out only keys off q_dtype and v_dtype when casts are present. If K reaches matmul_qk through a casted BNHS tensor, or through the just-added Transpose_BNHS_to_BNSH, get_dtype(k_bnsh) is often None at fuse time even though the pre-cast K source may not actually match Q/V. In that case this still emits MultiHeadAttention without ever proving K's type is compatible.

I think this wants the same treatment as V: validate K from its pre-transpose / pre-cast source, and in casted paths bail unless all three inputs are actually verified compatible.

self.nodes_to_add.append(mha_node)
self.node_name_to_graph_name[mha_node.name] = self.this_graph_name

# Remove fused nodes
Contributor


Before removing/replacing this chain, please guard that the matched intermediate outputs are single-consumer. This pattern is fairly generic (MatMul -> Mul -> Softmax -> MatMul), so if matmul_qk, mul_scale, softmax, or the optional casts feed another node, fusing the path can stop producing tensors still needed elsewhere. The existing MMDiT fusion has a narrower matcher, but this one should probably bail out when any matched intermediate has more than one child.

Contributor Author


Added in d74277c. The fusion now collects all intermediate outputs (matmul_qk, mul_scale, softmax, and both optional Casts) and bails out if any has more than one consumer in input_name_to_nodes.
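The guard described above amounts to a consumer count over input_name_to_nodes. A minimal sketch, with illustrative names:

```python
def all_single_consumer(intermediate_outputs, input_name_to_nodes):
    """Return True only if every matched intermediate tensor feeds exactly
    one downstream node; otherwise fusing would orphan a live consumer."""
    for name in intermediate_outputs:
        consumers = input_name_to_nodes.get(name, [])
        if len(consumers) > 1:
            return False
    return True

# input_name_to_nodes maps tensor name -> list of consumer nodes
consumers = {"qk": ["mul"], "scaled": ["softmax"], "probs": ["matmul_sv", "debug_tap"]}
assert all_single_consumer(["qk", "scaled"], consumers)
assert not all_single_consumer(["qk", "scaled", "probs"], consumers)  # "probs" has 2 consumers
```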

Contributor


The new single-consumer check is a good improvement, but it still only protects the logits-side intermediates.

This pass still unconditionally removes matmul_sv, transpose_out, and reshape_out below. If either matmul_sv.output[0] or transpose_out.output[0] feeds another live consumer, remove_nodes() deletes that branch outright and prune_graph() cannot recover it. I think this should also gate the downstream chain with self.model.is_safe_to_fuse_nodes(nodes_to_remove, [reshape_out.output[0]], ...) before removing those nodes.
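A simplified stand-in for the is_safe_to_fuse_nodes check suggested here, using plain dicts instead of ONNX NodeProto objects (illustrative only; the real OnnxModel method handles graphs and subgraph scopes):

```python
def is_safe_to_remove(nodes_to_remove, kept_outputs, all_nodes):
    """Removal is safe only if no surviving node consumes an output of the
    removed chain, except through outputs the fusion keeps (e.g. the final
    reshape output that the new MultiHeadAttention node will produce)."""
    removed = {id(n) for n in nodes_to_remove}
    removed_outputs = {o for n in nodes_to_remove for o in n["outputs"]} - set(kept_outputs)
    for n in all_nodes:
        if id(n) in removed:
            continue
        if any(i in removed_outputs for i in n["inputs"]):
            return False  # a live consumer would lose its input
    return True

matmul_sv = {"inputs": ["probs", "v"], "outputs": ["ctx"]}
transpose = {"inputs": ["ctx"], "outputs": ["ctx_t"]}
reshape = {"inputs": ["ctx_t"], "outputs": ["attn_out"]}
tap = {"inputs": ["ctx"], "outputs": ["debug"]}  # extra consumer of matmul_sv's output

assert is_safe_to_remove([matmul_sv, transpose, reshape], ["attn_out"],
                         [matmul_sv, transpose, reshape])
assert not is_safe_to_remove([matmul_sv, transpose, reshape], ["attn_out"],
                             [matmul_sv, transpose, reshape, tap])
```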

…check

- Trace V back through Cast in FP16 path to recover pre-cast tensor,
  ensuring Q/K/V share the same element type for the fused MHA
- Add explicit Q/K/V dtype consistency check; bail when casts are
  present and types are unverifiable (V not traced through Cast)
- Guard against Softmax axis=None on opset < 13 where default is
  axis=1, not last-axis
- Add single-consumer guard for all matched intermediate tensors to
  prevent removing nodes that feed other consumers
- Run onnx.shape_inference in test validation so get_dtype resolves
  intermediate tensor types, catching mixed-dtype fusions that pass
  structural assertions but fail at ORT load
Contributor

@tianleiwu left a comment


I still think there are two correctness issues left on the current head before this is ready to merge.

The blocking items are both in the DiT fusion itself:

  • the cast-wrapped type-safety check still does not fully verify the K path, so the fused MultiHeadAttention can still be emitted even when K's type was never actually proven compatible with Q/V;
  • the new single-consumer guard only protects the logits-side intermediates, but the pass still removes the downstream matmul_sv -> transpose_out -> reshape_out chain without checking whether matmul_sv or transpose_out feed any other consumers.

I also left one non-blocking inline suggestion on the FP16 synthetic test model. The new shape-inference-based validation is good, but the current graph still under-exercises the real mixed-precision path because the projections remain FP32.


if use_fp16_casts:
    # Cast QK scores FP16 -> FP32 (simulating FP16 model needing FP32 Softmax)
    nodes.append(helper.make_node("Cast", ["qk_scores"], ["qk_scores_fp32"], "cast_to_fp32", to=1))
Contributor


Non-blocking suggestion from the consolidated review: this still under-exercises the real mixed-precision path.

In use_fp16_casts=True, the topology now validates, but Q/K/V projections are still produced by FP32 matmuls, so cast_to_fp32 is only a placeholder no-op and the new dtype-safety checks never see a truly mixed-precision QK path. If we want this test to justify the new fusion safety logic more strongly, it would be worth adding one synthetic case where the projections themselves are FP16 (or are explicitly cast to FP16 before the QK matmul) so the pre-Softmax cast actually models the real graph class.
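As a numeric illustration of the graph class described here (FP16 projections with an FP32 Softmax), a NumPy sketch under assumed shapes and the 100.0 scale from the PR description:

```python
import numpy as np

rng = np.random.default_rng(0)
# FP16 projections feeding the QK matmul, as in a real mixed-precision graph
q = rng.standard_normal((1, 2, 3, 4)).astype(np.float16)
k_t = rng.standard_normal((1, 2, 4, 3)).astype(np.float16)  # K pre-transposed to BNHS
v = rng.standard_normal((1, 2, 3, 4)).astype(np.float16)

scores = (q @ k_t).astype(np.float32) * 100.0  # Cast FP16->FP32, then Mul(scale)
scores -= scores.max(axis=-1, keepdims=True)
probs = np.exp(scores)
probs /= probs.sum(axis=-1, keepdims=True)     # Softmax in FP32 for stability
out = probs.astype(np.float16) @ v             # Cast FP32->FP16, then MatMul(attn, V)
assert out.dtype == np.float16
```

With scale=100.0 the FP32 exponentials would overflow FP16's range before normalization, which is exactly why real FP16 graphs hoist Softmax to FP32 and why the pre-Softmax Cast is not just a no-op in such models.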



Successfully merging this pull request may close these issues.

Flash Attention not dispatched for DiT-style attention pattern (diffusion transformers)
