import gradio as gr import torch import spaces import os import numpy as np from PIL import Image from huggingface_hub import hf_hub_download from safetensors.torch import load_file from omegaconf import OmegaConf from image_datasets.dataset import image_resize def tensor_to_pil_image(in_image): tensor = in_image.squeeze(0) tensor = (tensor + 1) / 2 tensor = tensor * 255 numpy_array = tensor.permute(1, 2, 0).byte().numpy() pil_image = Image.fromarray(numpy_array) return pil_image # from src.flux.xflux_pipeline import XFluxSampler args = OmegaConf.load("inference_configs/inference.yaml") # is_schnell = args.model_name == "flux-schnell" # sampler = None device = torch.device("cuda") # dtype = torch.bfloat16 # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype) # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype) # t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype) # clip = load_clip("cpu").to(device, dtype=dtype) #test push @spaces.GPU def generate(image: Image.Image, edit_prompt: str): from src.flux.xflux_pipeline import XFluxSampler # vae.requires_grad_(False) # t5.requires_grad_(False) # clip.requires_grad_(False) # model_path = hf_hub_download( # repo_id="Boese0601/ByteMorpher", # filename="dit.safetensors", # use_auth_token=os.getenv("HF_TOKEN") # ) # state_dict = load_file(model_path) # dit.load_state_dict(state_dict) # dit.eval() # dit.to(device, dtype=dtype) sampler = XFluxSampler( ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None ) # global sampler # device = torch.device("cuda") # dtype = torch.bfloat16 # if sampler is None: # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype) # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype) # t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype) # clip = load_clip("cpu").to(device, dtype=dtype) # vae.requires_grad_(False) # t5.requires_grad_(False) # clip.requires_grad_(False) # model_path = hf_hub_download( # repo_id="Boese0601/ByteMorpher", # filename="dit.safetensors", # use_auth_token=os.getenv("HF_TOKEN") # ) # state_dict = load_file(model_path) # dit.load_state_dict(state_dict) # dit.eval() # sampler = XFluxSampler( # clip=clip, # t5=t5, # ae=vae, # model=dit, # device=device, # ip_loaded=False, # spatial_condition=False, # clip_image_processor=None, # image_encoder=None, # improj=None # ) img = image_resize(image, 512) w, h = img.size img = img.resize(((w // 32) * 32, (h // 32) * 32)) img = torch.from_numpy((np.array(img) / 127.5) - 1) img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype) result = sampler( device='cuda', prompt=edit_prompt, width=args.sample_width, height=args.sample_height, num_steps=args.sample_steps, image_prompt=None, true_gs=args.cfg_scale, seed=args.seed, ip_scale=args.ip_scale if args.use_ip else 1.0, source_image=img if args.use_spatial_condition else None, ) return tensor_to_pil_image(result) def get_samples(): sample_list = [ { "image": "assets/0_camera_zoom/20486354.png", "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.", }, ] return [ [ Image.open(sample["image"]).resize((512, 512)), sample["edit_prompt"], ] for sample in sample_list ] header = """ # ByteMorph
""" def create_app(): with gr.Blocks() as app: gr.Markdown(header, elem_id="header") with gr.Row(equal_height=False): with gr.Column(variant="panel", elem_classes="inputPanel"): original_image = gr.Image( type="pil", label="Condition Image", width=300, elem_id="input" ) edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt") submit_btn = gr.Button("Run", elem_id="submit_btn") with gr.Column(variant="panel", elem_classes="outputPanel"): output_image = gr.Image(type="pil", elem_id="output") with gr.Row(): examples = gr.Examples( examples=get_samples(), inputs=[original_image, edit_prompt], label="Examples", ) submit_btn.click( fn=generate, inputs=[original_image, edit_prompt], outputs=output_image, ) gr.HTML( """