File size: 4,259 Bytes
7e1fa02
 
99738e0
 
 
 
 
47dbef4
 
99738e0
2a887f4
b5ba35d
99738e0
 
ecd9835
d574ec9
47dbef4
 
 
5b14e9e
47dbef4
b3edf02
47dbef4
 
d7e8671
 
47dbef4
f3bd7d5
99738e0
 
 
 
 
 
47dbef4
 
 
 
 
 
 
 
 
 
 
1c193eb
99738e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e1fa02
99738e0
 
fc778e7
 
 
 
 
 
 
 
129ce66
fc778e7
 
 
 
99738e0
 
 
 
7e1fa02
99738e0
 
7e1fa02
99738e0
 
7e1fa02
99738e0
 
 
 
 
 
7e1fa02
99738e0
 
 
 
 
 
 
 
 
 
 
 
 
7e1fa02
 
99738e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
import spaces
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

from image_datasets.dataset import image_resize
# Module-level inference settings shared by every request.

# Precision the condition image tensor is cast to before sampling.
dtype = torch.bfloat16
# Device the sampler and the condition image are placed on.
device = torch.device("cuda")
# Sampling options (size, steps, guidance, seed, ...) read from the YAML config.
args = OmegaConf.load("inference_configs/inference.yaml")
@spaces.GPU
def generate(image: Image.Image, edit_prompt: str):
    """Run the spatially-conditioned sampler on ``image`` guided by ``edit_prompt``.

    Args:
        image: Condition image coming from the UI. Gradio passes ``None``
            when the user has not uploaded anything.
        edit_prompt: Text instruction forwarded to the sampler as ``prompt``.

    Returns:
        Whatever ``XFluxSampler.__call__`` returns; the caller wires it to
        the output ``gr.Image`` component.

    Raises:
        gr.Error: If no condition image was provided.
    """
    # Fail fast with a readable UI message instead of an AttributeError deep
    # inside the preprocessing when "Run" is clicked without an image.
    if image is None:
        raise gr.Error("Please upload a condition image before running.")

    # Imported inside the function, as in the original code, so the pipeline
    # module is only loaded when a request is actually served.
    from src.flux.xflux_pipeline import XFluxSampler

    sampler = XFluxSampler(
        device=device,
        ip_loaded=False,
        spatial_condition=True,
        clip_image_processor=None,
        image_encoder=None,
        improj=None,
        share_position_embedding=True,
    )

    # Force a 3-channel image: RGBA / palette / grayscale uploads would
    # otherwise yield a wrong channel count (or a 2-D array that breaks the
    # permute below).
    img = image_resize(image.convert("RGB"), 512)

    # Snap both sides down to a multiple of 32, but never down to 0, which
    # PIL refuses to resize to.
    w, h = img.size
    w = max(32, (w // 32) * 32)
    h = max(32, (h // 32) * 32)
    img = img.resize((w, h))

    # Map uint8 [0, 255] to [-1, 1] and reorder HWC -> 1xCxHxW.
    img = torch.from_numpy((np.array(img) / 127.5) - 1)
    img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)

    result = sampler(
        prompt=edit_prompt,
        width=args.sample_width,
        height=args.sample_height,
        num_steps=args.sample_steps,
        image_prompt=None,
        true_gs=args.cfg_scale,
        seed=args.seed,
        ip_scale=args.ip_scale if args.use_ip else 1.0,
        source_image=img if args.use_spatial_condition else None,
    )
    return result

def get_samples():
    """Build the rows displayed in the demo's "Examples" panel.

    Returns:
        A list of ``[image, edit_prompt]`` rows. Each image is opened from
        its path on disk and resized to 512x512.
    """
    catalog = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
    ]

    rows = []
    for entry in catalog:
        thumbnail = Image.open(entry["image"]).resize((512, 512))
        rows.append([thumbnail, entry["edit_prompt"]])
    return rows


def create_app():
    """Assemble the Gradio Blocks UI for the demo.

    Layout: an HTML header with project badges, an input panel (condition
    image, edit prompt, "Run" button), an output panel (result image), an
    examples row filled from ``get_samples()``, and a footer credit. The
    "Run" button calls ``generate`` with the image and the prompt.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as app:
        # The Hugging Face emoji in the badge URLs is percent-encoded
        # (%F0%9F%A4%97) so it survives any source-file encoding round trip.
        gr.HTML(
            """
            <div style="text-align: center;">
                <h2>ByteMorpher</h2>
                <a href="https://boese0601.github.io/bytemorph/" target="_blank"><img src="https://img.shields.io/badge/Project-Website-blue" style="display:inline-block;"></a>
                <a href="https://github.com/Boese0601/ByteMorph" target="_blank"><img src="https://img.shields.io/github/stars/Boese0601/ByteMorph?label=GitHub%20%E2%98%85&logo=github&color=green" style="display:inline-block;"></a>
                <a href="https://huggingface.co/datasets/Boese0601/ByteMorph-6M-Demo" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset_Demo-yellow" style="display:inline-block;"></a>
                <a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace%20-Benchmark-yellow" style="display:inline-block;"></a>
                <a href="https://huggingface.co/Boese0601/ByteMorpher" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face%20-Model-yellow" style="display:inline-block;"></a>
            </div>
            """
        )
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
                submit_btn = gr.Button("Run", elem_id="submit_btn")

            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        with gr.Row():
            # Clicking an example fills the two input components.
            gr.Examples(
                examples=get_samples(),
                inputs=[original_image, edit_prompt],
                label="Examples",
            )

        submit_btn.click(
            fn=generate,
            inputs=[original_image, edit_prompt],
            outputs=output_image,
        )
        gr.HTML(
            """
            <div style="text-align: center;">
                * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
            </div>
            """
        )
    return app

if __name__ == "__main__":
    # Build the UI first, then serve it locally (no debug mode, no public
    # share link, server-side rendering disabled).
    demo = create_app()
    demo.launch(debug=False, share=False, ssr_mode=False)