danhtran2mind committed on
Commit 7ee3710 · verified · 1 Parent(s): e81b8e6

Update app.py

Files changed (1):
  app.py: +53 -20
app.py CHANGED
@@ -23,10 +23,11 @@ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder
 unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
 scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
 
-def load_examples_from_directory(sample_output_dir="sample_output"):
+def load_examples_from_directory(sample_output_dir="assets/example"):
     """
-    Load example data from the sample_output directory.
+    Load example data from the assets/example directory.
     Assumes each image has a corresponding .json file with metadata.
+    Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, json_filename].
     """
     examples = []
     json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))
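For context, the loader above expects each entry in assets/example to be a .json metadata file sitting next to its image; the next hunk appends the seed and the JSON filename to each row. The sketch below shows a metadata file of the shape being read. It is an assumption, not part of this commit: only the width, num_inference_steps, guidance_scale and seed keys are visible in these hunks, so the prompt and height keys are inferred from the Dataframe headers and the fallback example row.

```python
# Minimal sketch (hypothetical helper, not part of this commit): write one example
# metadata file in the shape load_examples_from_directory reads. The "prompt" and
# "height" keys are assumptions; the remaining keys appear in the diff above.
import json
import os

def write_example_metadata(path="assets/example/example1.json"):
    metadata = {
        "prompt": "a serene landscape in Ghibli style",
        "height": 64,
        "width": 64,
        "num_inference_steps": 50,
        "guidance_scale": 3.5,
        "seed": 42,
    }
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(metadata, f, indent=2)
    return path
```

With a matching example1.png beside it, the loader would emit the row ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, "example1.json"]; when no JSON files are found it falls back to the single hard-coded row with None in the JSON File column.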
@@ -43,14 +44,15 @@ def load_examples_from_directory(sample_output_dir="sample_output"):
                 metadata["width"],
                 metadata["num_inference_steps"],
                 metadata["guidance_scale"],
-                metadata["seed"]
+                metadata["seed"],
+                os.path.basename(json_file)  # Store the JSON filename for image lookup
             ])
         except Exception as e:
             print(f"Error loading {json_file}: {e}")
 
     if not examples:
         examples = [
-            ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42]
+            ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, None]
         ]
     return examples
 
@@ -122,25 +124,60 @@ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, s
 
     return pil_image, f"Image generated successfully! Seed used: {seed}"
 
+def load_existing_image(json_filename, sample_output_dir="assets/example"):
+    """
+    Load an existing image corresponding to the given JSON filename.
+    Assumes the image has the same base name as the JSON file (e.g., example1.json -> example1.png).
+    """
+    if not json_filename:
+        return None, "No associated image found."
+
+    # Get the base name without extension
+    base_name = os.path.splitext(json_filename)[0]
+    # Look for common image extensions
+    image_extensions = ["*.png", "*.jpg", "*.jpeg"]
+    image_path = None
+
+    for ext in image_extensions:
+        candidates = glob.glob(os.path.join(sample_output_dir, f"{base_name}{ext}"))
+        if candidates:
+            image_path = candidates[0]
+            break
+
+    if not image_path:
+        return None, f"No image found for {json_filename} in {sample_output_dir}."
+
+    try:
+        pil_image = Image.open(image_path)
+        return pil_image, f"Loaded existing image for {json_filename}."
+    except Exception as e:
+        return None, f"Error loading image {image_path}: {e}"
+
 def on_example_select(selected_row):
     """
     Handle selection of a row in the examples Dataframe.
-    Returns the parameters to populate the input fields and trigger image generation.
+    Loads the existing image and updates input fields.
     """
+    prompt = selected_row[0]
+    height = selected_row[1]
+    width = selected_row[2]
+    num_inference_steps = selected_row[3]
+    guidance_scale = selected_row[4]
+    seed = selected_row[5]
+    json_filename = selected_row[6]  # JSON filename for image lookup
+    random_seed = False
+
+    # Load the existing image
+    image, status = load_existing_image(json_filename)
+
     return (
-        selected_row[0],  # prompt
-        selected_row[1],  # height
-        selected_row[2],  # width
-        selected_row[3],  # num_inference_steps
-        selected_row[4],  # guidance_scale
-        selected_row[5],  # seed
-        False  # random_seed
+        prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, image, status
     )
 
 # Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Ghibli-Style Image Generator")
-    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt or select an example below to create your image.")
+    gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt or select an example below to load a pre-generated image.")
     gr.Markdown("**Note:** For CPU inference, execution time is long (e.g., for 64x64 resolution with 50 inference steps, time is approximately 1800 seconds).")
 
     with gr.Row():
@@ -158,10 +195,10 @@ with gr.Blocks() as demo:
     output_text = gr.Textbox(label="Status")
 
     gr.Markdown("### Example Prompts")
-    examples_data = load_examples_from_directory("sample_output")
+    examples_data = load_examples_from_directory("assets/example")
     examples = gr.Dataframe(
         value=examples_data,
-        headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed"],
+        headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed", "JSON File"],
         label="Examples",
         interactive=True
     )
@@ -170,11 +207,7 @@ with gr.Blocks() as demo:
     examples.select(
         fn=on_example_select,
         inputs=examples,
-        outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed]
-    ).then(
-        fn=generate_image,
-        inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
-        outputs=[output_image, output_text]
+        outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, output_image, output_text]
     )
 
     generate_btn.click(
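Because load_existing_image matches images purely by base name, a quick way to confirm the assets/example directory is laid out as the new code expects is to mirror its glob patterns. The following is a minimal standalone sketch under that assumption, not part of the commit:

```python
# Minimal sketch (assumed helper, not part of this commit): report whether each
# example JSON in assets/example has a sibling image that load_existing_image
# would find via its "<base_name>*.png/jpg/jpeg" glob patterns.
import glob
import os

def check_example_pairs(sample_output_dir="assets/example"):
    for json_file in sorted(glob.glob(os.path.join(sample_output_dir, "*.json"))):
        base_name = os.path.splitext(os.path.basename(json_file))[0]
        matches = []
        for pattern in ("*.png", "*.jpg", "*.jpeg"):
            matches += glob.glob(os.path.join(sample_output_dir, f"{base_name}{pattern}"))
        status = "ok" if matches else "missing image"
        print(f"{os.path.basename(json_file)}: {status}")

if __name__ == "__main__":
    check_example_pairs()
```

Rows whose JSON File cell is None (the hard-coded fallback) simply return the "No associated image found." status when selected, per the guard at the top of load_existing_image.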
 