Update app.py
Browse files
app.py
CHANGED
@@ -14,9 +14,7 @@ from torchvision import transforms
|
|
14 |
from torchvision.transforms import functional as TF
|
15 |
from tqdm import trange
|
16 |
from transformers import CLIPProcessor, CLIPModel
|
17 |
-
|
18 |
-
# from diffusion_models import Diffusion # Swapped Diffusion model for DALL·E 2 based model - REMOVED
|
19 |
-
from huggingface_hub import hf_hub_url, cached_download
|
20 |
import gradio as gr # 🎨 The magic canvas for AI-powered image generation!
|
21 |
import math
|
22 |
|
@@ -130,8 +128,9 @@ def ddpm_sample(model, x, steps, **kwargs):
|
|
130 |
# NOTE: The HuggingFace URLs you provided might be placeholders.
|
131 |
# Make sure these point to the correct model files.
|
132 |
try:
|
133 |
-
|
134 |
-
|
|
|
135 |
except Exception as e:
|
136 |
print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
|
137 |
print("Using placeholder models which will not produce good images.")
|
@@ -213,7 +212,7 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
|
|
213 |
target_embeds.append(text_embed)
|
214 |
weights.append(1.0)
|
215 |
|
216 |
-
#
|
217 |
# Assign a default weight for image prompts
|
218 |
image_prompt_weight = 1.0
|
219 |
for image_path in images:
|
@@ -250,7 +249,6 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
|
|
250 |
return v
|
251 |
|
252 |
# 🎞️ Run the sampler to generate images
|
253 |
-
# **FIXED**: Call sampling functions directly without the 'sampling.' prefix
|
254 |
def run(x, steps):
|
255 |
if method == 'ddpm':
|
256 |
return ddpm_sample(cfg_model_fn, x, steps)
|
@@ -310,7 +308,6 @@ iface = gr.Interface(
|
|
310 |
fn=gen_ims,
|
311 |
inputs=[
|
312 |
gr.Textbox(label="Text prompt"),
|
313 |
-
# **FIXED**: Removed deprecated 'optional=True' argument
|
314 |
gr.Image(label="Image prompt", type='filepath')
|
315 |
],
|
316 |
outputs=gr.Image(type="pil", label="Generated Image"),
|
|
|
14 |
from torchvision.transforms import functional as TF
|
15 |
from tqdm import trange
|
16 |
from transformers import CLIPProcessor, CLIPModel
|
17 |
+
from huggingface_hub import hf_hub_download # FIXED: Replaced deprecated function
|
|
|
|
|
18 |
import gradio as gr # 🎨 The magic canvas for AI-powered image generation!
|
19 |
import math
|
20 |
|
|
|
128 |
# NOTE: The HuggingFace URLs you provided might be placeholders.
|
129 |
# Make sure these point to the correct model files.
|
130 |
try:
|
131 |
+
# FIXED: Using the new hf_hub_download function with keyword arguments
|
132 |
+
vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
|
133 |
+
diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
|
134 |
except Exception as e:
|
135 |
print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
|
136 |
print("Using placeholder models which will not produce good images.")
|
|
|
212 |
target_embeds.append(text_embed)
|
213 |
weights.append(1.0)
|
214 |
|
215 |
+
# Correctly process image prompts from Gradio
|
216 |
# Assign a default weight for image prompts
|
217 |
image_prompt_weight = 1.0
|
218 |
for image_path in images:
|
|
|
249 |
return v
|
250 |
|
251 |
# 🎞️ Run the sampler to generate images
|
|
|
252 |
def run(x, steps):
|
253 |
if method == 'ddpm':
|
254 |
return ddpm_sample(cfg_model_fn, x, steps)
|
|
|
308 |
fn=gen_ims,
|
309 |
inputs=[
|
310 |
gr.Textbox(label="Text prompt"),
|
|
|
311 |
gr.Image(label="Image prompt", type='filepath')
|
312 |
],
|
313 |
outputs=gr.Image(type="pil", label="Generated Image"),
|