Spaces:
Running
on
Zero
Running
on
Zero
Update inferencer.py
Browse files- inferencer.py +7 -0
inferencer.py
CHANGED
@@ -186,6 +186,13 @@ class UniPicV2Inferencer:
|
|
186 |
attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
|
187 |
|
188 |
# Get input embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
|
190 |
|
191 |
# Ensure meta queries are on correct device
|
|
|
186 |
attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
|
187 |
|
188 |
# Get input embeddings
|
189 |
+
|
190 |
+
# 获取 embedding 权重所在设备
|
191 |
+
embed_device = self.lmm.get_input_embeddings().weight.device
|
192 |
+
|
193 |
+
# 确保 input_ids 在同一设备
|
194 |
+
input_ids = input_ids.to(embed_device)
|
195 |
+
|
196 |
inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
|
197 |
|
198 |
# Ensure meta queries are on correct device
|