import os
import argparse

import torch
from PIL import Image
from transformers import AutoModelForCausalLM
def parse_args():
    parser = argparse.ArgumentParser(description="Test Image Editing")
    parser.add_argument(
        "--model_path",
        type=str,
        default="AIDC-AI/Ovis-U1-3B",
    )
    parser.add_argument(
        "--steps", type=int, default=50,
    )
    parser.add_argument(
        "--img_cfg", type=float, default=1.5,
    )
    parser.add_argument(
        "--txt_cfg", type=float, default=6,
    )
    args = parser.parse_args()
    return args
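# A plain white canvas stands in for "no image" when building the fully
# unconditional guidance branch below.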
def load_blank_image(width, height):
    return Image.new("RGB", (width, height), (255, 255, 255))
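# Tokenize the prompt and preprocess the (optional) conditioning image into the
# tensors the model expects: LLM input ids, ViT pixel values with their grids,
# and VAE pixel values for the diffusion decoder.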
def build_inputs(model, text_tokenizer, visual_tokenizer, prompt, pil_image, target_width, target_height):
    if pil_image is not None:
        # Fit the conditioning image to the target size and encode it for the
        # VAE; channel 0 of its position ids is set to 1.0 to tag it as the
        # condition image.
        target_size = (int(target_width), int(target_height))
        pil_image, vae_pixel_values, cond_img_ids = model.visual_generator.process_image_aspectratio(pil_image, target_size)
        cond_img_ids[..., 0] = 1.0
        vae_pixel_values = vae_pixel_values.unsqueeze(0).to(device=model.device)
        width = pil_image.width
        height = pil_image.height
        resized_height, resized_width = visual_tokenizer.smart_resize(height, width, max_pixels=visual_tokenizer.image_processor.min_pixels)
        pil_image = pil_image.resize((resized_width, resized_height))
    else:
        vae_pixel_values = None
        cond_img_ids = None
    prompt, input_ids, pixel_values, grid_thws = model.preprocess_inputs(
        prompt,
        [pil_image],
        generation_preface=None,
        return_labels=False,
        propagate_exception=False,
        multimodal_type='single_image',
        fix_sample_overall_length_navit=False,
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    # Move the visual inputs to the visual tokenizer's device.
    if pixel_values is not None:
        pixel_values = pixel_values.to(device=visual_tokenizer.device, dtype=torch.bfloat16)
    if grid_thws is not None:
        grid_thws = grid_thws.to(device=visual_tokenizer.device)
    return input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values
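# Image editing runs three conditioning branches for classifier-free guidance:
#   cond          - edit prompt + input image
#   no_txt_cond   - generic prompt + input image (drops the text condition)
#   no_both_cond  - generic prompt + blank image (drops text and image)
# txt_cfg and img_cfg weight the text and image guidance terms respectively.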
def pipe_img_edit(model, input_img, prompt, steps, txt_cfg, img_cfg, seed=42):
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()
    # Snap the output resolution to a multiple of 32 for the VAE.
    width, height = input_img.size
    height, width = visual_tokenizer.smart_resize(height, width, factor=32)
    gen_kwargs = dict(
        max_new_tokens=1024,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True,
        height=height,
        width=width,
        num_steps=steps,
        seed=seed,
        img_cfg=img_cfg,
        txt_cfg=txt_cfg,
    )
    # Fully unconditional branch: generic prompt on a blank canvas.
    uncond_image = load_blank_image(width, height)
    uncond_prompt = "<image>\nGenerate an image."
    input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(
        model, text_tokenizer, visual_tokenizer, uncond_prompt, uncond_image, width, height)
    with torch.inference_mode():
        no_both_cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)
    # Image-only branch: generic prompt on the actual input image.
    input_img = input_img.resize((width, height))
    prompt = "<image>\n" + prompt.strip()
    with torch.inference_mode():
        input_ids, pixel_values, attention_mask, grid_thws, _ = build_inputs(
            model, text_tokenizer, visual_tokenizer, uncond_prompt, input_img, width, height)
        no_txt_cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)
    # Full branch: edit prompt on the input image; keep the VAE latents of the
    # condition image for the diffusion decoder.
    input_ids, pixel_values, attention_mask, grid_thws, vae_pixel_values = build_inputs(
        model, text_tokenizer, visual_tokenizer, prompt, input_img, width, height)
    with torch.inference_mode():
        cond = model.generate_condition(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, grid_thws=grid_thws, **gen_kwargs)
        cond["vae_pixel_values"] = vae_pixel_values
        images = model.generate_img(cond=cond, no_both_cond=no_both_cond, no_txt_cond=no_txt_cond, **gen_kwargs)
    return images
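# Entry point: load the model in bfloat16 on CUDA and run a single edit on the
# bundled sample image (docs/imgs/cat.png).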
def main():
    args = parse_args()
    model, loading_info = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        output_loading_info=True,
        trust_remote_code=True,
    )
    print(f'Loading info of Ovis-U1:\n{loading_info}')
    model = model.eval().to("cuda")
    model = model.to(torch.bfloat16)
    image_path = os.path.join(os.path.dirname(__file__), "docs", "imgs", "cat.png")
    pil_img = Image.open(image_path).convert('RGB')
    prompt = "add a hat to this cat."
    image = pipe_img_edit(model, pil_img, prompt,
                          args.steps, args.txt_cfg, args.img_cfg)[0]
    image.save("test_image_edit.png")
if __name__ == "__main__":
    main()