kaupane commited on
Commit
bfbe2d4
·
verified ·
1 Parent(s): e92c9a9

Update app.py

Browse files

Added dit_size specific num_samples limit

Files changed (1) hide show
  1. app.py +36 -5
app.py CHANGED
@@ -166,11 +166,27 @@ def generate_random_seed():
166
  """Generate a random seed between 0 and 2^32 - 1"""
167
  return random.randint(0, 2**32 - 1)
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  @spaces.GPU(duration=120)
170
  def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
171
  """Main function for Gradio interface"""
172
- if num_samples < 1 or num_samples > 12:
173
- return None, gr.update(value="Number of samples must be between 1 and 12", visible=True)
 
174
 
175
  # Get genre and style IDs from mappings
176
  genre_id = reduced_genre_mapping.get(genre_name)
@@ -205,8 +221,21 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
205
 
206
  with gr.Row():
207
  with gr.Column(scale=1):
208
- num_samples = gr.Slider(minimum=1, maximum=12, value=2, step=1, label="Number of Samples", info="How many images to generate (1-12)")
209
- dit_size = gr.Radio(choices=["S", "B", "L"], value="B", label="DiT Model Size", info="Larger models produce better quality but take longer. Recommend to choose <4 num_samples for L(Large) model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  genre_names = list(reduced_genre_mapping.keys())
212
  style_names = list(reduced_style_mapping.keys())
@@ -230,8 +259,10 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
230
  progress_bar = gr.Progress(track_tqdm=True)
231
 
232
  with gr.Column(scale=2):
233
- output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=3, object_fit="contain", height=600)
234
  error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
 
 
235
 
236
  # Seed reset button functionality
237
  reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
 
166
  """Generate a random seed between 0 and 2^32 - 1"""
167
  return random.randint(0, 2**32 - 1)
168
 
169
+ MODEL_SAMPLE_LIMITS = {
170
+ "S": {"min":1, "max": 18, "default": 4},
171
+ "B": {"min":1, "max": 9, "default": 4},
172
+ "L": {"min":1, "max": 3, "default": 1}
173
+ }
174
+
175
+ def update_sample_slider(dit_size):
176
+ limits = MODEL_SAMPLE_LIMITS[dit_size]
177
+ return gr.update(
178
+ minimum=limits["min"],
179
+ maximum=limits["max"],
180
+ value=limits["default"],
181
+ info=f"How many images to generate ({limits['min']}-{limits['max']})"
182
+ )
183
+
184
  @spaces.GPU(duration=120)
185
  def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
186
  """Main function for Gradio interface"""
187
+ limits = MODEL_SAMPLE_LIMITS[dit_size]
188
+ if num_samples < limits["min"] or num_samples > limits["max"]:
189
+ return None, gr.update(value=f"Number of samples for {dit_size} model must be between {limits["min"]} and {limits["max"]}", visible=True)
190
 
191
  # Get genre and style IDs from mappings
192
  genre_id = reduced_genre_mapping.get(genre_name)
 
221
 
222
  with gr.Row():
223
  with gr.Column(scale=1):
224
+ dit_size = gr.Radio(
225
+ choices=["S", "B", "L"],
226
+ value="B",
227
+ label="DiT Model Size",
228
+ info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)"
229
+ )
230
+
231
+ num_samples = gr.Slider(
232
+ minimum=MODEL_SAMPLE_LIMITS["B"]["min"],
233
+ maximum=MODEL_SAMPLE_LIMITS["B"]["max"],
234
+ value=MODEL_SAMPLE_LIMITS["B"]["default"],
235
+ step=1,
236
+ label="Number of Samples",
237
+ info=f"How many images to generate ({MODEL_SAMPLE_LIMITS['B']['min']}-{MODEL_SAMPLE_LIMITS['B']['max']})"
238
+ )
239
 
240
  genre_names = list(reduced_genre_mapping.keys())
241
  style_names = list(reduced_style_mapping.keys())
 
259
  progress_bar = gr.Progress(track_tqdm=True)
260
 
261
  with gr.Column(scale=2):
262
+ output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600)
263
  error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
264
+
265
+ dit_size.change(update_sample_slider, inputs=[dit_size],outputs=[num_samples])
266
 
267
  # Seed reset button functionality
268
  reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])