Spaces:
dreroc
/
Running on Zero

yichenchenchen committed on
Commit
db62905
·
verified ·
1 Parent(s): f615fc2

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +9 -138
inferencer.py CHANGED
@@ -80,7 +80,7 @@ class Inferencer:
80
  self.image_size = image_size
81
  self.image_shape = (image_size // 16, image_size // 16)
82
  self.cfg_prompt = cfg_prompt
83
- self.model = self.init_model()
84
 
85
  @spaces.GPU
86
  def init_model(self):
@@ -132,6 +132,8 @@ class Inferencer:
132
  cfg_schedule="constant",
133
  temperature=1.0,
134
  ):
 
 
135
  prompt = self.model.prompt_template["INSTRUCTION"].format(
136
  input=f"Generate an image: {raw_prompt.strip()}."
137
  )
@@ -168,6 +170,8 @@ class Inferencer:
168
  @spaces.GPU
169
  def query_image(self, img: Image.Image, prompt=""):
170
  model = self.model
 
 
171
  tokenizer = model.tokenizer
172
  special_tokens_dict = {"additional_special_tokens": ["<image>"]}
173
  tokenizer.add_special_tokens(special_tokens_dict)
@@ -214,143 +218,6 @@ class Inferencer:
214
  # print(tokenizer.decode(output[0]))
215
  return tokenizer.decode(output[0])
216
 
217
- # def edit_image(
218
- # self,
219
- # img: Image.Image,
220
- # prompt: str,
221
- # cfg: float = 2.0,
222
- # cfg_prompt: str = "Repeat this image.",
223
- # cfg_schedule="constant",
224
- # temperature: float = 1.0,
225
- # grid_size: int = 2,
226
- # num_iter: int = 64,
227
- # mode: str = "conditional",
228
- # ) -> list[Image.Image]:
229
-
230
- # model = self.model
231
- # tokenizer = model.tokenizer
232
- # m = n = self.image_size // 16
233
- # image_length = m * n + 64
234
-
235
- # # preprocess image
236
- # image = img.convert("RGB")
237
- # original_size = image.size
238
- # image = image.resize((self.image_size, self.image_size))
239
- # image = torch.from_numpy(np.array(image)).to(
240
- # dtype=model.dtype, device=self.device
241
- # )
242
- # image = rearrange(image, "h w c -> c h w")[None]
243
- # image = 2 * (image / 255) - 1
244
-
245
- # # prepare prompt
246
- # special_tokens_dict = {"additional_special_tokens": ["<image>"]}
247
- # tokenizer.add_special_tokens(special_tokens_dict)
248
- # image_token_idx = tokenizer.encode("<image>", add_special_tokens=False)[-1]
249
-
250
- # full_prompt = model.prompt_template["INSTRUCTION"].format(
251
- # input="<image>\n" + prompt
252
- # )
253
- # full_prompt = full_prompt.replace("<image>", "<image>" * image_length)
254
- # input_ids = tokenizer.encode(
255
- # full_prompt, add_special_tokens=True, return_tensors="pt"
256
- # )[0].to(self.device)
257
-
258
- # if cfg != 1.0:
259
- # null_prompt = model.prompt_template["INSTRUCTION"].format(
260
- # input="<image>\n" + cfg_prompt
261
- # )
262
- # null_prompt = null_prompt.replace("<image>", "<image>" * image_length)
263
- # null_input_ids = tokenizer.encode(
264
- # null_prompt, add_special_tokens=True, return_tensors="pt"
265
- # )[0].to(self.device)
266
- # attention_mask = pad_sequence(
267
- # [torch.ones_like(input_ids), torch.ones_like(null_input_ids)],
268
- # batch_first=True,
269
- # padding_value=0,
270
- # ).to(torch.bool)
271
- # input_ids = pad_sequence(
272
- # [input_ids, null_input_ids],
273
- # batch_first=True,
274
- # padding_value=tokenizer.eos_token_id,
275
- # )
276
- # else:
277
- # input_ids = input_ids[None]
278
- # attention_mask = torch.ones_like(input_ids).to(torch.bool)
279
-
280
- # with torch.no_grad():
281
- # x_enc = model.encode(image).to(model.dtype)
282
- # x_con, z_enc = model.extract_visual_feature(x_enc)
283
-
284
- # if cfg != 1.0:
285
- # z_enc = torch.cat([z_enc, z_enc], dim=0)
286
- # x_con = torch.cat([x_con, x_con], dim=0)
287
-
288
- # inputs_embeds = z_enc.new_zeros(*input_ids.shape, model.llm.config.hidden_size)
289
- # inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1)
290
- # inputs_embeds[input_ids != image_token_idx] = model.llm.get_input_embeddings()(
291
- # input_ids[input_ids != image_token_idx]
292
- # )
293
-
294
- # # repeat
295
- # bsz = grid_size**2
296
- # x_con = torch.cat([x_con] * bsz)
297
- # if cfg != 1.0:
298
- # inputs_embeds = torch.cat(
299
- # [
300
- # inputs_embeds[:1].expand(bsz, -1, -1),
301
- # inputs_embeds[1:].expand(bsz, -1, -1),
302
- # ]
303
- # )
304
- # attention_mask = torch.cat(
305
- # [
306
- # attention_mask[:1].expand(bsz, -1),
307
- # attention_mask[1:].expand(bsz, -1),
308
- # ]
309
- # )
310
- # else:
311
- # inputs_embeds = inputs_embeds.expand(bsz, -1, -1)
312
- # attention_mask = attention_mask.expand(bsz, -1)
313
-
314
- # # sample
315
- # with torch.no_grad():
316
- # if mode == "conditional":
317
- # samples = model.sample(
318
- # inputs_embeds=inputs_embeds,
319
- # attention_mask=attention_mask,
320
- # num_iter=num_iter,
321
- # cfg=cfg,
322
- # cfg_schedule=cfg_schedule,
323
- # temperature=temperature,
324
- # progress=False,
325
- # image_shape=(m, n),
326
- # x_con=x_con,
327
- # )
328
- # else:
329
- # samples = model.sample(
330
- # inputs_embeds=inputs_embeds,
331
- # attention_mask=attention_mask,
332
- # num_iter=num_iter,
333
- # cfg=cfg,
334
- # cfg_schedule=cfg_schedule,
335
- # temperature=temperature,
336
- # progress=False,
337
- # image_shape=(m, n),
338
- # )
339
-
340
- # samples = rearrange(
341
- # samples, "(m n) c h w -> (m h) (n w) c", m=grid_size, n=grid_size
342
- # )
343
- # samples = (
344
- # torch.clamp(127.5 * samples + 128.0, 0, 255)
345
- # .to("cpu", dtype=torch.uint8)
346
- # .numpy()
347
- # )
348
-
349
- # output_image = Image.fromarray(samples).resize(
350
- # (original_size[0] * grid_size, original_size[1] * grid_size)
351
- # )
352
- # return [output_image]
353
-
354
  @spaces.GPU
355
  def edit_image(
356
  self,
@@ -366,6 +233,8 @@ class Inferencer:
366
  """Edit single image based on prompt."""
367
 
368
  model = self.model
 
 
369
  tokenizer = model.tokenizer
370
  special_tokens_dict = {"additional_special_tokens": ["<image>"]}
371
  tokenizer.add_special_tokens(special_tokens_dict)
@@ -466,6 +335,8 @@ class Inferencer:
466
  @spaces.GPU
467
  def query_text(self, prompt=""):
468
  model = self.model
 
 
469
  tokenizer = model.tokenizer
470
 
471
  # 构造文本 prompt
 
80
  self.image_size = image_size
81
  self.image_shape = (image_size // 16, image_size // 16)
82
  self.cfg_prompt = cfg_prompt
83
+ self.model = None
84
 
85
  @spaces.GPU
86
  def init_model(self):
 
132
  cfg_schedule="constant",
133
  temperature=1.0,
134
  ):
135
+ if not model:
136
+ self.model = self.init_model()
137
  prompt = self.model.prompt_template["INSTRUCTION"].format(
138
  input=f"Generate an image: {raw_prompt.strip()}."
139
  )
 
170
  @spaces.GPU
171
  def query_image(self, img: Image.Image, prompt=""):
172
  model = self.model
173
+ if not model:
174
+ self.model = self.init_model()
175
  tokenizer = model.tokenizer
176
  special_tokens_dict = {"additional_special_tokens": ["<image>"]}
177
  tokenizer.add_special_tokens(special_tokens_dict)
 
218
  # print(tokenizer.decode(output[0]))
219
  return tokenizer.decode(output[0])
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @spaces.GPU
222
  def edit_image(
223
  self,
 
233
  """Edit single image based on prompt."""
234
 
235
  model = self.model
236
+ if not model:
237
+ self.model = self.init_model()
238
  tokenizer = model.tokenizer
239
  special_tokens_dict = {"additional_special_tokens": ["<image>"]}
240
  tokenizer.add_special_tokens(special_tokens_dict)
 
335
  @spaces.GPU
336
  def query_text(self, prompt=""):
337
  model = self.model
338
+ if not model:
339
+ self.model = self.init_model()
340
  tokenizer = model.tokenizer
341
 
342
  # 构造文本 prompt