|
import random
import tempfile

import gradio as gr
import numpy as np
import torch
import spaces
from diffusers import FluxPipeline
from PIL import Image
from diffusers.utils import export_to_gif
from transformers import pipeline
|
|
|
|
|
|
|
|
|
FRAMES = 4  # square frames painted side by side in one strip (= frames in the GIF)
DEFAULT_HEIGHT = 256  # default side length, in pixels, of each square frame
DEFAULT_FPS = 8  # default GIF playback rate
MAX_SEED = int(np.iinfo("int32").max)  # 2**31 - 1: upper bound for seeds
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when one is visible; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# FLUX.1-dev text-to-image pipeline, loaded once at import time and moved to
# `device`. bfloat16 is the dtype used in the diffusers FLUX examples: plain
# float16 needs extra clipping of some text-encoder activations and otherwise
# produces degraded images.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(device)

# Korean -> English translation pipeline, used by translate_to_english().
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
|
|
|
|
|
|
|
|
|
|
|
def split_image(input_image: Image.Image, frame_size: int) -> list[Image.Image]:
    """Slice a horizontal strip into ``FRAMES`` square tiles, left to right.

    Tile ``k`` is the ``frame_size`` x ``frame_size`` region whose left edge
    is at ``k * frame_size`` pixels; only the top ``frame_size`` rows of the
    strip are used.
    """
    left_edges = (k * frame_size for k in range(FRAMES))
    return [
        input_image.crop((left, 0, left + frame_size, frame_size))
        for left in left_edges
    ]
|
|
|
|
|
def _contains_hangul(text: str) -> bool:
    """Return True if *text* contains a Hangul compatibility jamo or syllable."""
    return any("\u3131" <= ch <= "\u318E" or "\uAC00" <= ch <= "\uD7A3" for ch in text)


def translate_to_english(text: str) -> str:
    """Translate Korean text to English; return any other text unchanged.

    Only text containing Hangul is sent to the Korean->English model. Feeding
    English (or empty) input to a ko->en model would garble it rather than
    pass it through.

    Args:
        text: Prompt text, possibly Korean.

    Returns:
        The English translation when *text* contains Hangul, else *text* as-is.
    """
    if not _contains_hangul(text):
        return text
    return translator(text)[0]["translation_text"]
|
|
|
|
|
@spaces.GPU()
def predict(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 40,
    height: int = DEFAULT_HEIGHT,
    fps: int = DEFAULT_FPS,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Generate a looping GIF of *prompt* with FLUX.

    The model is asked for one wide image holding ``FRAMES`` square stills
    side by side; that strip is cut into frames and encoded as a GIF.

    Args:
        prompt: Subject of the animation. Text containing Hangul is
            translated to English first.
        seed: Sampler seed; ignored when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]`` instead.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        num_inference_steps: Passed to the pipeline as ``num_inference_steps``.
        height: Side length in pixels of each square frame; the generated
            strip is ``FRAMES * height`` pixels wide.
        fps: Playback rate of the exported GIF.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the
            pipeline's tqdm progress bar in the UI.

    Returns:
        ``(gif_path, image, seed)``: path of a GIF file unique to this call,
        the full side-by-side PIL image, and the seed actually used.
    """
    # Prompts containing Hangul (compatibility jamo or syllables) are
    # translated, since the template below is English.
    if any("\u3131" <= ch <= "\u318E" or "\uAC00" <= ch <= "\uD7A3" for ch in prompt):
        prompt = translate_to_english(prompt)

    prompt_template = (
        f"A side-by-side {FRAMES} frame image showing consecutive stills from a looped gif moving left to right. "
        f"The gif is of {prompt}."
    )

    # Gradio sliders may deliver numbers as floats, but manual_seed() and the
    # pipeline's height/width/step arguments require ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # One wide image holding FRAMES square frames of height x height pixels.
    width = FRAMES * height

    image = pipe(
        prompt=prompt_template,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=torch.Generator(device).manual_seed(seed),
        height=height,
        width=width,
    ).images[0]

    # Give every request its own output file. A fixed name such as "flux.gif"
    # is shared by all concurrent requests, so one user's GIF could be
    # overwritten by (or served in place of) another's before Gradio reads it.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_file = tmp.name
    gif_path = export_to_gif(split_image(image, height), gif_file, fps=fps)
    return gif_path, image, seed
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page styling: centre a narrow main column and hide the Gradio footer.
css = """
#col-container {max-width: 820px; margin: 0 auto;}
footer {visibility: hidden;}
"""

# Example prompts shown under the form (fed to `predict` as the prompt only).
examples = [
    "cat lazily swinging its paws in mid-air",
    "panda shaking its hips",
    "flower blooming in timelapse",
]

# UI: a prompt box, two image outputs (the animated GIF and the raw
# side-by-side strip), and an accordion of generation settings. Both the
# Generate button and pressing Enter in the prompt box run `predict`.
with gr.Blocks(theme="soft", css=css) as demo:
    gr.Markdown("<h1 style='text-align:center'>FLUX GIF Generator</h1>")

    # elem_id "col-container" is the selector styled by `css` above.
    with gr.Column(elem_id="col-container"):

        with gr.Row():
            prompt = gr.Text(
                label="", show_label=False, max_lines=1, placeholder="Enter your prompt here…"
            )
            submit = gr.Button("Generate", scale=0)

        # Shows the GIF file path returned first by `predict`.
        output_gif = gr.Image(label="", show_label=False)
        # Shows the full side-by-side strip (all frames in one image).
        output_stills = gr.Image(label="", show_label=False, elem_id="stills")

        with gr.Accordion("Advanced settings", open=False):
            # Also used as an output: `predict` writes back the seed it used.
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=1, maximum=15, step=0.1, value=7.0
                )
                num_inference_steps = gr.Slider(
                    label="Inference steps", minimum=10, maximum=60, step=1, value=40
                )
            with gr.Row():
                # Side length of each square frame; the strip is FRAMES times as wide.
                height = gr.Slider(
                    label="Frame size (px)", minimum=256, maximum=512, step=64, value=DEFAULT_HEIGHT
                )
                fps = gr.Slider(
                    label="GIF FPS", minimum=4, maximum=20, step=1, value=DEFAULT_FPS
                )

        # Examples pass only the prompt, so `predict` uses its own defaults
        # (seed 42, not randomized); "lazy" caches a result on first use.
        gr.Examples(
            examples=examples,
            fn=predict,
            inputs=[prompt],
            outputs=[output_gif, output_stills, seed],
            cache_examples="lazy",
        )

    # Input order must match predict's positional parameters.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            height,
            fps,
        ],
        outputs=[output_gif, output_stills, seed],
    )

demo.launch()
|
|