# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
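"""Gradio demo app for UNO.

Enhances the user's prompt into several variants with ChatGPT
(enhance_prompt_with_chatgpt), then generates one image per variant with the
UNO FLUX pipeline, optionally conditioned on up to four reference images.
"""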
import dataclasses
import json
import os
from pathlib import Path

import gradio as gr
import openai
import torch
from huggingface_hub import login

from uno.flux.pipeline import UNOPipeline
from uno.utils.prompt_enhancer import enhance_prompt_with_chatgpt

# Let the CUDA caching allocator grow segments instead of failing on
# fragmentation; PyTorch reads this lazily, so setting it after `import torch`
# still takes effect as long as no CUDA allocation has happened yet.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

openai.api_key = os.getenv("OPENAI_API_KEY")
login(token=os.getenv("HUGGINGFACE_TOKEN"))
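
# Each subdirectory of examples_dir is one demo case: its config.json holds a
# use-case label, a prompt, up to four reference image filenames, and a seed,
# in the order the gr.Examples block below expects.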
def get_examples(examples_dir: str = "assets/examples") -> list:
    examples = Path(examples_dir)
    ans = []
    for example in examples.iterdir():
        if not example.is_dir():
            continue
        with open(example / "config.json") as f:
            example_dict = json.load(f)
        # "useage" is spelled this way in the bundled config.json files.
        example_list = [example_dict["useage"], example_dict["prompt"]]
        for key in ["image_ref1", "image_ref2", "image_ref3", "image_ref4"]:
            example_list.append(str(example / example_dict[key]) if key in example_dict else None)
        example_list.append(example_dict["seed"])
        ans.append(example_list)
    return ans
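
# Build the Gradio UI around a single UNOPipeline instance; model_type selects
# the FLUX checkpoint and offload moves unused submodules to CPU.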
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
):
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)

    with gr.Blocks() as demo:
        gr.Markdown("# UNO by UNO team")
        gr.Markdown(
            """
            <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
            <a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
            <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
            <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
            <a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
            <a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
            </div>
            """
        )
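
        # Left column: prompt, up to four reference images, size sliders, and
        # advanced sampling options. Right column: five image/prompt pairs.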
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                with gr.Row():
                    image_prompt1 = gr.Image(label="Ref Img1", type="pil")
                    image_prompt2 = gr.Image(label="Ref Img2", type="pil")
                    image_prompt3 = gr.Image(label="Ref Img3", type="pil")
                    image_prompt4 = gr.Image(label="Ref Img4", type="pil")
                with gr.Row():
                    with gr.Column():
                        width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
                        height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
                    with gr.Column():
                        gr.Markdown("📌 Trained on 512x512. Larger sizes give better quality but are less stable.")
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance")
                        seed = gr.Number(-1, label="Seed (-1 for random)")
                        num_outputs = gr.Slider(1, 5, 5, step=1, label="Number of Enhanced Prompts / Images")
                generate_btn = gr.Button("Generate Enhanced Images")
            with gr.Column():
                # Interleaved output components: image 1, prompt 1, image 2, ...
                outputs = []
                for i in range(5):
                    outputs.append(gr.Image(label=f"Image {i+1}"))
                    outputs.append(gr.Textbox(label=f"Enhanced Prompt {i+1}"))
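
        # Enhance the user prompt into num_outputs variants via ChatGPT, then
        # generate one image per variant; each result contributes an
        # (image, prompt) pair to the flat results list.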
        def run_generation(prompt, width, height, guidance, num_steps, seed,
                           img1, img2, img3, img4, num_outputs):
            # Sliders may deliver floats depending on the Gradio version.
            num_outputs = int(num_outputs)
            uploaded_images = [img for img in [img1, img2, img3, img4] if img is not None]
            print(f"\n🔥 [DEBUG] User prompt: {prompt}")
            prompts = enhance_prompt_with_chatgpt(
                user_prompt=prompt,
                num_prompts=num_outputs,
                reference_images=uploaded_images,
            )
            print(f"\n🧠 [DEBUG] Final Prompt List (len={len(prompts)}):")
            for idx, p in enumerate(prompts):
                print(f"  [{idx+1}] {p}")
            # Fall back to the raw prompt if the enhancer returned fewer
            # variants than requested.
            while len(prompts) < num_outputs:
                prompts.append(prompt)
            results = []
            for i in range(num_outputs):
                try:
                    seed_val = int(seed) if seed != -1 else torch.randint(0, 10**8, (1,)).item()
                    print(f"🧪 [DEBUG] Using seed: {seed_val} for image {i+1}")
                    gen_image, _ = pipeline.gradio_generate(
                        prompt=prompts[i],
                        width=width,
                        height=height,
                        guidance=guidance,
                        num_steps=num_steps,
                        seed=seed_val,
                        image_prompt1=img1,
                        image_prompt2=img2,
                        image_prompt3=img3,
                        image_prompt4=img4,
                    )
                    print(f"✅ [DEBUG] Image {i+1} generated using prompt: {prompts[i]}")
                    results.append(gen_image)
                    results.append(prompts[i])
                except Exception as e:
                    print(f"❌ [ERROR] Failed to generate image {i+1}: {e}")
                    results.append(None)
                    results.append(f"⚠️ Failed to generate: {e}")
            # Pad to 10 outputs (5 image/prompt pairs): even indices are
            # images, odd indices are prompt textboxes.
            while len(results) < 10:
                results.append(None if len(results) % 2 == 0 else "")
            return results
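
        # The flat return list (image, prompt, image, prompt, ...) maps 1:1
        # onto the interleaved output components created above.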
        generate_btn.click(
            fn=run_generation,
            inputs=[
                prompt, width, height, guidance, num_steps,
                seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4, num_outputs,
            ],
            outputs=outputs,
        )
        example_text = gr.Text("", visible=False, label="Case For:")
        examples = get_examples("./assets/examples")

        # Each example supplies exactly 7 fields (use-case label, prompt, four
        # reference images, seed), so gr.Examples must target 7 components.
        gr.Examples(
            examples=examples,
            inputs=[
                example_text, prompt,
                image_prompt1, image_prompt2, image_prompt3, image_prompt4,
                seed,
            ],
        )

    return demo
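
# Command-line entry point: HfArgumentParser turns the dataclass fields into
# --name/--device/--offload/--port flags.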
if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    # HfArgumentParser only accepts dataclasses, so the decorator is required.
    @dataclasses.dataclass
    class AppArgs:
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload unused models to CPU"},
        )
        port: int = 7860

    parser = HfArgumentParser([AppArgs])
    args = parser.parse_args_into_dataclasses()[0]

    demo = create_demo(args.name, args.device, args.offload)
    demo.launch(server_port=args.port)
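
# Example launch (assumes the flux-dev weights are accessible and a CUDA GPU):
#   python app.py --name flux-dev --port 7860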