import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from tqdm import tqdm
import os
import json
import glob
# Set device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Model path
model_name = "danhtran2mind/ghibli-fine-tuned-sd-2.1"
# Load models with consistent dtype
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype).to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
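
# The pipeline components (VAE, tokenizer, text encoder, UNet, scheduler) are loaded
# individually rather than through a bundled pipeline class so the denoising loop
# below can be run manually.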
def load_examples_from_directory(sample_output_dir="sample_output"):
"""
Load example data from the sample_output directory.
Assumes each image has a corresponding .json file with metadata.
"""
examples = []
# Look for .json files in the directory
json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))
for json_file in json_files:
try:
with open(json_file, 'r') as f:
metadata = json.load(f)
# Ensure required fields are present
required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed"]
if all(key in metadata for key in required_keys):
examples.append([
metadata["prompt"],
metadata["height"],
metadata["width"],
metadata["num_inference_steps"],
metadata["guidance_scale"],
metadata["seed"]
])
except Exception as e:
print(f"Error loading {json_file}: {e}")
# If no valid examples are found, return a default example
if not examples:
examples = [
["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42]
]
return examples
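
# Illustrative metadata layout expected in sample_output/ (hypothetical filename and values,
# assuming one .json file per sample containing the keys checked above):
#   sample_output/example.json ->
#   {"prompt": "a serene landscape in Ghibli style", "height": 64, "width": 64,
#    "num_inference_steps": 50, "guidance_scale": 3.5, "seed": 42}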
def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
    # Validate inputs
    # if not prompt:
    #     return None, "Prompt cannot be empty."
    # if height % 8 != 0 or width % 8 != 0:
    #     return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
    # if height < 256 or width < 256 or height > 1024 or width > 1024:
    #     return None, "Height and width must be between 256 and 1024 pixels."
    # if num_inference_steps < 10 or num_inference_steps > 100:
    #     return None, "Number of inference steps must be between 10 and 100."
    # if guidance_scale < 1.0 or guidance_scale > 20.0:
    #     return None, "Guidance scale must be between 1.0 and 20.0."
    # if seed < 0 or seed > 4294967295:
    #     return None, "Seed must be between 0 and 4294967295."

    # Set batch size
    batch_size = 1

    # Handle random seed
    if random_seed:
        seed = torch.randint(0, 4294967295, (1,)).item()
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # Tokenize and encode prompt
    text_input = tokenizer(
        [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
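
    # Classifier-free guidance: also encode an empty ("unconditional") prompt and
    # concatenate it with the conditional embeddings so both branches run in a single UNet pass.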
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    # Initialize latents. Note: height and width are used directly as the latent
    # spatial dimensions, and the VAE decoder upsamples by 8x, so e.g. a 64x64
    # latent produces a 512x512 output image.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height, width),
        generator=generator,
        dtype=dtype,
        device=device
    )

    # Set scheduler timesteps and scale the initial noise
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    # Inference loop
    for t in tqdm(scheduler.timesteps, desc="Generating image"):
        # Duplicate latents so the unconditional and conditional passes share one batch
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            if device.type == "cuda":
                # Mixed-precision UNet forward pass on GPU
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            else:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # Combine the two predictions with classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    # Decode latents to image
    with torch.no_grad():
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents).sample

    # Convert to PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")
    pil_image = Image.fromarray(image[0])

    return pil_image, f"Image generated successfully! Seed used: {seed}"
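
# Illustrative direct call (hypothetical argument values; the Gradio UI below wires up
# the same parameters):
#   image, status = generate_image(
#       "a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, False
#   )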
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Ghibli-Style Image Generator")
gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter ABOVE a prompt and adjust parameters to create your image.")
gr.Markdown("**Note:** For CPU inference, execution time is long (e.g., for 64x64 resolution with 50 inference steps, time is approximately 1800 seconds).")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
height = gr.Slider(label="Height", minimum=10, maximum=512, step=1, value=64)
width = gr.Slider(label="Width", minimum=10, maximum=512, step=1, value=64)
num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=50)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=3.5)
seed = gr.Slider(label="Seed", minimum=0, maximum=4294967295, step=1, value=42)
random_seed = gr.Checkbox(label="Use Random Seed", value=False)
generate_btn = gr.Button("Generate Image")
with gr.Column():
output_image = gr.Image(label="Generated Image")
output_text = gr.Textbox(label="Status")
gr.Markdown("### Example Prompts")
# Load examples from sample_output directory
examples_data = load_examples_from_directory("sample_output")
examples = gr.Dataframe(
value=examples_data,
headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed"],
label="Examples"
)
generate_btn.click(
fn=generate_image,
inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
outputs=[output_image, output_text]
)
# Launch with a limited number of worker threads
demo.launch(max_threads=3)