# app.py for the "flxgif" Hugging Face Space by aiqtech (runs on ZeroGPU).
# Generates a short looping GIF from a text prompt using FLUX.1-dev.
import random
import gradio as gr
import numpy as np
import torch
import spaces
from diffusers import FluxPipeline
from PIL import Image
from diffusers.utils import export_to_gif
from transformers import pipeline
# -------------------------
# Configuration constants
# -------------------------
FRAMES = 4 # number of stills laid out horizontally
DEFAULT_HEIGHT = 256  # per-frame size (px)
DEFAULT_FPS = 8 # smoother playback than the original 4 fps
MAX_SEED = np.iinfo(np.int32).max
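# MAX_SEED is the largest 32-bit signed integer (2**31 - 1 = 2147483647),
# which keeps seed values within the range the Gradio slider and NumPy expect.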
# -------------------------
# Model initialisation
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,  # more mantissa bits than bfloat16 for crisper output
).to(device)
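# Optional (not in the original): on GPUs with limited VRAM, diffusers'
# pipe.enable_model_cpu_offload() can stand in for the .to(device) call above,
# trading inference speed for memory headroom.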
# English is the primary UI language, but Korean prompts are still accepted & translated.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
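# Illustrative usage (not from the source; actual output wording will vary):
#   translator("하늘을 나는 고양이")[0]["translation_text"]
#   -> something like "a cat flying in the sky"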
# -------------------------
# Helper functions
# -------------------------
def split_image(input_image: Image.Image, frame_size: int) -> list[Image.Image]:
    """Cut a wide strip into equal square frames."""
    return [
        input_image.crop((i * frame_size, 0, (i + 1) * frame_size, frame_size))
        for i in range(FRAMES)
    ]
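# For frame_size=256 and FRAMES=4, the crop boxes are:
#   (0, 0, 256, 256), (256, 0, 512, 256), (512, 0, 768, 256), (768, 0, 1024, 256)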
def translate_to_english(text: str) -> str:
    """Translate Korean text to English (the caller checks for Hangul first)."""
    return translator(text)[0]["translation_text"]
@spaces.GPU()
def predict(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 40,
    height: int = DEFAULT_HEIGHT,
    fps: int = DEFAULT_FPS,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    # 1) Language handling: detect Hangul (compatibility jamo U+3131..U+318E
    #    or syllables U+AC00..U+D7A3) and translate the prompt to English.
    if any("\u3131" <= ch <= "\u318E" or "\uAC00" <= ch <= "\uD7A3" for ch in prompt):
        prompt = translate_to_english(prompt)

    # 2) Prompt template: ask FLUX for one wide image containing FRAMES
    #    consecutive stills, which is later sliced into GIF frames.
    prompt_template = (
        f"A side-by-side {FRAMES} frame image showing consecutive stills from a looped gif moving left to right. "
        f"The gif is of {prompt}."
    )

    # 3) Seed control
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    width = FRAMES * height  # wide strip of square frames

    # 4) Generation
    image = pipe(
        prompt=prompt_template,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=torch.Generator(device).manual_seed(seed),
        height=height,
        width=width,
    ).images[0]

    # 5) Slice the strip and assemble the GIF
    gif_path = export_to_gif(split_image(image, height), "flux.gif", fps=fps)
    return gif_path, image, seed
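# Hypothetical local smoke test (requires a GPU and the model weights):
#   gif_path, strip, used_seed = predict("panda shaking its hips", randomize_seed=True)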
# -------------------------
# Interface
# -------------------------
css = """
#col-container {max-width: 820px; margin: 0 auto;}
footer {visibility: hidden;}
"""
examples = [
    "cat lazily swinging its paws in mid-air",
    "panda shaking its hips",
    "flower blooming in timelapse",
]
with gr.Blocks(theme="soft", css=css) as demo:
    gr.Markdown("<h1 style='text-align:center'>FLUX GIF Generator</h1>")

    with gr.Column(elem_id="col-container"):
        # Prompt row
        with gr.Row():
            prompt = gr.Text(
                label="", show_label=False, max_lines=1, placeholder="Enter your prompt here…"
            )
            submit = gr.Button("Generate", scale=0)

        # Outputs
        output_gif = gr.Image(label="", show_label=False)
        output_stills = gr.Image(label="", show_label=False, elem_id="stills")

        # Advanced controls
        with gr.Accordion("Advanced settings", open=False):
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=1, maximum=15, step=0.1, value=7.0
                )
                num_inference_steps = gr.Slider(
                    label="Inference steps", minimum=10, maximum=60, step=1, value=40
                )
            with gr.Row():
                height = gr.Slider(
                    label="Frame size (px)", minimum=256, maximum=512, step=64, value=DEFAULT_HEIGHT
                )
                fps = gr.Slider(
                    label="GIF FPS", minimum=4, maximum=20, step=1, value=DEFAULT_FPS
                )

        # Example prompts
        gr.Examples(
            examples=examples,
            fn=predict,
            inputs=[prompt],
            outputs=[output_gif, output_stills, seed],
            cache_examples="lazy",
        )

    # Event wiring: both the button click and Enter in the textbox run predict
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            height,
            fps,
        ],
        outputs=[output_gif, output_stills, seed],
    )
demo.launch()