sayakpaul (HF Staff) committed f4d1b25 (verified) · 1 parent: ba2c444

Update qwenimage/qwen_fa3_processor.py

Files changed (1):
  1. qwenimage/qwen_fa3_processor.py +18 -6
qwenimage/qwen_fa3_processor.py CHANGED
@@ -2,6 +2,23 @@ import torch
 from typing import Optional, Tuple
 from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
 
+try:
+    from kernels import get_kernel
+    _k = get_kernel("kernels-community/vllm-flash-attn3")
+    _flash_attn_func = _k.flash_attn_func
+except Exception as e:
+    _flash_attn_func = None
+    _kernels_err = e
+
+
+def _ensure_fa3_available():
+    if _flash_attn_func is None:
+        raise ImportError(
+            "FlashAttention-3 via Hugging Face `kernels` is required. "
+            "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
+            f"{_kernels_err}"
+        )
+
+
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
 def flash_attn_func(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
@@ -35,12 +52,7 @@ class QwenDoubleStreamAttnProcessorFA3:
     _attention_backend = "fa3" # for parity with your other processors, not used internally
 
     def __init__(self):
-        try:
-            from flash_attn.flash_attn_interface import flash_attn_interface_func
-        except ImportError:
-            raise ImportError(
-                "flash_attention v3 package is required to be installed"
-            )
+        _ensure_fa3_available()
 
     @torch.no_grad()
     def __call__(
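
For readers unfamiliar with the `torch.library.custom_op` decorator that appears in the unchanged context lines above: it registers a function as a first-class PyTorch operator so the wrapped kernel call stays compatible with `torch.compile`. Below is a minimal, self-contained sketch of that registration pattern using a toy op; the `demo::scaled_copy` name and its body are illustrative and not part of this file.

import torch

# Register a toy custom op; mutates_args=() declares it side-effect free,
# mirroring the flash::flash_attn_func registration in this file.
@torch.library.custom_op("demo::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale

# A "fake" implementation lets the compiler infer output shapes without
# actually running the op.
@scaled_copy.register_fake
def _(x, scale):
    return torch.empty_like(x)

out = torch.ops.demo.scaled_copy(torch.randn(4), 2.0)  # callable like any other torch op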
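And a hypothetical end-to-end usage sketch for the processor this commit touches. Everything here is an assumption about the surrounding setup rather than part of the commit: it presumes a recent `diffusers` with `QwenImagePipeline`, that the transformer exposes `set_attn_processor` like other diffusers transformers, that the `kernels` package is installed, and that the GPU supports FlashAttention-3.

import torch
from diffusers import QwenImagePipeline  # assumed pipeline class; model id below is also an assumption

from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3

pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")

# __init__ now calls _ensure_fa3_available(), so this raises a descriptive
# ImportError (carrying the original kernel-loading error) if the FA3 kernel
# from kernels-community/vllm-flash-attn3 could not be fetched.
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

image = pipe("a capybara reading a newspaper", num_inference_steps=30).images[0]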