Update app.py
app.py
CHANGED
Hunks @@ -1,61 +1,69 @@ and @@ -69,100 +77,147 @@ styles = {. The previous revision's pane of the side-by-side view is truncated in this export; the full updated app.py follows:

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
# For Image-to-Image, you would also import:
# from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import os  # For better logging/debugging
from typing import Literal  # For type hinting the gender choices

# --- Configuration ---
# 1. Force CPU usage for compatibility on machines without a GPU
device = "cpu"

# 2. Choose a smaller/distilled Stable Diffusion model for CPU speed
# 'nota-ai/bk-sdm-small' offers a good balance of size, speed, and reasonable quality for CPU.
# If higher quality is essential and you can tolerate much longer generation times on CPU,
# you might consider 'runwayml/stable-diffusion-v1-5', but be prepared for significant slowdowns
# and potentially higher memory consumption that might require `enable_sequential_cpu_offload()`.
model_id = "nota-ai/bk-sdm-small"

# 3. Tiny VAE for drastically faster encoding/decoding on CPU. This is a crucial optimization.
tiny_vae_id = "sayakpaul/taesd-diffusers"

# --- Model Loading ---
# Load the pipeline globally when the application starts to avoid reloading on each request.
print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
try:
    # Use StableDiffusionPipeline for Text-to-Image generation (generate a new person in a style)
    # If you want to transform an uploaded image (Image-to-Image), uncomment the line below
    # and replace `StableDiffusionPipeline` with `StableDiffusionImg2ImgPipeline`.
    pipe_class = StableDiffusionPipeline
    # pipe_class = StableDiffusionImg2ImgPipeline  # Uncomment this for Image-to-Image functionality

    pipe = pipe_class.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU usually performs best with float32
        low_cpu_mem_usage=True,  # Helps reduce peak memory usage on CPU
        safety_checker=None  # Disable safety checker to save CPU cycles and memory for faster generation
    )
    print(f"[{os.getpid()}] Main pipeline loaded.")

    # Load and assign the Tiny VAE for significant speed optimization in the VAE step
    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
    try:
        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
    except Exception as vae_e:
        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")
        # Ensure the default VAE is explicitly moved to CPU if Tiny VAE fails to load
        pipe.vae.to(device)

    # Move entire pipeline components to CPU explicitly
    pipe.to(device)

    # Set up the scheduler. DDIMScheduler is a good general-purpose choice.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Optional: Enable CPU offload if you encounter Out-Of-Memory errors on CPU,
    # especially with larger models. Be aware that this will make generation significantly slower.
    # pipe.enable_sequential_cpu_offload()

    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")

except Exception as e:
    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
    # Raise an exception to prevent the application from starting if model loading fails
    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}")

# --- Preset Styles ---
styles = {
    # ... preset style prompt entries (unchanged in this commit; not shown in the diff view) ...
}

# --- Generation Function ---
def generate_avatar(image_input: Image.Image, style: str, gender: Literal["male", "female", "unspecified"]):
    """
    Generates an avatar based on a chosen style and gender.

    - If using StableDiffusionPipeline (Text-to-Image): The uploaded `image_input`
      is used only to trigger the generation and is NOT directly used to
      influence the avatar's appearance. A new person is generated based on the text.
    - If using StableDiffusionImg2ImgPipeline (Image-to-Image - commented out by default):
      The `image_input` WOULD be used as the base image for transformation.
    """
    if image_input is None:
        gr.Warning("Please upload an image to enable avatar generation. (Even if it's not directly used for content, it acts as a trigger).")
        return None

    # Base prompt from the selected style
    base_prompt = styles[style]

    # Construct the subject part of the prompt based on gender selection
    gender_subject = ""
    if gender == "male":
        gender_subject = "a man"
    elif gender == "female":
        gender_subject = "a woman"
    else:  # unspecified
        gender_subject = "a person"  # Model will default based on its biases if no gender specified

    # Enhance the prompt for better quality and detail in text-to-image generation
    prompt = f"{base_prompt} {gender_subject}, high quality, detailed, professional photography, studio lighting, volumetric lighting, 4k, cinematic, sharp focus"
    # Stronger negative prompt to avoid common issues like low quality, distortions, and undesired artifacts
    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, out of frame, bad anatomy, deformed, extra limbs, malformed hands, missing fingers, watermark, text, signature, low contrast, oversaturated"

    # Inference parameters (tuned for a balance of speed and quality on CPU)
    num_inference_steps = 25  # Generally, 20-30 steps is a good range for quality vs speed on CPU
    guidance_scale = 7.5  # Higher values make output closer to prompt, but can be less diverse

    print(f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")

    try:
        # Use torch.no_grad() or torch.inference_mode() to disable gradient calculations
        # during inference, which saves memory and speeds up computation.
        with torch.no_grad():  # For PyTorch >= 1.9, torch.inference_mode() is also an option
            if isinstance(pipe, StableDiffusionPipeline):
                # Text-to-Image generation: image_input is ignored for content
                generated_image = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=512,  # Stable Diffusion 1.x models are usually trained at 512x512
                    width=512
                ).images[0]
            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
            #     # Image-to-Image generation: Uncomment this block if you switch to Img2ImgPipeline
            #     # The 'strength' parameter controls how much noise is added to the input image.
            #     # 0.0 means no change, 1.0 means complete re-imagining (like text-to-image).
            #     # A value around 0.7-0.8 is typical for style transfer.
            #     strength = 0.75
            #     generated_image = pipe(
            #         prompt=prompt,
            #         image=image_input,  # Pass the uploaded image here for img2img
            #         negative_prompt=negative_prompt,
            #         num_inference_steps=num_inference_steps,
            #         guidance_scale=guidance_scale,
            #         strength=strength
            #     ).images[0]
            else:
                raise ValueError("Unsupported pipeline type. Please check model loading.")

        print(f"[{os.getpid()}] Image generation complete.")
        return generated_image

    except Exception as e:
        print(f"[{os.getpid()}] Error during image generation: {e}")
        # Surface the error to the user in the Gradio interface (gr.Error must be raised, not just called, to be displayed)
        raise gr.Error(f"An error occurred during image generation: {e}")

# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
    gr.Markdown(
        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
        "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters).<br>"
        "**Note:** The uploaded image is currently used only to trigger generation and is **not directly influencing the avatar's appearance**. "
        "It's here for your reference or potential future Image-to-Image features. You will get a new person in the chosen style."
    )

    with gr.Row():
        with gr.Column():
            # Image input component. type="pil" ensures a PIL Image object is passed to the function.
            image_input = gr.Image(
                label="Upload your photo",
                type="pil",
                sources=["upload", "webcam"],  # Allow file upload or webcam capture
                # Optional: Add a placeholder image path if you want a default visual
                # value="assets/placeholder.jpg"
            )
            style_selector = gr.Radio(
                choices=list(styles.keys()),
                label="Choose a style",
                value="Anime",  # Default selected style
                info="Select the artistic style for your avatar."
            )
            gender_selector = gr.Radio(
                choices=["male", "female", "unspecified"],
                label="Choose a Gender",
                value="male",  # Default selected gender
                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model."
            )
            generate_btn = gr.Button("Generate Avatar", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Avatar")

    # Connect the button click to the generation function, passing all inputs
    generate_btn.click(
        fn=generate_avatar,
        inputs=[image_input, style_selector, gender_selector],  # Now includes gender_selector
        outputs=output_image
    )

    # Optional: Add examples for quick testing
    gr.Examples(
        examples=[
            # Example format: [image_path_or_None, style_name, gender]
            # Use None for image_input as it's not directly influencing the output in text-to-image mode
            [None, "Pixar", "male"],
            [None, "Anime", "female"],
            [None, "Cyberpunk", "unspecified"],  # To show what 'unspecified' might produce
            [None, "Disney", "male"],
            [None, "Sketch", "female"],
            [None, "Astronaut", "male"]
        ],
        inputs=[image_input, style_selector, gender_selector],
        # fn=generate_avatar,  # Uncomment if you want examples to run the generation live
        # outputs=output_image,
        cache_examples=False,  # Set to True if examples are pre-computed images, False for live generation
        label="Quick Examples (Generates new images each time)"
    )

# Launch the Gradio application
# share=True will generate a public link (useful for sharing demos temporarily)
# auth=("username", "password") for basic authentication
demo.launch(inbrowser=True, show_error=True)
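
For a quick function-level check of generate_avatar without going through the browser UI, a minimal sketch along the following lines could be used. It rests on assumptions that are not part of the Space itself: the file is importable as a module named app, the demo.launch(...) call at the bottom is temporarily commented out so the import does not start the Gradio server, and "Anime" is one of the keys in the styles dict, as the interface default suggests.

# smoke_test.py (hypothetical helper, not part of this Space)
# Assumes: app.py sits next to this file, the demo.launch(...) line at the bottom
# of app.py is temporarily commented out, and "Anime" is a key in the styles dict.
from PIL import Image

import app  # importing runs the module-level model loading, which is slow on CPU

# Any image works here: in text-to-image mode it only acts as a trigger.
dummy = Image.new("RGB", (512, 512), color="white")

avatar = app.generate_avatar(dummy, style="Anime", gender="female")
if avatar is not None:
    avatar.save("avatar_test.png")
    print("Saved avatar_test.png")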