b2bomber committed · Commit 14d4d16 · verified · 1 Parent(s): abf7663

Update app.py

Files changed (1):
  1. app.py +124 -68
app.py CHANGED
@@ -2,58 +2,63 @@ import gradio as gr
 import torch
 from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
 from PIL import Image
+import os  # For better logging/debugging
 
-# 1. Force CPU usage
+# --- Configuration ---
+# 1. Force CPU usage for compatibility on Spaces without GPU
 device = "cpu"
 
-# 2. Choose a smaller/distilled Stable Diffusion model
-# 'nota-ai/bk-sdm-small' is a good example of a distilled model that's faster.
-# Another option is 'segmind/SSD-1B' (though still relatively large, it's optimized).
-# For truly tiny models, you might look for "TinySD" variations.
-# Let's start with a well-known distilled model for better CPU performance.
-model_id = "nota-ai/bk-sdm-small"  # Smaller and faster than SD 2.1
-# model_id = "segmind/SSD-1B"  # Another optimized, but still larger, option.
+# 2. Choose a smaller/distilled Stable Diffusion model for CPU speed
+# 'nota-ai/bk-sdm-small' is a good balance of size/speed/quality for CPU.
+# If quality is paramount and you can tolerate more time, consider 'runwayml/stable-diffusion-v1-5',
+# but expect significantly slower generation times on CPU.
+model_id = "nota-ai/bk-sdm-small"
 
-# Load the pipeline. For CPU, use torch_dtype=torch.float32.
-# Disable safe_serialization if you encounter issues with some older models.
+# 3. Tiny VAE for drastically faster encoding/decoding on CPU
+tiny_vae_id = "sayakpaul/taesd-diffusers"
+
+# --- Model Loading ---
+# Load the pipeline globally to avoid reloading on each request
 print(f"Loading model: {model_id} on {device}...")
 try:
+    # Use StableDiffusionPipeline for text-to-image generation.
+    # If you want image-to-image, you'd use StableDiffusionImg2ImgPipeline here.
     pipe = StableDiffusionPipeline.from_pretrained(
         model_id,
-        torch_dtype=torch.float32,  # CPU usually prefers float32 for stability/speed unless specialized kernels are used
-        low_cpu_mem_usage=True  # Helps with memory on CPU
-    )
-except Exception as e:
-    print(f"Error loading model {model_id}: {e}. Trying without low_cpu_mem_usage.")
-    pipe = StableDiffusionPipeline.from_pretrained(
-        model_id,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float32,  # CPU usually prefers float32 for stability/speed
+        low_cpu_mem_usage=True,     # Helps with memory on CPU
+        safety_checker=None         # Disable safety checker to save CPU cycles and memory
     )
+    print("Main pipeline loaded.")
 
-# Optimize VAE (Very Important for Speed and Memory on CPU)
-# The VAE (Variational AutoEncoder) is a bottleneck. Using a tiny VAE helps a lot.
-# 'sayakpaul/taesd-diffusers' is a known tiny VAE.
-print("Loading Tiny VAE...")
-try:
-    pipe.vae = AutoencoderTiny.from_pretrained("sayakpaul/taesd-diffusers", torch_dtype=torch.float32)
-except Exception as e:
-    print(f"Could not load Tiny VAE: {e}. Model might be slower.")
-    # Fallback: if Tiny VAE fails, ensure the default VAE is on CPU
-    pipe.vae.to(device)
+    # Load and assign the Tiny VAE for speed optimization
+    print(f"Loading Tiny VAE from {tiny_vae_id}...")
+    try:
+        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
+        print("Tiny VAE loaded successfully.")
+    except Exception as vae_e:
+        print(f"Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (might be slower).")
+        # Ensure the default VAE is on CPU
+        pipe.vae.to(device)
 
-
-# Move pipeline components to CPU explicitly
-pipe.to(device)
+    # Move the entire pipeline to CPU explicitly
+    pipe.to(device)
 
-# Set up the scheduler. DDIMScheduler is fine.
-pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    # Set up the scheduler. DDIMScheduler is a good choice.
+    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
 
-# Enable CPU offload for even lower memory (can make it slower, but might be necessary for very limited RAM)
-# pipe.enable_sequential_cpu_offload()  # Use if you hit OOM errors, but it will be much slower.
+    # Optional: Enable CPU offload if you run into out-of-memory errors with larger models.
+    # Be aware: this will make generation *much* slower.
+    # pipe.enable_sequential_cpu_offload()
 
-print("Model loaded and configured.")
+    print("Model loaded and configured successfully.")
 
-# Preset styles (same as before)
+except Exception as e:
+    print(f"FATAL ERROR: Failed to load models: {e}")
+    # Raise an exception to prevent the app from starting if model loading fails
+    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}")
+
+# --- Preset Styles ---
 styles = {
     "Pixar": "pixar style portrait of",
     "Anime": "anime style portrait of",
@@ -63,50 +68,101 @@ styles = {
     "Astronaut": "realistic astronaut with helmet, portrait of"
 }
 
-def generate_avatar(image, style):
-    if image is None:
-        # You might want to generate a default image or throw an error via Gradio
-        # For a more robust app, consider a placeholder image or a clear error message in the UI.
+# --- Generation Function ---
+def generate_avatar(image_input: Image.Image, style: str):
+    """
+    Generates an avatar based on a chosen style using Stable Diffusion.
+    Note: In this text-to-image setup, the uploaded `image_input` is used
+    only to trigger the generation, not to influence the image content directly.
+    """
+    if image_input is None:
         gr.Warning("Please upload an image to generate an avatar.")
         return None
 
-    # Although the original intent was image-to-image, your current logic
-    # converts the image input into a text-only prompt.
-    # To truly use the image as input, you would need an img2img pipeline or a specific
-    # controlnet/adapter for Stable Diffusion.
-    # For now, let's keep it as a text-to-image generation based on the style and a generic prompt.
-
+    # Base prompt from the selected style
    base_prompt = styles[style]
-    # For CPU, fewer steps and lower guidance scale can yield faster (but potentially lower quality) results.
-    num_inference_steps = 20  # Reduced for speed
-    guidance_scale = 7.0  # Slightly reduced guidance
 
-    prompt = f"{base_prompt} a person, high quality, detailed, professional"  # Enhance prompt
-    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch"  # Add negative prompt for better results
-
-    # Generate image
-    print(f"Generating for style: {style} with prompt: {prompt}")
-    with torch.no_grad():  # Disable gradient calculations for inference
-        generated_image = pipe(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale
-        ).images[0]
-
-    return generated_image
+    # Enhance the prompt for better quality
+    prompt = f"{base_prompt} a person, highly detailed, professional, studio lighting, volumetric lighting, 4k, cinematic"
+    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, out of frame, bad anatomy, deformed, extra limbs, watermark, text"
+
+    # Inference parameters (adjusted for speed on CPU, can be tweaked for quality)
+    num_inference_steps = 25  # Increased slightly for better quality; balance with speed
+    guidance_scale = 7.5      # Slightly increased for stronger adherence to the prompt
+
+    print(f"Generating for style: {style} with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")
+
+    try:
+        # Use torch.no_grad() for efficient inference (disables gradient calculations)
+        with torch.no_grad():  # Or torch.inference_mode() for PyTorch >= 1.9
+            generated_image = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                height=512,  # Explicitly set output dimensions; try 768 for SD 2.1 or larger models
+                width=512
+            ).images[0]
+
+        print("Image generation complete.")
+        return generated_image
+
+    except Exception as e:
+        print(f"Error during image generation: {e}")
+        gr.Error(f"An error occurred during generation: {e}")
+        return None
 
+# --- Gradio Interface ---
 with gr.Blocks() as demo:
     gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
-    gr.Markdown("This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. Generation will still take time on CPU, but should be faster than larger models.")
+    gr.Markdown(
+        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
+        "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters).<br>"
+        "**Note:** The uploaded image is currently used only to trigger generation and does not directly influence the avatar's appearance. "
+        "It's here for user reference or potential future image-to-image features."
+    )
+
     with gr.Row():
         with gr.Column():
-            image_input = gr.Image(label="Upload your photo (Note: Image currently used only to trigger generation, not as direct input)", type="pil", sources=["upload", "webcam"])
-            style_selector = gr.Radio(choices=list(styles.keys()), label="Choose a style", value="Anime")
-            generate_btn = gr.Button("Generate Avatar")
+            # Image input component (type="pil" yields a Pillow Image object)
+            image_input = gr.Image(
+                label="Upload your photo",
+                type="pil",
+                sources=["upload", "webcam"],  # Allow file upload or webcam capture
+                # You might want to set a default for testing: value="path/to/default_image.jpg"
+            )
+            style_selector = gr.Radio(
+                choices=list(styles.keys()),
+                label="Choose a style",
+                value="Anime"  # Default selected style
+            )
+            generate_btn = gr.Button("Generate Avatar", variant="primary")
+
         with gr.Column():
             output_image = gr.Image(label="Generated Avatar")
 
-    generate_btn.click(fn=generate_avatar, inputs=[image_input, style_selector], outputs=output_image)
+    # Connect the button click to the generation function
+    generate_btn.click(
+        fn=generate_avatar,
+        inputs=[image_input, style_selector],
+        outputs=output_image
+    )
+
+    gr.Examples(
+        examples=[
+            [None, "Pixar"],
+            [None, "Anime"],
+            [None, "Cyberpunk"],
+            [None, "Disney"],
+            [None, "Sketch"],
+            [None, "Astronaut"]
+        ],
+        inputs=[image_input, style_selector],
+        fn=generate_avatar,
+        outputs=output_image,
+        cache_examples=False,  # Set to True only if examples are pre-computed
+        label="Quick Examples (Generates new images each time)"
+    )
 
+# Launch the Gradio application
 demo.launch()
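
Side note on the image-to-image comments in this commit: the new code states twice that the uploaded photo only triggers generation and that actually using it would require StableDiffusionImg2ImgPipeline. A minimal sketch of what that follow-up could look like, not part of commit 14d4d16: the strength=0.6 value and the 512x512 resize are assumed starting points, and how well the distilled checkpoint behaves in img2img mode is untested here.

# Hypothetical follow-up (not in this commit): steer the output with the uploaded photo.
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nota-ai/bk-sdm-small",     # same distilled checkpoint as above (assumed compatible)
    torch_dtype=torch.float32,  # float32 for CPU
    safety_checker=None,
).to("cpu")

def generate_avatar_img2img(image_input: Image.Image, prompt: str) -> Image.Image:
    # Match the model's expected input size before denoising.
    init_image = image_input.convert("RGB").resize((512, 512))
    with torch.no_grad():
        return img2img_pipe(
            prompt=prompt,
            image=init_image,   # the uploaded photo now influences composition
            strength=0.6,       # 0 keeps the input, 1 ignores it; tune per style
            num_inference_steps=25,
            guidance_scale=7.5,
        ).images[0]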