anirudh97 commited on
Commit
ec9e2d7
·
1 Parent(s): b085f1f

cpu_optimize

Browse files
Files changed (1) hide show
  1. app.py +243 -140
app.py CHANGED
@@ -8,39 +8,67 @@ from tqdm.auto import tqdm
8
  import torchvision.transforms as T
9
  import torch.nn.functional as F
10
  import gc
 
 
 
11
 
12
- # Configure constants
13
- HEIGHT, WIDTH = 512, 512
14
- GUIDANCE_SCALE = 8
15
  LOSS_SCALE = 200
16
- NUM_INFERENCE_STEPS = 50
17
  BATCH_SIZE = 1
18
  DEFAULT_PROMPT = "A deadly witcher slinging a sword with a lion medallion in his neck, casting a fire spell from his hand in a snowy forest"
19
 
20
  # Define the device
21
  TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
22
 
23
  # Initialize the elastic transformer
24
  elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)
25
 
 
 
 
 
26
  # Load the model
27
  def load_model():
28
- pipe = DiffusionPipeline.from_pretrained(
29
- "CompVis/stable-diffusion-v1-4",
30
- torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32
31
- ).to(TORCH_DEVICE)
32
-
33
- # Load textual inversion concepts
34
  try:
35
- pipe.load_textual_inversion("sd-concepts-library/rimworld-art-style", mean_resizing=False)
36
- pipe.load_textual_inversion("sd-concepts-library/hk-goldenlantern", mean_resizing=False)
37
- pipe.load_textual_inversion("sd-concepts-library/phoenix-01", mean_resizing=False)
38
- pipe.load_textual_inversion("sd-concepts-library/fractal-flame", mean_resizing=False)
39
- pipe.load_textual_inversion("sd-concepts-library/scarlet-witch", mean_resizing=False)
40
- except Exception as e:
41
- print(f"Warning: Could not load all textual inversion concepts: {e}")
42
 
43
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Helper functions
46
  def image_grid(imgs, rows, cols):
@@ -85,145 +113,178 @@ def latents_to_pil(latents, pipe):
85
  return pil_images
86
 
87
  def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()):
88
- # Initialization and Setup
89
- generator = torch.manual_seed(seed_no)
 
 
 
 
 
 
90
 
91
- scheduler = LMSDiscreteScheduler(
92
- beta_start=0.00085,
93
- beta_end=0.012,
94
- beta_schedule="scaled_linear",
95
- num_train_timesteps=1000
96
- )
97
- scheduler.set_timesteps(NUM_INFERENCE_STEPS)
98
- scheduler.timesteps = scheduler.timesteps.to(torch.float32)
99
 
100
- # Text Processing
101
- text_input = pipe.tokenizer(
102
- prompts,
103
- padding='max_length',
104
- max_length=pipe.tokenizer.model_max_length,
105
- truncation=True,
106
- return_tensors="pt"
107
- )
108
- input_ids = text_input.input_ids.to(TORCH_DEVICE)
109
 
110
- # Convert text inputs to embeddings
111
- with torch.no_grad():
112
- text_embeddings = pipe.text_encoder(input_ids)[0]
113
 
114
- # Handle padding and truncation of text inputs
115
- max_length = text_input.input_ids.shape[-1]
116
- uncond_input = pipe.tokenizer(
117
- [""] * BATCH_SIZE,
118
- padding="max_length",
119
- max_length=max_length,
120
- return_tensors="pt"
121
- )
122
 
123
- with torch.no_grad():
124
- uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0]
125
 
126
- # Concatenate unconditioned and text embeddings
127
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
128
 
129
- # Create random initial latents
130
- latents = torch.randn(
131
- (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
132
- generator=generator,
133
- )
134
 
135
- # Move latents to device and apply noise scaling
136
- if TORCH_DEVICE == "cuda":
137
- latents = latents.to(torch.float16)
138
- latents = latents.to(TORCH_DEVICE)
139
- latents = latents * scheduler.init_noise_sigma
140
 
141
- # Diffusion Process
142
- timesteps = scheduler.timesteps
143
- progress(0, desc="Generating")
144
-
145
- # Fixed loop - separate the progress tracking from the enumeration
146
- for i in range(len(timesteps)):
147
- progress((i + 1) / len(timesteps), desc=f"Diffusion step {i+1}/{len(timesteps)}")
148
- t = timesteps[i]
149
 
150
- # Process the latent model input
151
- latent_model_input = torch.cat([latents] * 2)
152
- sigma = scheduler.sigmas[i]
153
- latent_model_input = scheduler.scale_model_input(latent_model_input, t)
 
 
 
 
 
154
 
155
- with torch.no_grad():
156
- noise_pred = pipe.unet(
157
- latent_model_input,
158
- t,
159
- encoder_hidden_states=text_embeddings
160
- )["sample"]
161
 
162
- # Apply noise prediction
163
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
164
- noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)
165
 
166
- # Apply loss if requested
167
- if loss_apply and i % 5 == 0:
168
- latents = latents.detach().requires_grad_()
169
- latents_x0 = latents - sigma * noise_pred
170
 
171
- # Use VAE to decode the image
172
- denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
173
 
174
- # Apply loss
175
- loss = image_loss(denoised_images, loss_type) * LOSS_SCALE
176
- print(f"Step {i}, Loss: {loss.item()}")
177
 
178
- # Compute gradients for optimization
179
- cond_grad = torch.autograd.grad(loss, latents)[0]
180
- latents = latents.detach() - cond_grad * sigma**2
181
 
182
- # Update latents using the scheduler
183
- latents = scheduler.step(noise_pred, t, latents).prev_sample
 
 
 
 
184
 
185
- return latents
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- def generate_images(prompt, loss_type, apply_loss, seeds, pipe):
188
- latents_collect = []
189
-
190
- # Convert comma-separated string to list and clean
191
- seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()]
192
-
193
- if not seeds:
194
- seeds = [1000] # Default seed if none provided
195
 
196
- # List of SD concepts (can be empty if not used)
197
- sdconcepts = [''] * len(seeds)
198
-
199
- # Generate images for each seed
200
- for seed_no, sd in zip(seeds, sdconcepts):
201
- # Clear CUDA cache
202
- if TORCH_DEVICE == "cuda":
203
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  gc.collect()
205
- torch.cuda.empty_cache()
206
 
207
- # Generate image
208
- prompts = [f'{prompt} {sd}']
209
- latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss)
210
- latents_collect.append(latents)
211
-
212
- # Stack latents and convert to images
213
- latents_collect = torch.vstack(latents_collect)
214
- images = latents_to_pil(latents_collect, pipe)
215
-
216
- # Create image grid
217
- if len(images) > 1:
218
- result = image_grid(images, 1, len(images))
219
- return result
220
- else:
221
- return images[0]
222
 
223
  # Gradio Interface
224
  def create_interface():
225
- pipe = load_model()
226
-
227
  with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app:
228
  gr.Markdown("""
229
  # Stable Diffusion Text Inversion with Loss Functions
@@ -231,6 +292,14 @@ def create_interface():
231
  Generate images using Stable Diffusion with various loss functions to guide the diffusion process.
232
  """)
233
 
 
 
 
 
 
 
 
 
234
  with gr.Row():
235
  with gr.Column():
236
  prompt = gr.Textbox(
@@ -250,19 +319,44 @@ def create_interface():
250
  value=False
251
  )
252
 
253
- seeds = gr.Textbox(
254
- label="Seeds (comma-separated)",
255
- value="3000,2000,1000",
256
- lines=1
257
- )
 
 
 
 
 
 
 
258
 
259
- generate_btn = gr.Button("Generate Images")
 
 
 
 
260
 
261
  with gr.Column():
262
  output_image = gr.Image(label="Generated Image")
 
 
 
 
 
 
 
 
263
 
 
 
 
 
 
 
264
  generate_btn.click(
265
- fn=lambda p, lt, al, s: generate_images(p, lt, al, s, pipe),
266
  inputs=[prompt, loss_type, apply_loss, seeds],
267
  outputs=output_image
268
  )
@@ -275,8 +369,17 @@ def create_interface():
275
  - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions
276
  - **Saturation**: Increases color saturation in the image
277
 
278
- Set "N/A" and unchecks "Apply Loss Function" for normal image generation.
279
  """)
 
 
 
 
 
 
 
 
 
280
 
281
  return app
282
 
 
8
  import torchvision.transforms as T
9
  import torch.nn.functional as F
10
  import gc
11
+ import signal
12
+ import time
13
+ import traceback
14
 
15
+ # Configure constants - optimized for CPU
16
+ HEIGHT, WIDTH = 384, 384 # Smaller images use less memory
17
+ GUIDANCE_SCALE = 7.5
18
  LOSS_SCALE = 200
19
+ NUM_INFERENCE_STEPS = 30 # Reduced from 50
20
  BATCH_SIZE = 1
21
  DEFAULT_PROMPT = "A deadly witcher slinging a sword with a lion medallion in his neck, casting a fire spell from his hand in a snowy forest"
22
 
23
  # Define the device
24
  TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
25
+ print(f"Using device: {TORCH_DEVICE}")
26
 
27
  # Initialize the elastic transformer
28
  elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)
29
 
30
+ # Timeout handler for CPU processing
31
+ def timeout_handler(signum, frame):
32
+ raise TimeoutError("Image generation took too long")
33
+
34
  # Load the model
35
  def load_model():
 
 
 
 
 
 
36
  try:
37
+ pipe = DiffusionPipeline.from_pretrained(
38
+ "CompVis/stable-diffusion-v1-4",
39
+ torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32,
40
+ safety_checker=None, # Disable safety checker for memory
41
+ low_cpu_mem_usage=True # Enable memory optimization
42
+ ).to(TORCH_DEVICE)
 
43
 
44
+ # Load textual inversion for all devices including CPU
45
+ try:
46
+ # Load one at a time with memory cleanup between each
47
+ concepts = [
48
+ "sd-concepts-library/rimworld-art-style",
49
+ "sd-concepts-library/hk-goldenlantern",
50
+ "sd-concepts-library/phoenix-01",
51
+ "sd-concepts-library/fractal-flame",
52
+ "sd-concepts-library/scarlet-witch"
53
+ ]
54
+
55
+ for concept in concepts:
56
+ try:
57
+ print(f"Loading textual inversion concept: {concept}")
58
+ pipe.load_textual_inversion(concept, mean_resizing=False)
59
+ # Clear memory after loading each concept
60
+ if TORCH_DEVICE == "cpu":
61
+ gc.collect()
62
+ except Exception as e:
63
+ print(f"Warning: Could not load textual inversion concept {concept}: {e}")
64
+ except Exception as e:
65
+ print(f"Warning: Could not load textual inversion concepts: {e}")
66
+
67
+ return pipe
68
+ except Exception as e:
69
+ print(f"Error loading model: {e}")
70
+ traceback.print_exc()
71
+ raise
72
 
73
  # Helper functions
74
  def image_grid(imgs, rows, cols):
 
113
  return pil_images
114
 
115
  def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()):
116
+ try:
117
+ # Set timeout for CPU
118
+ if TORCH_DEVICE == "cpu":
119
+ signal.signal(signal.SIGALRM, timeout_handler)
120
+ signal.alarm(600) # 10 minute timeout
121
+
122
+ # Initialization and Setup
123
+ generator = torch.manual_seed(seed_no)
124
 
125
+ scheduler = LMSDiscreteScheduler(
126
+ beta_start=0.00085,
127
+ beta_end=0.012,
128
+ beta_schedule="scaled_linear",
129
+ num_train_timesteps=1000
130
+ )
131
+ scheduler.set_timesteps(NUM_INFERENCE_STEPS)
132
+ scheduler.timesteps = scheduler.timesteps.to(torch.float32)
133
 
134
+ # Text Processing
135
+ text_input = pipe.tokenizer(
136
+ prompts,
137
+ padding='max_length',
138
+ max_length=pipe.tokenizer.model_max_length,
139
+ truncation=True,
140
+ return_tensors="pt"
141
+ )
142
+ input_ids = text_input.input_ids.to(TORCH_DEVICE)
143
 
144
+ # Convert text inputs to embeddings
145
+ with torch.no_grad():
146
+ text_embeddings = pipe.text_encoder(input_ids)[0]
147
 
148
+ # Handle padding and truncation of text inputs
149
+ max_length = text_input.input_ids.shape[-1]
150
+ uncond_input = pipe.tokenizer(
151
+ [""] * BATCH_SIZE,
152
+ padding="max_length",
153
+ max_length=max_length,
154
+ return_tensors="pt"
155
+ )
156
 
157
+ with torch.no_grad():
158
+ uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0]
159
 
160
+ # Concatenate unconditioned and text embeddings
161
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
162
 
163
+ # Create random initial latents
164
+ latents = torch.randn(
165
+ (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
166
+ generator=generator,
167
+ )
168
 
169
+ # Move latents to device and apply noise scaling
170
+ if TORCH_DEVICE == "cuda":
171
+ latents = latents.to(torch.float16)
172
+ latents = latents.to(TORCH_DEVICE)
173
+ latents = latents * scheduler.init_noise_sigma
174
 
175
+ # Diffusion Process
176
+ timesteps = scheduler.timesteps
177
+ progress(0, desc="Generating")
 
 
 
 
 
178
 
179
+ # Fixed loop - separate the progress tracking from the enumeration
180
+ for i in range(len(timesteps)):
181
+ progress((i + 1) / len(timesteps), desc=f"Diffusion step {i+1}/{len(timesteps)}")
182
+ t = timesteps[i]
183
+
184
+ # Process the latent model input
185
+ latent_model_input = torch.cat([latents] * 2)
186
+ sigma = scheduler.sigmas[i]
187
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
188
 
189
+ with torch.no_grad():
190
+ noise_pred = pipe.unet(
191
+ latent_model_input,
192
+ t,
193
+ encoder_hidden_states=text_embeddings
194
+ )["sample"]
195
 
196
+ # Apply noise prediction
197
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
198
+ noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)
199
 
200
+ # Apply loss if requested
201
+ if loss_apply and i % 5 == 0 and loss_type != "N/A":
202
+ latents = latents.detach().requires_grad_()
203
+ latents_x0 = latents - sigma * noise_pred
204
 
205
+ # Use VAE to decode the image
206
+ denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
207
 
208
+ # Apply loss
209
+ loss = image_loss(denoised_images, loss_type) * LOSS_SCALE
210
+ print(f"Step {i}, Loss: {loss.item()}")
211
 
212
+ # Compute gradients for optimization
213
+ cond_grad = torch.autograd.grad(loss, latents)[0]
214
+ latents = latents.detach() - cond_grad * sigma**2
215
 
216
+ # Update latents using the scheduler
217
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
218
+
219
+ # Garbage collect every 5 steps if on CPU
220
+ if TORCH_DEVICE == "cpu" and i % 5 == 0:
221
+ gc.collect()
222
 
223
+ # Clear the alarm if set
224
+ if TORCH_DEVICE == "cpu":
225
+ signal.alarm(0)
226
+
227
+ return latents
228
+
229
+ except Exception as e:
230
+ print(f"Error in generate_image: {e}")
231
+ traceback.print_exc()
232
+ # Return empty latents as fallback
233
+ return torch.zeros(
234
+ (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
235
+ device=TORCH_DEVICE
236
+ )
237
 
238
+ def generate_images(prompt, loss_type, apply_loss, seeds, pipe, progress=gr.Progress()):
239
+ try:
240
+ images_list = []
 
 
 
 
 
241
 
242
+ # Convert comma-separated string to list and clean
243
+ seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()]
244
+
245
+ if not seeds:
246
+ seeds = [1000] # Default seed if none provided
247
+
248
+ # Process one seed at a time to save memory
249
+ for i, seed_no in enumerate(seeds):
250
+ progress((i / len(seeds)) * 0.1, desc=f"Starting seed {seed_no}")
251
+
252
+ # Clear memory
253
+ if TORCH_DEVICE == "cuda":
254
+ torch.cuda.empty_cache()
255
+ gc.collect()
256
+
257
+ try:
258
+ # Generate image
259
+ prompts = [prompt]
260
+ latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss, progress=progress)
261
+ pil_images = latents_to_pil(latents, pipe)
262
+ images_list.extend(pil_images)
263
+ except Exception as e:
264
+ print(f"Error generating image with seed {seed_no}: {e}")
265
+ # Create an error image
266
+ error_img = Image.new('RGB', (HEIGHT, WIDTH), color=(255, 0, 0))
267
+ images_list.append(error_img)
268
+
269
+ # Force garbage collection
270
  gc.collect()
 
271
 
272
+ # Create image grid
273
+ if len(images_list) > 1:
274
+ result = image_grid(images_list, 1, len(images_list))
275
+ return result
276
+ else:
277
+ return images_list[0]
278
+
279
+ except Exception as e:
280
+ print(f"Error in generate_images: {e}")
281
+ traceback.print_exc()
282
+ # Create an error image
283
+ error_img = Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))
284
+ return error_img
 
 
285
 
286
  # Gradio Interface
287
  def create_interface():
 
 
288
  with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app:
289
  gr.Markdown("""
290
  # Stable Diffusion Text Inversion with Loss Functions
 
292
  Generate images using Stable Diffusion with various loss functions to guide the diffusion process.
293
  """)
294
 
295
+ if TORCH_DEVICE == "cpu":
296
+ gr.Markdown("""
297
+ ⚠️ **Running on CPU**: Generation will be slow and memory-intensive.
298
+ Each image may take several minutes to generate.
299
+ """)
300
+
301
+ pipe = None # Initialize to None to avoid loading during interface creation
302
+
303
  with gr.Row():
304
  with gr.Column():
305
  prompt = gr.Textbox(
 
319
  value=False
320
  )
321
 
322
+ if TORCH_DEVICE == "cpu":
323
+ seeds = gr.Textbox(
324
+ label="Seeds (comma-separated) - Use fewer seeds for CPU",
325
+ value="1000",
326
+ lines=1
327
+ )
328
+ else:
329
+ seeds = gr.Textbox(
330
+ label="Seeds (comma-separated)",
331
+ value="3000,2000,1000",
332
+ lines=1
333
+ )
334
 
335
+ # Load model button
336
+ load_model_btn = gr.Button("Load Model")
337
+ model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
338
+
339
+ generate_btn = gr.Button("Generate Images", interactive=False)
340
 
341
  with gr.Column():
342
  output_image = gr.Image(label="Generated Image")
343
+
344
+ def load_model_fn():
345
+ nonlocal pipe
346
+ try:
347
+ pipe = load_model()
348
+ return "Model loaded successfully", True
349
+ except Exception as e:
350
+ return f"Error loading model: {str(e)}", False
351
 
352
+ load_model_btn.click(
353
+ fn=load_model_fn,
354
+ inputs=[],
355
+ outputs=[model_status, generate_btn]
356
+ )
357
+
358
  generate_btn.click(
359
+ fn=lambda p, lt, al, s, prog: generate_images(p, lt, al, s, pipe, prog),
360
  inputs=[prompt, loss_type, apply_loss, seeds],
361
  outputs=output_image
362
  )
 
369
  - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions
370
  - **Saturation**: Increases color saturation in the image
371
 
372
+ Set "N/A" and uncheck "Apply Loss Function" for normal image generation.
373
  """)
374
+
375
+ if TORCH_DEVICE == "cpu":
376
+ gr.Markdown("""
377
+ ## CPU Mode Tips
378
+ - Use smaller prompts
379
+ - Process one seed at a time
380
+ - Be patient, generation can take 5-10 minutes per image
381
+ - If you encounter memory errors, try restarting the app and using even smaller dimensions
382
+ """)
383
 
384
  return app
385