Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import FluxPipeline | |
| import huggingface_hub | |
| from huggingface_hub import InferenceClient | |
| import os | |
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which cannot
# be answered when the app runs headless; skipping the call leaves any
# previously stored credentials in effect instead.
_hf_token = os.getenv("HUGGINGFACE_API_TOKEN")
if _hf_token:
    huggingface_hub.login(token=_hf_token)
# Initialize the Flux pipeline
def initialize_flux_pipeline():
    """Build the Flux image pipeline with the emoji LoRA weights attached.

    The base checkpoint is loaded in bfloat16, model CPU offload is enabled,
    and the LoRA weights are loaded on top, in that order.

    Returns:
        The configured ``FluxPipeline`` instance.
    """
    base_checkpoint = "black-forest-labs/FLUX.1-dev"
    lora_checkpoint = "EvanZhouDev/open-genmoji"

    flux = FluxPipeline.from_pretrained(base_checkpoint, torch_dtype=torch.bfloat16)
    flux.enable_model_cpu_offload()
    flux.load_lora_weights(lora_checkpoint)
    return flux
# Build the image pipeline once at import time so every request reuses it.
flux_pipeline = initialize_flux_pipeline()

# Initialize the language model client used to rewrite user prompts.
llm_client = InferenceClient(
    model="Qwen/Qwen2.5-72B-Instruct",
    token=os.getenv("HUGGINGFACE_API_TOKEN"),
)
# Function to refine the prompt
def refine_prompt(original_prompt):
    """Ask the chat model to rewrite a user description as an emoji prompt.

    Args:
        original_prompt: The raw description typed by the user.

    Returns:
        The content of the first completion choice, stripped of surrounding
        whitespace.
    """
    system_instructions = (
        "You are helping create a prompt for a Emoji generation image model. An emoji must be easily "
        "interpreted when small so details must be exaggerated to be clear. Your goal is to use descriptions "
        "to achieve this.\n\nYou will receive a user description, and you must rephrase it to consist of "
        "short phrases separated by periods, adding detail to everything the user provides.\n\nAdd describe "
        "the color of all parts or components of the emoji. Unless otherwise specified by the user, do not "
        "describe people. Do not describe the background of the image. Your output should be in the format:\n\n"
        "```emoji of {description}. {addon phrases}. 3D lighting. no cast shadows.```\n\nThe description "
        "should be a 1 sentence of your interpretation of the emoji. Then, you may choose to add addon phrases."
        " You must use the following in the given scenarios:\n\n- \"cute.\": If generating anything that's not "
        "an object, and also not a human\n- \"enlarged head in cartoon style.\": ONLY animals\n- \"head is "
        "turned towards viewer.\": ONLY humans or animals\n- \"detailed texture.\": ONLY objects\n\nFurther "
        "addon phrases may be added to ensure the clarity of the emoji."
    )
    system_message = {"role": "system", "content": system_instructions}
    user_message = {"role": "user", "content": original_prompt}

    # The reply is capped at 100 tokens.
    response = llm_client.chat_completion([system_message, user_message], max_tokens=100)
    first_choice = response["choices"][0]
    return first_choice["message"]["content"].strip()
def _refine_and_generate(prompt, guidance_scale, num_inference_steps, height, width, seed):
    """Refine the prompt once, generate the image, and return both.

    Args:
        prompt: Raw description entered by the user.
        guidance_scale: Guidance scale forwarded to the Flux pipeline.
        num_inference_steps: Number of denoising steps.
        height: Output image height in pixels.
        width: Output image width in pixels.
        seed: Seed for the random generator.

    Returns:
        A ``(refined_prompt, image)`` tuple, where ``image`` is the first
        image of the pipeline output.

    Raises:
        gr.Error: If the prompt is blank, or if refinement or generation
            fails. Raising (rather than returning the message) keeps an error
            string from being sent to the image output component.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    print(f"Original Prompt: {prompt}")

    # Refine the prompt
    try:
        refined_prompt = refine_prompt(prompt)
    except Exception as e:
        raise gr.Error(f"Error refining prompt: {str(e)}") from e
    print(f"Refined Prompt: {refined_prompt}")

    try:
        # Set the random generator seed; UI widgets may hand back floats, so
        # the integer-valued settings are coerced before use.
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
        # Generate the image
        output = flux_pipeline(
            prompt=refined_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            height=int(height),
            width=int(width),
            generator=generator,
        )
    except Exception as e:
        raise gr.Error(f"Error generating image: {str(e)}") from e

    return refined_prompt, output.images[0]


# Define the process function
def process(prompt, guidance_scale, num_inference_steps, height, width, seed):
    """Generate an image for ``prompt`` after refining it with the chat model.

    Takes the same arguments as ``_refine_and_generate`` and returns only the
    generated image.

    Raises:
        gr.Error: If the prompt is blank, or if refinement or generation fails.
    """
    _, image = _refine_and_generate(
        prompt, guidance_scale, num_inference_steps, height, width, seed
    )
    return image


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flux Text-to-Image Generator with Prompt Refinement")

    # User inputs
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter a Prompt", placeholder="Describe your image")
        guidance_scale_input = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.5, step=0.1
        )
    with gr.Row():
        num_inference_steps_input = gr.Slider(
            label="Inference Steps", minimum=1, maximum=100, value=50, step=1
        )
        seed_input = gr.Number(label="Seed", value=42, precision=0)
    with gr.Row():
        height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=768, step=64)
        width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1360, step=64)

    # Output components
    refined_prompt_output = gr.Textbox(label="Refined Prompt", interactive=False)
    image_output = gr.Image(label="Generated Image")

    # Button to generate the image
    generate_button = gr.Button("Generate Image")

    # Define button click behavior. A single handler refines the prompt once,
    # so the text shown to the user is exactly the prompt used for the image.
    generate_button.click(
        fn=_refine_and_generate,
        inputs=[
            prompt_input,
            guidance_scale_input,
            num_inference_steps_input,
            height_input,
            width_input,
            seed_input,
        ],
        outputs=[refined_prompt_output, image_output],
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {"show_error": True}
    demo.launch(**launch_options)