Update app.py
app.py CHANGED
@@ -14,8 +14,8 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
 from transformers import CLIPProcessor, CLIPModel
-from huggingface_hub import hf_hub_download
-import gradio as gr
+from huggingface_hub import hf_hub_download
+import gradio as gr
 import math

 # -----------------------------------------------------------------------------
@@ -27,10 +27,6 @@ import math
 class VQVAE2(nn.Module):
     def __init__(self, n_embed=8192, embed_dim=256, ch=128):
         super().__init__()
-        # This is a simplified placeholder. The actual architecture would be more complex.
-        # The key is having a 'decode' method that matches the state_dict.
-        # A full implementation would require the original model's architecture file.
-        # For this fix, we assume a basic structure that allows loading the state_dict.
         self.decoder = nn.Sequential(
             nn.Conv2d(embed_dim, ch * 4, 3, padding=1),
             nn.ReLU(),
@@ -40,40 +36,32 @@ class VQVAE2(nn.Module):
             nn.ReLU(),
             nn.ConvTranspose2d(ch, 3, 4, stride=2, padding=1),
         )
-
+
     def decode(self, latents):
-        # A real VQVAE would involve lookup tables, but for generation we only need the decoder part.
-        # This part is highly dependent on the model checkpoint.
-        # The following is a guess to make it runnable, assuming latents are ready for the decoder.
         return self.decoder(latents)

 # Diffusion Model Definition
 class Diffusion(nn.Module):
-    def __init__(self, n_inputs=3, n_embed=512, n_head=8, n_layer=12):
+    # FIXED: Changed n_head default from 8 to 4 to make dimensions compatible
+    def __init__(self, n_inputs=3, n_embed=512, n_head=4, n_layer=12):
         super().__init__()
-        # This is also a placeholder for the architecture.
-        # A full UNet-style model is expected here. The key is that it can be called
-        # with x, t, and conditional embeddings, and returns the predicted noise.
         self.time_embed = nn.Embedding(1000, n_inputs * 4)
         self.cond_embed = nn.Linear(n_embed, n_inputs * 4)
-
+
         self.layers = nn.ModuleList([
-            nn.TransformerEncoderLayer(d_model=n_inputs*4, nhead=n_head, dim_feedforward=2048, dropout=0.1, activation='gelu')
+            nn.TransformerEncoderLayer(d_model=n_inputs*4, nhead=n_head, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True)
             for _ in range(n_layer)
         ])
         self.out = nn.Linear(n_inputs*4, n_inputs)

     def forward(self, x, t, c):
-        # A very simplified forward pass
-        # The actual model is likely a UNet with cross-attention.
         bs, ch, h, w = x.shape
         x = x.permute(0, 2, 3, 1).reshape(bs, h * w, ch)
-
+
         t_emb = self.time_embed(t.long())
         c_emb = self.cond_embed(c)
         emb = t_emb + c_emb
-
-        # This is a gross simplification; a real model would use cross-attention here.
+
         x_out = self.out(x + emb.unsqueeze(1))
         x_out = x_out.reshape(bs, h, w, ch).permute(0, 3, 1, 2)
         return x_out
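For context on why the n_head default had to change: nn.TransformerEncoderLayer (via nn.MultiheadAttention) requires d_model to be divisible by nhead, and here d_model = n_inputs * 4 = 12, which 8 does not divide but 4 does. A minimal standalone sketch of that constraint, not code from app.py:

import torch
import torch.nn as nn

d_model = 3 * 4  # n_inputs * 4 = 12, matching the Diffusion class in the diff above

# nhead=4 divides 12, so the layer builds and runs
layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=4, dim_feedforward=2048,
    dropout=0.1, activation='gelu', batch_first=True)
out = layer(torch.randn(2, 16, d_model))  # (batch, seq, d_model) because batch_first=True
print(out.shape)  # torch.Size([2, 16, 12])

# nhead=8 does not divide 12, so construction itself fails
try:
    nn.TransformerEncoderLayer(d_model=d_model, nhead=8)
except Exception as err:
    print("nhead=8 rejected:", err)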
@@ -81,42 +69,35 @@ class Diffusion(nn.Module):

 # Sampling Function Definitions
 def get_sigmas(n_steps):
-    """Returns the sigma schedule."""
     t = torch.linspace(1, 0, n_steps + 1)
     return ((t[:-1] ** 2) / (t[1:] ** 2) - 1).sqrt()

 @torch.no_grad()
 def plms_sample(model, x, steps, **kwargs):
-    """Poor Man's LMS Sampler"""
     ts = x.new_ones([x.shape[0]])
     sigmas = get_sigmas(steps)
     model_fn = lambda x, t: model(x, t * 1000, **kwargs)
-
-    x_outs = []
+
     old_denoised = None
-
+
     for i in trange(len(sigmas) -1, disable=True):
         denoised = model_fn(x, ts * sigmas[i])
-
+
         if old_denoised is None:
             d = (denoised - x) / sigmas[i]
         else:
-            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i]
+            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i]

         x = x + d * (sigmas[i+1] - sigmas[i])
         old_denoised = denoised
-        x_outs.append(x)
-    return x_outs[-1]

-
-
+    return x
+
 def ddim_sample(model, x, steps, eta, **kwargs):
-    # This is a placeholder for a full DDIM implementation
     print("Warning: DDIM sampler is not fully implemented. Using PLMS instead.")
     return plms_sample(model, x, steps, **kwargs)

 def ddpm_sample(model, x, steps, **kwargs):
-    # This is a placeholder for a full DDPM implementation
     print("Warning: DDPM sampler is not fully implemented. Using PLMS instead.")
     return plms_sample(model, x, steps, **kwargs)

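As a quick reference for the schedule used by the sampler above: get_sigmas builds t = linspace(1, 0, steps + 1) and returns sqrt(t[i]**2 / t[i+1]**2 - 1) for each consecutive pair. A small standalone sketch that just evaluates the same schedule (the function body is copied from the hunk above so the snippet runs on its own):

import torch

def get_sigmas(n_steps):
    # Copied from app.py: noise levels derived from a linear ramp of t from 1 down to 0.
    t = torch.linspace(1, 0, n_steps + 1)
    return ((t[:-1] ** 2) / (t[1:] ** 2) - 1).sqrt()

sigmas = get_sigmas(5)
print(sigmas.shape)  # torch.Size([5]), one sigma per step
print(sigmas)        # values increase toward the final entry, which is infinite because t ends at 0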
@@ -125,16 +106,12 @@ def ddpm_sample(model, x, steps, **kwargs):
 # -----------------------------------------------------------------------------

 # 🖼️ Download the necessary model files from HuggingFace
-# NOTE: The HuggingFace URLs you provided might be placeholders.
-# Make sure these point to the correct model files.
 try:
-    # FIXED: Using the new hf_hub_download function with keyword arguments
     vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
     diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
 except Exception as e:
     print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
     print("Using placeholder models which will not produce good images.")
-    # Create dummy files if download fails to allow script to run
     Path("vqvae_model.ckpt").touch()
     Path("diffusion_model.ckpt").touch()
     vqvae_model_path = "vqvae_model.ckpt"
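For reference, hf_hub_download resolves a single file from a Hub repo into the local cache and returns its path on disk. A standalone sketch of the same call pattern and fallback used above (the repo id and filename are reused from the hunk; the variable name local_path is just for illustration):

from pathlib import Path
from huggingface_hub import hf_hub_download

try:
    # Downloads the file on first use, reuses the cached copy afterwards, and returns the local path.
    local_path = hf_hub_download(
        repo_id="dalle-mini/vqgan_imagenet_f16_16384",
        filename="flax_model.msgpack",
    )
    print("resolved to:", local_path)
except Exception as err:
    # Same fallback idea as app.py: create a dummy file so the script can keep running.
    print("download failed:", err)
    Path("vqvae_model.ckpt").touch()
    local_path = "vqvae_model.ckpt"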
@@ -142,26 +119,17 @@ except Exception as e:


 # 📐 Utility Functions: Math and images, what could go wrong?
-# These functions help parse prompts and resize/crop images to fit nicely
-
 def parse_prompt(prompt, default_weight=3.):
-    """
-    🎯 Parses a prompt into text and weight.
-    """
     vals = prompt.rsplit(':', 1)
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])

 def resize_and_center_crop(image, size):
-    """
-    ✂️ Resize and crop image to center it beautifully.
-    """
     fac = max(size[0] / image.size[0], size[1] / image.size[1])
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])

 # 🧠 Model loading: the brain of our operation! 🔥
-
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 print('loading models... 🛠️')
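A short standalone check of how parse_prompt splits on the last colon; the function body is copied from the hunk above and the expected values follow directly from it:

def parse_prompt(prompt, default_weight=3.):
    # Copied from app.py: split on the last ':' and fall back to default_weight when no weight is given.
    vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])

assert parse_prompt('a red circle:2.5') == ('a red circle', 2.5)
assert parse_prompt('a red circle') == ('a red circle', 3.0)
# Only the last colon counts, so colons earlier in the text survive intact:
assert parse_prompt('https://example.com/cat.png:1.0') == ('https://example.com/cat.png', 1.0)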
@@ -171,23 +139,17 @@ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

 # Load VQ-VAE-2 Autoencoder
-# NOTE: The VQVAE2 class is a placeholder. Loading a real checkpoint will likely fail
-# unless the class definition perfectly matches the architecture of the saved model.
 try:
     vqvae = VQVAE2()
-    # vqvae.load_state_dict(torch.load(vqvae_model_path, map_location=device))
     print("Skipping VQVAE weight loading due to placeholder architecture.")
 except Exception as e:
     print(f"Could not load VQVAE model: {e}. Using placeholder.")
     vqvae = VQVAE2()
 vqvae.eval().requires_grad_(False).to(device)

-
 # Load Diffusion Model
-# NOTE: The Diffusion class is a placeholder. This will also likely fail.
 try:
     diffusion_model = Diffusion()
-    # diffusion_model.load_state_dict(torch.load(diffusion_model_path, map_location=device))
     print("Skipping Diffusion Model weight loading due to placeholder architecture.")
 except Exception as e:
     print(f"Could not load Diffusion model: {e}. Using placeholder.")
@@ -196,27 +158,19 @@ diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)


 # 🎨 The key function: Where the magic happens!
-# This is where we generate images based on text and image prompts
-
 def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
-    """
-    🖼️ Generates a list of PIL images based on given text and image prompts.
-    """
     zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
     target_embeds, weights = [zero_embed], []

-    # Parse text prompts and encode with CLIP
     for prompt in prompts:
         inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
         text_embed = clip_model.get_text_features(**inputs).float()
         target_embeds.append(text_embed)
         weights.append(1.0)

-    # Correctly process image prompts from Gradio
-    # Assign a default weight for image prompts
     image_prompt_weight = 1.0
     for image_path in images:
-        if image_path:
+        if image_path:
             try:
                 img = Image.open(image_path).convert('RGB')
                 img = resize_and_center_crop(img, (224, 224))
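The conditioning path above relies on CLIP text embeddings whose width matches clip_model.config.projection_dim (512 for openai/clip-vit-base-patch32, which is why zero_embed has that shape). A standalone sketch of just that encoding step:

import torch
from transformers import CLIPModel, CLIPProcessor

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Tokenize one text prompt and project it into CLIP's shared embedding space.
inputs = clip_processor(text="a red circle", return_tensors="pt").to(device)
with torch.no_grad():
    text_embed = clip_model.get_text_features(**inputs).float()

print(text_embed.shape)                  # torch.Size([1, 512]) == (1, projection_dim)
print(clip_model.config.projection_dim)  # 512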
@@ -227,28 +181,23 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
         except Exception as e:
             print(f"Warning: Could not process image prompt {image_path}. Error: {e}")

-
-    # Adjust weights and set seed
     weights = torch.tensor([1 - sum(weights), *weights], device=device)
     torch.manual_seed(seed)

-    # 💡 Model function with classifier-free guidance
     def cfg_model_fn(x, t):
         n = x.shape[0]
         n_conds = len(target_embeds)
         x_in = x.repeat([n_conds, 1, 1, 1])
         t_in = t.repeat([n_conds])
         embed_in = torch.cat(target_embeds).repeat_interleave(n, 0)
-
-        # Ensure correct dimensions for the placeholder Diffusion model
+
         if isinstance(diffusion_model, Diffusion):
-            embed_in = embed_in[:, :512]
-
+            embed_in = embed_in[:, :512]
+
         vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
         v = vs.mul(weights[:, None, None, None, None]).sum(0)
         return v

-    # 🎞️ Run the sampler to generate images
     def run(x, steps):
         if method == 'ddpm':
             return ddpm_sample(cfg_model_fn, x, steps)
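The weighting in cfg_model_fn is the classifier-free-guidance style mix the old comment referred to: the unconditional prediction gets weight 1 - sum(prompt weights) and each prompt gets its own weight, so the weights sum to 1. A tiny standalone tensor sketch of that reduction (toy shapes, random stand-ins for the real model outputs):

import torch

n_conds, n, c, h, w = 3, 2, 3, 8, 8          # unconditional + two prompts; toy image shape
weights = torch.tensor([1 - 2.0, 1.0, 1.0])  # [1 - sum(prompt_weights), w1, w2] -> sums to 1
vs = torch.randn(n_conds, n, c, h, w)        # stand-in for the per-condition predictions

# Same reduction as cfg_model_fn: broadcast weights over (n, c, h, w) and sum over conditions.
v = vs.mul(weights[:, None, None, None, None]).sum(0)
print(v.shape)  # torch.Size([2, 3, 8, 8])
assert torch.allclose(v, (vs * weights.view(-1, 1, 1, 1, 1)).sum(0))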
@@ -258,52 +207,39 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
             return plms_sample(cfg_model_fn, x, steps)
         assert False, f"Unknown method: {method}"

-    # 🏃♂️ Generate the output images
     batch_size = n
     x = torch.randn([n, 3, 64, 64], device=device)
-
     pil_ims = []
     for i in trange(0, n, batch_size):
         cur_batch_size = min(n - i, batch_size)
         out_latents = run(x[i:i + cur_batch_size], steps)
-
-        # The VQVAE expects specific dimensions. Adjusting for the placeholder.
-        # This will likely need tuning for the real model.
+
         if isinstance(vqvae, VQVAE2):
-
-
-
-            quant_guess = F.gumbel_softmax(out_latents, hard=True).permute(0, 2, 3, 1) # (B, H, W, C)
-            pil_ims.append(transforms.ToPILImage()(quant_guess[0].permute(2, 0, 1)))
+            outs = vqvae.decode(out_latents)
+            for j, out in enumerate(outs):
+                pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))
         else:
             outs = vqvae.decode(out_latents)
             for j, out in enumerate(outs):
                 pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))
-
     return pil_ims

 # 🖌️ Interface: Gradio's brush to paint the UI
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
-    """
-    💡 Gradio function to wrap image generation.
-    """
     if seed is None:
         seed = random.randint(0, 10000)
     prompts = [prompt]
     im_prompts = []
     if im_prompt is not None:
         im_prompts = [im_prompt]
-
     try:
         pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
         return pil_ims[0]
     except Exception as e:
         print(f"ERROR during generation: {e}")
-        # Return a blank image on failure
         return Image.new('RGB', (256, 256), color = 'red')

-
-# 🖼️ Gradio UI: The interface where users can input text or image prompts
+# 🖼️ Gradio UI
 iface = gr.Interface(
     fn=gen_ims,
     inputs=[
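The diff cuts off inside the inputs=[ list, so the app's actual Gradio components are not shown here. Purely as a hypothetical sketch of how a gen_ims-shaped function could be wired up (the component choices, labels, and the gen_ims stub below are assumptions for illustration, not the app's real configuration):

import gradio as gr
from PIL import Image

def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    # Stub with the same signature as app.py's gen_ims, so this sketch runs standalone.
    return Image.new('RGB', (256, 256), color='gray')

iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt", value="a red circle"),        # -> prompt
        gr.Image(type="filepath", label="Image prompt (optional)"),   # -> im_prompt
        gr.Number(label="Seed (optional)", precision=0),              # -> seed
        gr.Slider(1, 50, value=10, step=1, label="Steps"),            # -> n_steps
    ],
    outputs=gr.Image(type="pil", label="Generated image"),
    title="CLIP-conditioned diffusion demo",
)

if __name__ == "__main__":
    iface.launch()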