# app.py — commit 9390207 ("Update app.py") by b2bomber.
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
# For Image-to-Image, you would also import:
# from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import os # For better logging/debugging
from typing import Literal # For type hinting the gender choices
# --- Configuration ---
# Device everything runs on. Pinned to CPU so the app also works on
# machines that have no GPU at all.
device: str = "cpu"

# Diffusion checkpoint. A small, distilled Stable Diffusion model keeps CPU
# generation times tolerable. 'runwayml/stable-diffusion-v1-5' gives higher
# quality but is much slower on CPU and may need more memory (possibly
# requiring `enable_sequential_cpu_offload()`).
model_id: str = "nota-ai/bk-sdm-small"

# Tiny VAE checkpoint swapped in for the pipeline's default VAE; it makes the
# encode/decode step drastically cheaper on CPU.
tiny_vae_id: str = "sayakpaul/taesd-diffusers"
# --- Model Loading ---
# Build the pipeline once, at import time, so individual requests never pay
# the model-loading cost. Any failure here aborts start-up with RuntimeError.
print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
try:
    # Text-to-Image: a new person is generated in the requested style.
    # For Image-to-Image, import StableDiffusionImg2ImgPipeline (see the
    # commented import at the top of the file) and swap it in below.
    pipe_class = StableDiffusionPipeline
    # pipe_class = StableDiffusionImg2ImgPipeline # Uncomment this for Image-to-Image functionality
    pipe = pipe_class.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU usually performs best with float32
        low_cpu_mem_usage=True,  # Helps reduce peak memory usage on CPU
        safety_checker=None,  # Skip the safety checker to save CPU time and memory
    )
    print(f"[{os.getpid()}] Main pipeline loaded.")

    # Swap in the Tiny VAE; this is the main speed-up for the VAE step on CPU.
    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
    try:
        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
    except Exception as vae_e:
        # Best-effort optimisation: keep the pipeline's default VAE. No explicit
        # move is needed here because pipe.to(device) below moves every
        # component, including whichever VAE ended up attached.
        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")

    pipe.to(device)

    # Replace the scheduler with DDIM, reusing the checkpoint's scheduler config.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Optional: enable CPU offload if you hit Out-Of-Memory errors, especially
    # with larger models. This makes generation significantly slower.
    # pipe.enable_sequential_cpu_offload()
    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")
except Exception as e:
    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
    # Refuse to start without a model. Chain the original error so the real
    # cause (network, missing weights, OOM, ...) stays in the traceback.
    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}") from e
# --- Preset Styles ---
# UI style label -> prompt prefix that is placed directly before the subject
# (e.g. "a man"). Insertion order is the order the style selector shows.
styles = dict(
    Pixar="pixar style portrait of",
    Anime="anime style portrait of",
    Cyberpunk="cyberpunk futuristic avatar of",
    Disney="disney movie character of",
    Sketch="pencil sketch portrait of",
    Astronaut="realistic astronaut with helmet, portrait of",
)
# --- Generation Function ---
def generate_avatar(image_input: Image.Image, style: str, gender: Literal["male", "female", "unspecified"]):
    """
    Generate an avatar for a preset style and gender.

    Args:
        image_input: The uploaded photo, or None. With the default
            StableDiffusionPipeline (Text-to-Image) it only gates generation
            and does NOT influence the avatar's appearance; a new person is
            generated from the text prompt. With StableDiffusionImg2ImgPipeline
            (commented-out branch below) it would be the base image.
        style: A key of the module-level ``styles`` dict.
        gender: "male" or "female"; any other value produces a neutral subject.

    Returns:
        The generated PIL image, or None when no image was uploaded or the
        style is unknown (a warning toast is shown in both cases).

    Raises:
        gr.Error: When generation fails, so Gradio displays the message to the
            user instead of silently showing nothing.
    """
    if image_input is None:
        gr.Warning("Please upload an image to enable avatar generation. (Even if it's not directly used for content, it acts as a trigger).")
        return None

    base_prompt = styles.get(style)
    if base_prompt is None:
        # Unknown/empty style: warn like the missing-image case instead of
        # letting a bare KeyError escape.
        gr.Warning(f"Unknown style {style!r}. Please choose one of: {', '.join(styles)}.")
        return None

    # "unspecified" (or anything unexpected) falls back to a neutral subject;
    # the model then follows its own biases.
    gender_subject = {"male": "a man", "female": "a woman"}.get(gender, "a person")

    # Quality/detail boosters appended to the style + subject.
    prompt = f"{base_prompt} {gender_subject}, high quality, detailed, professional photography, studio lighting, volumetric lighting, 4k, cinematic, sharp focus"
    # "cartoon" and "sketch" are deliberately NOT in this list. The Sketch
    # preset asks for a pencil sketch and Pixar/Disney/Anime are cartoon looks,
    # so negating those terms fought the very style the user selected.
    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, duplicate, out of frame, bad anatomy, deformed, extra limbs, malformed hands, missing fingers, watermark, text, signature, low contrast, oversaturated"

    # Inference parameters (a speed/quality compromise for CPU).
    num_inference_steps = 25  # 20-30 steps is a reasonable range on CPU
    guidance_scale = 7.5  # Higher = closer to the prompt, but less diverse
    print(f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")
    try:
        # Disable gradient tracking during inference to save memory and time.
        with torch.no_grad():
            if isinstance(pipe, StableDiffusionPipeline):
                # Text-to-Image generation: image_input is ignored for content.
                generated_image = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=512,  # Stable Diffusion 1.x models are usually trained at 512x512
                    width=512,
                ).images[0]
            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
            #     # Image-to-Image generation: uncomment if you switch to Img2ImgPipeline.
            #     # 'strength' controls how much noise is added to the input image:
            #     # 0.0 = no change, 1.0 = complete re-imagining (like text-to-image).
            #     # Around 0.7-0.8 is typical for style transfer.
            #     strength = 0.75
            #     generated_image = pipe(
            #         prompt=prompt,
            #         image=image_input,  # Pass the uploaded image here for img2img
            #         negative_prompt=negative_prompt,
            #         num_inference_steps=num_inference_steps,
            #         guidance_scale=guidance_scale,
            #         strength=strength,
            #     ).images[0]
            else:
                raise ValueError("Unsupported pipeline type. Please check model loading.")
        print(f"[{os.getpid()}] Image generation complete.")
        return generated_image
    except Exception as e:
        print(f"[{os.getpid()}] Error during image generation: {e}")
        # gr.Error is only surfaced to the user when it is *raised*; merely
        # constructing it (as before) showed nothing.
        raise gr.Error(f"An error occurred during image generation: {e}") from e
# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
    gr.Markdown(
        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
        "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters).<br>"
        "**Note:** The uploaded image is currently used only to trigger generation and is **not directly influencing the avatar's appearance**. "
        "It's here for your reference or potential future Image-to-Image features. You will get a new person in the chosen style."
    )
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio pass a PIL Image (or None) to the handler.
            image_input = gr.Image(
                label="Upload your photo",
                type="pil",
                sources=["upload", "webcam"],  # Allow file upload or webcam capture
                # Optional default visual: value="assets/placeholder.jpg"
            )
            # Choices come straight from the styles dict, in its order.
            style_selector = gr.Radio(
                choices=list(styles.keys()),
                label="Choose a style",
                value="Anime",  # Default selected style
                info="Select the artistic style for your avatar.",
            )
            gender_selector = gr.Radio(
                choices=["male", "female", "unspecified"],
                label="Choose a Gender",
                value="male",  # Explicit default so the subject's gender is not left to the model
                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model.",
            )
            generate_btn = gr.Button("Generate Avatar", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Avatar")

    # Run generation on click with the photo, style and gender as inputs.
    generate_btn.click(
        fn=generate_avatar,
        inputs=[image_input, style_selector, gender_selector],
        outputs=output_image,
    )

    # Example rows only populate the inputs (no fn/outputs are wired), so
    # clicking one does not generate anything by itself.
    gr.Examples(
        examples=[
            # Format: [image_path_or_None, style_name, gender]. None leaves the
            # photo slot empty; generate_avatar still asks for an upload first.
            [None, "Pixar", "male"],
            [None, "Anime", "female"],
            [None, "Cyberpunk", "unspecified"],  # Shows what 'unspecified' produces
            [None, "Disney", "male"],
            [None, "Sketch", "female"],
            [None, "Astronaut", "male"],
        ],
        inputs=[image_input, style_selector, gender_selector],
        # fn=generate_avatar,  # Uncomment (with outputs) to run generation when an example is clicked
        # outputs=output_image,
        cache_examples=False,  # Nothing is pre-computed; examples just fill in the inputs
        label="Quick Examples (fill in the inputs; then click Generate Avatar)",
    )

# Launch the Gradio application.
# share=True creates a temporary public link; auth=("username", "password")
# adds basic authentication.
demo.launch(inbrowser=True, show_error=True)