jiuhai committed
Commit ef6b35b · verified · 1 parent: 7acdf62

Update blip3o/model/language_model/blip3o_qwen.py

blip3o/model/language_model/blip3o_qwen.py CHANGED
@@ -53,167 +53,167 @@ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCau
         return self.model


-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        ids: Optional[list] = None,
-        i_s_pos: Optional[list] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        gen_image: Optional[torch.FloatTensor] = None,
-        und_image: Optional[torch.FloatTensor] = None,
-        grid_thw: Optional[torch.FloatTensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    # def forward(
+    #     self,
+    #     input_ids: torch.LongTensor = None,
+    #     attention_mask: Optional[torch.Tensor] = None,
+    #     position_ids: Optional[torch.LongTensor] = None,
+    #     past_key_values: Optional[List[torch.FloatTensor]] = None,
+    #     inputs_embeds: Optional[torch.FloatTensor] = None,
+    #     labels: Optional[torch.LongTensor] = None,
+    #     ids: Optional[list] = None,
+    #     i_s_pos: Optional[list] = None,
+    #     use_cache: Optional[bool] = None,
+    #     output_attentions: Optional[bool] = None,
+    #     output_hidden_states: Optional[bool] = None,
+    #     gen_image: Optional[torch.FloatTensor] = None,
+    #     und_image: Optional[torch.FloatTensor] = None,
+    #     grid_thw: Optional[torch.FloatTensor] = None,
+    #     image_sizes: Optional[List[List[int]]] = None,
+    #     return_dict: Optional[bool] = None,
+    #     cache_position: Optional[torch.LongTensor] = None
+    # ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+    #     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    #     output_hidden_states = (
+    #         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    #     )
+    #     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if inputs_embeds is None:
-            (
-                input_ids,
-                position_ids,
-                attention_mask,
-                past_key_values,
-                inputs_embeds,
-                labels,
-                latents
-            ) = self.prepare_inputs_labels_for_multimodal(
-                input_ids,
-                position_ids,
-                attention_mask,
-                past_key_values,
-                labels,
-                gen_image,
-                und_image,
-                grid_thw,
-                i_s_pos,
-                image_sizes
-            )
-
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
+    #     if inputs_embeds is None:
+    #         (
+    #             input_ids,
+    #             position_ids,
+    #             attention_mask,
+    #             past_key_values,
+    #             inputs_embeds,
+    #             labels,
+    #             latents
+    #         ) = self.prepare_inputs_labels_for_multimodal(
+    #             input_ids,
+    #             position_ids,
+    #             attention_mask,
+    #             past_key_values,
+    #             labels,
+    #             gen_image,
+    #             und_image,
+    #             grid_thw,
+    #             i_s_pos,
+    #             image_sizes
+    #         )
+
+    #     outputs = self.model(
+    #         input_ids=input_ids,
+    #         attention_mask=attention_mask,
+    #         position_ids=position_ids,
+    #         past_key_values=past_key_values,
+    #         inputs_embeds=inputs_embeds,
+    #         use_cache=use_cache,
+    #         output_attentions=output_attentions,
+    #         output_hidden_states=output_hidden_states,
+    #         return_dict=return_dict,
+    #     )

-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
+    #     hidden_states = outputs[0]
+    #     logits = self.lm_head(hidden_states)
+    #     logits = logits.float()

-        total_loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = torch.nn.CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-
-            # compute image loss
-            # target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
-            img_loss_funct = torch.nn.MSELoss()
-            # img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
-            img_hidden_states = []
+    #     total_loss = None
+    #     if labels is not None:
+    #         # Shift so that tokens < n predict n
+    #         shift_logits = logits[..., :-1, :].contiguous()
+    #         shift_labels = labels[..., 1:].contiguous()
+    #         # Flatten the tokens
+    #         loss_fct = torch.nn.CrossEntropyLoss()
+    #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
+    #         shift_labels = shift_labels.view(-1)
+    #         # Enable model parallelism
+    #         shift_labels = shift_labels.to(shift_logits.device)
+    #         loss = loss_fct(shift_logits, shift_labels)
+
+
+    #         # compute image loss
+    #         # target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
+    #         img_loss_funct = torch.nn.MSELoss()
+    #         # img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
+    #         img_hidden_states = []

-            for b in range(hidden_states.shape[0]):
-                img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
-            img_hidden_states = torch.stack(img_hidden_states,dim=0)
-            img_hidden_states = self.get_model().down_projector(img_hidden_states)
-            # img_loss = 0.0
-            if latents is None:
-                img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
-            else:
-                bsz = latents.shape[0]
-                # device = latents.device
-                dtype = latents.dtype
-                noise = torch.randn_like(latents, device=latents.device)
-                u = torch.rand(size=(bsz,), device="cpu")
-                indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
-                timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
-                sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
-                noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
-                noise_pred = self.get_model().dit(
-                    x=noisy_latents,
-                    timestep=timesteps,
-                    z_latents=self.mask_drop(img_hidden_states),
-                )
-                target = noise - latents
-                img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-                print(f"img loss {img_loss}")
-                total_loss = img_loss
-
-        return CausalLMOutputWithPast(
-            loss=total_loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+    #         for b in range(hidden_states.shape[0]):
+    #             img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
+    #         img_hidden_states = torch.stack(img_hidden_states,dim=0)
+    #         img_hidden_states = self.get_model().down_projector(img_hidden_states)
+    #         # img_loss = 0.0
+    #         if latents is None:
+    #             img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
+    #         else:
+    #             bsz = latents.shape[0]
+    #             # device = latents.device
+    #             dtype = latents.dtype
+    #             noise = torch.randn_like(latents, device=latents.device)
+    #             u = torch.rand(size=(bsz,), device="cpu")
+    #             indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
+    #             timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
+    #             sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
+    #             noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
+    #             noise_pred = self.get_model().dit(
+    #                 x=noisy_latents,
+    #                 timestep=timesteps,
+    #                 z_latents=self.mask_drop(img_hidden_states),
+    #             )
+    #             target = noise - latents
+    #             img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+    #             print(f"img loss {img_loss}")
+    #             total_loss = img_loss
+
+    #     return CausalLMOutputWithPast(
+    #         loss=total_loss,
+    #         logits=logits,
+    #         past_key_values=outputs.past_key_values,
+    #         hidden_states=outputs.hidden_states,
+    #         attentions=outputs.attentions,
+    #     )


-    @torch.no_grad()
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        images: Optional[torch.Tensor] = None,
-        image_sizes: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[GenerateOutput, torch.LongTensor]:
-        position_ids = kwargs.pop("position_ids", None)
-        attention_mask = kwargs.pop("attention_mask", None)
-        if "inputs_embeds" in kwargs:
-            raise NotImplementedError("`inputs_embeds` is not supported")
-
-        if images is not None:
-            (
-                inputs,
-                position_ids,
-                attention_mask,
-                _,
-                inputs_embeds,
-                img_indicator,
-                _
-            ) = self.prepare_inputs_labels_for_understanding(
-                inputs,
-                position_ids,
-                attention_mask,
-                None,
-                None,
-                images,
-                image_sizes=image_sizes
-            )
-        else:
-            inputs_embeds = self.get_model().embed_tokens(inputs)
-
-        return super().generate(
-            position_ids=position_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            **kwargs
-        )
+    # @torch.no_grad()
+    # def generate(
+    #     self,
+    #     inputs: Optional[torch.Tensor] = None,
+    #     images: Optional[torch.Tensor] = None,
+    #     image_sizes: Optional[torch.Tensor] = None,
+    #     **kwargs,
+    # ) -> Union[GenerateOutput, torch.LongTensor]:
+    #     position_ids = kwargs.pop("position_ids", None)
+    #     attention_mask = kwargs.pop("attention_mask", None)
+    #     if "inputs_embeds" in kwargs:
+    #         raise NotImplementedError("`inputs_embeds` is not supported")
+
+    #     if images is not None:
+    #         (
+    #             inputs,
+    #             position_ids,
+    #             attention_mask,
+    #             _,
+    #             inputs_embeds,
+    #             img_indicator,
+    #             _
+    #         ) = self.prepare_inputs_labels_for_understanding(
+    #             inputs,
+    #             position_ids,
+    #             attention_mask,
+    #             None,
+    #             None,
+    #             images,
+    #             image_sizes=image_sizes
+    #         )
+    #     else:
+    #         inputs_embeds = self.get_model().embed_tokens(inputs)
+
+    #     return super().generate(
+    #         position_ids=position_ids,
+    #         attention_mask=attention_mask,
+    #         inputs_embeds=inputs_embeds,
+    #         **kwargs
+    #     )

     @torch.no_grad()
     def generate_image(
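
With both methods commented out, calls to forward and generate presumably fall through to the Qwen2_5_VLForConditionalGeneration parent class. For reference, the disabled forward paired a shifted cross-entropy text loss with a rectified-flow image loss: clean latents are blended with Gaussian noise as noisy = (1 - sigma) * latents + sigma * noise, and the DiT regresses the path's velocity noise - latents (note that only img_loss was ever assigned to total_loss, so the cross-entropy term never reached the returned loss). Below is a minimal, self-contained sketch of that flow-matching objective; the linear sigma schedule and the zero-output stand-in for the DiT are illustrative assumptions, not the repo's get_sigmas or dit module.

import torch
import torch.nn.functional as F

def flow_matching_loss(dit, latents, z_latents, num_train_timesteps=1000):
    # Uniformly sample one training timestep per example, as the disabled forward did.
    bsz = latents.shape[0]
    u = torch.rand(size=(bsz,))
    timesteps = (u * num_train_timesteps).long().to(latents.device)
    # Assumed schedule: sigma grows linearly from 0 (clean) to 1 (pure noise);
    # the repo instead derives sigmas from its noise_scheduler via get_sigmas().
    sigmas = (timesteps.float() / num_train_timesteps).view(bsz, *([1] * (latents.ndim - 1)))
    noise = torch.randn_like(latents)
    # Straight-line (rectified-flow) interpolation between data and noise.
    noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
    # The conditioned model predicts the constant velocity of that path.
    noise_pred = dit(x=noisy_latents, timestep=timesteps, z_latents=z_latents)
    target = noise - latents
    return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

# Toy check with a zero-output stand-in for the DiT (illustrative only):
dit = lambda x, timestep, z_latents: torch.zeros_like(x)
latents = torch.randn(2, 4, 32, 32)   # hypothetical VAE latents
cond = torch.randn(2, 64, 896)        # e.g. the 64 projected query hidden states
print(flow_matching_loss(dit, latents, cond).item())

With the zero stand-in the printed loss settles around 2.0, the per-element variance of noise - latents for independent unit Gaussians, which makes a quick sanity check for a real DiT to beat.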