Update inferencer.py
inferencer.py CHANGED (+8 −8)
@@ -244,8 +244,8 @@ class Inferencer:
 
         # 2) Encode image and extract features
         with torch.no_grad():
-            x_enc =
-            x_con, z_enc =
+            x_enc = model.encode(img_tensor)
+            x_con, z_enc = model.extract_visual_feature(x_enc)
 
         # 3) Prepare text prompts
         m = n = self.image_size // 16
@@ -267,18 +267,18 @@ class Inferencer:
         cfg_prompt_str = cfg_prompt_str.replace('<image>', '<image>' * image_length)
 
         # 4) Tokenize and prepare inputs
-        input_ids =
+        input_ids = model.tokenizer.encode(
             prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda()
 
         if cfg != 1.0:
-            null_input_ids =
+            null_input_ids = model.tokenizer.encode(
                 cfg_prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda()
             attention_mask = pad_sequence(
                 [torch.ones_like(input_ids), torch.ones_like(null_input_ids)],
                 batch_first=True, padding_value=0).to(torch.bool)
             input_ids = pad_sequence(
                 [input_ids, null_input_ids],
-                batch_first=True, padding_value=
+                batch_first=True, padding_value=model.tokenizer.eos_token_id)
         else:
             input_ids = input_ids[None]
             attention_mask = torch.ones_like(input_ids).to(torch.bool)
@@ -288,10 +288,10 @@ class Inferencer:
             z_enc = torch.cat([z_enc, z_enc], dim=0)
             x_con = torch.cat([x_con, x_con], dim=0)
 
-        inputs_embeds = z_enc.new_zeros(*input_ids.shape,
+        inputs_embeds = z_enc.new_zeros(*input_ids.shape, model.llm.config.hidden_size)
         # debug: this line currently raises an error
         inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1)
-        inputs_embeds[input_ids != image_token_idx] =
+        inputs_embeds[input_ids != image_token_idx] = model.llm.get_input_embeddings()(
             input_ids[input_ids != image_token_idx]
         )
 
@@ -312,7 +312,7 @@ class Inferencer:
         attention_mask = attention_mask.expand(bsz, -1)
 
         # 7) Sampling
-        samples =
+        samples = model.sample(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
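
The `# debug` comment marks the masked scatter into `inputs_embeds`, which only succeeds when the number of `<image>` tokens in `input_ids` exactly matches the number of visual feature vectors in `z_enc.flatten(0, 1)`. That is also why the CFG branch pads with `eos_token_id`: if the padding value collided with `image_token_idx`, the mask would select extra slots. A minimal sanity check is sketched below, reusing the tensor names from the diff; the helper itself is hypothetical and not part of this commit.

    import torch

    def check_image_token_count(input_ids: torch.Tensor,
                                z_enc: torch.Tensor,
                                image_token_idx: int) -> None:
        # Hypothetical helper: verify that
        #     inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1)
        # is shape-consistent before attempting the assignment.
        n_tokens = int((input_ids == image_token_idx).sum())  # <image> slots in the batch
        n_feats = z_enc.flatten(0, 1).shape[0]                # visual feature vectors
        if n_tokens != n_feats:
            raise ValueError(
                f'{n_tokens} <image> tokens vs. {n_feats} feature vectors; '
                'check image_length, the CFG duplication of z_enc, and that '
                'padding_value differs from image_token_idx.')

Calling this right before the scatter (with the duplicated `z_enc` when `cfg != 1.0`) should turn the current silent shape mismatch into an actionable error message.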