yichenchenchen commited on
Commit
cc7c8f0
·
verified ·
1 Parent(s): ad7a819

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +7 -1
inferencer.py CHANGED
@@ -212,7 +212,13 @@ class UniPicV2Inferencer:
212
  # Forward through LMM
213
  if hasattr(self.lmm.model, "rope_deltas"):
214
  self.lmm.model.rope_deltas = None
215
-
 
 
 
 
 
 
216
  outputs = self.lmm.model(
217
  inputs_embeds=inputs_embeds.to(self.device),
218
  attention_mask=attention_mask.to(self.device),
 
212
  # Forward through LMM
213
  if hasattr(self.lmm.model, "rope_deltas"):
214
  self.lmm.model.rope_deltas = None
215
+
216
+ model_device = self.lmm.model.embed_tokens.weight.device
217
+ # 强制将所有 tensor 输入搬到这个设备
218
+ for k, v in inputs.items():
219
+ if isinstance(v, torch.Tensor):
220
+ inputs[k] = v.to(model_device)
221
+
222
  outputs = self.lmm.model(
223
  inputs_embeds=inputs_embeds.to(self.device),
224
  attention_mask=attention_mask.to(self.device),