|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration |
|
from gui_actor.constants import IGNORE_INDEX |
|
from typing import List, Tuple, Union, Optional |
|
from gui_actor.trainer import rank0_print |
|
|
|
class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast): |
|
""" |
|
Output class for Qwen2VL with pointer head, extending the base output class. |
|
|
|
Args: |
|
lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*): |
|
Language modeling loss. |
|
pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*): |
|
Vision pointer network loss. |
|
pointer_scores (`List[torch.FloatTensor]`, *optional*): |
|
Attention scores from the pointer network, one tensor per batch item. |
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*): |
|
Combined loss (weighted sum of lm_loss and pointer_loss). |
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
|
Prediction scores from the language modeling head. |
|
past_key_values, hidden_states, attentions, rope_deltas: |
|
Same as parent class. |
|
""" |
|
def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.lm_loss = lm_loss |
|
self.pointer_loss = pointer_loss |
|
self.pointer_scores = pointer_scores |
|
|
|
|
|
class VisionHead_MultiPatch(nn.Module): |
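    """
    Multi-patch pointer head. Visual (encoder) tokens are first contextualized
    with self-attention, then both encoder and decoder hidden states are projected
    into a shared space and compared by scaled dot-product, yielding a score
    distribution over image patches for each decoder query token.
    """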
|
def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1): |
|
super().__init__() |
|
self.d_model = d_model |
|
|
|
|
|
|
|
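        # MLP projections applied to the encoder (visual) and decoder hidden
        # states before they are compared.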
self.projection_enc = nn.Sequential( |
|
nn.Linear(d_model, projection_dim), |
|
nn.GELU(), |
|
nn.Linear(projection_dim, d_model) |
|
) |
|
self.projection_dec = nn.Sequential( |
|
nn.Linear(d_model, projection_dim), |
|
nn.GELU(), |
|
nn.Linear(projection_dim, d_model) |
|
) |
|
|
|
|
|
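        # Self-attention used to contextualize the visual tokens of a sample.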
self.self_attention = nn.MultiheadAttention( |
|
embed_dim=d_model, |
|
num_heads=num_attention_heads, |
|
dropout=dropout_rate, |
|
batch_first=True |
|
) |
|
|
|
|
|
self.layer_norm = nn.LayerNorm(d_model) |
|
self.dropout = nn.Dropout(dropout_rate) |
|
|
|
def forward(self, |
|
hidden_state_enc, |
|
hidden_state_dec, |
|
labels: Optional[torch.Tensor] = None, |
|
do_single_patch: bool = False, |
|
): |
|
|
|
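        # Treat all visual tokens of the sample as a single sequence (batch size 1)
        # and let them attend to each other.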
enc_input = hidden_state_enc.unsqueeze(0) |
|
attn_output, _ = self.self_attention( |
|
query=enc_input, |
|
key=enc_input, |
|
value=enc_input, |
|
|
|
need_weights=False |
|
) |
|
|
|
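        # Residual connection + layer norm around the self-attention output.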
hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)) |
|
|
|
hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) |
|
|
|
|
|
proj_enc = self.projection_enc(hidden_state_enc_ctx) |
|
proj_dec = self.projection_dec(hidden_state_dec) |
|
|
|
|
|
|
|
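        # Scaled dot-product between every decoder query and every visual patch:
        # (n_dec, d) x (d, n_enc) -> (n_dec, n_enc).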
scaling = self.d_model ** 0.5 |
|
patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling |
|
|
|
|
|
attn_weights = F.softmax(patch_logits, dim=-1) |
|
|
|
loss = None |
|
if (labels is not None) and (not do_single_patch): |
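            # Multi-patch supervision: normalize the binary patch labels into a
            # target distribution and minimize the KL divergence to the predicted
            # distribution over patches.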
|
epsilon = 1e-8 |
|
labels_float = labels.float() |
|
|
|
target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon) |
|
|
|
|
|
pred_log_probs = F.log_softmax(patch_logits, dim=-1) |
|
|
|
loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean') |
|
|
|
        if do_single_patch and (labels is not None):
            # Single-patch supervision: `labels` holds the index of the target visual
            # token, so a standard cross-entropy over the patch logits applies.
            loss = F.cross_entropy(patch_logits, labels)
|
|
|
return attn_weights, loss |
|
|
|
|
|
class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
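        # Pointer head that scores visual patches against the decoder hidden states
        # at pointer-placeholder positions.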
|
self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size) |
|
self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0) |
|
self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0) |
|
self.post_init() |
|
|
|
def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight): |
|
self.pointer_loss_weight = pointer_loss_weight |
|
self.lm_loss_weight = lm_loss_weight |
|
|
|
def forward(self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
pixel_values: Optional[torch.Tensor] = None, |
|
pixel_values_videos: Optional[torch.FloatTensor] = None, |
|
image_grid_thw: Optional[torch.LongTensor] = None, |
|
video_grid_thw: Optional[torch.LongTensor] = None, |
|
rope_deltas: Optional[torch.LongTensor] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
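                # Grounding-specific arguments: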
|
|
|
visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, |
|
multi_patch_labels: Optional[torch.Tensor] = None, |
|
if_multi_patch: bool = True, |
|
coordinates: Optional[List[Tuple[float, float]]] = None, |
|
verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]: |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if verbose: |
|
rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...") |
|
rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...") |
|
rank0_print(f"pixel_values: {pixel_values.shape}") |
|
rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}") |
|
rank0_print(f"coordinates: {coordinates}") |
|
rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}") |
|
rank0_print(f"return_dict: {return_dict}") |
|
|
|
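        # Embed the text tokens, then splice the visual features produced by the
        # vision tower into the positions of the image/video placeholder tokens.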
if inputs_embeds is None: |
|
inputs_embeds = self.model.embed_tokens(input_ids) |
|
if pixel_values is not None: |
|
pixel_values = pixel_values.type(self.visual.dtype) |
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) |
|
n_image_tokens = (input_ids == self.config.image_token_id).sum().item() |
|
n_image_features = image_embeds.shape[0] |
|
if n_image_tokens != n_image_features: |
|
raise ValueError( |
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
|
) |
|
image_mask = ( |
|
(input_ids == self.config.image_token_id) |
|
.unsqueeze(-1) |
|
.expand_as(inputs_embeds) |
|
.to(inputs_embeds.device) |
|
) |
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) |
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) |
|
|
|
if pixel_values_videos is not None: |
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype) |
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) |
|
n_video_tokens = (input_ids == self.config.video_token_id).sum().item() |
|
n_video_features = video_embeds.shape[0] |
|
if n_video_tokens != n_video_features: |
|
raise ValueError( |
|
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
|
) |
|
video_mask = ( |
|
(input_ids == self.config.video_token_id) |
|
.unsqueeze(-1) |
|
.expand_as(inputs_embeds) |
|
.to(inputs_embeds.device) |
|
) |
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) |
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) |
|
|
|
if attention_mask is not None: |
|
attention_mask = attention_mask.to(inputs_embeds.device) |
|
|
|
|
|
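        # Compute the multimodal RoPE position ids (3 x batch x seq) on the first
        # forward pass; on later decoding steps, reuse the cached rope_deltas.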
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): |
|
|
|
if ( |
|
(cache_position is not None and cache_position[0] == 0) |
|
or self.rope_deltas is None |
|
or (past_key_values is None or past_key_values.get_seq_length() == 0) |
|
): |
|
position_ids, rope_deltas = self.get_rope_index( |
|
input_ids, image_grid_thw, video_grid_thw, attention_mask |
|
) |
|
self.rope_deltas = rope_deltas |
|
|
|
else: |
|
batch_size, seq_length, _ = inputs_embeds.shape |
|
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 |
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device) |
|
position_ids = position_ids.view(1, -1).expand(batch_size, -1) |
|
if cache_position is not None: |
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) |
|
delta = delta.to(position_ids.device) |
|
position_ids = position_ids.add(delta) |
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) |
|
|
|
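        # Run the language-model backbone on the merged text + vision embeddings.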
outputs = self.model( |
|
input_ids=None, |
|
position_ids=position_ids, |
|
attention_mask=attention_mask, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
cache_position=cache_position, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
logits = self.lm_head(hidden_states) |
|
|
|
lm_loss = None |
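        # Standard next-token language-modeling loss (only computed when labels
        # are given and the LM loss is enabled).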
|
if labels is not None and self.lm_loss_weight > 0: |
|
|
|
logits = logits.float() |
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
shift_logits = shift_logits.view(-1, self.config.vocab_size) |
|
shift_labels = shift_labels.view(-1) |
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
lm_loss = loss_fct(shift_logits, shift_labels) |
|
|
|
|
|
|
|
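        # Pointer (grounding) loss: for each sample, score the visual tokens against
        # the hidden states at the pointer-placeholder positions.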
pointer_loss = None |
|
pointer_scores = [] |
|
if visual_token_indices_of_coordinates is not None: |
|
batch_size = input_ids.shape[0] |
|
pointer_losses = [] |
|
|
|
|
|
for i in range(batch_size): |
|
dummy_target = False |
|
|
|
|
|
token_ids = input_ids[i] |
|
hs = hidden_states[i] |
|
|
|
|
|
visual_mask = (token_ids == self.config.image_token_id) |
|
visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) |
|
|
|
|
|
target_mask = (token_ids == self.config.pointer_pad_token_id) |
|
target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1) |
|
|
|
|
|
                if visual_indices.numel() == 0:
                    raise ValueError(f"No visual tokens found for sample {i}.")
|
                if target_indices.numel() == 0:
                    # No pointer target in this sample: fall back to a dummy target
                    # (last position / first patches) so the pointer head still runs;
                    # the resulting loss is zeroed out below via `dummy_target`.
                    target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device)
                    gt = torch.tensor([0]).to(hs.device)
                    if if_multi_patch:
                        sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
                        sample_labels[0][:4] = 1
                    dummy_target = True
|
else: |
|
|
|
|
|
gt = visual_token_indices_of_coordinates[i].to(hs.device) |
|
if if_multi_patch: |
|
sample_labels = multi_patch_labels[i] |
|
|
|
|
|
|
|
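                # Gather the visual token embeddings (keys) and the hidden states at
                # the pointer positions (queries) for this sample.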
visual_embeds = inputs_embeds[i][visual_indices] |
|
target_hidden = hs[target_indices] |
|
|
|
|
|
if if_multi_patch: |
|
|
|
if sample_labels.shape[0] != target_indices.shape[0]: |
|
raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens") |
|
|
|
|
|
attn_scores, loss_v = self.multi_patch_pointer_head( |
|
visual_embeds, |
|
target_hidden, |
|
labels=sample_labels |
|
) |
|
|
|
                else:
                    # Single-patch path: supervise with the index of the single target
                    # visual token. Note that `self.pointer_head` (a single-patch head)
                    # is not initialized in this class; only `multi_patch_pointer_head` is.
                    attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
|
|
|
pointer_scores.append(attn_scores.detach().cpu()) |
|
|
|
pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v) |
|
|
|
pointer_loss = torch.stack(pointer_losses).mean() |
|
|
|
|
|
|
|
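        # Combine the LM and pointer losses with their configured weights.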
if lm_loss is None: |
|
total_loss = pointer_loss |
|
elif pointer_loss is None: |
|
total_loss = lm_loss |
|
else: |
|
total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss |
|
|
|
if return_dict: |
|
return QwenVLwithVisionHeadOutputWithPast( |
|
lm_loss=lm_loss, |
|
pointer_loss=pointer_loss, |
|
pointer_scores=pointer_scores, |
|
loss=total_loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
rope_deltas=self.rope_deltas, |
|
) |
|
else: |
|
|
|
if labels is not None: |
|
|
|
                output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
                return (total_loss,) + output if total_loss is not None else output
|
else: |
|
return outputs |