Add DiT attention fusion for F5-TTS and diffusion transformer models #27999

Rishi-Dave wants to merge 4 commits into microsoft:main
Conversation
… models

DiT models like F5-TTS use an attention pattern where Q, K, V are pre-computed (e.g. after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g. 100.0) is applied before Softmax. Optional Cast nodes (FP16<->FP32) may wrap Softmax for mixed precision. The existing MMDit fusion (for SD3/Flux) expects a very specific Mul->Sqrt->Div->Sqrt->Cast->Slice->Shape scaling path and does not match the simpler Mul(scalar) pattern used in DiT models, so Flash Attention is never dispatched.

This commit adds FusionMultiHeadAttentionDiT, which recognizes

MatMul(Q, K^T) -> [Cast] -> Mul(scale) -> Softmax -> [Cast] -> MatMul(attn, V)

and fuses it into a single MultiHeadAttention op with the custom scale attribute, enabling Flash Attention dispatch.

Fixes microsoft#27983
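The matched subgraph can be sketched as a plain numpy reference computation (a sketch only: tensor names, shapes, and the 100.0 scale are illustrative, not taken from the actual fusion code):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the last axis, matching
    # MultiHeadAttention's normalization of attention scores.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dit_attention(q_bnsh, k_bnhs, v_bnsh, scale):
    # MatMul(Q, K^T): K arrives pre-transposed to BNHS, so a plain
    # MatMul already produces the BNSS score matrix.
    scores = q_bnsh @ k_bnhs
    # Mul(scale): a custom scalar scale (e.g. 100.0) rather than 1/sqrt(d).
    probs = softmax(scores * scale, axis=-1)
    # MatMul(attn, V): back to BNSH layout.
    return probs @ v_bnsh

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 4, 6, 8))  # batch, num_heads, seq, head_dim
k = rng.standard_normal((1, 4, 6, 8))
v = rng.standard_normal((1, 4, 6, 8))
out = dit_attention(q, k.transpose(0, 1, 3, 2), v, scale=100.0)
```

The fusion replaces exactly this chain with one MultiHeadAttention node carrying the scale attribute, which is what lets ORT dispatch Flash Attention.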
Pull request overview
This PR extends the Python transformers graph fuser to recognize and fuse DiT-style attention patterns (e.g., F5-TTS) into a com.microsoft::MultiHeadAttention node so Flash Attention can be dispatched for these diffusion transformer models.
Changes:
- Add a new DiT attention fusion pass (FusionMultiHeadAttentionDiT) that detects MatMul → (Cast) → Mul(scale) → Softmax → (Cast) → MatMul patterns and replaces them with MultiHeadAttention (including scale).
- Register the new fusion as a second attention fusion pass in MmditOnnxModel.fuse_multi_head_attention().
- Add a synthetic DiT ONNX model generator and new unit tests for FP32, optional Cast-wrapped Softmax, and custom scale variants.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/python/tools/transformers/fusion_mha_dit.py | Implements DiT-specific pattern matching and replacement with MultiHeadAttention + scale. |
| onnxruntime/python/tools/transformers/onnx_model_mmdit.py | Runs the new DiT attention fusion pass after the existing MMDiT fusion pass. |
| onnxruntime/test/python/transformers/dit_model_generator.py | Adds synthetic DiT attention subgraph generators used by fusion tests. |
| onnxruntime/test/python/transformers/test_attention_fusion.py | Adds three unit tests validating DiT attention fusion behavior and attributes. |
Cast V to FP16 in the FP16-cast test model so the attention MatMul has type-consistent inputs. Add Softmax-count-is-zero assertions to the FP16-cast and custom-scale tests to match the base test coverage.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.
```python
def fuse(self, node, input_name_to_nodes, output_name_to_node):
    assert node.op_type == "Softmax"
    softmax = node
```
The fusion assumes Softmax is applied along the last axis of the attention scores, but it never validates Softmax's axis attribute. If axis differs, fusing into MultiHeadAttention would change semantics. Please add an explicit check that axis is -1 (or equivalently the last dimension) before applying the fusion.
```python
# MultiHeadAttention normalizes attention scores along the last dimension.
# Only fuse when Softmax explicitly uses the last axis.
axis = None
for attr in softmax.attribute:
    if attr.name == "axis":
        axis = helper.get_attribute_value(attr)
        break
if axis is None or axis not in (-1, 3):
    return
```
Added — the fusion now bails out if Softmax.axis is set to anything other than -1 or 3.
Thanks for adding the explicit axis guard. There is one remaining default-axis case that still worries me: when axis is omitted, this code treats it as safe, but that only means last-axis for opset >= 13. For opset < 13, ONNX Softmax defaults axis to 1, so an older model with no axis attribute can still be fused and silently change semantics. I verified this by clearing the axis in the synthetic model, setting opset 11, and seeing the pass still produce one MultiHeadAttention and remove Softmax. Please gate axis is None on self.model.get_opset_version() >= 13 or require an explicit last-axis value before fusing.
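The semantic difference the reviewer describes can be demonstrated directly in numpy. A sketch (shapes illustrative): for opset < 13, ONNX Softmax coerces the input to 2D around the axis and normalizes the flattened trailing block, so a defaulted axis=1 on a BNSS score tensor computes something entirely different from last-axis normalization:

```python
import numpy as np

def softmax_last_axis(x):
    # Opset >= 13 default: normalize along the last axis, which
    # matches MultiHeadAttention's score normalization.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_opset_lt13(x, axis=1):
    # Opset < 13 default (axis=1): coerce to 2D
    # [prod(dims[:axis]), prod(dims[axis:])] and normalize the whole
    # flattened trailing block in one go.
    shape = x.shape
    x2 = x.reshape(int(np.prod(shape[:axis], dtype=np.int64)), -1)
    e = np.exp(x2 - x2.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).reshape(shape)

scores = np.random.default_rng(1).standard_normal((1, 4, 6, 6))
new_default = softmax_last_axis(scores)   # per-row probabilities
old_default = softmax_opset_lt13(scores)  # one distribution over all 4*6*6 scores
```

Fusing the old-default form into MultiHeadAttention would therefore silently change the model's output, which is why the opset gate matters.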
Good catch — fixed in d74277c. The fusion now explicitly gates axis is None on the opset version:

```python
if axis is None and self.model.get_opset_version() < 13:
    return
```

For opset < 13, the ONNX Softmax default is axis=1 (not last-axis), so fusing would silently change semantics. For opset >= 13, the default is last-axis, which matches MHA's behavior.
This was in my working tree but not pushed when you reviewed — sorry for the timing confusion.
…, shrink test models

- Fix get_data_input_of_mul to handle Python int/float scalars (not just np.ndarray)
- Validate Softmax axis=-1 before fusing to avoid semantic changes
- Add detect_num_heads_from_input_shape fallback for graph inputs
- Switch to numpy_helper.from_array for initializer creation
- Reduce default test tensor sizes (num_heads=4, head_dim=8)
- Fix FP16 type consistency: cast attn output back to FP32 before o_matmul
- Add test_dit_attention_fusion_no_k_transpose for the inserted-Transpose path
The fusion is well scoped and follows the existing MMDiT pattern, but I think there are still a couple of correctness issues to address before merge.
Blocking items:

- The cast-wrapped path can create a MultiHeadAttention with mixed element types for Q/K vs V. I reproduced this with the new FP16-cast synthetic model: after saving with the MS opset, ORT rejects the optimized model because MultiHeadAttention input type parameter T is bound to both tensor(float) and tensor(float16).
- The Softmax axis guard handles explicit non-last axes, but still treats a missing axis as safe. That is only true for opset >= 13. For opset < 13, an omitted Softmax axis defaults to 1, so this fusion can change semantics. I added that detail as a reply on the existing Softmax-axis thread.

Also worth addressing: add a single-consumer guard for the matched intermediate tensors, and make the tests load or run parity on the optimized model so node-count assertions do not miss invalid fused graphs.
```python
# ========================================================================
q_bnsh = matmul_qk.input[0]
k_transposed_input = matmul_qk.input[1]
v_bnsh = matmul_sv.input[1]
```
This takes the value input from the second MatMul after any post-Softmax cast. In the cast-wrapped topology, that can be v_bnsh_fp16 while Q/K are still float from the QK path. The resulting fused MultiHeadAttention(q=float, k=float, v=float16) is type-invalid: saving the optimized model with the MS opset and loading it in ORT fails with Type parameter (T) ... bound to different types (tensor(float) and tensor(float16)). Please either require matching Q/K/V element types before fusing or insert casts so the fused MHA obeys the op type contract.
Fixed in d74277c. Three-layer defense now:

- V trace-back: when cast_after_softmax exists, trace V back through any Cast to recover the pre-cast tensor (same type as Q/K).
- Explicit dtype check: get_dtype on Q, K, V — bail on any confirmed mismatch.
- Conservative bail-out: when casts are present and V was not traced through a Cast (e.g., V is natively a different type), bail if we can't verify Q/V types via get_dtype.

The test validation now runs onnx.shape_inference.infer_shapes() before checking types, so get_dtype resolves intermediate tensor types that would otherwise return None. This catches the exact scenario you described — a fused model that passes structural assertions but produces a type-invalid MHA.
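The V trace-back step can be sketched as a tiny helper (the helper name is hypothetical; it assumes ONNX-style node objects with op_type/input fields and the fuser's output_name_to_node map, with a SimpleNamespace standing in for NodeProto here):

```python
from types import SimpleNamespace

def trace_through_cast(tensor_name, output_name_to_node):
    # If the tensor is produced by a Cast node, return the Cast's input,
    # i.e. the pre-cast tensor that shares Q/K's element type.
    # Otherwise return the tensor name unchanged.
    producer = output_name_to_node.get(tensor_name)
    if producer is not None and producer.op_type == "Cast":
        return producer.input[0]
    return tensor_name

# Stand-in for an onnx NodeProto: v_bnsh_fp16 = Cast(v_bnsh)
cast_node = SimpleNamespace(op_type="Cast", input=["v_bnsh"], output=["v_bnsh_fp16"])
producers = {"v_bnsh_fp16": cast_node}
```

Using the recovered pre-cast name as the MHA's V input keeps all three inputs bound to the same type parameter T.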
I agree the new V trace-back fixes the original failure mode, but I still think the casted K path is under-validated here.
Right now the conservative bail-out only keys off q_dtype and v_dtype when casts are present. If K reaches matmul_qk through a casted BNHS tensor, or through the just-added Transpose_BNHS_to_BNSH, get_dtype(k_bnsh) is often None at fuse time even though the pre-cast K source may not actually match Q/V. In that case this still emits MultiHeadAttention without ever proving K's type is compatible.
I think this wants the same treatment as V: validate K from its pre-transpose / pre-cast source, and in casted paths bail unless all three inputs are actually verified compatible.
```python
self.nodes_to_add.append(mha_node)
self.node_name_to_graph_name[mha_node.name] = self.this_graph_name

# Remove fused nodes
```
Before removing/replacing this chain, please guard that the matched intermediate outputs are single-consumer. This pattern is fairly generic (MatMul -> Mul -> Softmax -> MatMul), so if matmul_qk, mul_scale, softmax, or the optional casts feed another node, fusing the path can stop producing tensors still needed elsewhere. The existing MMDiT fusion has a narrower matcher, but this one should probably bail out when any matched intermediate has more than one child.
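A guard along these lines might look like the following sketch (helper name hypothetical; input_name_to_nodes maps a tensor name to the list of nodes consuming it, as in the fusion framework, with SimpleNamespace standing in for NodeProto):

```python
from types import SimpleNamespace

def intermediates_are_single_consumer(nodes, input_name_to_nodes):
    # Refuse to fuse unless every output of every matched intermediate
    # node feeds exactly one downstream consumer; otherwise removing
    # the chain would delete a tensor still needed elsewhere.
    for node in nodes:
        for out in node.output:
            if len(input_name_to_nodes.get(out, [])) > 1:
                return False
    return True

# Stand-in: the Softmax output also feeds a second (debug) branch.
softmax = SimpleNamespace(output=["probs"])
consumers = {"probs": ["matmul_sv", "debug_identity"]}
```

In the multi-consumer case the function returns False and the pass should simply leave the subgraph unfused.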
Added in d74277c. The fusion now collects all intermediate outputs (matmul_qk, mul_scale, softmax, and both optional Casts) and bails out if any has more than one consumer in input_name_to_nodes.
The new single-consumer check is a good improvement, but it still only protects the logits-side intermediates.
This pass still unconditionally removes matmul_sv, transpose_out, and reshape_out below. If either matmul_sv.output[0] or transpose_out.output[0] feeds another live consumer, remove_nodes() deletes that branch outright and prune_graph() cannot recover it. I think this should also gate the downstream chain with self.model.is_safe_to_fuse_nodes(nodes_to_remove, [reshape_out.output[0]], ...) before removing those nodes.
…check

- Trace V back through Cast in FP16 path to recover pre-cast tensor, ensuring Q/K/V share the same element type for the fused MHA
- Add explicit Q/K/V dtype consistency check; bail when casts are present and types are unverifiable (V not traced through Cast)
- Guard against Softmax axis=None on opset < 13 where default is axis=1, not last-axis
- Add single-consumer guard for all matched intermediate tensors to prevent removing nodes that feed other consumers
- Run onnx.shape_inference in test validation so get_dtype resolves intermediate tensor types, catching mixed-dtype fusions that pass structural assertions but fail at ORT load
I still think there are two correctness issues left on the current head before this is ready to merge.
The blocking items are both in the DiT fusion itself:
- the cast-wrapped type-safety check still does not fully verify the K path, so the fused MultiHeadAttention can still be emitted even when K's type was never actually proven compatible with Q/V;
- the new single-consumer guard only protects the logits-side intermediates, but the pass still removes the downstream matmul_sv -> transpose_out -> reshape_out chain without checking whether matmul_sv or transpose_out feed any other consumers.
I also left one non-blocking inline suggestion on the FP16 synthetic test model. The new shape-inference-based validation is good, but the current graph still under-exercises the real mixed-precision path because the projections remain FP32.
```python
if use_fp16_casts:
    # Cast QK scores FP16 -> FP32 (simulating FP16 model needing FP32 Softmax)
    nodes.append(helper.make_node("Cast", ["qk_scores"], ["qk_scores_fp32"], "cast_to_fp32", to=1))
```
Non-blocking suggestion from the consolidated review: this still under-exercises the real mixed-precision path.
In use_fp16_casts=True, the topology now validates, but Q/K/V projections are still produced by FP32 matmuls, so cast_to_fp32 is only a placeholder no-op and the new dtype-safety checks never see a truly mixed-precision QK path. If we want this test to justify the new fusion safety logic more strongly, it would be worth adding one synthetic case where the projections themselves are FP16 (or are explicitly cast to FP16 before the QK matmul) so the pre-Softmax cast actually models the real graph class.
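A numpy sketch of the compute path such a synthetic case would model, with genuinely FP16 projections so the pre-Softmax Cast is no longer a no-op (shapes and the 100.0 scale are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
# Q/K/V actually in FP16, modeling FP16 projection outputs.
q = rng.standard_normal((1, 4, 6, 8)).astype(np.float16)
k = rng.standard_normal((1, 4, 6, 8)).astype(np.float16)
v = rng.standard_normal((1, 4, 6, 8)).astype(np.float16)

# FP16 QK matmul, then the meaningful Cast -> Mul(scale) in FP32.
scores_fp16 = q @ k.transpose(0, 1, 3, 2)
scores_fp32 = scores_fp16.astype(np.float32) * 100.0

# Stable softmax in FP32, then Cast back to FP16 before the SV matmul.
e = np.exp(scores_fp32 - scores_fp32.max(axis=-1, keepdims=True))
probs = (e / e.sum(axis=-1, keepdims=True)).astype(np.float16)
out = probs @ v
```

Here the dtype-safety checks would actually see an FP16 QK path wrapped by real FP16↔FP32 casts, which is the graph class the fusion's bail-out logic is meant to protect.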
Summary

- Add FusionMultiHeadAttentionDiT to recognize DiT-style attention patterns (F5-TTS, etc.) and fuse them into MultiHeadAttention, enabling Flash Attention dispatch.
- Registered in MmditOnnxModel.fuse_multi_head_attention(), alongside the existing MMDit fusion for SD3/Flux.

Motivation
FusionMultiHeadAttentionDiTto recognize DiT-style attention patterns (F5-TTS, etc.) and fuse them intoMultiHeadAttention, enabling Flash Attention dispatch.MmditOnnxModel.fuse_multi_head_attention(), alongside the existing MMDit fusion for SD3/Flux.Motivation
Fixes #27983
DiT models like F5-TTS use an attention pattern where Q, K, V are pre-computed (e.g., after RoPE) in BNSH format, K is pre-transposed to BNHS, and a custom scalar scale (e.g., 100.0) is applied via Mul before Softmax. Optional Cast nodes (FP16↔FP32) may wrap Softmax for mixed-precision inference.

The existing MMDit fusion (for SD3/Flux) expects a specific Mul→Sqrt→Div→Sqrt→Cast→Slice→Shape scaling path and does not match the simpler Mul(scalar_constant) pattern, so the attention is never fused and Flash Attention is never dispatched. This causes ~44 extra Cast ops per inference and ~200ms overhead per forward pass.

Changes
New files:

- onnxruntime/python/tools/transformers/fusion_mha_dit.py — Core fusion class that matches the pattern

  MatMul(Q, K^T) -> [Cast] -> Mul(scale) -> Softmax -> [Cast] -> MatMul(attn, V)

  and replaces it with a single MultiHeadAttention op (with scale attribute).
- onnxruntime/test/python/transformers/dit_model_generator.py — Synthetic ONNX graph generators for testing.

Modified files:

- onnxruntime/python/tools/transformers/onnx_model_mmdit.py — Register FusionMultiHeadAttentionDiT as a second fusion pass after the existing MMDit fusion.
- onnxruntime/test/python/transformers/test_attention_fusion.py — Three new test cases:
  - test_dit_attention_fusion — FP32 with K pre-transpose, scale=100.0
  - test_dit_attention_fusion_with_fp16_casts — FP16 Cast nodes around Softmax
  - test_dit_attention_fusion_custom_scale — Standard 1/√d_k scale
MultiHeadAttentionnode is producednum_headsattribute is correctly detected from upstream Reshape shapesscaleattribute matches the original scalar constantSoftmaxnodes remain after fusionruff check,ruff format, andlintrunner -apass clean