
Conversation

@dxqb (Contributor) commented Dec 26, 2025

What does this PR do?

Using an attention backend (https://huggingface.co/docs/diffusers/main/optimization/attention_backends) with a model that passes attention masks yields incorrect results.

This is already checked in parallel backends...

raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")

...but not yet in the regular ones.

This PR changes that.
Fixes #12605
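
For context, here is a minimal sketch of the kind of guard this adds (not the actual diffusers code; `flash_attention_backend` and its signature are purely illustrative): a backend that cannot honor a mask should raise instead of silently computing unmasked attention.

```python
import torch

def flash_attention_backend(query, key, value, attn_mask=None, dropout_p=0.0):
    # Hypothetical backend wrapper: it cannot apply attn_mask, so it must
    # fail loudly rather than return unmasked (i.e. wrong) attention output.
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for this backend.")
    # Plain SDPA stands in for the real kernel, purely to keep this runnable.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, dropout_p=dropout_p
    )
```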

Who can review?

@yiyixuxu and @asomoza
CC @zzlol63 @tolgacangoz

@yiyixuxu (Collaborator) left a comment

thanks!

@yiyixuxu (Collaborator) commented Jan 6, 2026

@bot /style

github-actions bot (Contributor) commented Jan 6, 2026

Style bot fixed some files and pushed the changes.

@yiyixuxu requested a review from sayakpaul on January 6, 2026 at 01:44
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment

Thanks!

But we don't even support passing attn_mask in the respective methods. So, I am not sure how this is applicable.

@dxqb (Contributor, Author) commented Jan 6, 2026

> Thanks!
>
> But we don't even support passing attn_mask in the respective methods. So, I am not sure how this is applicable.

True, but that is the point: those backends don't currently have an `attn_mask` parameter. Therefore, when the caller passes an `attn_mask` argument, the attention mask is silently ignored here:

if _AttentionBackendRegistry._checks_enabled:

if checks are disabled, which they are by default.
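
To illustrate the failure mode, here is a hypothetical, simplified dispatcher (not the actual diffusers implementation; `dispatch` and `no_mask_backend` are made-up names): if only the arguments the selected backend declares are forwarded, a mask the caller passed never reaches it.

```python
import inspect
import torch

def dispatch(backend_fn, **kwargs):
    # Forward only the arguments the backend's signature accepts.
    supported = inspect.signature(backend_fn).parameters
    forwarded = {k: v for k, v in kwargs.items() if k in supported}
    dropped = set(kwargs) - set(forwarded)
    if dropped:
        # With checks disabled nothing happens here; the change in this PR
        # makes a dropped `attn_mask` raise instead of being silently ignored.
        pass
    return backend_fn(**forwarded)

def no_mask_backend(query, key, value):
    # A backend without an `attn_mask` parameter.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

q = k = v = torch.randn(1, 4, 8, 16)
mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
mask[..., :4] = True
# The caller believes the mask is applied; the backend never sees it.
out = dispatch(no_mask_backend, query=q, key=k, value=v, attn_mask=mask)
```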

Ignoring extra arguments might be valid behaviour for other parameters, but not for attention masks: ignoring an attention mask yields incorrect results. It should fail instead of returning incorrect results; that's what this PR does.
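
For concreteness, a tiny self-contained check (plain PyTorch SDPA, nothing diffusers-specific) that a dropped mask really changes the result:

```python
import torch

q = k = v = torch.randn(1, 4, 8, 16)              # (batch, heads, seq, head_dim)
mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
mask[..., :4] = True                              # attend to the first 4 keys only

masked = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
unmasked = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert not torch.allclose(masked, unmasked)       # dropping the mask changes the output
```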

@yiyixuxu your style bot changed a line I didn't even mean to insert. Fixed.

@sayakpaul merged commit 41a6e86 into huggingface:main on Jan 6, 2026
10 of 11 checks passed
@dxqb deleted the check_attn_mask branch on January 7, 2026 at 18:26
@dxqb restored the check_attn_mask branch on January 7, 2026 at 18:26


Development

Successfully merging this pull request may close these issues.

dispatch_attention_fn silently ignores attn_mask for certain backends
