# NOTE: page-capture residue, not program text (the hosting page's status
# banner read: "Spaces: Sleeping / Sleeping").
import os  # For better logging/debugging
from typing import Literal  # For type hinting the gender choices

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
# For Image-to-Image, you would also import:
# from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
# --- Configuration ---
# 1. Force CPU usage for compatibility on machines without a GPU.
device = "cpu"

# 2. Choose a smaller/distilled Stable Diffusion model for CPU speed.
#    'nota-ai/bk-sdm-small' offers a good balance of size, speed, and
#    reasonable quality for CPU. If higher quality is essential and much
#    longer generation times are tolerable, 'runwayml/stable-diffusion-v1-5'
#    is an alternative, at the cost of significant slowdowns and higher
#    memory consumption.
#    NOTE(review): the original comment suggested
#    `enable_sequential_cpu_offload()` as a remedy for that memory cost;
#    confirm it is applicable on a CPU-only host before relying on it.
model_id = "nota-ai/bk-sdm-small"

# 3. Tiny VAE for drastically faster encoding/decoding on CPU.
#    This is a crucial optimization.
tiny_vae_id = "sayakpaul/taesd-diffusers"
# --- Model Loading ---
# Load the pipeline once, at import time, so that individual requests do not
# pay the model-loading cost again.
print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
try:
    # StableDiffusionPipeline performs Text-to-Image generation (a new person
    # is generated in the chosen style). To transform an uploaded image
    # instead (Image-to-Image), swap in StableDiffusionImg2ImgPipeline below.
    pipe_class = StableDiffusionPipeline
    # pipe_class = StableDiffusionImg2ImgPipeline  # Uncomment this for Image-to-Image functionality

    pipe = pipe_class.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU usually performs best with float32
        low_cpu_mem_usage=True,  # Helps reduce peak memory usage on CPU
        safety_checker=None,  # Disable safety checker to save CPU cycles and memory
    )
    print(f"[{os.getpid()}] Main pipeline loaded.")

    # Swap in the Tiny VAE to speed up the VAE step. This is best-effort:
    # if it cannot be loaded, the pipeline keeps the VAE it came with.
    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
    try:
        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
    except Exception as vae_e:
        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")
        # Make sure the default VAE sits on the target device.
        pipe.vae.to(device)

    # Move every pipeline component to the target device explicitly.
    pipe.to(device)

    # Replace the scheduler with DDIM, reusing the existing scheduler's config.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # NOTE(review): the original comment suggested enabling sequential CPU
    # offload against out-of-memory errors. Confirm against the diffusers
    # documentation that this helps on a CPU-only host before enabling it;
    # it makes generation significantly slower.
    # pipe.enable_sequential_cpu_offload()

    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")
except Exception as e:
    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
    # Abort start-up: the application is unusable without the model. Chain
    # the original exception so the full traceback is preserved.
    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}") from e
# --- Preset Styles ---
# Maps the style name shown in the UI to the prompt prefix for that style.
# Each prefix ends in "of" so that a subject ("a man", "a woman", ...) can be
# appended directly after it.
styles = {
    "Pixar": "pixar style portrait of",
    "Anime": "anime style portrait of",
    "Cyberpunk": "cyberpunk futuristic avatar of",
    "Disney": "disney movie character of",
    "Sketch": "pencil sketch portrait of",
    "Astronaut": "realistic astronaut with helmet, portrait of",
}
# --- Generation Function ---
def generate_avatar(image_input: Image.Image, style: str, gender: Literal["male", "female", "unspecified"]):
    """Generate an avatar image for the chosen style and gender.

    With the Text-to-Image pipeline (`StableDiffusionPipeline`) the uploaded
    `image_input` only gates the generation; it does not influence the
    result, and a new person is generated from the text prompt alone. The
    commented-out Image-to-Image branch below shows how the image would be
    used if the pipeline were switched to `StableDiffusionImg2ImgPipeline`.

    Args:
        image_input: The uploaded image, or None if nothing was uploaded.
        style: A key of the module-level `styles` mapping.
        gender: "male", "female" or "unspecified"; any other value is
            treated like "unspecified".

    Returns:
        The first image produced by the pipeline, or None (after showing a
        warning) when no image was uploaded.

    Raises:
        gr.Error: If `style` is not a known style, or if generation fails.
            Gradio only surfaces the message to the user when the error is
            raised, not merely constructed.
    """
    if image_input is None:
        gr.Warning("Please upload an image to enable avatar generation. (Even if it's not directly used for content, it acts as a trigger).")
        return None

    # Base prompt from the selected style; reject unknown styles with a
    # readable message instead of a bare KeyError.
    base_prompt = styles.get(style)
    if base_prompt is None:
        raise gr.Error(f"Unknown style: {style!r}")

    # Subject part of the prompt. Without an explicit gender the model
    # falls back on its own biases.
    subject_by_gender = {"male": "a man", "female": "a woman"}
    gender_subject = subject_by_gender.get(gender, "a person")

    # Quality keywords appended to the prompt for text-to-image generation.
    prompt = f"{base_prompt} {gender_subject}, high quality, detailed, professional photography, studio lighting, volumetric lighting, 4k, cinematic, sharp focus"
    # Negative prompt to steer away from low quality, distortions and artifacts.
    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, out of frame, bad anatomy, deformed, extra limbs, malformed hands, missing fingers, watermark, text, signature, low contrast, oversaturated"

    # Inference parameters (a speed/quality trade-off for CPU).
    num_inference_steps = 25  # 20-30 steps is a reasonable range on CPU
    guidance_scale = 7.5  # Higher values follow the prompt more closely, with less diversity

    print(f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")
    try:
        # Gradients are not needed for inference; disabling them saves
        # memory and computation.
        with torch.no_grad():
            if isinstance(pipe, StableDiffusionPipeline):
                # Text-to-Image generation: image_input is ignored for content.
                generated_image = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=512,  # Stable Diffusion 1.x models are usually trained at 512x512
                    width=512,
                ).images[0]
            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
            #     # Image-to-Image generation: uncomment this block if you switch to Img2ImgPipeline.
            #     # 'strength' controls how much noise is added to the input image:
            #     # 0.0 means no change, 1.0 means complete re-imagining (like text-to-image).
            #     # A value around 0.7-0.8 is typical for style transfer.
            #     strength = 0.75
            #     generated_image = pipe(
            #         prompt=prompt,
            #         image=image_input,  # Pass the uploaded image here for img2img
            #         negative_prompt=negative_prompt,
            #         num_inference_steps=num_inference_steps,
            #         guidance_scale=guidance_scale,
            #         strength=strength,
            #     ).images[0]
            else:
                raise ValueError("Unsupported pipeline type. Please check model loading.")
        print(f"[{os.getpid()}] Image generation complete.")
        return generated_image
    except Exception as e:
        print(f"[{os.getpid()}] Error during image generation: {e}")
        # Raise (not just construct) the error so Gradio shows it to the user.
        raise gr.Error(f"An error occurred during image generation: {e}") from e
# --- Gradio Interface Definition ---
# Layout: one row with two columns. The left column holds the inputs and the
# generate button; the right column shows the generated avatar.
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
    gr.Markdown(
        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
        "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters).<br>"
        "**Note:** The uploaded image is currently used only to trigger generation and is **not directly influencing the avatar's appearance**. "
        "It's here for your reference or potential future Image-to-Image features. You will get a new person in the chosen style."
    )

    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand a PIL Image object to the callback.
            image_input = gr.Image(
                label="Upload your photo",
                type="pil",
                sources=["upload", "webcam"],  # Allow file upload or webcam capture
                # Optional: add a placeholder image path for a default visual
                # value="assets/placeholder.jpg"
            )
            style_selector = gr.Radio(
                choices=list(styles.keys()),
                label="Choose a style",
                value="Anime",  # Default selected style
                info="Select the artistic style for your avatar.",
            )
            gender_selector = gr.Radio(
                choices=["male", "female", "unspecified"],
                label="Choose a Gender",
                value="male",  # Default selection
                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model.",
            )
            generate_btn = gr.Button("Generate Avatar", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Avatar")

    # Wire the button to the generation function, passing all three inputs.
    generate_btn.click(
        fn=generate_avatar,
        inputs=[image_input, style_selector, gender_selector],
        outputs=output_image,
    )

    # Quick-fill examples. No `fn` is attached, so selecting an example only
    # populates the inputs; the user still has to upload an image and press
    # the button, because generate_avatar refuses to run without an image.
    gr.Examples(
        examples=[
            # Example format: [image_path_or_None, style_name, gender]
            [None, "Pixar", "male"],
            [None, "Anime", "female"],
            [None, "Cyberpunk", "unspecified"],  # To show what 'unspecified' might produce
            [None, "Disney", "male"],
            [None, "Sketch", "female"],
            [None, "Astronaut", "male"],
        ],
        inputs=[image_input, style_selector, gender_selector],
        # fn=generate_avatar,  # Uncomment if you want examples to run the generation live
        # outputs=output_image,
        cache_examples=False,  # Set to True if examples are pre-computed images, False for live generation
        label="Quick Examples (Generates new images each time)",
    )

# Launch the Gradio application.
# share=True would generate a public link (useful for sharing demos temporarily);
# auth=("username", "password") would enable basic authentication.
demo.launch(inbrowser=True, show_error=True)