44444 / app.py
1
Update app.py
65dcc19
raw
history blame
9.9 kB
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
from pathlib import Path
import gradio as gr
import torch
import spaces
from uno.flux.pipeline import UNOPipeline
def get_examples(examples_dir: str = "assets/examples") -> list:
    """Collect the demo examples stored under ``examples_dir``.

    Every immediate sub-directory that contains a ``config.json`` is one
    example. The JSON must provide ``"useage"`` (the "case for" label; the key
    is spelled this way in the config files), ``"prompt"`` and ``"seed"``, and
    may provide ``"image_ref1"`` .. ``"image_ref4"``, which are joined onto the
    example's directory.

    Args:
        examples_dir: Directory holding one sub-directory per example.

    Returns:
        One row per example, in sorted directory order, shaped
        ``[useage, prompt, ref1, ref2, ref3, ref4, seed]``; a reference image
        that is not configured is ``None``.

    Raises:
        FileNotFoundError: If ``examples_dir`` does not exist.
        KeyError: If a config lacks ``useage``, ``prompt`` or ``seed``.
    """
    rows = []
    # Path.iterdir() yields entries in arbitrary order; sort so the example
    # gallery looks the same on every machine / filesystem.
    for example in sorted(Path(examples_dir).iterdir()):
        config_path = example / "config.json"
        # Ignore stray files and directories that are not examples (e.g.
        # hidden cache folders) instead of crashing the app at startup.
        if not example.is_dir() or not config_path.is_file():
            continue
        with open(config_path, encoding="utf-8") as f:
            example_dict = json.load(f)

        row = [
            example_dict["useage"],  # case for
            example_dict["prompt"],  # prompt
        ]
        # Always emit four reference-image slots so every row lines up with
        # the four image inputs of the UI.
        row.extend(
            str(example / example_dict[key]) if key in example_dict else None
            for key in ("image_ref1", "image_ref2", "image_ref3", "image_ref4")
        )
        row.append(example_dict["seed"])
        rows.append(row)
    return rows
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
):
    """Build the UNO-FLUX Gradio Blocks UI around a freshly loaded pipeline.

    Args:
        model_type: Model name forwarded to ``UNOPipeline`` (the CLI in this
            file offers "flux-dev", "flux-dev-fp8" and "flux-schnell").
        device: Device string forwarded to ``UNOPipeline``; the default is
            "cuda" if ``torch.cuda.is_available()`` when this module is
            imported, otherwise "cpu".
        offload: Forwarded to ``UNOPipeline`` as its ``offload`` argument.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` demo. Examples are
        loaded from ``./assets/examples`` via ``get_examples``.
    """
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
    # Run each generation call under a ZeroGPU allocation. The keyword accepted
    # by spaces.GPU is ``duration`` (seconds); it was previously misspelled
    # ``duratioin``, so the intended 120 s budget was never requested.
    pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)

    # Custom CSS styles
    css = """
.gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
background: linear-gradient(to right, #4776E6, #8E54E9);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 700;
padding: 1rem 0;
}
.container {
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
padding: 20px;
background: white;
margin-bottom: 1.5rem;
}
.input-container {
background: rgba(245, 247, 250, 0.7);
border-radius: 10px;
padding: 1rem;
margin-bottom: 1rem;
}
.image-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
gap: 10px;
}
.generate-btn {
background: linear-gradient(90deg, #4776E6, #8E54E9);
border: none;
color: white;
padding: 10px 20px;
border-radius: 50px;
font-weight: 600;
box-shadow: 0 4px 10px rgba(0,0,0,0.1);
transition: all 0.3s ease;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 15px rgba(0,0,0,0.15);
}
.badge-container {
display: flex;
justify-content: center;
align-items: center;
gap: 8px;
flex-wrap: wrap;
margin-bottom: 1rem;
}
.badge {
display: inline-block;
padding: 0.25rem 0.75rem;
font-size: 0.875rem;
font-weight: 500;
line-height: 1.5;
text-align: center;
white-space: nowrap;
vertical-align: middle;
border-radius: 30px;
color: white;
background: #6c5ce7;
text-decoration: none;
}
.output-container {
background: rgba(243, 244, 246, 0.7);
border-radius: 10px;
padding: 1.5rem;
}
.slider-container label {
font-weight: 600;
margin-bottom: 0.5rem;
color: #4a5568;
}
"""

    # Header badges linking to the code, project page, paper, model and demo.
    badges_text = r"""
<div class="badge-container">
<a href="https://github.com/bytedance/UNO" class="badge" style="background: #24292e;"><img alt="GitHub Stars" src="https://img.shields.io/github/stars/bytedance/UNO" style="vertical-align: middle;"></a>
<a href="https://bytedance.github.io/UNO/" class="badge" style="background: #f1c40f; color: #333;"><img alt="Project Page" src="https://img.shields.io/badge/Project%20Page-UNO-yellow" style="vertical-align: middle;"></a>
<a href="https://arxiv.org/abs/2504.02160" class="badge" style="background: #b31b1b;"><img alt="arXiv" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg" style="vertical-align: middle;"></a>
<a href="https://huggingface.co/bytedance-research/UNO" class="badge" style="background: #FF9D00;"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange" style="vertical-align: middle;"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX" class="badge" style="background: #FF9D00;"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange" style="vertical-align: middle;"></a>
</div>
""".strip()

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# <div class='main-header'>UNO-FLUX Image Generator</div>")
        gr.Markdown(badges_text)

        with gr.Row():
            # Left column: prompt, reference images and generation settings.
            with gr.Column(scale=3):
                with gr.Group(elem_classes="container"):
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        value="handsome woman in the city",
                        elem_classes="input-container"
                    )

                    gr.Markdown("### Reference Images")
                    with gr.Row(elem_classes="image-grid"):
                        image_prompt1 = gr.Image(label="Ref Img 1", visible=True, interactive=True, type="pil")
                        image_prompt2 = gr.Image(label="Ref Img 2", visible=True, interactive=True, type="pil")
                        image_prompt3 = gr.Image(label="Ref Img 3", visible=True, interactive=True, type="pil")
                        image_prompt4 = gr.Image(label="Ref Img 4", visible=True, interactive=True, type="pil")

                    with gr.Row():
                        with gr.Column(scale=2):
                            with gr.Group(elem_classes="slider-container"):
                                width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
                                height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
                        with gr.Column(scale=1):
                            gr.Markdown("<div style='background: #f8f9fa; padding: 10px; border-radius: 8px; border-left: 4px solid #4776E6;'>📌 The model was trained on 512x512 resolution.<br>Sizes closer to 512 are more stable, higher sizes give better visual effects but are less stable.</div>")

                    with gr.Accordion("Advanced Options", open=False):
                        with gr.Row():
                            with gr.Column():
                                num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                            with gr.Column():
                                guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                            with gr.Column():
                                seed = gr.Number(-1, label="Seed (-1 for random)")

                    generate_btn = gr.Button("Generate", elem_classes="generate-btn")

            # Right column: generated image and downloadable file.
            with gr.Column(scale=2):
                with gr.Group(elem_classes="output-container"):
                    gr.Markdown("### Generated Result")
                    output_image = gr.Image(label="Generated Image")
                    download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)

        # Gradio passes these values positionally, so this order must match the
        # parameters of pipeline.gradio_generate (defined in uno.flux.pipeline,
        # not visible here) — TODO confirm when changing either side.
        inputs = [
            prompt, width, height, guidance, num_steps,
            seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
        ]
        generate_btn.click(
            fn=pipeline.gradio_generate,
            inputs=inputs,
            # Presumably gradio_generate returns (image, file path) — verify.
            outputs=[output_image, download_btn],
        )

        # Hidden component that only receives the "case for" label of an example.
        example_text = gr.Text("", visible=False, label="Case For:")
        examples = get_examples("./assets/examples")

        with gr.Group(elem_classes="container"):
            gr.Markdown("### <div style='text-align: center; margin-bottom: 1rem;'>Examples</div>")
            # Components line up one-to-one with the 7 values of each
            # get_examples row: label, prompt, four reference images, seed.
            # (output_image used to be listed too, but no row ever supplied it.)
            gr.Examples(
                examples=examples,
                inputs=[
                    example_text, prompt,
                    image_prompt1, image_prompt2, image_prompt3, image_prompt4,
                    seed,
                ],
            )

    return demo
if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class AppArgs:
        """Command-line options for launching the demo (parsed by HfArgumentParser)."""

        # Model variant forwarded to create_demo / UNOPipeline.
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        # Evaluated once when the class is defined.
        device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload the models (ae, dit, text encoder) to CPU when not in use."}
        )
        # Port passed to demo.launch(server_port=...).
        port: int = 7860

    parser = HfArgumentParser([AppArgs])
    # One dataclass was registered, so the parsed result holds exactly one AppArgs.
    args = parser.parse_args_into_dataclasses()[0]

    demo = create_demo(args.name, args.device, args.offload)
    # ssr_mode=False turns off Gradio's server-side rendering mode.
    demo.launch(server_port=args.port, ssr_mode=False)