I believe there is a bug with per-module attention backend setting.
If this is intended behavior, please feel free to close this issue.
Because the per-module method skips the kernel download, hub-based backends (e.g. sage_hub) are never actually loaded. The backend name is set on the processor, but kernel_fn remains None. No error is raised at this point; the failure only occurs later at inference time, when dispatch_attention_fn tries to invoke the missing kernel.
# Reproduction: per-module attention-backend setting never downloads hub kernels,
# so hub-based backends (e.g. "sage_hub") fail at inference time.
import torch
from diffusers import WanPipeline
from diffusers.models.attention import AttentionModuleMixin

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# This works: ModelMixin.set_attention_backend downloads the hub kernel internally (uncomment line below for running the functional code)
# pipe.transformer.set_attention_backend("sage_hub")

# This does NOT work: AttentionModuleMixin.set_attention_backend skips the kernel download
for name, module in pipe.transformer.named_modules():
    if isinstance(module, AttentionModuleMixin):
        module.set_attention_backend("sage_hub")

pipe("a cat walking in the snow", num_inference_steps=2)  # Here, the error is thrown since the kernel is None as not downloaded

# Root cause: compare the two set_attention_backend implementations
from diffusers.models.attention_dispatch import _HUB_KERNELS_REGISTRY, AttentionBackendName

kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
print(f"sage_hub kernel_fn after submodule-level set_attention_backend: {kernel_fn}")
# None, because AttentionModuleMixin.set_attention_backend (attention.py:161)
# only does `self.processor._attention_backend = backend`
# without calling `_maybe_download_kernel_for_backend(backend)`
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 9
6 if isinstance(module, AttentionModuleMixin):
7 module.set_attention_backend("sage_hub")
----> 9 pipe("a cat walking in the snow", num_inference_steps=2)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/utils/_contextlib.py:124, in context_decorator.<locals>.decorate_context(*args, **kwargs)
120 @functools.wraps(func)
121 def decorate_context(*args, **kwargs):
122 # pyrefly: ignore [bad-context-manager]
123 with ctx_factory():
--> 124 return func(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py:608, in WanPipeline.__call__(self, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, guidance_scale_2, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
605 timestep = t.expand(latents.shape[0])
607 with current_model.cache_context("cond"):
--> 608 noise_pred = current_model(
609 hidden_states=latent_model_input,
610 timestep=timestep,
611 encoder_hidden_states=prompt_embeds,
612 attention_kwargs=attention_kwargs,
613 return_dict=False,
614 )[0]
616 if self.do_classifier_free_guidance:
617 with current_model.cache_context("uncond"):
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/utils/peft_utils.py:315, in apply_lora_scale.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
311 scale_lora_layers(self, lora_scale)
313 try:
314 # Execute the forward pass
--> 315 result = forward_fn(self, *args, **kwargs)
316 return result
317 finally:
318 # Always unscale, even if forward pass raises an exception
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:677, in WanTransformer3DModel.forward(self, hidden_states, timestep, encoder_hidden_states, encoder_hidden_states_image, return_dict, attention_kwargs)
675 else:
676 for block in self.blocks:
--> 677 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
679 # 5. Output norm, projection & unpatchify
680 if temb.ndim == 3:
681 # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:489, in WanTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, rotary_emb)
487 # 1. Self-attention
488 norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
--> 489 attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
490 hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
492 # 2. Cross-attention
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:281, in WanAttention.forward(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
273 def forward(
274 self,
275 hidden_states: torch.Tensor,
(...) 279 **kwargs,
280 ) -> torch.Tensor:
--> 281 return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:143, in WanAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, rotary_emb)
140 hidden_states_img = hidden_states_img.flatten(2, 3)
141 hidden_states_img = hidden_states_img.type_as(query)
--> 143 hidden_states = dispatch_attention_fn(
144 query,
145 key,
146 value,
147 attn_mask=attention_mask,
148 dropout_p=0.0,
149 is_causal=False,
150 backend=self._attention_backend,
151 # Reference: https://github.com/huggingface/diffusers/pull/12909
152 parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
153 )
154 hidden_states = hidden_states.flatten(2, 3)
155 hidden_states = hidden_states.type_as(query)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:432, in dispatch_attention_fn(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, attention_kwargs, backend, parallel_config)
428 check(**kwargs)
430 kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
--> 432 return backend_fn(**kwargs)
File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:3277, in _sage_attention_hub(query, key, value, attn_mask, is_causal, scale, return_lse, _parallel_config)
3275 func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
3276 if _parallel_config is None:
-> 3277 out = func(
3278 q=query,
3279 k=key,
3280 v=value,
3281 tensor_layout="NHD",
3282 is_causal=is_causal,
3283 sm_scale=scale,
3284 return_lse=return_lse,
3285 )
3286 if return_lse:
3287 out, lse, *_ = out
TypeError: 'NoneType' object is not callable
Describe the bug
Hellooo :),
I believe there is a bug with per-module attention backend setting.
Currently,
`set_attention_backend()` works correctly when called on a top-level model (e.g. `pipe.transformer.set_attention_backend("sage_hub")`), but fails silently when called on individual attention submodules. This means it is not possible to apply a hub-based attention backend (like `sage_hub`) to only specific transformer blocks of a model. If this is intended behavior, please feel free to close this issue.
Root Cause
There are two set_attention_backend() methods in diffusers:
- `ModelMixin.set_attention_backend` (modeling_utils.py:586): Validates the backend, calls `_check_attention_backend_requirements()` and `_maybe_download_kernel_for_backend()`, sets `processor._attention_backend` on all child attention modules, and updates the global active backend.
- `AttentionModuleMixin.set_attention_backend` (attention.py:161): Only validates the backend name and sets `self.processor._attention_backend`. It does not call `_maybe_download_kernel_for_backend()`.

Because the per-module method skips the kernel download, hub-based backends (e.g. `sage_hub`) are never actually loaded. The backend name is set on the processor, but `kernel_fn` remains `None`. No error is raised at this point; the failure only occurs later at inference time, when `dispatch_attention_fn` tries to invoke the missing kernel.
Reproduction
Just copy-paste into an empty notebook:
Logs
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[2], line 9 6 if isinstance(module, AttentionModuleMixin): 7 module.set_attention_backend("sage_hub") ----> 9 pipe("a cat walking in the snow", num_inference_steps=2) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/utils/_contextlib.py:124, in context_decorator.<locals>.decorate_context(*args, **kwargs) 120 @functools.wraps(func) 121 def decorate_context(*args, **kwargs): 122 # pyrefly: ignore [bad-context-manager] 123 with ctx_factory(): --> 124 return func(*args, **kwargs) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/pipelines/wan/pipeline_wan.py:608, in WanPipeline.__call__(self, prompt, negative_prompt, height, width, num_frames, num_inference_steps, guidance_scale, guidance_scale_2, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length) 605 timestep = t.expand(latents.shape[0]) 607 with current_model.cache_context("cond"): --> 608 noise_pred = current_model( 609 hidden_states=latent_model_input, 610 timestep=timestep, 611 encoder_hidden_states=prompt_embeds, 612 attention_kwargs=attention_kwargs, 613 return_dict=False, 614 )[0] 616 if self.do_classifier_free_guidance: 617 with current_model.cache_context("uncond"): File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs) 1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1775 else: -> 1776 return self._call_impl(*args, **kwargs) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs) 1782 # If we don't have any hooks, we want to skip the rest of the logic in 1783 # this function, and just call 
forward. 1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1785 or _global_backward_pre_hooks or _global_backward_hooks 1786 or _global_forward_hooks or _global_forward_pre_hooks): -> 1787 return forward_call(*args, **kwargs) 1789 result = None 1790 called_always_called_hooks = set() File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/utils/peft_utils.py:315, in apply_lora_scale.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs) 311 scale_lora_layers(self, lora_scale) 313 try: 314 # Execute the forward pass --> 315 result = forward_fn(self, *args, **kwargs) 316 return result 317 finally: 318 # Always unscale, even if forward pass raises an exception File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:677, in WanTransformer3DModel.forward(self, hidden_states, timestep, encoder_hidden_states, encoder_hidden_states_image, return_dict, attention_kwargs) 675 else: 676 for block in self.blocks: --> 677 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) 679 # 5. Output norm, projection & unpatchify 680 if temb.ndim == 3: 681 # batch_size, seq_len, inner_dim (wan 2.2 ti2v) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs) 1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1775 else: -> 1776 return self._call_impl(*args, **kwargs) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs) 1782 # If we don't have any hooks, we want to skip the rest of the logic in 1783 # this function, and just call forward. 
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1785 or _global_backward_pre_hooks or _global_backward_hooks 1786 or _global_forward_hooks or _global_forward_pre_hooks): -> 1787 return forward_call(*args, **kwargs) 1789 result = None 1790 called_always_called_hooks = set() File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:489, in WanTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, rotary_emb) 487 # 1. Self-attention 488 norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) --> 489 attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) 490 hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) 492 # 2. Cross-attention File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs) 1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1775 else: -> 1776 return self._call_impl(*args, **kwargs) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs) 1782 # If we don't have any hooks, we want to skip the rest of the logic in 1783 # this function, and just call forward. 
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1785 or _global_backward_pre_hooks or _global_backward_hooks 1786 or _global_forward_hooks or _global_forward_pre_hooks): -> 1787 return forward_call(*args, **kwargs) 1789 result = None 1790 called_always_called_hooks = set() File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:281, in WanAttention.forward(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) 273 def forward( 274 self, 275 hidden_states: torch.Tensor, (...) 279 **kwargs, 280 ) -> torch.Tensor: --> 281 return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py:143, in WanAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, rotary_emb) 140 hidden_states_img = hidden_states_img.flatten(2, 3) 141 hidden_states_img = hidden_states_img.type_as(query) --> 143 hidden_states = dispatch_attention_fn( 144 query, 145 key, 146 value, 147 attn_mask=attention_mask, 148 dropout_p=0.0, 149 is_causal=False, 150 backend=self._attention_backend, 151 # Reference: https://github.com/huggingface/diffusers/pull/12909 152 parallel_config=(self._parallel_config if encoder_hidden_states is None else None), 153 ) 154 hidden_states = hidden_states.flatten(2, 3) 155 hidden_states = hidden_states.type_as(query) File ~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:432, in dispatch_attention_fn(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, attention_kwargs, backend, parallel_config) 428 check(**kwargs) 430 kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} --> 432 return backend_fn(**kwargs) File 
~/miniconda3/envs/pruna-dev/lib/python3.11/site-packages/diffusers/models/attention_dispatch.py:3277, in _sage_attention_hub(query, key, value, attn_mask, is_causal, scale, return_lse, _parallel_config) 3275 func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn 3276 if _parallel_config is None: -> 3277 out = func( 3278 q=query, 3279 k=key, 3280 v=value, 3281 tensor_layout="NHD", 3282 is_causal=is_causal, 3283 sm_scale=scale, 3284 return_lse=return_lse, 3285 ) 3286 if return_lse: 3287 out, lse, *_ = out TypeError: 'NoneType' object is not callable

System Info
Who can help?
No response