danhtran2mind committed
Commit d193794 · verified · 1 Parent(s): 19bcfaf

Update app.py

Files changed (1): app.py (+58 −7)
app.py CHANGED
@@ -5,6 +5,9 @@ import numpy as np
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
 from tqdm import tqdm
+import os
+import json
+import glob
 
 # Set device and dtype
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,7 +23,41 @@ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
 unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
 scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
 
-def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed):
+def load_examples_from_directory(sample_output_dir="sample_output"):
+    """
+    Load example data from the sample_output directory.
+    Assumes each image has a corresponding .json file with metadata.
+    """
+    examples = []
+    # Look for .json files in the directory
+    json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))
+
+    for json_file in json_files:
+        try:
+            with open(json_file, 'r') as f:
+                metadata = json.load(f)
+            # Ensure required fields are present
+            required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed"]
+            if all(key in metadata for key in required_keys):
+                examples.append([
+                    metadata["prompt"],
+                    metadata["height"],
+                    metadata["width"],
+                    metadata["num_inference_steps"],
+                    metadata["guidance_scale"],
+                    metadata["seed"]
+                ])
+        except Exception as e:
+            print(f"Error loading {json_file}: {e}")
+
+    # If no valid examples are found, return a default example
+    if not examples:
+        examples = [
+            ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42]
+        ]
+    return examples
+
+def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
     # Validate inputs
     # if not prompt:
     #     return None, "Prompt cannot be empty."
@@ -38,7 +75,9 @@ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed):
     # Set batch size
     batch_size = 1
 
-    # Create generator
+    # Handle random seed
+    if random_seed:
+        seed = torch.randint(0, 4294967295, (1,)).item()
     generator = torch.Generator(device=device).manual_seed(int(seed))
 
     # Tokenize and encode prompt
@@ -96,13 +135,14 @@ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed):
     image = (image * 255).round().astype("uint8")
     pil_image = Image.fromarray(image[0])
 
-    return pil_image, "Image generated successfully!"
+    return pil_image, f"Image generated successfully! Seed used: {seed}"
 
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Ghibli-Style Image Generator")
-    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt and adjust parameters to create your image.")
-
+    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter ABOVE a prompt and adjust parameters to create your image.")
+    gr.Markdown("**Note:** For CPU inference, execution time is long (e.g., for 64x64 resolution with 50 inference steps, time is approximately 1800 seconds).")
+
     with gr.Row():
         with gr.Column():
             prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
@@ -111,15 +151,26 @@ with gr.Blocks() as demo:
             num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=50)
             guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=3.5)
             seed = gr.Slider(label="Seed", minimum=0, maximum=4294967295, step=1, value=42)
+            random_seed = gr.Checkbox(label="Use Random Seed", value=False)
             generate_btn = gr.Button("Generate Image")
         with gr.Column():
             output_image = gr.Image(label="Generated Image")
             output_text = gr.Textbox(label="Status")
 
+    gr.Markdown("### Example Prompts")
+    # Load examples from sample_output directory
+    examples_data = load_examples_from_directory("sample_output")
+    examples = gr.Dataframe(
+        value=examples_data,
+        headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed"],
+        label="Examples"
+    )
+
     generate_btn.click(
         fn=generate_image,
-        inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed],
+        inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
         outputs=[output_image, output_text]
     )
 
-demo.launch()
+# Launch with limited users
+demo.launch(max_threads=3)
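
For reference, load_examples_from_directory() only accepts a JSON file when all six keys it checks for are present. A minimal sketch of writing one such metadata file, assuming the sample_output layout the loader expects (the filename example_01.json and the field values are illustrative, not taken from the commit):

import json
import os

# Hypothetical metadata file matching the keys the loader checks for;
# filename and values are examples only, not part of the commit.
os.makedirs("sample_output", exist_ok=True)
metadata = {
    "prompt": "a quiet fishing village at dawn in Ghibli style",
    "height": 64,
    "width": 64,
    "num_inference_steps": 50,
    "guidance_scale": 3.5,
    "seed": 123,
}
with open(os.path.join("sample_output", "example_01.json"), "w") as f:
    json.dump(metadata, f, indent=2)

Any .json file in sample_output with this shape will show up as a row in the Examples table on the next app start.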
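As committed, the gr.Dataframe only displays the example rows; selecting a row does not copy its values into the input widgets. If that behavior is wanted, the Dataframe's select event can provide it. A sketch, assuming recent Gradio 4.x where the Dataframe value arrives as a pandas DataFrame; the helper name fill_from_example is hypothetical and not part of the commit:

import gradio as gr

# Hypothetical addition: clicking a row in the examples table copies its
# values into the corresponding input widgets.
def fill_from_example(evt: gr.SelectData, data):
    row = data.iloc[evt.index[0]]  # evt.index is [row, column] for a Dataframe
    return (row.iloc[0], int(row.iloc[1]), int(row.iloc[2]),
            int(row.iloc[3]), float(row.iloc[4]), int(row.iloc[5]))

# Inside the gr.Blocks() context, after the components above are defined:
examples.select(
    fn=fill_from_example,
    inputs=[examples],
    outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed],
)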
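One caveat on the new launch line: the comment says "limited users", but max_threads caps Gradio's worker thread pool rather than the number of concurrent users. If the intent is to serialize the expensive CPU generations, queueing with a concurrency limit is the usual approach; a sketch assuming Gradio 4.x, where queue() accepts default_concurrency_limit:

# Hypothetical alternative to the committed launch line: run at most one
# generation at a time while keeping the thread pool small.
demo.queue(default_concurrency_limit=1).launch(max_threads=3)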