AkashDataScience committed
Commit 7f3baaa · 1 Parent(s): 2809d2c

First commit

app.py ADDED
@@ -0,0 +1,197 @@
+ import os
+ import torch
+ import gradio as gr
+ from tqdm import tqdm
+ from PIL import Image
+ from torchvision import transforms as tfms
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+
+ torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
+
+ # Load the autoencoder model which will be used to decode the latents into image space.
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
+
+ # Load the tokenizer and text encoder to tokenize and encode the text.
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # The UNet model for generating the latents.
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+
+ # The noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
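+ # These scheduler settings (scaled_linear betas from 0.00085 to 0.012 over 1000 training
+ # timesteps) match the noise schedule Stable Diffusion v1 was trained with.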
+
+ style_token_dict = {'Concept':'<concept-art>', 'Realistic':'<doose-realistic>', 'Line':'<line-art>',
+                     'Ricky':'<RickyArt>', 'Plane Scape':'<tony-diterlizzi-planescape>'}
+
+ # To the GPU we go!
+ vae = vae.to(torch_device)
+ text_encoder = text_encoder.to(torch_device)
+ unet = unet.to(torch_device)
+
+ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+ pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+ position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+ position_embeddings = pos_emb_layer(position_ids)
+
+ concept_art_embed = torch.load('concept-art.bin')
+ doose_s_realistic_art_style_embed = torch.load('doose-s-realistic-art-style.bin')
+ line_art_embed = torch.load('line-art.bin')
+ rickyart_embed = torch.load('rickyart.bin')
+ tony_diterlizzi_s_planescape_art_embed = torch.load('tony-diterlizzi-s-planescape-art.bin')
+
+ tokenizer.add_tokens(['<concept-art>', '<doose-realistic>', '<line-art>', '<RickyArt>', '<tony-diterlizzi-planescape>'])
+
+ token_emb_layer_with_art = torch.nn.Embedding(49413, 768)
+ token_emb_layer_with_art.load_state_dict({'weight': torch.cat((token_emb_layer.state_dict()['weight'],
+                                                                concept_art_embed['<concept-art>'].unsqueeze(0).to(torch_device),
+                                                                doose_s_realistic_art_style_embed['<doose-realistic>'].unsqueeze(0).to(torch_device),
+                                                                line_art_embed['<line-art>'].unsqueeze(0).to(torch_device),
+                                                                rickyart_embed['<RickyArt>'].unsqueeze(0).to(torch_device),
+                                                                tony_diterlizzi_s_planescape_art_embed['<tony-diterlizzi-planescape>'].unsqueeze(0).to(torch_device)))})
+ token_emb_layer_with_art = token_emb_layer_with_art.to(torch_device)
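+ # The base CLIP ViT-L/14 vocabulary holds 49408 tokens with 768-dim embeddings; the five
+ # style tokens added above bring the count to 49413, hence the (49413, 768) embedding table.
+ # Each .bin file is assumed to hold a single learned textual-inversion embedding keyed by
+ # its placeholder token, which is concatenated onto the original weight matrix here.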
+
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+
+ def pil_to_latent(input_im):
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
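+ # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE: latents are scaled
+ # by it after encoding and un-scaled by 1/0.18215 before decoding.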
+
+ def build_causal_attention_mask(bsz, seq_len, dtype):
+     mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+     mask.fill_(torch.tensor(torch.finfo(dtype).min)) # fill with large negative number (acts like -inf)
+     mask = mask.triu_(1) # zero out the lower diagonal to enforce causality
+     return mask.unsqueeze(1) # add a dimension to broadcast over attention heads
+
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses a causal mask, so we prepare it here:
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None, # We aren't using an attention mask so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True, # We want the output embs not the final output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
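+ # get_output_embeds reproduces what text_encoder(input_ids)[0] would return, but starts from
+ # pre-built input embeddings rather than token ids, so the learned style embeddings injected
+ # above flow through the text transformer.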
+
+ def generate_with_embs(num_inference_steps, guidance_scale, seed, text_input, text_embeddings):
+     height = 512 # default height of Stable Diffusion
+     width = 512 # default width of Stable Diffusion
+     generator = torch.manual_seed(seed) # Seed generator to create the initial latent noise
+     batch_size = 1
+
+     max_length = text_input.input_ids.shape[-1]
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.config.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # perform guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
+ def inference(text, style, inference_step, guidance_scale, seed):
+     prompt = text + " in the style of " + style_token_dict[style]
+
+     # Tokenize
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_embeddings = token_emb_layer_with_art(input_ids)
+
+     # Combine with pos embs
+     input_embeddings = token_embeddings + position_embeddings
+
+     # Feed through to get final output embs
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # And generate an image with this:
+     image = generate_with_embs(inference_step, guidance_scale, seed, text_input, modified_output_embeddings)
+
+     return image
+
+ title = "Stable Diffusion with Textual Inversion"
+ description = "A simple Gradio interface to run Stable Diffusion and generate images in different art styles"
+ examples = [["A sweet potato farm", 'Concept', 30, 0.5, 1],
+             ["Sky full of cotton candy", 'Realistic', 30, 1.5, 2],
+             ["Coffin full of jello", 'Line', 30, 2.5, 3],
+             ["Water skiing on a lake", 'Ricky', 30, 3.5, 4],
+             ["Super slippery noodles", 'Plane Scape', 30, 4.5, 5],
+             ["Beautiful sunset", 'Concept', 30, 5.5, 6],
+             ["A glittering gem", 'Realistic', 30, 6.5, 7],
+             ["River rafting", 'Line', 30, 7.5, 8],
+             ["A green tea", 'Ricky', 30, 8.5, 9],
+             ["Three sphered rocks", 'Plane Scape', 30, 9.5, 10]]
+
+ demo = gr.Interface(inference,
+                     inputs=[gr.Textbox(label="Prompt", type="text"),
+                             gr.Dropdown(label="Style", choices=['Concept', 'Realistic', 'Line',
+                                                                 'Ricky', 'Plane Scape'], value="Concept"),
+                             gr.Slider(10, 50, 30, step=10, label="Inference steps"),
+                             gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
+                             gr.Slider(0, 10000, 1, step=1, label="Seed")],
+                     outputs=[gr.Image(width=320, height=320, label="Generated image")],
+                     title=title,
+                     description=description,
+                     examples=examples)
+
+ demo.launch()
concept-art.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd046d1c90c6e58769033de23adadf936e873597b11fed16a07dde7750bd348c
+ size 3819
doose-s-realistic-art-style.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a2cfe14214eb4055475b3445420796adfe6aa1bfbc9e1fb7dc62dedd5d71808
+ size 3819
line-art.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0528436ec2228c659e0cf1316e713345bc97a3d88294f1a2987a3505d220e770
+ size 3819
requirements.txt ADDED
@@ -0,0 +1,88 @@
+ aiofiles==23.2.1
+ altair==5.3.0
+ annotated-types==0.7.0
+ anyio==4.4.0
+ attrs==23.2.0
+ certifi==2024.6.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.1
+ cycler==0.12.1
+ diffusers==0.29.2
+ dnspython==2.6.1
+ email_validator==2.1.1
+ fastapi==0.111.0
+ fastapi-cli==0.0.4
+ ffmpy==0.3.2
+ filelock==3.13.1
+ fonttools==4.53.0
+ fsspec==2024.2.0
+ ftfy==6.2.0
+ gradio==4.36.1
+ gradio_client==1.0.1
+ h11==0.14.0
+ httpcore==1.0.5
+ httptools==0.6.1
+ httpx==0.27.0
+ huggingface-hub==0.23.4
+ idna==3.7
+ importlib_resources==6.4.0
+ intel-openmp==2021.4.0
+ Jinja2==3.1.3
+ jsonschema==4.22.0
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.0
+ mdurl==0.1.2
+ mkl==2021.4.0
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.3
+ opencv-python==4.10.0.84
+ orjson==3.10.5
+ packaging==24.1
+ pandas==2.2.2
+ pillow==10.2.0
+ pydantic==2.7.4
+ pydantic_core==2.18.4
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.35.1
+ regex==2024.5.15
+ requests==2.32.3
+ rich==13.7.1
+ rpds-py==0.18.1
+ ruff==0.4.9
+ scipy==1.14.0
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.37.2
+ sympy==1.12
+ tbb==2021.11.0
+ tiktoken==0.7.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.3.1
+ torchaudio==2.3.1
+ torchvision==0.18.1
+ tqdm==4.66.4
+ transformers==4.43.3
+ typer==0.12.3
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ ujson==5.10.0
+ urllib3==2.2.1
+ uvicorn==0.30.1
+ watchfiles==0.22.0
+ websockets==11.0.3
rickyart.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0e2263b57fb66b48c9b09bee7b01b5fd8d708a6c52754265fd35052424d82ee
+ size 3819
tony-diterlizzi-s-planescape-art.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:718333440c0986d401954d82acd2b2e0e8222f6b4d8587d4332c16bc9191cba4
+ size 3819