Spaces:
Sleeping
Sleeping
Update app.py
Browse filesAdded dit_size specific num_samples limit
app.py
CHANGED
@@ -166,11 +166,27 @@ def generate_random_seed():
|
|
166 |
"""Generate a random seed between 0 and 2^32 - 1"""
|
167 |
return random.randint(0, 2**32 - 1)
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
@spaces.GPU(duration=120)
|
170 |
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
|
171 |
"""Main function for Gradio interface"""
|
172 |
-
|
173 |
-
|
|
|
174 |
|
175 |
# Get genre and style IDs from mappings
|
176 |
genre_id = reduced_genre_mapping.get(genre_name)
|
@@ -205,8 +221,21 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
205 |
|
206 |
with gr.Row():
|
207 |
with gr.Column(scale=1):
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
genre_names = list(reduced_genre_mapping.keys())
|
212 |
style_names = list(reduced_style_mapping.keys())
|
@@ -230,8 +259,10 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
230 |
progress_bar = gr.Progress(track_tqdm=True)
|
231 |
|
232 |
with gr.Column(scale=2):
|
233 |
-
output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=
|
234 |
error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
|
|
|
|
|
235 |
|
236 |
# Seed reset button functionality
|
237 |
reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
|
|
|
166 |
"""Generate a random seed between 0 and 2^32 - 1"""
|
167 |
return random.randint(0, 2**32 - 1)
|
168 |
|
169 |
+
MODEL_SAMPLE_LIMITS = {
|
170 |
+
"S": {"min":1, "max": 18, "default": 4},
|
171 |
+
"B": {"min":1, "max": 9, "default": 4},
|
172 |
+
"L": {"min":1, "max": 3, "default": 1}
|
173 |
+
}
|
174 |
+
|
175 |
+
def update_sample_slider(dit_size):
|
176 |
+
limits = MODEL_SAMPLE_LIMITS[dit_size]
|
177 |
+
return gr.update(
|
178 |
+
minimum=limits["min"],
|
179 |
+
maximum=limits["max"],
|
180 |
+
value=limits["default"],
|
181 |
+
info=f"How many images to generate ({limits['min']}-{limits['max']})"
|
182 |
+
)
|
183 |
+
|
184 |
@spaces.GPU(duration=120)
|
185 |
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
|
186 |
"""Main function for Gradio interface"""
|
187 |
+
limits = MODEL_SAMPLE_LIMITS[dit_size]
|
188 |
+
if num_samples < limits["min"] or num_samples > limits["max"]:
|
189 |
+
return None, gr.update(value=f"Number of samples for {dit_size} model must be between {limits["min"]} and {limits["max"]}", visible=True)
|
190 |
|
191 |
# Get genre and style IDs from mappings
|
192 |
genre_id = reduced_genre_mapping.get(genre_name)
|
|
|
221 |
|
222 |
with gr.Row():
|
223 |
with gr.Column(scale=1):
|
224 |
+
dit_size = gr.Radio(
|
225 |
+
choices=["S", "B", "L"],
|
226 |
+
value="B",
|
227 |
+
label="DiT Model Size",
|
228 |
+
info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)"
|
229 |
+
)
|
230 |
+
|
231 |
+
num_samples = gr.Slider(
|
232 |
+
minimum=MODEL_SAMPLE_LIMITS["B"]["min"],
|
233 |
+
maximum=MODEL_SAMPLE_LIMITS["B"]["max"],
|
234 |
+
value=MODEL_SAMPLE_LIMITS["B"]["default"],
|
235 |
+
step=1,
|
236 |
+
label="Number of Samples",
|
237 |
+
info=f"How many images to generate ({MODEL_SAMPLE_LIMITS['B']['min']}-{MODEL_SAMPLE_LIMITS['B']['max']})"
|
238 |
+
)
|
239 |
|
240 |
genre_names = list(reduced_genre_mapping.keys())
|
241 |
style_names = list(reduced_style_mapping.keys())
|
|
|
259 |
progress_bar = gr.Progress(track_tqdm=True)
|
260 |
|
261 |
with gr.Column(scale=2):
|
262 |
+
output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600)
|
263 |
error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
|
264 |
+
|
265 |
+
dit_size.change(update_sample_slider, inputs=[dit_size],outputs=[num_samples])
|
266 |
|
267 |
# Seed reset button functionality
|
268 |
reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
|