# ByteMorph-Demo / app.py
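"""Gradio demo for ByteMorpher: prompt-driven editing of a condition image."""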
import gradio as gr
import torch
import spaces
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from image_datasets.dataset import image_resize
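
# Sampling settings (output size, number of steps, guidance scale, seed, and
# IP-adapter / spatial-condition flags) are read from the inference YAML config.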
args = OmegaConf.load("inference_configs/inference.yaml")
device = torch.device("cuda")
dtype = torch.bfloat16
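
# @spaces.GPU requests a GPU for each call on ZeroGPU Spaces; the sampler is
# constructed inside the function so model loading happens on that allocated GPU.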
@spaces.GPU
def generate(image: Image.Image, edit_prompt: str):
    from src.flux.xflux_pipeline import XFluxSampler

    sampler = XFluxSampler(
        device=device,
        ip_loaded=False,
        spatial_condition=True,
        clip_image_processor=None,
        image_encoder=None,
        improj=None,
        share_position_embedding=True,
    )

    # Preprocess the condition image: resize with the repo's image_resize helper,
    # round both sides down to multiples of 32, and scale pixels to [-1, 1]
    # in NCHW layout on the GPU.
    img = image_resize(image, 512)
    w, h = img.size
    img = img.resize(((w // 32) * 32, (h // 32) * 32))
    img = torch.from_numpy((np.array(img) / 127.5) - 1)
    img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)

    # Run the edit with the prompt and the settings loaded from the config.
    result = sampler(
        prompt=edit_prompt,
        width=args.sample_width,
        height=args.sample_height,
        num_steps=args.sample_steps,
        image_prompt=None,
        true_gs=args.cfg_scale,
        seed=args.seed,
        ip_scale=args.ip_scale if args.use_ip else 1.0,
        source_image=img if args.use_spatial_condition else None,
    )
    return result
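
# Example (condition image, edit prompt) pairs shown in the Examples panel;
# image paths are relative to this repository.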
def get_samples():
    sample_list = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
    ]
    return [
        [
            Image.open(sample["image"]).resize((512, 512)),
            sample["edit_prompt"],
        ]
        for sample in sample_list
    ]
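
# Build the Gradio Blocks UI: condition image and edit prompt on the left,
# edited output on the right, with clickable examples underneath.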
def create_app():
    with gr.Blocks() as app:
        gr.HTML(
            """
            <div style="text-align: center;">
                <h2>ByteMorpher</h2>
                <a href="https://boese0601.github.io/bytemorph/" target="_blank"><img src="https://img.shields.io/badge/Project-Website-blue" style="display:inline-block;"></a>
                <a href="https://github.com/Boese0601/ByteMorph" target="_blank"><img src="https://img.shields.io/github/stars/Boese0601/ByteMorph?label=GitHub%20%E2%98%85&logo=github&color=green" style="display:inline-block;"></a>
                <a href="https://huggingface.co/datasets/Boese0601/ByteMorph-6M-Demo" target="_blank"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-Dataset_Demo-yellow" style="display:inline-block;"></a>
                <a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace%20-Benchmark-yellow" style="display:inline-block;"></a>
                <a href="https://huggingface.co/Boese0601/ByteMorpher" target="_blank"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-yellow" style="display:inline-block;"></a>
            </div>
            """
        )
        # gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
                submit_btn = gr.Button("Run", elem_id="submit_btn")
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")
        with gr.Row():
            examples = gr.Examples(
                examples=get_samples(),
                inputs=[original_image, edit_prompt],
                label="Examples",
            )
        submit_btn.click(
            fn=generate,
            inputs=[original_image, edit_prompt],
            outputs=output_image,
        )
        gr.HTML(
            """
            <div style="text-align: center;">
                * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
            </div>
            """
        )
    return app

if __name__ == "__main__":
    create_app().launch(debug=False, share=False, ssr_mode=False)