Spaces:
dreroc
/
Running on Zero

yichenchenchen committed on
Commit
8f7a765
·
verified ·
1 Parent(s): 7db5491

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +8 -8
inferencer.py CHANGED
@@ -244,8 +244,8 @@ class Inferencer:
244
 
245
  # 2) Encode image and extract features
246
  with torch.no_grad():
247
- x_enc = self.model.encode(img_tensor)
248
- x_con, z_enc = self.model.extract_visual_feature(x_enc)
249
 
250
  # 3) Prepare text prompts
251
  m = n = self.image_size // 16
@@ -267,18 +267,18 @@ class Inferencer:
267
  cfg_prompt_str = cfg_prompt_str.replace('<image>', '<image>' * image_length)
268
 
269
  # 4) Tokenize and prepare inputs
270
- input_ids = self.model.tokenizer.encode(
271
  prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda()
272
 
273
  if cfg != 1.0:
274
- null_input_ids = self.model.tokenizer.encode(
275
  cfg_prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda()
276
  attention_mask = pad_sequence(
277
  [torch.ones_like(input_ids), torch.ones_like(null_input_ids)],
278
  batch_first=True, padding_value=0).to(torch.bool)
279
  input_ids = pad_sequence(
280
  [input_ids, null_input_ids],
281
- batch_first=True, padding_value=self.model.tokenizer.eos_token_id)
282
  else:
283
  input_ids = input_ids[None]
284
  attention_mask = torch.ones_like(input_ids).to(torch.bool)
@@ -288,10 +288,10 @@ class Inferencer:
288
  z_enc = torch.cat([z_enc, z_enc], dim=0)
289
  x_con = torch.cat([x_con, x_con], dim=0)
290
 
291
- inputs_embeds = z_enc.new_zeros(*input_ids.shape, self.model.llm.config.hidden_size)
292
  #debug:目前这里报错
293
  inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1)
294
- inputs_embeds[input_ids != image_token_idx] = self.model.llm.get_input_embeddings()(
295
  input_ids[input_ids != image_token_idx]
296
  )
297
 
@@ -312,7 +312,7 @@ class Inferencer:
312
  attention_mask = attention_mask.expand(bsz, -1)
313
 
314
  # 7) Sampling
315
- samples = self.model.sample(
316
  inputs_embeds=inputs_embeds,
317
  attention_mask=attention_mask,
318
  num_iter=num_iter,
 
244
 
245
  # 2) Encode image and extract features
246
  with torch.no_grad():
247
+ x_enc = model.encode(img_tensor)
248
+ x_con, z_enc = model.extract_visual_feature(x_enc)
249
 
250
  # 3) Prepare text prompts
251
  m = n = self.image_size // 16
 
267
  cfg_prompt_str = cfg_prompt_str.replace('<image>', '<image>' * image_length)
268
 
269
  # 4) Tokenize and prepare inputs
270
+ input_ids = model.tokenizer.encode(
271
  prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda()
272
 
273
  if cfg != 1.0:
274
+ null_input_ids = model.tokenizer.encode(
275
  cfg_prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda()
276
  attention_mask = pad_sequence(
277
  [torch.ones_like(input_ids), torch.ones_like(null_input_ids)],
278
  batch_first=True, padding_value=0).to(torch.bool)
279
  input_ids = pad_sequence(
280
  [input_ids, null_input_ids],
281
+ batch_first=True, padding_value=model.tokenizer.eos_token_id)
282
  else:
283
  input_ids = input_ids[None]
284
  attention_mask = torch.ones_like(input_ids).to(torch.bool)
 
288
  z_enc = torch.cat([z_enc, z_enc], dim=0)
289
  x_con = torch.cat([x_con, x_con], dim=0)
290
 
291
+ inputs_embeds = z_enc.new_zeros(*input_ids.shape, model.llm.config.hidden_size)
292
  #debug:目前这里报错
293
  inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1)
294
+ inputs_embeds[input_ids != image_token_idx] = model.llm.get_input_embeddings()(
295
  input_ids[input_ids != image_token_idx]
296
  )
297
 
 
312
  attention_mask = attention_mask.expand(bsz, -1)
313
 
314
  # 7) Sampling
315
+ samples = model.sample(
316
  inputs_embeds=inputs_embeds,
317
  attention_mask=attention_mask,
318
  num_iter=num_iter,