# ByteMorph-Demo / app.py
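"""Gradio demo for ByteMorph: instruction-driven image editing.

Wraps src.flux.xflux_pipeline.XFluxSampler behind an image + edit-prompt
interface and runs generation on GPU via the @spaces.GPU decorator.
"""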
import gradio as gr
import torch
import spaces
import os
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from omegaconf import OmegaConf
from image_datasets.dataset import image_resize
def tensor_to_pil_image(in_image: torch.Tensor) -> Image.Image:
    """Convert a 1xCxHxW tensor in [-1, 1] to a PIL image."""
    tensor = in_image.squeeze(0)            # drop the batch dimension
    tensor = (tensor + 1) / 2               # [-1, 1] -> [0, 1]
    tensor = (tensor * 255).clamp(0, 255)   # [0, 1] -> [0, 255]
    numpy_array = tensor.permute(1, 2, 0).byte().cpu().numpy()  # CHW -> HWC uint8, on CPU for numpy
    return Image.fromarray(numpy_array)
# from src.flux.xflux_pipeline import XFluxSampler
args = OmegaConf.load("inference_configs/inference.yaml")
# is_schnell = args.model_name == "flux-schnell"
# sampler = None
device = torch.device("cuda")
dtype = torch.bfloat16  # used below when moving the input image to the GPU
# dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
# vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
# t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
# clip = load_clip("cpu").to(device, dtype=dtype)
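# On ZeroGPU Spaces, CUDA work has to happen inside the @spaces.GPU-decorated
# function, so the sampler is constructed per request in generate() below.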
@spaces.GPU
def generate(image: Image.Image, edit_prompt: str):
from src.flux.xflux_pipeline import XFluxSampler
# vae.requires_grad_(False)
# t5.requires_grad_(False)
# clip.requires_grad_(False)
# model_path = hf_hub_download(
# repo_id="Boese0601/ByteMorpher",
# filename="dit.safetensors",
# use_auth_token=os.getenv("HF_TOKEN")
# )
# state_dict = load_file(model_path)
# dit.load_state_dict(state_dict)
# dit.eval()
# dit.to(device, dtype=dtype)
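    # NOTE: this XFluxSampler build assumes the pipeline loads the Flux modules
    # (CLIP, T5, VAE, DiT) itself; the commented-out block above is the explicit
    # loading path from the original x-flux pipeline, kept for reference.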
sampler = XFluxSampler(
ip_loaded=False,
spatial_condition=False,
clip_image_processor=None,
image_encoder=None,
improj=None
)
# global sampler
# device = torch.device("cuda")
# dtype = torch.bfloat16
# if sampler is None:
# dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
# vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
# t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
# clip = load_clip("cpu").to(device, dtype=dtype)
# vae.requires_grad_(False)
# t5.requires_grad_(False)
# clip.requires_grad_(False)
# model_path = hf_hub_download(
# repo_id="Boese0601/ByteMorpher",
# filename="dit.safetensors",
# use_auth_token=os.getenv("HF_TOKEN")
# )
# state_dict = load_file(model_path)
# dit.load_state_dict(state_dict)
# dit.eval()
# sampler = XFluxSampler(
# clip=clip,
# t5=t5,
# ae=vae,
# model=dit,
# device=device,
# ip_loaded=False,
# spatial_condition=False,
# clip_image_processor=None,
# image_encoder=None,
# improj=None
# )
    img = image_resize(image, 512)                       # dataset helper: resize toward 512px
    w, h = img.size
    img = img.resize(((w // 32) * 32, (h // 32) * 32))   # round down to a multiple of 32
    img = torch.from_numpy((np.array(img) / 127.5) - 1)  # uint8 [0, 255] -> float [-1, 1]
    img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)  # HWC -> 1xCxHxW
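    # Sampling hyperparameters (output size, steps, CFG scale, seed, IP/spatial
    # flags) come from inference_configs/inference.yaml loaded at import time.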
result = sampler(
device='cuda',
prompt=edit_prompt,
width=args.sample_width,
height=args.sample_height,
num_steps=args.sample_steps,
image_prompt=None,
true_gs=args.cfg_scale,
seed=args.seed,
ip_scale=args.ip_scale if args.use_ip else 1.0,
source_image=img if args.use_spatial_condition else None,
)
return tensor_to_pil_image(result)
def get_samples():
sample_list = [
{
"image": "assets/0_camera_zoom/20486354.png",
"edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
},
]
return [
[
Image.open(sample["image"]).resize((512, 512)),
sample["edit_prompt"],
]
for sample in sample_list
]
header = """
# ByteMorph
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
with gr.Blocks() as app:
gr.Markdown(header, elem_id="header")
with gr.Row(equal_height=False):
with gr.Column(variant="panel", elem_classes="inputPanel"):
original_image = gr.Image(
type="pil", label="Condition Image", width=300, elem_id="input"
)
edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
submit_btn = gr.Button("Run", elem_id="submit_btn")
with gr.Column(variant="panel", elem_classes="outputPanel"):
output_image = gr.Image(type="pil", elem_id="output")
with gr.Row():
examples = gr.Examples(
examples=get_samples(),
inputs=[original_image, edit_prompt],
label="Examples",
)
submit_btn.click(
fn=generate,
inputs=[original_image, edit_prompt],
outputs=output_image,
)
gr.HTML(
"""
<div style="text-align: center;">
* This demo's template was adapted from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
</div>
"""
)
return app
if __name__ == "__main__":
create_app().launch(debug=False, share=False, ssr_mode=False)