import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from tqdm import tqdm
import os
import json
import glob

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_name = "danhtran2mind/ghibli-fine-tuned-sd-2.1"

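# Load the fine-tuned Stable Diffusion components individually so the denoising
# loop below can be run step by step instead of through a pipeline wrapper.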
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype).to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")


def load_examples_from_directory(sample_output_dir="sample_output"):
    """
    Load example data from the sample_output directory.
    Assumes each image has a corresponding .json file with metadata.
    """
    examples = []

    json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))

    for json_file in json_files:
        try:
            with open(json_file, 'r') as f:
                metadata = json.load(f)

            required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed"]
            if all(key in metadata for key in required_keys):
                examples.append([
                    metadata["prompt"],
                    metadata["height"],
                    metadata["width"],
                    metadata["num_inference_steps"],
                    metadata["guidance_scale"],
                    metadata["seed"]
                ])
        except Exception as e:
            print(f"Error loading {json_file}: {e}")

    if not examples:
        examples = [
            ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42]
        ]
    return examples


def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
    """Run the diffusion denoising loop for the prompt and return the decoded image."""
    # The sliders may deliver float values; cast the integer-valued inputs before use.
    height, width = int(height), int(width)
    num_inference_steps = int(num_inference_steps)

    batch_size = 1

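    # Draw a fresh seed if requested, then seed a generator so the run is reproducible.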
    if random_seed:
        seed = torch.randint(0, 4294967295, (1,)).item()
    generator = torch.Generator(device=device).manual_seed(int(seed))

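    # Tokenize and encode the prompt with the CLIP text encoder.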
    text_input = tokenizer(
        [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)

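    # Encode an empty prompt as the unconditional branch for classifier-free guidance.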
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)

    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

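    # Start from random latents. Height and width are used directly as the latent
    # resolution, so the decoded image is 8x larger (the VAE upsamples by a factor of 8).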
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height, width),
        generator=generator,
        dtype=dtype,
        device=device
    )

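    # Prepare the scheduler timesteps and scale the initial noise accordingly.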
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

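    # Denoising loop: predict noise for the unconditional and conditional branches in one
    # batch, combine them with classifier-free guidance, then step the scheduler.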
    for t in tqdm(scheduler.timesteps, desc="Generating image"):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            if device.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            else:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample

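    # Decode the latents back to pixel space with the VAE.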
    with torch.no_grad():
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents).sample

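    # Convert from [-1, 1] tensors to a uint8 numpy array and then to a PIL image.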
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")
    pil_image = Image.fromarray(image[0])

    return pil_image, f"Image generated successfully! Seed used: {seed}"


with gr.Blocks() as demo:
    gr.Markdown("# Ghibli-Style Image Generator")
    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt below and adjust the parameters to create your image.")
    gr.Markdown("**Note:** CPU inference is slow; generating at 64x64 with 50 inference steps takes roughly 1800 seconds.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
            height = gr.Slider(label="Height", minimum=10, maximum=512, step=1, value=64)
            width = gr.Slider(label="Width", minimum=10, maximum=512, step=1, value=64)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=50)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=3.5)
            seed = gr.Slider(label="Seed", minimum=0, maximum=4294967295, step=1, value=42)
            random_seed = gr.Checkbox(label="Use Random Seed", value=False)
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            output_text = gr.Textbox(label="Status")

gr.Markdown("### Example Prompts") |
|
|
|
|
|
examples_data = load_examples_from_directory("sample_output") |
|
|
examples = gr.Dataframe( |
|
|
value=examples_data, |
|
|
headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed"], |
|
|
label="Examples" |
|
|
) |
|
|
|
|
|
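    # Wire the button to the generation function.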
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
        outputs=[output_image, output_text]
    )

demo.launch(max_threads=3)