
AttentionModuleMixin.set_attention_backend does not download hub kernels #13284

@Marius-Graml


Describe the bug

Hellooo :),

I believe there is a bug with per-module attention backend setting.

Currently, set_attention_backend() works correctly when called on a top-level model (e.g. pipe.transformer.set_attention_backend("sage_hub")), but fails silently when called on individual attention submodules. This means it is not possible to apply a hub-based attention backend (like sage_hub) to only specific transformer blocks of a model.

If this is intended behavior, please feel free to close this issue.

Root Cause

There are two set_attention_backend() methods in diffusers:

  1. ModelMixin.set_attention_backend (modeling_utils.py:586): Validates the backend, calls _check_attention_backend_requirements() and _maybe_download_kernel_for_backend(), sets processor._attention_backend on all child attention modules, and updates the global active backend.

  2. AttentionModuleMixin.set_attention_backend (attention.py:161): Only validates the backend name and sets self.processor._attention_backend. It does not call _maybe_download_kernel_for_backend().

Because the per-module method skips the kernel download, hub-based backends (e.g. sage_hub) are never actually loaded. The backend name is set on the processor, but kernel_fn remains None. No error is raised at this point; the failure only surfaces later, at inference time, when dispatch_attention_fn tries to invoke the missing kernel.
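
A possible fix would be for the per-module method to mirror the model-level one and trigger the kernel download before assigning the backend. The sketch below is only illustrative: it assumes _maybe_download_kernel_for_backend lives in diffusers.models.attention_dispatch and accepts the backend enum; the real helper's location and signature may differ.

from diffusers.models import attention_dispatch
from diffusers.models.attention_dispatch import AttentionBackendName

def set_attention_backend(self, backend: str) -> None:
    # Keep the existing validation of the backend name.
    backend = backend.lower()
    available_backends = {x.value for x in AttentionBackendName.__members__.values()}
    if backend not in available_backends:
        raise ValueError(f"`{backend=}` must be one of: {', '.join(sorted(available_backends))}")
    backend = AttentionBackendName(backend)
    # Missing step: download hub-based kernels (e.g. sage_hub) before use,
    # as ModelMixin.set_attention_backend already does. The helper's exact
    # signature is assumed here.
    attention_dispatch._maybe_download_kernel_for_backend(backend)
    self.processor._attention_backend = backend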

Reproduction

Just copy-paste into an empty notebook:

import torch
from diffusers import WanPipeline
from diffusers.models.attention import AttentionModuleMixin

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# This works: ModelMixin.set_attention_backend downloads the hub kernel internally (uncomment the line below to run the working variant)
# pipe.transformer.set_attention_backend("sage_hub")

# This does NOT work: AttentionModuleMixin.set_attention_backend skips the kernel download
for name, module in pipe.transformer.named_modules():
    if isinstance(module, AttentionModuleMixin):
        module.set_attention_backend("sage_hub")

pipe("a cat walking in the snow", num_inference_steps=2) # Here, the error is thrown since the kernel is None as not downloaded

# Root cause: compare the two set_attention_backend implementations
from diffusers.models.attention_dispatch import _HUB_KERNELS_REGISTRY, AttentionBackendName

kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
print(f"sage_hub kernel_fn after submodule-level set_attention_backend: {kernel_fn}")
# None, because AttentionModuleMixin.set_attention_backend (attention.py:161)
# only does `self.processor._attention_backend = backend`
# without calling `_maybe_download_kernel_for_backend(backend)`
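
Until that is fixed, a possible workaround (sketch only; it assumes the model-level call downloads the kernel as described above, that "native" is the default backend name, and that keeping sage_hub on blocks.0 is just an example selection) is to let ModelMixin.set_attention_backend do the download once and then switch the other blocks back:

# Model-level call: validates, downloads the sage_hub kernel, and sets the backend on every attention module.
pipe.transformer.set_attention_backend("sage_hub")

# Then revert every attention module outside the blocks that should keep sage_hub
# (here, hypothetically, only the first transformer block keeps it).
for name, module in pipe.transformer.named_modules():
    if isinstance(module, AttentionModuleMixin) and not name.startswith("blocks.0."):
        module.set_attention_backend("native")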

Logs

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 9
      6     if isinstance(module, AttentionModuleMixin):
      7         module.set_attention_backend("sage_hub")
----> 9 pipe("a cat walking in the snow", num_inference_steps=2)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/utils/_contextlib.py:124, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    120 @functools.wraps(func)
    121 def decorate_context(*args, **kwargs):
    122     # pyrefly: ignore [bad-context-manager]
    123     with ctx_factory():
--> 124         return func(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py:608, in WanPipeline.__call__(self, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, guidance_scale_2, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    605     timestep = t.expand(latents.shape[0])
    607 with current_model.cache_context("cond"):
--> 608     noise_pred = current_model(
    609         hidden_states=latent_model_input,
    610         timestep=timestep,
    611         encoder_hidden_states=prompt_embeds,
    612         attention_kwargs=attention_kwargs,
    613         return_dict=False,
    614     )[0]
    616 if self.do_classifier_free_guidance:
    617     with current_model.cache_context("uncond"):

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/utils/peft_utils.py:315, in apply_lora_scale.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    311     scale_lora_layers(self, lora_scale)
    313 try:
    314     # Execute the forward pass
--> 315     result = forward_fn(self, *args, **kwargs)
    316     return result
    317 finally:
    318     # Always unscale, even if forward pass raises an exception

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:677, in WanTransformer3DModel.forward(self, hidden_states, timestep, encoder_hidden_states, encoder_hidden_states_image, return_dict, attention_kwargs)
    675 else:
    676     for block in self.blocks:
--> 677         hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
    679 # 5. Output norm, projection & unpatchify
    680 if temb.ndim == 3:
    681     # batch_size, seq_len, inner_dim (wan 2.2 ti2v)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:489, in WanTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, rotary_emb)
    487 # 1. Self-attention
    488 norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
--> 489 attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
    490 hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
    492 # 2. Cross-attention

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:281, in WanAttention.forward(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
    273 def forward(
    274     self,
    275     hidden_states: torch.Tensor,
   (...)    279     **kwargs,
    280 ) -> torch.Tensor:
--> 281     return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:143, in WanAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, rotary_emb)
    140     hidden_states_img = hidden_states_img.flatten(2, 3)
    141     hidden_states_img = hidden_states_img.type_as(query)
--> 143 hidden_states = dispatch_attention_fn(
    144     query,
    145     key,
    146     value,
    147     attn_mask=attention_mask,
    148     dropout_p=0.0,
    149     is_causal=False,
    150     backend=self._attention_backend,
    151     # Reference: https://github.com/huggingface/diffusers/pull/12909
    152     parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
    153 )
    154 hidden_states = hidden_states.flatten(2, 3)
    155 hidden_states = hidden_states.type_as(query)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:432, in dispatch_attention_fn(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, attention_kwargs, backend, parallel_config)
    428         check(**kwargs)
    430 kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
--> 432 return backend_fn(**kwargs)

File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:3277, in _sage_attention_hub(query, key, value, attn_mask, is_causal, scale, return_lse, _parallel_config)
   3275 func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
   3276 if _parallel_config is None:
-> 3277     out = func(
   3278         q=query,
   3279         k=key,
   3280         v=value,
   3281         tensor_layout="NHD",
   3282         is_causal=is_causal,
   3283         sm_scale=scale,
   3284         return_lse=return_lse,
   3285     )
   3286     if return_lse:
   3287         out, lse, *_ = out

TypeError: 'NoneType' object is not callable

System Info

  • 🤗 Diffusers version: 0.37.0
  • Platform: Linux-6.8.0-106-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.11.15
  • PyTorch version (GPU?): 2.10.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.2
  • Transformers version: 4.57.6
  • Accelerate version: 1.13.0
  • PEFT version: 0.18.1
  • Bitsandbytes version: 0.49.2
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA H100 PCIe, 81559 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
