import torch
import gradio as gr
from PIL import Image
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
import torchvision.transforms as T
import torch.nn.functional as F
import gc
import signal
import traceback

# Configure constants - optimized for CPU
HEIGHT, WIDTH = 384, 384  # Smaller images use less memory
GUIDANCE_SCALE = 7.5
LOSS_SCALE = 200
NUM_INFERENCE_STEPS = 30  # Reduced from 50
BATCH_SIZE = 1
DEFAULT_PROMPT = "A deadly witcher slinging a sword with a lion medallion in his neck, casting a fire spell from his hand in a snowy forest"

# Define the device
TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {TORCH_DEVICE}")

# Initialize the elastic transformer
elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)

# Timeout handler for CPU processing
def timeout_handler(signum, frame):
    raise TimeoutError("Image generation took too long")

# Load the model
def load_model():
    try:
        # Initialize signal handler only on Unix-like systems
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(2100)  # 35 minute timeout for model loading
        
        pipe = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32,
            safety_checker=None,  # Disable safety checker for memory
            low_cpu_mem_usage=True  # Enable memory optimization
        ).to(TORCH_DEVICE)
        
        # Load textual inversion for all devices including CPU
        try:
            # Load one at a time with memory cleanup between each
            concepts = [
                "sd-concepts-library/rimworld-art-style",
                "sd-concepts-library/hk-goldenlantern",
                "sd-concepts-library/phoenix-01",
                "sd-concepts-library/fractal-flame",
                "sd-concepts-library/scarlet-witch"
            ]
            
            for concept in concepts:
                try:
                    print(f"Loading textual inversion concept: {concept}")
                    pipe.load_textual_inversion(concept, mean_resizing=False)
                    # Clear memory after loading each concept
                    if TORCH_DEVICE == "cpu":
                        gc.collect()
                except Exception as e:
                    print(f"Warning: Could not load textual inversion concept {concept}: {e}")
        except Exception as e:
            print(f"Warning: Could not load textual inversion concepts: {e}")
        
        # Clear the alarm if set
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
            
        return pipe
    except Exception as e:
        # Clear the alarm if set
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
        
        print(f"Error loading model: {e}")
        traceback.print_exc()
        raise

# Helper functions
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

def image_loss(images, loss_type):
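    """Scalar guidance loss computed on decoded images in [0, 1].

    The gradient of this loss with respect to the latents is used in the
    diffusion loop to steer generation toward the desired property.
    """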
    if loss_type == 'blue':
        # blue loss
        error = torch.abs(images[:,2] - 0.9).mean()
    elif loss_type == 'elastic':
        # elastic loss
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
    elif loss_type == 'symmetry':
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
    elif loss_type == 'saturation':
        # saturation loss
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
    else:
        print("Error. Loss not defined")
        error = torch.tensor(0.0)

    return error

def latents_to_pil(latents, pipe):
    # batch of latents -> list of images
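    # 0.18215 is the SD v1 VAE scaling factor: the encoder's latents were
    # multiplied by it during training, so undo that before decoding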
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()):
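    """Run a manual diffusion loop with classifier-free guidance and optional
    gradient-based loss guidance on the latents; returns the final latents."""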
    try:
        # Set timeout for CPU (SIGALRM only exists on Unix-like systems)
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(600)  # 10 minute timeout
            
        # Initialization and Setup
        generator = torch.manual_seed(seed_no)

        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085, 
            beta_end=0.012, 
            beta_schedule="scaled_linear", 
            num_train_timesteps=1000
        )
        scheduler.set_timesteps(NUM_INFERENCE_STEPS)
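        # Cast timesteps to float32: a common workaround for dtype mismatches
        # when the scheduler's float64 timesteps are indexed or fed to the UNet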
        scheduler.timesteps = scheduler.timesteps.to(torch.float32)

        # Text Processing
        text_input = pipe.tokenizer(
            prompts, 
            padding='max_length', 
            max_length=pipe.tokenizer.model_max_length, 
            truncation=True, 
            return_tensors="pt"
        )
        input_ids = text_input.input_ids.to(TORCH_DEVICE)

        # Convert text inputs to embeddings
        with torch.no_grad():
            text_embeddings = pipe.text_encoder(input_ids)[0]

        # Handle padding and truncation of text inputs
        max_length = text_input.input_ids.shape[-1]
        uncond_input = pipe.tokenizer(
            [""] * BATCH_SIZE, 
            padding="max_length", 
            max_length=max_length, 
            return_tensors="pt"
        )

        with torch.no_grad():
            uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0]

        # Concatenate unconditioned and text embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Create random initial latents
        latents = torch.randn(
            (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
            generator=generator,
        )

        # Move latents to device and apply noise scaling
        if TORCH_DEVICE == "cuda":
            latents = latents.to(torch.float16)
        latents = latents.to(TORCH_DEVICE)
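        # Scale by the scheduler's initial sigma so the starting latents match
        # the noise level LMSDiscreteScheduler expects at the first timestep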
        latents = latents * scheduler.init_noise_sigma

        # Diffusion Process
        timesteps = scheduler.timesteps
        progress(0, desc="Generating")
        
        # Diffusion loop: drive the progress bar from the step index
        for i in range(len(timesteps)):
            progress((i + 1) / len(timesteps), desc=f"Diffusion step {i+1}/{len(timesteps)}")
            t = timesteps[i]
            
            # Process the latent model input
            latent_model_input = torch.cat([latents] * 2)
            sigma = scheduler.sigmas[i]
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                noise_pred = pipe.unet(
                    latent_model_input, 
                    t, 
                    encoder_hidden_states=text_embeddings
                )["sample"]

            # Classifier-free guidance: combine unconditional and text-conditioned predictions
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)

            # Apply loss if requested
            if loss_apply and i % 5 == 0 and loss_type != "N/A":
                latents = latents.detach().requires_grad_()
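                # One-step estimate of the denoised latents (x0 ≈ x_t - sigma * eps);
                # noise_pred was computed under no_grad, so gradients flow only
                # through the current latents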
                latents_x0 = latents - sigma * noise_pred

                # Use VAE to decode the image
                denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

                # Apply loss
                loss = image_loss(denoised_images, loss_type) * LOSS_SCALE
                print(f"Step {i}, Loss: {loss.item()}")

                # Compute gradients for optimization
                cond_grad = torch.autograd.grad(loss, latents)[0]
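                # Step against the gradient, scaled by sigma^2 so the nudge
                # shrinks as the noise level falls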
                latents = latents.detach() - cond_grad * sigma**2

            # Update latents using the scheduler
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            
            # Garbage collect every 5 steps if on CPU
            if TORCH_DEVICE == "cpu" and i % 5 == 0:
                gc.collect()

        # Clear the alarm if set (SIGALRM only exists on Unix-like systems)
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
            
        return latents
        
    except Exception as e:
        # Clear the alarm if set so the handler cannot fire later
        if TORCH_DEVICE == "cpu" and hasattr(signal, 'SIGALRM'):
            signal.alarm(0)
        print(f"Error in generate_image: {e}")
        traceback.print_exc()
        # Return empty latents as fallback
        return torch.zeros(
            (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
            device=TORCH_DEVICE
        )

def generate_images(prompt, loss_type, apply_loss, seeds, pipe, progress=gr.Progress()):
    try:
        images_list = []
        
        # Convert comma-separated string to list and clean
        seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()]
        
        if not seeds:
            seeds = [1000]  # Default seed if none provided
            
        # Process one seed at a time to save memory
        for i, seed_no in enumerate(seeds):
            progress((i / len(seeds)) * 0.1, desc=f"Starting seed {seed_no}")  
            
            # Clear memory
            if TORCH_DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()
            
            try:
                # Generate image
                prompts = [prompt]
                latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss, progress=progress)
                pil_images = latents_to_pil(latents, pipe)
                images_list.extend(pil_images)
            except Exception as e:
                print(f"Error generating image with seed {seed_no}: {e}")
                # Create an error image
                error_img = Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))  # PIL size is (width, height)
                images_list.append(error_img)
                
            # Force garbage collection
            gc.collect()
        
        # Create image grid
        if len(images_list) > 1:
            result = image_grid(images_list, 1, len(images_list))
            return result
        else:
            return images_list[0]
            
    except Exception as e:
        print(f"Error in generate_images: {e}")
        traceback.print_exc()
        # Create an error image
        error_img = Image.new('RGB', (WIDTH, HEIGHT), color=(255, 0, 0))
        return error_img
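
# Example of headless (non-UI) usage, as a sketch (assumes the model download
# succeeds; gr.Progress() may be inert outside a Gradio event context):
#
#   pipe = load_model()
#   grid = generate_images(DEFAULT_PROMPT, "blue", True, "1000,2000", pipe)
#   grid.save("output.png")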

# Gradio Interface
def create_interface():
    with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app:
        gr.Markdown("""
        # Stable Diffusion Text Inversion with Loss Functions
        
        Generate images using Stable Diffusion with various loss functions to guide the diffusion process.
        """)
        
        if TORCH_DEVICE == "cpu":
            gr.Markdown("""
            ⚠️ **Running on CPU**: Generation will be slow and memory-intensive. 
            Each image may take several minutes to generate.
            """)
        
        pipe = None  # Initialize to None to avoid loading during interface creation
        
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt", 
                    value=DEFAULT_PROMPT,
                    lines=3
                )
                
                loss_type = gr.Radio(
                    label="Loss Type",
                    choices=["N/A", "blue", "elastic", "symmetry", "saturation"],
                    value="N/A"
                )
                
                apply_loss = gr.Checkbox(
                    label="Apply Loss Function", 
                    value=False
                )
                
                if TORCH_DEVICE == "cpu":
                    seeds = gr.Textbox(
                        label="Seeds (comma-separated) - Use fewer seeds for CPU",
                        value="1000",
                        lines=1
                    )
                else:
                    seeds = gr.Textbox(
                        label="Seeds (comma-separated)",
                        value="3000,2000,1000",
                        lines=1
                    )
                
                # Load model button
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", value="Model not loaded", interactive=False)
                
                generate_btn = gr.Button("Generate Images", interactive=False)
                
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
        
        def load_model_fn():
            nonlocal pipe
            try:
                pipe = load_model()
                # gr.update(interactive=True) is what actually enables the button;
                # returning a bare bool would just overwrite its label
                return "Model loaded successfully", gr.update(interactive=True)
            except Exception as e:
                return f"Error loading model: {str(e)}", gr.update(interactive=False)
                
        load_model_btn.click(
            fn=load_model_fn,
            inputs=[],
            outputs=[model_status, generate_btn]
        )
        
        generate_btn.click(
            # a gr.Progress() default lets Gradio inject progress tracking;
            # only the four visible components are passed as inputs
            fn=lambda p, lt, al, s, progress=gr.Progress(): generate_images(p, lt, al, s, pipe, progress),
            inputs=[prompt, loss_type, apply_loss, seeds],
            outputs=output_image
        )
        
        gr.Markdown("""
        ## About the Loss Functions
        
        - **Blue**: Encourages more blue tones in the image
        - **Elastic**: Creates distortion effects by minimizing differences with elastically transformed versions
        - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions
        - **Saturation**: Increases color saturation in the image
        
        Set "N/A" and uncheck "Apply Loss Function" for normal image generation.
        """)
        
        if TORCH_DEVICE == "cpu":
            gr.Markdown("""
            ## CPU Mode Tips
            - Use shorter prompts
            - Process one seed at a time
            - Be patient, generation can take 5-10 minutes per image
            - If you encounter memory errors, try restarting the app and using even smaller dimensions
            """)
    
    return app

if __name__ == "__main__":
    # Create and launch the interface
    app = create_interface()
    app.launch()