Update app.py
app.py CHANGED
@@ -5,6 +5,9 @@ import numpy as np
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
 from tqdm import tqdm
+import os
+import json
+import glob
 
 # Set device and dtype
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,7 +23,41 @@ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
 unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
 scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
 
-def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed):
+def load_examples_from_directory(sample_output_dir="sample_output"):
+    """
+    Load example data from the sample_output directory.
+    Assumes each image has a corresponding .json file with metadata.
+    """
+    examples = []
+    # Look for .json files in the directory
+    json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))
+
+    for json_file in json_files:
+        try:
+            with open(json_file, 'r') as f:
+                metadata = json.load(f)
+            # Ensure required fields are present
+            required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed"]
+            if all(key in metadata for key in required_keys):
+                examples.append([
+                    metadata["prompt"],
+                    metadata["height"],
+                    metadata["width"],
+                    metadata["num_inference_steps"],
+                    metadata["guidance_scale"],
+                    metadata["seed"]
+                ])
+        except Exception as e:
+            print(f"Error loading {json_file}: {e}")
+
+    # If no valid examples are found, return a default example
+    if not examples:
+        examples = [
+            ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42]
+        ]
+    return examples
+
+def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
     # Validate inputs
     # if not prompt:
     #     return None, "Prompt cannot be empty."
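The loader above only accepts a JSON file when all six required keys are present. For reference, here is a minimal sketch of a metadata file in the shape load_examples_from_directory expects; the filename sample_output/example_01.json and the field values are illustrative placeholders, not taken from the repository:

import json, os

# Write one hypothetical metadata file that the loader would pick up.
# Field names mirror required_keys in load_examples_from_directory;
# the values are placeholders matching the default example row.
os.makedirs("sample_output", exist_ok=True)
metadata = {
    "prompt": "a serene landscape in Ghibli style",
    "height": 64,
    "width": 64,
    "num_inference_steps": 50,
    "guidance_scale": 3.5,
    "seed": 42,
}
with open(os.path.join("sample_output", "example_01.json"), "w") as f:
    json.dump(metadata, f, indent=2)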
@@ -38,7 +75,9 @@ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed):
     # Set batch size
     batch_size = 1
 
-    #
+    # Handle random seed
+    if random_seed:
+        seed = torch.randint(0, 4294967295, (1,)).item()
     generator = torch.Generator(device=device).manual_seed(int(seed))
 
     # Tokenize and encode prompt
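One detail in this hunk: torch.randint samples from the half-open range [low, high), so the seeds drawn here fall in 0..4294967294 and never reach the slider's maximum of 4294967295. If full coverage of the slider range matters, an alternative draw would be:

# Alternative covering the slider's full 0..4294967295 range, since
# torch.randint's upper bound is exclusive (2**32 == 4294967296).
seed = torch.randint(0, 2**32, (1,)).item()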
@@ -96,13 +135,14 @@ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed):
     image = (image * 255).round().astype("uint8")
     pil_image = Image.fromarray(image[0])
 
-    return pil_image, "Image generated successfully!"
+    return pil_image, f"Image generated successfully! Seed used: {seed}"
 
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Ghibli-Style Image Generator")
-    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt and adjust parameters to create your image.")
-
+    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt and adjust the parameters to create your image.")
+    gr.Markdown("**Note:** CPU inference is slow; a 64x64 image at 50 inference steps takes roughly 1800 seconds.")
+
     with gr.Row():
         with gr.Column():
             prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
@@ -111,15 +151,26 @@ with gr.Blocks() as demo:
             num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=50)
             guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=3.5)
             seed = gr.Slider(label="Seed", minimum=0, maximum=4294967295, step=1, value=42)
+            random_seed = gr.Checkbox(label="Use Random Seed", value=False)
             generate_btn = gr.Button("Generate Image")
         with gr.Column():
             output_image = gr.Image(label="Generated Image")
             output_text = gr.Textbox(label="Status")
 
+    gr.Markdown("### Example Prompts")
+    # Load examples from the sample_output directory
+    examples_data = load_examples_from_directory("sample_output")
+    examples = gr.Dataframe(
+        value=examples_data,
+        headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed"],
+        label="Examples"
+    )
+
     generate_btn.click(
         fn=generate_image,
-        inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed],
+        inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
         outputs=[output_image, output_text]
     )
 
-
+# Launch with a limited number of worker threads
+demo.launch(max_threads=3)
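The commit renders the examples in a read-only gr.Dataframe but does not connect it to the input widgets. A possible follow-up, sketched here rather than taken from the commit, would use Gradio's Dataframe .select event so that clicking a row fills the inputs; apply_example is a hypothetical helper, the pandas-style row access assumes the Dataframe arrives as a pandas DataFrame, and the snippet would live inside the with gr.Blocks() block before demo.launch:

# Hypothetical wiring (not in this commit): clicking an example row
# copies its values into the inputs. For a Dataframe select event,
# evt.index is [row, col]; `data` arrives as a pandas DataFrame.
def apply_example(evt: gr.SelectData, data):
    row = data.iloc[evt.index[0]]
    return (str(row.iloc[0]), int(row.iloc[1]), int(row.iloc[2]),
            int(row.iloc[3]), float(row.iloc[4]), int(row.iloc[5]))

examples.select(
    fn=apply_example,
    inputs=[examples],
    outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed],
)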