yichenchenchen commited on
Commit
ad7a819
·
verified ·
1 Parent(s): 00ad5ed

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +7 -0
inferencer.py CHANGED
@@ -186,6 +186,13 @@ class UniPicV2Inferencer:
186
  attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
187
 
188
  # Get input embeddings
 
 
 
 
 
 
 
189
  inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
190
 
191
  # Ensure meta queries are on correct device
 
186
  attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
187
 
188
  # Get input embeddings
189
+
190
+ # 获取 embedding 权重所在设备
191
+ embed_device = self.lmm.get_input_embeddings().weight.device
192
+
193
+ # 确保 input_ids 在同一设备
194
+ input_ids = input_ids.to(embed_device)
195
+
196
  inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
197
 
198
  # Ensure meta queries are on correct device