b2bomber committed on
Commit 9390207 · verified · 1 Parent(s): d8eef13

Update app.py

Files changed (1)
  1. app.py +128 -73
app.py CHANGED
@@ -1,61 +1,69 @@
 import gradio as gr
 import torch
 from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
+# For Image-to-Image, you would also import:
+# from diffusers import StableDiffusionImg2ImgPipeline
 from PIL import Image
 import os # For better logging/debugging
+from typing import Literal # For type hinting the gender choices

 # --- Configuration ---
-# 1. Force CPU usage for compatibility on Spaces without GPU
+# 1. Force CPU usage for compatibility on machines without a GPU
 device = "cpu"

 # 2. Choose a smaller/distilled Stable Diffusion model for CPU speed
-# 'nota-ai/bk-sdm-small' is a good balance of size/speed/quality for CPU.
-# If quality is paramount and you can tolerate more time, consider 'runwayml/stable-diffusion-v1-5'
-# but expect significantly slower generation times on CPU.
+# 'nota-ai/bk-sdm-small' offers a good balance of size, speed, and reasonable quality for CPU.
+# If higher quality is essential and you can tolerate much longer generation times on CPU,
+# you might consider 'runwayml/stable-diffusion-v1-5', but be prepared for significant slowdowns
+# and potentially higher memory consumption that might require `enable_sequential_cpu_offload()`.
 model_id = "nota-ai/bk-sdm-small"

-# 3. Tiny VAE for drastically faster encoding/decoding on CPU
+# 3. Tiny VAE for drastically faster encoding/decoding on CPU. This is a crucial optimization.
 tiny_vae_id = "sayakpaul/taesd-diffusers"

 # --- Model Loading ---
-# Load the pipeline globally to avoid reloading on each request
-print(f"Loading model: {model_id} on {device}...")
+# Load the pipeline globally when the application starts to avoid reloading on each request.
+print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
 try:
-    # Use StableDiffusionPipeline for Text-to-Image generation
-    # If you want Image-to-Image, you'd use StableDiffusionImg2ImgPipeline here.
-    pipe = StableDiffusionPipeline.from_pretrained(
+    # Use StableDiffusionPipeline for Text-to-Image generation (generate a new person in a style)
+    # If you want to transform an uploaded image (Image-to-Image), uncomment the line below
+    # and replace `StableDiffusionPipeline` with `StableDiffusionImg2ImgPipeline`.
+    pipe_class = StableDiffusionPipeline
+    # pipe_class = StableDiffusionImg2ImgPipeline # Uncomment this for Image-to-Image functionality
+
+    pipe = pipe_class.from_pretrained(
         model_id,
-        torch_dtype=torch.float32, # CPU usually prefers float32 for stability/speed
-        low_cpu_mem_usage=True, # Helps with memory on CPU
-        safety_checker=None # Disable safety checker to save CPU cycles and memory
+        torch_dtype=torch.float32, # CPU usually performs best with float32
+        low_cpu_mem_usage=True, # Helps reduce peak memory usage on CPU
+        safety_checker=None # Disable safety checker to save CPU cycles and memory for faster generation
     )
-    print("Main pipeline loaded.")
+    print(f"[{os.getpid()}] Main pipeline loaded.")

-    # Load and assign the Tiny VAE for speed optimization
-    print(f"Loading Tiny VAE from {tiny_vae_id}...")
+    # Load and assign the Tiny VAE for significant speed optimization in the VAE step
+    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
     try:
         pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
-        print("Tiny VAE loaded successfully.")
+        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
     except Exception as vae_e:
-        print(f"Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (might be slower).")
-        # Ensure default VAE is on CPU
+        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")
+        # Ensure the default VAE is explicitly moved to CPU if Tiny VAE fails to load
        pipe.vae.to(device)

-    # Move entire pipeline to CPU explicitly
+    # Move entire pipeline components to CPU explicitly
     pipe.to(device)

-    # Set up the scheduler. DDIMScheduler is a good choice.
+    # Set up the scheduler. DDIMScheduler is a good general-purpose choice.
     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

-    # Optional: Enable CPU offload if you run into Out-Of-Memory errors on CPU with larger models.
-    # Be aware: This will make generation *much* slower.
+    # Optional: Enable CPU offload if you encounter Out-Of-Memory errors on CPU,
+    # especially with larger models. Be aware that this will make generation significantly slower.
     # pipe.enable_sequential_cpu_offload()

-    print("Model loaded and configured successfully.")
+    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")

 except Exception as e:
-    print(f"FATAL ERROR: Failed to load models: {e}")
-    # Raise an exception to prevent the app from starting if model loading fails
+    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
+    # Raise an exception to prevent the application from starting if model loading fails
     raise RuntimeError(f"Failed to load Stable Diffusion model: {e}")

 # --- Preset Styles ---
@@ -69,100 +77,147 @@ styles = {
 }

 # --- Generation Function ---
-def generate_avatar(image_input: Image.Image, style: str):
+def generate_avatar(image_input: Image.Image, style: str, gender: Literal["male", "female", "unspecified"]):
     """
-    Generates an avatar based on a chosen style using Stable Diffusion.
-    Note: In this text-to-image setup, the uploaded `image_input` is used
-    only to trigger the generation, not to influence the image content directly.
+    Generates an avatar based on a chosen style and gender.
+
+    - If using StableDiffusionPipeline (Text-to-Image): The uploaded `image_input`
+      is used only to trigger the generation and is NOT directly used to
+      influence the avatar's appearance. A new person is generated based on the text.
+    - If using StableDiffusionImg2ImgPipeline (Image-to-Image - commented out by default):
+      The `image_input` WOULD be used as the base image for transformation.
     """
     if image_input is None:
-        gr.Warning("Please upload an image to generate an avatar.")
+        gr.Warning("Please upload an image to enable avatar generation. (Even if it's not directly used for content, it acts as a trigger).")
         return None

-    # Base prompt from selected style
+    # Base prompt from the selected style
     base_prompt = styles[style]

-    # Enhance prompt for better quality
-    prompt = f"{base_prompt} a person, highly detailed, professional, studio lighting, volumetric lighting, 4k, cinematic"
-    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, out of frame, bad anatomy, deformed, extra limbs, watermark, text"
+    # Construct the subject part of the prompt based on gender selection
+    gender_subject = ""
+    if gender == "male":
+        gender_subject = "a man"
+    elif gender == "female":
+        gender_subject = "a woman"
+    else: # unspecified
+        gender_subject = "a person" # Model will default based on its biases if no gender specified
+
+    # Enhance the prompt for better quality and detail in text-to-image generation
+    prompt = f"{base_prompt} {gender_subject}, high quality, detailed, professional photography, studio lighting, volumetric lighting, 4k, cinematic, sharp focus"
+    # Stronger negative prompt to avoid common issues like low quality, distortions, and undesired artifacts
+    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, out of frame, bad anatomy, deformed, extra limbs, malformed hands, missing fingers, watermark, text, signature, low contrast, oversaturated"

-    # Inference parameters (adjusted for speed on CPU, can be tweaked for quality)
-    num_inference_steps = 25 # Increased slightly for better quality, balance with speed
-    guidance_scale = 7.5 # Slightly increased for stronger adherence to prompt
+    # Inference parameters (tuned for a balance of speed and quality on CPU)
+    num_inference_steps = 25 # Generally, 20-30 steps is a good range for quality vs speed on CPU
+    guidance_scale = 7.5 # Higher values make output closer to prompt, but can be less diverse

-    print(f"Generating for style: {style} with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")
+    print(f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")

     try:
-        # Use torch.no_grad() for efficient inference (disables gradient calculations)
-        with torch.no_grad(): # Or torch.inference_mode() for PyTorch >= 1.9
-            generated_image = pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale,
-                height=512, # Explicitly set output dimensions, can try 768 for SD 2.1 or larger models
-                width=512
-            ).images[0]
-
-        print("Image generation complete.")
+        # Use torch.no_grad() or torch.inference_mode() to disable gradient calculations
+        # during inference, which saves memory and speeds up computation.
+        with torch.no_grad(): # For PyTorch >= 1.9, torch.inference_mode() is also an option
+            if isinstance(pipe, StableDiffusionPipeline):
+                # Text-to-Image generation: Image_input is ignored for content
+                generated_image = pipe(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    height=512, # Stable Diffusion 1.x models are usually trained at 512x512
+                    width=512
+                ).images[0]
+            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
+            #     # Image-to-Image generation: Uncomment this block if you switch to Img2ImgPipeline
+            #     # The 'strength' parameter controls how much noise is added to the input image.
+            #     # 0.0 means no change, 1.0 means complete re-imagining (like text-to-image).
+            #     # A value around 0.7-0.8 is typical for style transfer.
+            #     strength = 0.75
+            #     generated_image = pipe(
+            #         prompt=prompt,
+            #         image=image_input, # Pass the uploaded image here for img2img
+            #         negative_prompt=negative_prompt,
+            #         num_inference_steps=num_inference_steps,
+            #         guidance_scale=guidance_scale,
+            #         strength=strength
+            #     ).images[0]
+            else:
+                raise ValueError("Unsupported pipeline type. Please check model loading.")
+
+        print(f"[{os.getpid()}] Image generation complete.")
         return generated_image

     except Exception as e:
-        print(f"Error during image generation: {e}")
-        gr.Error(f"An error occurred during generation: {e}")
-        return None
+        print(f"[{os.getpid()}] Error during image generation: {e}")
+        # Display an error message to the user in the Gradio interface
+        gr.Error(f"An error occurred during image generation: {e}")
+        return None # Return None to clear the output image

-# --- Gradio Interface ---
+# --- Gradio Interface Definition ---
 with gr.Blocks() as demo:
     gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
     gr.Markdown(
         "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
         "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters).<br>"
-        "**Note:** The uploaded image is currently used only to trigger generation and is not directly influencing the avatar's appearance. "
-        "It's here for user reference or potential future Image-to-Image features."
+        "**Note:** The uploaded image is currently used only to trigger generation and is **not directly influencing the avatar's appearance**. "
+        "It's here for your reference or potential future Image-to-Image features. You will get a new person in the chosen style."
     )

     with gr.Row():
         with gr.Column():
-            # Image input component (type="pil" for Pillow Image object)
+            # Image input component. type="pil" ensures a PIL Image object is passed to the function.
             image_input = gr.Image(
                 label="Upload your photo",
                 type="pil",
                 sources=["upload", "webcam"], # Allow file upload or webcam capture
-                # You might want to set a default for testing: value="path/to/default_image.jpg"
+                # Optional: Add a placeholder image path if you want a default visual
+                # value="assets/placeholder.jpg"
             )
             style_selector = gr.Radio(
                 choices=list(styles.keys()),
                 label="Choose a style",
-                value="Anime" # Default selected style
+                value="Anime", # Default selected style
+                info="Select the artistic style for your avatar."
+            )
+            gender_selector = gr.Radio(
+                choices=["male", "female", "unspecified"],
+                label="Choose a Gender",
+                value="male", # Default to male to address your specific issue
+                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model."
             )
             generate_btn = gr.Button("Generate Avatar", variant="primary")

         with gr.Column():
             output_image = gr.Image(label="Generated Avatar")

-    # Connect the button click to the generation function
+    # Connect the button click to the generation function, passing all inputs
     generate_btn.click(
         fn=generate_avatar,
-        inputs=[image_input, style_selector],
+        inputs=[image_input, style_selector, gender_selector], # Now includes gender_selector
         outputs=output_image
     )

+    # Optional: Add examples for quick testing
     gr.Examples(
         examples=[
-            [None, "Pixar"],
-            [None, "Anime"],
-            [None, "Cyberpunk"],
-            [None, "Disney"],
-            [None, "Sketch"],
-            [None, "Astronaut"]
+            # Example format: [image_path_or_None, style_name, gender]
+            # Use None for image_input as it's not directly influencing the output in text-to-image mode
+            [None, "Pixar", "male"],
+            [None, "Anime", "female"],
+            [None, "Cyberpunk", "unspecified"], # To show what 'unspecified' might produce
+            [None, "Disney", "male"],
+            [None, "Sketch", "female"],
+            [None, "Astronaut", "male"]
         ],
-        inputs=[image_input, style_selector],
-        fn=generate_avatar,
-        outputs=output_image,
-        cache_examples=False, # Set to True if examples are pre-computed, False for live generation
+        inputs=[image_input, style_selector, gender_selector],
+        # fn=generate_avatar, # Uncomment if you want examples to run the generation live
+        # outputs=output_image,
+        cache_examples=False, # Set to True if examples are pre-computed images, False for live generation
         label="Quick Examples (Generates new images each time)"
     )

 # Launch the Gradio application
-demo.launch()
+# share=True will generate a public link (useful for sharing demos temporarily)
+# auth=("username", "password") for basic authentication
+demo.launch(inbrowser=True, show_error=True)
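
The Image-to-Image path referenced in the new comments stays commented out in this commit. For readers who want to try it, the following is a minimal standalone sketch of that path, assuming the same nota-ai/bk-sdm-small model and sayakpaul/taesd-diffusers Tiny VAE used in app.py; the prompt text, the file names photo.jpg / avatar.png, and the strength value 0.75 are illustrative placeholders, not part of the commit.

# Hypothetical img2img sketch -- not part of the committed app.py.
import torch
from PIL import Image
from diffusers import AutoencoderTiny, StableDiffusionImg2ImgPipeline

device = "cpu"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nota-ai/bk-sdm-small",   # same distilled model as app.py
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    safety_checker=None,
)
pipe.vae = AutoencoderTiny.from_pretrained("sayakpaul/taesd-diffusers", torch_dtype=torch.float32)
pipe.to(device)

# "photo.jpg" is a placeholder input; SD 1.x models are trained at 512x512.
init_image = Image.open("photo.jpg").convert("RGB").resize((512, 512))
with torch.no_grad():
    avatar = pipe(
        prompt="anime style portrait of a person, high quality, detailed",  # placeholder; app.py builds this from the styles dict
        image=init_image,   # the uploaded photo now influences the output
        strength=0.75,      # 0.0 keeps the input unchanged, 1.0 ignores it
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]
avatar.save("avatar.png")

Lower strength values keep more of the uploaded photo, while values near 1.0 behave almost like the text-to-image path the app currently uses.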