"""Gradio demo: CPU-only Stable Diffusion avatar generator with preset styles.

The script loads a distilled Stable Diffusion text-to-image pipeline (plus a
tiny VAE) once at start-up, builds a small Gradio UI, and launches it.
"""

import os  # Used for the PID prefix in log lines
from typing import Literal, Optional  # Type hints for the gender choices / return value

import gradio as gr
import torch
from diffusers import AutoencoderTiny, DDIMScheduler, StableDiffusionPipeline
# For Image-to-Image, you would also import:
# from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

# --- Configuration ---
# 1. Force CPU usage for compatibility on machines without a GPU.
device = "cpu"

# 2. Choose a smaller/distilled Stable Diffusion model for CPU speed.
#    'nota-ai/bk-sdm-small' offers a good balance of size, speed, and reasonable
#    quality for CPU. If higher quality is essential and you can tolerate much
#    longer generation times on CPU, you might consider
#    'runwayml/stable-diffusion-v1-5', but be prepared for significant slowdowns
#    and potentially higher memory consumption that might require
#    `enable_sequential_cpu_offload()`.
model_id = "nota-ai/bk-sdm-small"

# 3. Tiny VAE for drastically faster encoding/decoding on CPU.
tiny_vae_id = "sayakpaul/taesd-diffusers"

# --- Model Loading ---
# Load the pipeline globally when the application starts to avoid reloading on
# each request.
print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
try:
    # StableDiffusionPipeline performs Text-to-Image generation (a new person
    # in the chosen style). To transform an uploaded image instead
    # (Image-to-Image), switch `pipe_class` to `StableDiffusionImg2ImgPipeline`.
    pipe_class = StableDiffusionPipeline
    # pipe_class = StableDiffusionImg2ImgPipeline  # Uncomment this for Image-to-Image functionality

    pipe = pipe_class.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU usually performs best with float32
        low_cpu_mem_usage=True,  # Helps reduce peak memory usage on CPU
        safety_checker=None,  # Disabled to save CPU cycles and memory
    )
    print(f"[{os.getpid()}] Main pipeline loaded.")

    # Swap in the Tiny VAE; failure here is non-fatal (best effort), the
    # pipeline simply keeps its default, slower VAE.
    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
    try:
        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
    except Exception as vae_e:
        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")
        # Ensure the default VAE is explicitly moved to CPU if Tiny VAE fails to load
        pipe.vae.to(device)

    # Move entire pipeline components to CPU explicitly
    pipe.to(device)

    # Set up the scheduler. DDIMScheduler is a good general-purpose choice.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Optional: Enable CPU offload if you encounter Out-Of-Memory errors on CPU,
    # especially with larger models. Be aware that this will make generation
    # significantly slower.
    # pipe.enable_sequential_cpu_offload()

    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")
except Exception as e:
    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
    # Abort start-up if model loading fails; chain the cause so the original
    # traceback is preserved.
    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}") from e

# --- Preset Styles ---
# Maps the UI style name to the prompt prefix placed before the subject.
styles = {
    "Pixar": "pixar style portrait of",
    "Anime": "anime style portrait of",
    "Cyberpunk": "cyberpunk futuristic avatar of",
    "Disney": "disney movie character of",
    "Sketch": "pencil sketch portrait of",
    "Astronaut": "realistic astronaut with helmet, portrait of",
}

# Subject phrase per gender choice; anything else falls back to "a person".
_GENDER_SUBJECTS = {"male": "a man", "female": "a woman"}
_DEFAULT_SUBJECT = "a person"

# Quality keywords appended to every prompt.
_PROMPT_SUFFIX = (
    "high quality, detailed, professional photography, studio lighting, "
    "volumetric lighting, 4k, cinematic, sharp focus"
)

# NOTE(review): "cartoon" and "sketch" below work against the Pixar / Anime /
# Disney / Sketch presets -- consider making the negative prompt style-aware.
_NEGATIVE_PROMPT = (
    "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, "
    "duplicate, out of frame, bad anatomy, deformed, extra limbs, "
    "malformed hands, missing fingers, watermark, text, signature, "
    "low contrast, oversaturated"
)

# Inference parameters (tuned for a balance of speed and quality on CPU).
_NUM_INFERENCE_STEPS = 25  # Generally, 20-30 steps is a good range on CPU
_GUIDANCE_SCALE = 7.5  # Higher values follow the prompt more closely, less diverse


def _build_prompt(style: str, gender: str) -> str:
    """Return the full positive prompt for a known *style* and a *gender* choice."""
    subject = _GENDER_SUBJECTS.get(gender, _DEFAULT_SUBJECT)
    return f"{styles[style]} {subject}, {_PROMPT_SUFFIX}"


# --- Generation Function ---
def generate_avatar(
    image_input: Image.Image,
    style: str,
    gender: Literal["male", "female", "unspecified"],
) -> Optional[Image.Image]:
    """Generate an avatar for the chosen style and gender.

    With the default text-to-image pipeline the uploaded image only acts as a
    trigger: it is not passed to the model, and a new person is generated from
    the text prompt alone. (With an Img2Img pipeline -- see the commented-out
    branch below -- the image would be used as the base for transformation.)

    Args:
        image_input: Uploaded PIL image, or ``None`` if nothing was uploaded.
        style: A key of the module-level ``styles`` mapping.
        gender: Gender choice; values other than "male"/"female" yield a
            generic "a person" subject.

    Returns:
        The generated PIL image, or ``None`` when no image was uploaded (a
        Gradio warning is shown in that case).

    Raises:
        gr.Error: If ``style`` is not a known preset or if generation fails.
            Gradio displays the message to the user.
    """
    if image_input is None:
        gr.Warning(
            "Please upload an image to enable avatar generation. "
            "(Even if it's not directly used for content, it acts as a trigger)."
        )
        return None

    # Surface a readable message instead of a bare KeyError for unknown styles.
    if style not in styles:
        raise gr.Error(f"Unknown style: {style!r}")

    prompt = _build_prompt(style, gender)
    negative_prompt = _NEGATIVE_PROMPT
    num_inference_steps = _NUM_INFERENCE_STEPS
    guidance_scale = _GUIDANCE_SCALE

    print(
        f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', "
        f"with prompt: '{prompt}' "
        f"(Steps: {num_inference_steps}, Guidance: {guidance_scale})"
    )

    try:
        # Disable gradient tracking during inference to save memory and time.
        with torch.no_grad():
            if isinstance(pipe, StableDiffusionPipeline):
                # Text-to-Image generation: image_input is ignored for content.
                generated_image = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=512,  # Stable Diffusion 1.x models are usually trained at 512x512
                    width=512,
                ).images[0]
            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
            #     # Image-to-Image generation: Uncomment this block if you switch to Img2ImgPipeline
            #     # The 'strength' parameter controls how much noise is added to the input image.
            #     # 0.0 means no change, 1.0 means complete re-imagining (like text-to-image).
            #     # A value around 0.7-0.8 is typical for style transfer.
            #     strength = 0.75
            #     generated_image = pipe(
            #         prompt=prompt,
            #         image=image_input,  # Pass the uploaded image here for img2img
            #         negative_prompt=negative_prompt,
            #         num_inference_steps=num_inference_steps,
            #         guidance_scale=guidance_scale,
            #         strength=strength,
            #     ).images[0]
            else:
                raise ValueError("Unsupported pipeline type. Please check model loading.")
    except Exception as e:
        print(f"[{os.getpid()}] Error during image generation: {e}")
        # gr.Error is an exception: it must be *raised* for Gradio to show the
        # message in the UI (merely constructing it has no effect).
        raise gr.Error(f"An error occurred during image generation: {e}") from e

    print(f"[{os.getpid()}] Image generation complete.")
    return generated_image


# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
    gr.Markdown(
        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
        "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters). "
        "**Note:** The uploaded image is currently used only to trigger generation and is **not directly influencing the avatar's appearance**. "
        "It's here for your reference or potential future Image-to-Image features. You will get a new person in the chosen style."
    )

    with gr.Row():
        with gr.Column():
            # type="pil" ensures a PIL Image object is passed to the function.
            image_input = gr.Image(
                label="Upload your photo",
                type="pil",
                sources=["upload", "webcam"],  # Allow file upload or webcam capture
                # Optional: Add a placeholder image path if you want a default visual
                # value="assets/placeholder.jpg"
            )
            style_selector = gr.Radio(
                choices=list(styles.keys()),
                label="Choose a style",
                value="Anime",  # Default selected style
                info="Select the artistic style for your avatar.",
            )
            gender_selector = gr.Radio(
                choices=["male", "female", "unspecified"],
                label="Choose a Gender",
                value="male",  # Default selected gender
                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model.",
            )
            generate_btn = gr.Button("Generate Avatar", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Avatar")

    # Connect the button click to the generation function, passing all inputs.
    generate_btn.click(
        fn=generate_avatar,
        inputs=[image_input, style_selector, gender_selector],
        outputs=output_image,
    )

    # Optional: examples that pre-fill the inputs for quick testing.
    # NOTE(review): the examples supply no image, but generate_avatar requires
    # one, so the user still has to upload a photo before clicking Generate.
    gr.Examples(
        examples=[
            # Example format: [image_path_or_None, style_name, gender]
            [None, "Pixar", "male"],
            [None, "Anime", "female"],
            [None, "Cyberpunk", "unspecified"],  # To show what 'unspecified' might produce
            [None, "Disney", "male"],
            [None, "Sketch", "female"],
            [None, "Astronaut", "male"],
        ],
        inputs=[image_input, style_selector, gender_selector],
        # fn=generate_avatar,  # Uncomment if you want examples to run the generation live
        # outputs=output_image,
        cache_examples=False,  # Set to True if examples are pre-computed images
        label="Quick Examples (Generates new images each time)",
    )

# Launch the Gradio application.
# share=True will generate a public link (useful for sharing demos temporarily)
# auth=("username", "password") for basic authentication
demo.launch(inbrowser=True, show_error=True)