# File size: 18,490 Bytes (revision 9f57ecf)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration
from gui_actor.constants import IGNORE_INDEX
from typing import List, Tuple, Union, Optional
from gui_actor.trainer import rank0_print
class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast):
    """
    Qwen2-VL causal-LM output extended with the results of the pointer head.

    Args:
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Next-token (language-modeling) loss.
        pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Loss of the vision pointer head.
        pointer_scores (`List[torch.FloatTensor]`, *optional*):
            Pointer attention scores, one tensor per batch item.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Combined loss (weighted sum of `lm_loss` and `pointer_loss`).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language-modeling head.
        past_key_values, hidden_states, attentions, rope_deltas:
            Passed straight through to the parent output class.

    The three pointer-related values are assigned as instance attributes after
    the parent initializer has run.
    NOTE(review): they are not declared dataclass fields, so they presumably do
    not show up in `to_tuple()` / integer indexing of the output — confirm
    against transformers' `ModelOutput` if tuple access is needed.
    """

    def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        extra_fields = {
            "lm_loss": lm_loss,
            "pointer_loss": pointer_loss,
            "pointer_scores": pointer_scores,
        }
        for field_name, field_value in extra_fields.items():
            setattr(self, field_name, field_value)
class VisionHead_MultiPatch(nn.Module):
    """
    Pointer head that scores every visual (encoder) token for each query (decoder) token.

    The visual tokens are first contextualised by one self-attention layer
    (residual connection + LayerNorm). Visual and query tokens are then passed
    through separate two-layer MLP projections, and their scaled dot products
    give one logit per (query, visual-token) pair.

    Args:
        d_model: Hidden size of the incoming token representations.
        projection_dim: Hidden width of the two projection MLPs.
        num_attention_heads: Heads of the visual self-attention layer
            (`nn.MultiheadAttention` requires it to divide `d_model`).
        dropout_rate: Dropout used inside the self-attention and on its output.
    """

    def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        # No extra normalization on the inputs: Qwen2VL already normalizes its
        # hidden states with RMSNorm.
        self.projection_enc = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model),
        )
        self.projection_dec = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model),
        )
        # Self-attention over the visual tokens so each patch sees its neighbours.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_attention_heads,
            dropout=dropout_rate,
            batch_first=True,
        )
        # Post-attention residual + LayerNorm.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self,
                hidden_state_enc,  # shape: [n_enc, d_model] where n_enc can vary with image size
                hidden_state_dec,  # shape: [n_dec, d_model] there can be multiple queries in one sample
                labels: Optional[torch.Tensor] = None,
                do_single_patch: bool = False,
                ):
        """
        Score the visual tokens for each query token and optionally compute a loss.

        Args:
            hidden_state_enc: Visual-token features, `[n_enc, d_model]`.
            hidden_state_dec: Query-token features, `[n_dec, d_model]`.
            labels: Optional supervision.
                - Multi-patch mode (default): `[n_dec, n_enc]` binary mask of the
                  patches inside the target box. Each row is normalised into a
                  target distribution and compared with KL divergence.
                - Single-patch mode: `[n_dec]` integer index of the ground-truth
                  visual token per query, used with cross-entropy.
            do_single_patch: Select the single-patch (cross-entropy) loss instead
                of the multi-patch (KL) loss.

        Returns:
            `(attn_weights, loss)`. `attn_weights` is `[n_dec, n_enc]`, a softmax
            over the visual tokens. `loss` is a scalar tensor, or None when
            `labels` is None.
        """
        enc_input = hidden_state_enc.unsqueeze(0)  # add batch dim: [1, n_enc, d_model]
        attn_output, _ = self.self_attention(
            query=enc_input,
            key=enc_input,
            value=enc_input,
            need_weights=False,
        )
        # Residual connection + LayerNorm, then drop the batch dimension again.
        hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)).squeeze(0)  # [n_enc, d_model]

        proj_enc = self.projection_enc(hidden_state_enc_ctx)  # [n_enc, d_model]
        proj_dec = self.projection_dec(hidden_state_dec)  # [n_dec, d_model]

        # Scaled dot-product logits (scaled by sqrt(d_model) regardless of n_enc).
        scaling = self.d_model ** 0.5
        patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling  # [n_dec, n_enc]
        # Softmax over the visual (encoder) dimension.
        attn_weights = F.softmax(patch_logits, dim=-1)

        loss = None
        if labels is not None:
            labels = labels.to(patch_logits.device)
            if do_single_patch:
                # cross_entropy takes raw logits and applies log_softmax itself.
                # (Previously this referenced an undefined name and raised NameError.)
                loss = F.cross_entropy(patch_logits, labels)
            else:
                epsilon = 1e-8
                labels_float = labels.float()
                # Normalise each row of the mask into a target probability distribution.
                target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
                pred_log_probs = F.log_softmax(patch_logits, dim=-1)
                loss = F.kl_div(pred_log_probs, target_dist, reduction="batchmean")
        return attn_weights, loss
class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration):
    """
    Qwen2-VL conditional-generation model with an additional multi-patch pointer head.

    Besides the next-token loss, `forward` can score the image tokens for every
    pointer-pad token (`config.pointer_pad_token_id`) with `VisionHead_MultiPatch`.
    The resulting pointer loss is combined with the LM loss using
    `lm_loss_weight` and `pointer_loss_weight`.
    """

    def __init__(self, *args, **kwargs):
        # Remove the loss weights from kwargs *before* delegating. Previously they
        # were forwarded to the parent constructor and read back afterwards with
        # `kwargs.get`, so a non-default value could never take effect.
        # NOTE(review): the parent constructor presumably accepts only `config`,
        # in which case forwarding these keys would also raise a TypeError.
        pointer_loss_weight = kwargs.pop("pointer_loss_weight", 1.0)
        lm_loss_weight = kwargs.pop("lm_loss_weight", 1.0)
        super().__init__(*args, **kwargs)
        self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
        self.pointer_loss_weight = pointer_loss_weight
        self.lm_loss_weight = lm_loss_weight
        self.post_init()

    def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
        """Replace the weights used to combine the pointer loss and the LM loss."""
        self.pointer_loss_weight = pointer_loss_weight
        self.lm_loss_weight = lm_loss_weight

    def forward(self,
                input_ids: torch.LongTensor = None,  # (batch_size, seq_len)
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                image_grid_thw: Optional[torch.LongTensor] = None,
                video_grid_thw: Optional[torch.LongTensor] = None,
                rope_deltas: Optional[torch.LongTensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                # Grounding
                visual_token_indices_of_coordinates: Optional[torch.Tensor] = None,  # shape: (batch_size, n_target); ground-truth visual-token index per target token
                multi_patch_labels: Optional[torch.Tensor] = None,  # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
                if_multi_patch: bool = True,
                coordinates: Optional[List[Tuple[float, float]]] = None,
                verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
        """
        Run Qwen2-VL and, when grounding supervision is given, the pointer head.

        Pointer-specific arguments:
            visual_token_indices_of_coordinates: Per-sample ground-truth visual-token
                indices, one per pointer-pad token. When it is not None, the
                pointer head runs.
            multi_patch_labels: Per-sample `(n_target, n_visual)` binary masks of the
                patches inside the target box. Required when `if_multi_patch` is True.
            if_multi_patch: Use the multi-patch (KL) head. False selects the
                deprecated single-patch path.
            coordinates: Only printed when `verbose` is set.
            verbose: Log input shapes through `rank0_print`.

        Returns:
            A `QwenVLwithVisionHeadOutputWithPast` when `return_dict` is true.
            Otherwise a tuple:
            - with `labels`: `(total_loss, lm_loss, pointer_loss, logits,
              pointer_scores, *rest)`, where `total_loss` is omitted when None;
            - without `labels`: `(logits, *rest)`.
            Here `rest` are the decoder outputs after the last hidden state.

        Raises:
            ValueError: Image/video token counts do not match the vision features,
                a sample has no image tokens, or the pointer labels are missing or
                mismatched.
            NotImplementedError: The deprecated single-patch mode is requested but
                no `pointer_head` exists.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if verbose:
            rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
            # Optional inputs may be None (e.g. at inference); only print what exists.
            if labels is not None:
                rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
            if pixel_values is not None:
                rank0_print(f"pixel_values: {pixel_values.shape}")
            if image_grid_thw is not None:
                rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
            rank0_print(f"coordinates: {coordinates}")
            rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
            rank0_print(f"return_dict: {return_dict}")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)  # shape: (batch_size, seq_len, d_model)
            # Replace the placeholder image/video token embeddings with vision-encoder features.
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                    delta = delta.to(position_ids.device)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]  # shape: (batch_size, seq_len, d_model)
        logits = self.lm_head(hidden_states)

        lm_loss = None
        if labels is not None and self.lm_loss_weight > 0:
            # Upcast to float to avoid precision issues in the loss; note that the
            # returned logits are the upcast ones in this case.
            logits = logits.float()
            # Shift so that tokens < n predict n, then flatten.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            lm_loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

        pointer_loss = None
        pointer_scores = []
        if visual_token_indices_of_coordinates is not None:
            pointer_loss, pointer_scores = self._compute_pointer_loss(
                input_ids,
                inputs_embeds,
                hidden_states,
                visual_token_indices_of_coordinates,
                multi_patch_labels,
                if_multi_patch,
            )

        # Combine the LM loss and the pointer loss.
        # NOTE(review): when only one loss is available it is returned unweighted,
        # so pointer_loss_weight has no effect while the LM loss is skipped —
        # confirm this is intended.
        if lm_loss is None:
            total_loss = pointer_loss
        elif pointer_loss is None:
            total_loss = lm_loss
        else:
            total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss

        if return_dict:
            return QwenVLwithVisionHeadOutputWithPast(
                lm_loss=lm_loss,
                pointer_loss=pointer_loss,
                pointer_scores=pointer_scores,
                loss=total_loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )

        if labels is not None:
            output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
            if verbose:
                rank0_print("returning: total_loss, lm_loss, pointer_loss, logits, pointer_scores, ...")
            return (total_loss,) + output if total_loss is not None else output
        # Without labels, follow the parent's tuple convention (logits first). The
        # previous code returned the raw decoder tuple, whose first element is the
        # last hidden state rather than the logits.
        return (logits,) + outputs[1:]

    def _compute_pointer_loss(self,
                              input_ids,
                              inputs_embeds,
                              hidden_states,
                              visual_token_indices_of_coordinates,
                              multi_patch_labels,
                              if_multi_patch):
        """
        Run the pointer head on each sample and average the per-sample losses.

        Samples are processed one at a time because the number of image tokens and
        pointer-pad tokens can differ between samples. A sample without any
        pointer-pad token gets a dummy query (its last token) and dummy labels.
        Its loss is multiplied by 0.0, so it contributes nothing to the total,
        presumably so the head still takes part in the backward pass (e.g. for
        DDP) — confirm.

        Returns:
            `(pointer_loss, pointer_scores)`: the mean of the per-sample losses,
            and a list of detached CPU attention-score tensors, one per sample.
        """
        pointer_losses = []
        pointer_scores = []
        for i in range(input_ids.shape[0]):
            token_ids = input_ids[i]  # shape: (seq_length,)
            hs = hidden_states[i]  # shape: (seq_length, d_model)

            visual_indices = torch.nonzero(token_ids == self.config.image_token_id, as_tuple=False).squeeze(-1)
            target_indices = torch.nonzero(token_ids == self.config.pointer_pad_token_id, as_tuple=False).squeeze(-1)
            if visual_indices.numel() == 0:
                raise ValueError(f"No visual tokens found for sample {i}.")

            dummy_target = target_indices.numel() == 0
            gt = None
            sample_labels = None
            if dummy_target:
                # Last token as the dummy query; the first visual token(s) as dummy
                # ground truth. The resulting loss is zeroed out below.
                target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device)
                gt = torch.tensor([0], device=hs.device)
                if if_multi_patch:  # take the first 4 visual tokens as the ground truth
                    sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
                    sample_labels[0][:4] = 1
            elif if_multi_patch:
                if multi_patch_labels is None:
                    raise ValueError("multi_patch_labels must be provided when if_multi_patch=True.")
                sample_labels = multi_patch_labels[i].to(hs.device)
            else:
                # One integer in [0, n_visual - 1] per target token.
                gt = visual_token_indices_of_coordinates[i].to(hs.device)  # shape: (n_target,)

            # The head scores the vision-encoder embeddings scattered into
            # inputs_embeds, not the decoder hidden states at those positions.
            visual_embeds = inputs_embeds[i][visual_indices]  # shape: (n_visual, d_model)
            target_hidden = hs[target_indices]  # shape: (n_target, d_model)

            if if_multi_patch:
                if sample_labels.shape[0] != target_indices.shape[0]:
                    raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
                attn_scores, loss_v = self.multi_patch_pointer_head(
                    visual_embeds,
                    target_hidden,
                    labels=sample_labels,
                )
            else:
                # Deprecated single-patch mode: this class never defines
                # `pointer_head`, so fail with an explicit message rather than an
                # opaque AttributeError.
                pointer_head = getattr(self, "pointer_head", None)
                if pointer_head is None:
                    raise NotImplementedError(
                        "Single-patch pointer mode (if_multi_patch=False) is deprecated and no `pointer_head` is defined."
                    )
                attn_scores, loss_v = pointer_head(visual_embeds, target_hidden, labels=gt)

            pointer_scores.append(attn_scores.detach().cpu())
            pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)

        return torch.stack(pointer_losses).mean(), pointer_scores