"""
Paired with a good language model. Thanks!
"""

import torch
from typing import Optional, Tuple
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen

try:
    from kernels import get_kernel
    _k = get_kernel("kernels-community/vllm-flash-attn3")
    _flash_attn_func = _k.flash_attn_func
except Exception as e:
    _flash_attn_func = None
    _kernels_err = e
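# Note: `kernels` (https://github.com/huggingface/kernels) fetches a prebuilt
# FA3 kernel from the Hub at import time; install it with `pip install kernels`.
# Assumption: FlashAttention-3 targets Hopper-class GPUs (e.g. H100), so a
# matching CUDA setup is expected for the kernel to load successfully.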


def _ensure_fa3_available():
    if _flash_attn_func is None:
        raise ImportError(
            "FlashAttention-3 via Hugging Face `kernels` is required. "
            "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
            f"{_kernels_err}"
        )

@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
) -> torch.Tensor:
    outputs, lse = _flash_attn_func(q, k, v, causal=causal)
    return outputs

@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    # The underlying kernel returns (output, softmax_lse), but the custom op above
    # only exposes the attention output, shaped (batch, seq_len, num_heads, head_dim),
    # so the fake mirrors that single tensor. The float32 softmax_lse of shape
    # (batch, num_heads, seq_len) is intentionally dropped.
    return torch.empty_like(q).contiguous()
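# Registering the custom op plus its fake above is what lets the FA3 call trace
# cleanly under torch.compile. A minimal sketch (illustrative only; assumes a
# Hopper GPU and bf16/fp16 tensors shaped (B, S, H, D_h)):
#
#   compiled = torch.compile(flash_attn_func, fullgraph=True)
#   q = k = v = torch.randn(1, 1024, 24, 128, device="cuda", dtype=torch.bfloat16)
#   out = compiled(q, k, v, causal=False)   # -> (1, 1024, 24, 128)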


class QwenDoubleStreamAttnProcessorFA3:
    """
    FA3-based attention processor for Qwen double-stream architecture.
    Computes joint attention over concatenated [text, image] streams using vLLM FlashAttention-3
    accessed via Hugging Face `kernels`.

    Notes / limitations:
    - General attention masks are not supported here (FA3 path). `is_causal=False` and no arbitrary mask.
    - Optional windowed attention / sink tokens / softcap can be plumbed through if you use those features.
    - Expects an available `apply_rotary_emb_qwen` in scope (same as your non-FA3 processor).
    """

    _attention_backend = "fa3"  # for parity with your other processors, not used internally

    def __init__(self):
        _ensure_fa3_available()

    @torch.no_grad()
    def __call__(
        self,
        attn,  # Attention module with to_q/to_k/to_v/add_*_proj, norms, to_out, to_add_out, and .heads
        hidden_states: torch.FloatTensor,                 # (B, S_img, D_model)  image stream
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # (B, S_txt, D_model)  text stream
        encoder_hidden_states_mask: Optional[torch.FloatTensor] = None,  # unused in FA3 path
        attention_mask: Optional[torch.FloatTensor] = None,    # unused in FA3 path
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (img_freqs, txt_freqs)
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
        if attention_mask is not None:
            # FA3 kernel path here does not consume arbitrary masks; fail fast to avoid silent correctness issues.
            raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")

        _ensure_fa3_available()

        B, S_img, _ = hidden_states.shape
        S_txt = encoder_hidden_states.shape[1]

        # ---- QKV projections (image/sample stream) ----
        img_q = attn.to_q(hidden_states)   # (B, S_img, D)
        img_k = attn.to_k(hidden_states)
        img_v = attn.to_v(hidden_states)

        # ---- QKV projections (text/context stream) ----
        txt_q = attn.add_q_proj(encoder_hidden_states)  # (B, S_txt, D)
        txt_k = attn.add_k_proj(encoder_hidden_states)
        txt_v = attn.add_v_proj(encoder_hidden_states)

        # ---- Reshape to (B, S, H, D_h) ----
        H = attn.heads
        img_q = img_q.unflatten(-1, (H, -1))
        img_k = img_k.unflatten(-1, (H, -1))
        img_v = img_v.unflatten(-1, (H, -1))

        txt_q = txt_q.unflatten(-1, (H, -1))
        txt_k = txt_k.unflatten(-1, (H, -1))
        txt_v = txt_v.unflatten(-1, (H, -1))

        # ---- Q/K normalization (per your module contract) ----
        if getattr(attn, "norm_q", None) is not None:
            img_q = attn.norm_q(img_q)
        if getattr(attn, "norm_k", None) is not None:
            img_k = attn.norm_k(img_k)
        if getattr(attn, "norm_added_q", None) is not None:
            txt_q = attn.norm_added_q(txt_q)
        if getattr(attn, "norm_added_k", None) is not None:
            txt_k = attn.norm_added_k(txt_k)

        # ---- RoPE (Qwen variant) ----
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            # expects tensors shaped (B, S, H, D_h)
            img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)

        # ---- Joint attention over [text, image] along sequence axis ----
        # Shapes: (B, S_total, H, D_h)
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)

        # FlashAttention-3 path expects (B, S, H, D_h) and returns (out, softmax_lse)
        out = flash_attn_func(q, k, v, causal=False)  # out: (B, S_total, H, D_h)

        # ---- Back to (B, S, D_model) ----
        out = out.flatten(2, 3).to(q.dtype)

        # Split back to text / image segments
        txt_attn_out = out[:, :S_txt, :]
        img_attn_out = out[:, S_txt:, :]

        # ---- Output projections ----
        img_attn_out = attn.to_out[0](img_attn_out)
        if len(attn.to_out) > 1:
            img_attn_out = attn.to_out[1](img_attn_out)  # dropout if present

        txt_attn_out = attn.to_add_out(txt_attn_out)

        return img_attn_out, txt_attn_out
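
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): swap the processor into
# a diffusers Qwen-Image pipeline. Assumptions: the transformer exposes its
# double-stream blocks as `transformer_blocks`, each with an `attn` module that
# supports `set_processor`; adjust names to your diffusers version.
# ---------------------------------------------------------------------------
# from diffusers import DiffusionPipeline
#
# pipe = DiffusionPipeline.from_pretrained(
#     "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
# ).to("cuda")
# for block in pipe.transformer.transformer_blocks:
#     block.attn.set_processor(QwenDoubleStreamAttnProcessorFA3())
# image = pipe("a serene koi pond at dusk", num_inference_steps=30).images[0]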