anirudh97 committed (verified)
Commit a21e693 · Parent(s): 6891195

Upload 2 files

Files changed (2):
  1. README.md +129 -13
  2. gradio_app.py +279 -0
README.md CHANGED
@@ -1,13 +1,129 @@
- ---
- title: Text Inversion
- emoji: 🖼
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 5.0.1
- app_file: app.py
- pinned: false
- short_description: Stable Diffusion - Text Inversion
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Stable Diffusion Text Inversion with Loss Functions
+
+ This repository contains a Gradio web application that provides an intuitive interface for generating images using Stable Diffusion with textual inversion and guided loss functions.
+
+ ## Overview
+
+ The application allows users to explore the capabilities of Stable Diffusion by:
+ - Generating images from text prompts
+ - Using textual inversion concepts
+ - Applying various loss functions to guide the diffusion process
+ - Generating multiple images with different seeds
+
+ ![Example output from the application](image.png)
+
+ ## Features
+
+ ### Core Functionality
+ - **Text-to-Image Generation**: Create detailed images from descriptive text prompts
+ - **Textual Inversion**: Apply learned concepts to your generations
+ - **Loss Function Guidance**: Shape image generation with specialized loss functions:
+   - **Blue Loss**: Emphasizes blue tones in the generated images
+   - **Elastic Loss**: Creates distortion effects by applying elastic transformations
+   - **Symmetry Loss**: Encourages symmetrical image generation
+   - **Saturation Loss**: Enhances color saturation in the output
+ - **Multi-Seed Generation**: Create multiple variations of an image with different seeds
+
+ ## Installation
+
+ ### Prerequisites
+ - Python 3.8+
+ - CUDA-capable GPU (recommended)
+
+ ### Setup
+ 1. Clone this repository:
+ ```bash
+ git clone https://github.com/yourusername/stable-diffusion-text-inversion.git
+ cd stable-diffusion-text-inversion
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install torch diffusers transformers tqdm torchvision matplotlib gradio
+ ```
+
+ 3. Run the application:
+ ```bash
+ python gradio_app.py
+ ```
+
+ 4. Open the provided URL (typically http://localhost:7860) in your browser.
+
+ ## Understanding the Technology
+
+ ### Stable Diffusion
+
+ Stable Diffusion is a latent text-to-image diffusion model developed by Stability AI. It works by:
+
+ 1. **Encoding text**: Converting text prompts into embeddings that the model can understand
+ 2. **Starting with noise**: Beginning with random noise in a latent space
+ 3. **Iterative denoising**: Gradually removing noise while being guided by the text embeddings
+ 4. **Decoding to image**: Converting the final latent representation to a pixel-based image
+
+ The model operates in a compressed latent space (64x64x4) rather than pixel space (512x512x3), allowing for efficient generation of high-resolution images with limited computational resources.
+
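+ For orientation, the snippet below is a minimal sketch of those four stages using the high-level `DiffusionPipeline` API (the checkpoint is the same one `gradio_app.py` loads; the prompt is illustrative). `gradio_app.py` drives the same stages manually so it can intervene in the denoising loop:
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Text encoding, noise initialization, iterative denoising, and VAE
+ # decoding all happen inside this single call.
+ image = pipe("A majestic castle on a floating island", num_inference_steps=50).images[0]
+ image.save("castle.png")
+ ```
+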
+ ### Textual Inversion
+
+ Textual Inversion is a technique that allows Stable Diffusion to learn new concepts from just a few example images. Key aspects include:
+
+ - **Custom Concepts**: Learn new visual concepts not present in the model's training data
+ - **Few-Shot Learning**: Typically requires only 3-5 examples of a concept
+ - **Token Optimization**: Creates a new "pseudo-word" embedding that represents the concept
+ - **Seamless Integration**: Once learned, concepts can be used in prompts just like regular words
+
+ In this application, we load several pre-trained textual inversion concepts from the SD concepts library (a usage sketch follows the list):
+ - Rimworld art style
+ - HK Golden Lantern
+ - Phoenix-01
+ - Fractal Flame
+ - Scarlet Witch
+
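+ As a rough sketch of how a loaded concept is used (with `pipe` as above): each concept defines its own placeholder token, so check the concept's page in the SD concepts library for the exact string; `<phoenix-01>` below is an assumption.
+
+ ```python
+ pipe.load_textual_inversion("sd-concepts-library/phoenix-01")
+ # The learned pseudo-word can now appear in prompts like any other word.
+ image = pipe("A firebird rising, in the style of <phoenix-01>").images[0]
+ ```
+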
+ ### Guided Loss Functions
+
+ This application introduces an innovative approach by applying custom loss functions during the diffusion process:
+
+ 1. **How it works**: During generation, we periodically decode the current latent representation, apply a loss function to the decoded image, and backpropagate that loss to adjust the latents (see the sketch after this list).
+
+ 2. **Types of Loss Functions**:
+    - **Blue Loss**: Encourages pixels to have higher values in the blue channel
+    - **Elastic Loss**: Minimizes the difference between the image and an elastically transformed version
+    - **Symmetry Loss**: Minimizes the difference between the image and its horizontal mirror
+    - **Saturation Loss**: Pushes the image toward higher color saturation
+
+ 3. **Impact**: These loss functions can dramatically alter the aesthetic qualities of the generated images, creating effects that would be difficult to achieve through prompt engineering alone.
+
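+ A condensed sketch of one guidance step, following the same logic as `generate_image` in `gradio_app.py` (`latents`, `noise_pred`, and `sigma` are assumed to come from the surrounding denoising loop; `loss_fn` stands in for one of the losses above):
+
+ ```python
+ import torch
+
+ def guidance_step(pipe, latents, noise_pred, sigma, loss_fn, loss_scale=200):
+     latents = latents.detach().requires_grad_()
+     # Estimate the fully denoised latents implied by the current prediction
+     latents_x0 = latents - sigma * noise_pred
+     # Decode to image space (0.18215 is the SD v1 VAE scaling factor)
+     images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+     # Score the approximate image, then push the latents down the gradient,
+     # scaling the step by the current noise level
+     loss = loss_fn(images) * loss_scale
+     grad = torch.autograd.grad(loss, latents)[0]
+     return latents.detach() - grad * sigma**2
+ ```
+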
+ ## Usage Examples
+
+ ### Basic Image Generation
+ 1. Enter a prompt in the text box (e.g., "A majestic castle on a floating island with waterfalls")
+ 2. Set Loss Type to "N/A" and uncheck "Apply Loss Function"
+ 3. Enter a seed value (e.g., "42")
+ 4. Click "Generate Images"
+
+ ### Applying Loss Functions
+ 1. Enter your prompt
+ 2. Select a Loss Type (e.g., "symmetry")
+ 3. Check "Apply Loss Function"
+ 4. Enter a seed value
+ 5. Click "Generate Images"
+
+ ### Batch Generation
+ 1. Enter your prompt
+ 2. Select desired loss settings
+ 3. Enter multiple comma-separated seeds (e.g., "42, 100, 500")
+ 4. Click "Generate Images" to generate a grid of variations
+
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
+
+ ## Acknowledgments
+
+ - [Stability AI](https://stability.ai/) for developing Stable Diffusion
+ - [Hugging Face](https://huggingface.co/) for the Diffusers library
+ - [Gradio](https://gradio.app/) for the web interface framework
+ - The creators of the textual inversion concepts used in this project
gradio_app.py ADDED
@@ -0,0 +1,279 @@
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from diffusers import DiffusionPipeline, LMSDiscreteScheduler
+ import torchvision.transforms as T
+ import torch.nn.functional as F
+ import gc
+
+ # Configure constants
+ HEIGHT, WIDTH = 512, 512
+ GUIDANCE_SCALE = 8
+ LOSS_SCALE = 200
+ NUM_INFERENCE_STEPS = 50
+ BATCH_SIZE = 1
+ DEFAULT_PROMPT = "A deadly witcher slinging a sword, with a lion medallion around his neck, casting a fire spell from his hand in a snowy forest"
+
+ # Define the device (prefer CUDA, then Apple Silicon MPS, then CPU)
+ TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+ # Initialize the elastic transformer (alpha controls displacement strength,
+ # sigma the smoothness of the displacement field)
+ elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)
+
+ # Load the model
+ def load_model():
+     pipe = DiffusionPipeline.from_pretrained(
+         "CompVis/stable-diffusion-v1-4",
+         torch_dtype=torch.float16 if TORCH_DEVICE == "cuda" else torch.float32
+     ).to(TORCH_DEVICE)
+
+     # Load textual inversion concepts
+     try:
+         pipe.load_textual_inversion("sd-concepts-library/rimworld-art-style", mean_resizing=False)
+         pipe.load_textual_inversion("sd-concepts-library/hk-goldenlantern", mean_resizing=False)
+         pipe.load_textual_inversion("sd-concepts-library/phoenix-01", mean_resizing=False)
+         pipe.load_textual_inversion("sd-concepts-library/fractal-flame", mean_resizing=False)
+         pipe.load_textual_inversion("sd-concepts-library/scarlet-witch", mean_resizing=False)
+     except Exception as e:
+         print(f"Warning: Could not load all textual inversion concepts: {e}")
+
+     return pipe
+
+ # Helper functions
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows * cols
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
+ def image_loss(images, loss_type):
+     if loss_type == 'blue':
+         # Blue loss: penalize distance of the blue channel from 0.9
+         error = torch.abs(images[:, 2] - 0.9).mean()
+     elif loss_type == 'elastic':
+         # Elastic loss: distance to an elastically transformed copy
+         transformed_imgs = elastic_transformer(images)
+         error = torch.abs(transformed_imgs - images).mean()
+     elif loss_type == 'symmetry':
+         # Symmetry loss: distance to the horizontally flipped image
+         flipped_image = torch.flip(images, [3])
+         error = F.mse_loss(images, flipped_image)
+     elif loss_type == 'saturation':
+         # Saturation loss: distance to a heavily saturated copy
+         transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
+         error = torch.abs(transformed_imgs - images).mean()
+     else:
+         print("Error: loss type not defined")
+         error = torch.tensor(0.0, device=images.device)
+
+     return error
+
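+ # Note: image_loss assumes `images` is a batch tensor of shape (N, 3, H, W)
+ # with values roughly in [0, 1] (the range produced by the VAE decode in
+ # generate_image below), and that it stays attached to the autograd graph
+ # so the loss can be backpropagated to the latents.
+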
+ def latents_to_pil(latents, pipe):
+     # Batch of latents -> list of PIL images
+     latents = (1 / 0.18215) * latents  # undo the SD v1 VAE scaling factor
+     with torch.no_grad():
+         image = pipe.vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ def generate_image(pipe, seed_no, prompts, loss_type, loss_apply=False, progress=gr.Progress()):
+     # Initialization and setup
+     generator = torch.manual_seed(seed_no)
+
+     scheduler = LMSDiscreteScheduler(
+         beta_start=0.00085,
+         beta_end=0.012,
+         beta_schedule="scaled_linear",
+         num_train_timesteps=1000
+     )
+     scheduler.set_timesteps(NUM_INFERENCE_STEPS)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+
+     # Text processing
+     text_input = pipe.tokenizer(
+         prompts,
+         padding='max_length',
+         max_length=pipe.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt"
+     )
+     input_ids = text_input.input_ids.to(TORCH_DEVICE)
+
+     # Convert text inputs to embeddings
+     with torch.no_grad():
+         text_embeddings = pipe.text_encoder(input_ids)[0]
+
+     # Build the unconditional (empty-prompt) embeddings for classifier-free guidance
+     max_length = text_input.input_ids.shape[-1]
+     uncond_input = pipe.tokenizer(
+         [""] * BATCH_SIZE,
+         padding="max_length",
+         max_length=max_length,
+         return_tensors="pt"
+     )
+
+     with torch.no_grad():
+         uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(TORCH_DEVICE))[0]
+
+     # Concatenate unconditional and text embeddings
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Create random initial latents
+     latents = torch.randn(
+         (BATCH_SIZE, pipe.unet.config.in_channels, HEIGHT // 8, WIDTH // 8),
+         generator=generator,
+     )
+
+     # Move latents to device and apply noise scaling
+     if TORCH_DEVICE == "cuda":
+         latents = latents.to(torch.float16)
+     latents = latents.to(TORCH_DEVICE)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Diffusion process
+     for i, t in progress.tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # Duplicate the latents for the unconditional and conditional passes
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         with torch.no_grad():
+             noise_pred = pipe.unet(
+                 latent_model_input,
+                 t,
+                 encoder_hidden_states=text_embeddings
+             )["sample"]
+
+         # Classifier-free guidance: combine unconditional and text-conditioned predictions
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)
+
+         # Apply the guided loss every 5 steps if requested
+         if loss_apply and loss_type != "N/A" and i % 5 == 0:
+             latents = latents.detach().requires_grad_()
+             # Estimate the fully denoised latents implied by the current prediction
+             latents_x0 = latents - sigma * noise_pred
+
+             # Use the VAE to decode the estimate to image space
+             denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+
+             # Apply the loss
+             loss = image_loss(denoised_images, loss_type) * LOSS_SCALE
+             print(f"Step {i}, Loss: {loss.item()}")
+
+             # Backpropagate the loss to the latents and step against the gradient
+             cond_grad = torch.autograd.grad(loss, latents)[0]
+             latents = latents.detach() - cond_grad * sigma**2
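+             # latents_x0 is only an estimate of the final image, but it is
+             # differentiable w.r.t. the current latents, so its gradient tells
+             # us how to nudge them; scaling by sigma**2 applies stronger
+             # corrections at higher noise levels.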
+
+         # Update latents using the scheduler
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents
+
+ def generate_images(prompt, loss_type, apply_loss, seeds, pipe):
+     latents_collect = []
+
+     # Convert the comma-separated string to a list of ints
+     seeds = [int(seed.strip()) for seed in seeds.split(',') if seed.strip()]
+
+     if not seeds:
+         seeds = [1000]  # Default seed if none provided
+
+     # List of SD concepts (can be empty if not used)
+     sdconcepts = [''] * len(seeds)
+
+     # Generate images for each seed
+     for seed_no, sd in zip(seeds, sdconcepts):
+         # Clear CUDA cache
+         if TORCH_DEVICE == "cuda":
+             torch.cuda.empty_cache()
+             gc.collect()
+             torch.cuda.empty_cache()
+
+         # Generate image
+         prompts = [f'{prompt} {sd}']
+         latents = generate_image(pipe, seed_no, prompts, loss_type, loss_apply=apply_loss)
+         latents_collect.append(latents)
+
+     # Stack latents and convert to images
+     latents_collect = torch.vstack(latents_collect)
+     images = latents_to_pil(latents_collect, pipe)
+
+     # Arrange multiple images in a single-row grid
+     if len(images) > 1:
+         return image_grid(images, 1, len(images))
+     else:
+         return images[0]
+
+ # Gradio interface
+ def create_interface():
+     pipe = load_model()
+
+     with gr.Blocks(title="Stable Diffusion Text Inversion with Loss Functions") as app:
+         gr.Markdown("""
+         # Stable Diffusion Text Inversion with Loss Functions
+
+         Generate images using Stable Diffusion with various loss functions to guide the diffusion process.
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value=DEFAULT_PROMPT,
+                     lines=3
+                 )
+
+                 loss_type = gr.Radio(
+                     label="Loss Type",
+                     choices=["N/A", "blue", "elastic", "symmetry", "saturation"],
+                     value="N/A"
+                 )
+
+                 apply_loss = gr.Checkbox(
+                     label="Apply Loss Function",
+                     value=False
+                 )
+
+                 seeds = gr.Textbox(
+                     label="Seeds (comma-separated)",
+                     value="3000,2000,1000",
+                     lines=1
+                 )
+
+                 generate_btn = gr.Button("Generate Images")
+
+             with gr.Column():
+                 output_image = gr.Image(label="Generated Image")
+
+         generate_btn.click(
+             fn=lambda p, lt, al, s: generate_images(p, lt, al, s, pipe),
+             inputs=[prompt, loss_type, apply_loss, seeds],
+             outputs=output_image
+         )
+
+         gr.Markdown("""
+         ## About the Loss Functions
+
+         - **Blue**: Encourages more blue tones in the image
+         - **Elastic**: Creates distortion effects by minimizing differences with elastically transformed versions
+         - **Symmetry**: Encourages symmetrical images by minimizing differences with horizontally flipped versions
+         - **Saturation**: Increases color saturation in the image
+
+         Set "N/A" and uncheck "Apply Loss Function" for normal image generation.
+         """)
+
+     return app
+
+ if __name__ == "__main__":
+     # Create and launch the interface
+     app = create_interface()
+     app.launch()