Rogerjs committed on
Commit
9e91681
·
verified ·
1 Parent(s): ac04126

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +419 -0
app.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import time
import uuid  # For unique filenames
import zipfile

import gradio as gr

# --- LLM/Model Setup ---
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler  # For image generation
from gradio_client import Client as GradioClient, handle_file  # For 3D generation
from huggingface_hub import InferenceClient  # For prompt refinement via API
from transformers import pipeline as transformers_pipeline  # For local list generation
12
+
13
# --- Configuration ---
# Consider making these configurable in the UI later.

# Models / remote Spaces used by each pipeline stage.
LIST_GENERATION_MODEL = "google/flan-t5-base"  # Small local model that drafts the item list
PROMPT_REFINEMENT_MODEL_API = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # Instruct model reached via the Inference API
IMAGE_GENERATION_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # Alternative: "runwayml/stable-diffusion-v1-5"
HUNYUAN_SPACE_ID = "tencent/Hunyuan3D-2"  # Space that performs the 3D generation

# Output layout: OUTPUT_DIR/<subdir>/..., plus the ZIP written into OUTPUT_DIR.
OUTPUT_DIR = "outputs"
MODELS_SUBDIR = "3d_models"
IMAGES_SUBDIR = "image_previews"
ZIP_FILENAME = "3d_collection.zip"
23
+
24
# --- Initialize Clients/Pipelines (can be slow, consider loading on demand if needed) ---
# Every stage is optional: a failed load leaves its handle set to None and the
# helper functions below check for that before using it.

# Use HF Token from Space secrets if available/needed for Inference API
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)

# Basic List Generator (local)
try:
    list_generator = transformers_pipeline("text2text-generation", model=LIST_GENERATION_MODEL)
except Exception as e:
    print(f"Warning: Could not load local list generator {LIST_GENERATION_MODEL}: {e}")
    list_generator = None

# Prompt Refiner (API)
try:
    if not HF_TOKEN:
        print("Warning: HUGGINGFACE_TOKEN not set. Inference API calls might be rate-limited or fail.")
    prompt_refiner = InferenceClient(model=PROMPT_REFINEMENT_MODEL_API, token=HF_TOKEN)
except Exception as e:
    print(f"Warning: Could not initialize InferenceClient for {PROMPT_REFINEMENT_MODEL_API}: {e}")
    prompt_refiner = None

# Image Generator (Local with Diffusers - requires GPU on Space for reasonable speed)
# Or consider an Image Gen API service if running on CPU hardware
try:
    # torch was previously referenced without being imported, so this block
    # always failed with a NameError and previews were silently disabled.
    # It is imported inside the guard so that a missing install still degrades
    # to "no image previews" instead of crashing the app at start-up.
    import torch
    # The auto class picks the pipeline that matches the checkpoint, so both the
    # SDXL default and an SD 1.5 alternative load with the right class.
    from diffusers import AutoPipelineForText2Image

    _image_device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only used on GPU; CPU falls back to float32.
    # NOTE(review): SDXL in float32 needs a lot of RAM on CPU hardware - confirm
    # the Space tier or switch IMAGE_GENERATION_MODEL to a smaller checkpoint.
    _image_dtype = torch.float16 if _image_device == "cuda" else torch.float32

    image_pipeline = AutoPipelineForText2Image.from_pretrained(
        IMAGE_GENERATION_MODEL,
        torch_dtype=_image_dtype,
        use_safetensors=True,
    )
    image_pipeline.scheduler = EulerDiscreteScheduler.from_config(image_pipeline.scheduler.config)
    image_pipeline.to(_image_device)
except Exception as e:
    print(f"Warning: Could not load diffusers pipeline {IMAGE_GENERATION_MODEL}. Image generation might fail: {e}")
    image_pipeline = None


# 3D Generator Client
try:
    hunyuan_client = GradioClient(HUNYUAN_SPACE_ID)
except Exception as e:
    print(f"Error initializing GradioClient for {HUNYUAN_SPACE_ID}: {e}")
    hunyuan_client = None
64
+
65
+ # --- Helper Functions ---
66
+
67
def generate_list_local(theme, count):
    """Ask the local text2text model for up to ``count`` distinct item names.

    Args:
        theme: Subject the items should belong to; inserted into the prompt.
        count: Maximum number of items to return.

    Returns:
        A list of item-name strings, at most ``count`` long and free of
        duplicates (compared case-insensitively, first occurrence wins).
        On failure the list holds a single message starting with ``"Error:"``,
        which is the sentinel the caller checks for.
    """
    if not list_generator:
        return ["Error: List generator model not loaded."]
    prompt = f"Generate a comma-separated list of {count} distinct types of {theme}."
    try:
        generated_text = list_generator(prompt, max_length=200)[0]['generated_text']

        # Duplicates are dropped here because item names are later used as
        # dict keys and as output file names.
        unique_items = []
        seen = set()
        for piece in generated_text.split(','):
            name = piece.strip()
            key = name.casefold()
            if name and key not in seen:
                seen.add(key)
                unique_items.append(name)
        return unique_items[:count]  # Ensure we don't exceed the requested count
    except Exception as e:
        print(f"Error generating list: {e}")
        return [f"Error: {e}"]
78
+
79
def refine_prompt_api(item_name):
    """Turn an item name into a descriptive generation prompt via the Inference API.

    Args:
        item_name: Name of the single object to describe.

    Returns:
        The refined prompt text. A basic template prompt is returned instead
        when the API client is unavailable, the call raises, or the model
        answers with nothing but whitespace/quotes.
    """
    if not prompt_refiner:
        return f"A 3D model of a {item_name}"  # Fallback basic prompt

    # Used whenever the API call does not yield usable text.
    fallback_prompt = f"A high quality 3D model of a {item_name}"
    prompt = f"Create a detailed, descriptive prompt for generating a highly realistic image of a single '{item_name}'. Focus on visual details suitable for a text-to-image AI. Only output the prompt itself."
    try:
        refined = prompt_refiner.text_generation(prompt, max_new_tokens=100)
    except Exception as e:
        print(f"Error refining prompt for '{item_name}': {e}")
        return fallback_prompt

    # Strip surrounding whitespace and the quotes instruct models often add.
    refined = (refined or "").strip().strip('"').strip()
    # An empty completion would otherwise be passed on as an empty caption.
    return refined or fallback_prompt
92
+
93
def generate_image_local(refined_prompt, output_path):
    """Render one preview image for ``refined_prompt`` and save it to ``output_path``.

    Returns ``output_path`` when the image was generated and saved, or ``None``
    when the pipeline is not loaded or generation/saving raised.
    """
    if image_pipeline:
        try:
            # Adjust inference steps/guidance as needed
            outcome = image_pipeline(refined_prompt, num_inference_steps=25, guidance_scale=7.5)
            preview = outcome.images[0]
            preview.save(output_path)
        except Exception as err:
            print(f"Error generating image for prompt '{refined_prompt}': {err}")
            return None
        return output_path

    # No pipeline: callers treat None as "no preview available".
    print("Image generation pipeline not available.")
    return None
107
+
108
def _extract_model_path(result, max_candidates=2):
    """Return the first file path found in the leading elements of ``result``.

    Only the first ``max_candidates`` elements are inspected, in order. An
    element counts as a path when it is a non-empty string, or a dict whose
    ``"value"`` or ``"path"`` entry is a non-empty string.
    """
    if isinstance(result, (str, dict)):
        result = (result,)
    if not isinstance(result, (list, tuple)):
        return None
    for candidate in result[:max_candidates]:
        # NOTE(review): some Spaces return file outputs wrapped in an update
        # dict such as {"value": <path>, ...} - confirm against the live API.
        if isinstance(candidate, dict):
            candidate = candidate.get("value") or candidate.get("path")
        if isinstance(candidate, str) and candidate:
            return candidate
    return None


def generate_3d_model_hunyuan(refined_prompt_for_3d, output_dir, item_name_safe):
    """Generate a 3D model from text via the Hunyuan3D Space and store it locally.

    Args:
        refined_prompt_for_3d: Caption sent to the Space.
        output_dir: Directory the model file is placed in (created if missing).
        item_name_safe: Filesystem-safe base name for the saved file.

    Returns:
        ``(final_path, "Success")`` when a model file was saved, otherwise
        ``(None, message)`` describing why nothing was produced. Exceptions
        from the remote call or the file handling are caught and reported
        through the message.
    """
    if not hunyuan_client:
        print("Hunyuan 3D client not available.")
        return None, "Client not initialized"

    print(f"Requesting 3D model for: {refined_prompt_for_3d}")
    # Use defaults for most parameters initially
    try:
        result_tuple = hunyuan_client.predict(
            caption=refined_prompt_for_3d,
            # Leave image and mv_image inputs as None for text-to-3D
            image=None,
            mv_image_front=None,
            mv_image_back=None,
            mv_image_left=None,
            mv_image_right=None,
            # Default values from API docs (can be overridden)
            steps=30,
            guidance_scale=5,
            seed=1234,  # Or use randomize_seed=True
            octree_resolution=256,
            check_box_rembg=True,
            num_chunks=8000,
            randomize_seed=True,
            api_name="/generation_all"  # Crucial!
        )

        # --- VERIFICATION NEEDED ---
        # The exact layout of the result tuple is unconfirmed; the model file
        # is assumed to be the first or second element.
        raw_filepath = _extract_model_path(result_tuple)
        # --- END VERIFICATION NEEDED ---

        if not raw_filepath:
            error_msg = f"Job for '{refined_prompt_for_3d}' did not return a valid filepath in expected tuple elements."
            print(error_msg)
            # You might want to inspect the full result_tuple here for debugging
            print(f"Full result tuple: {result_tuple}")
            return None, error_msg

        print(f"Job completed. Raw result path: {raw_filepath}")

        # handle_file() only marks *inputs* for upload and has no download_dir
        # parameter, so it cannot be used to fetch results. The result is
        # expected to be a local file already; anything else is reported.
        if not os.path.isfile(raw_filepath):
            error_msg = f"Result file is not available locally: {raw_filepath}"
            print(error_msg)
            return None, error_msg

        os.makedirs(output_dir, exist_ok=True)
        file_ext = os.path.splitext(raw_filepath)[1] or ".glb"  # Assume glb if unknown
        final_path = os.path.join(output_dir, f"{item_name_safe}{file_ext}")
        # shutil.move also works when the temp dir is on another filesystem,
        # where os.rename would raise.
        shutil.move(raw_filepath, final_path)
        print(f"Model saved to: {final_path}")
        return final_path, "Success"

    except Exception as e:
        error_msg = f"Error calling Hunyuan3D API for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, str(e)
177
+
178
def create_zip(files_to_zip, zip_filepath):
    """Bundle existing files into a flat, compressed ZIP archive.

    Args:
        files_to_zip: Iterable of file paths. Falsy entries and paths that do
            not exist are skipped.
        zip_filepath: Destination path of the archive (overwritten if present).

    Returns:
        ``zip_filepath``.

    Each file is stored under its base name. When two inputs share a base
    name, later ones get a ``_1``, ``_2``, ... suffix before the extension so
    no archive member shadows another.
    """
    used_names = set()
    with zipfile.ZipFile(zip_filepath, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        for file_path in files_to_zip:
            if not file_path or not os.path.exists(file_path):
                continue
            base_name = os.path.basename(file_path)
            stem, ext = os.path.splitext(base_name)
            arcname = base_name
            suffix = 0
            while arcname in used_names:
                suffix += 1
                arcname = f"{stem}_{suffix}{ext}"
            used_names.add(arcname)
            zf.write(file_path, arcname)
    return zip_filepath
184
+
185
+
186
+ # --- Gradio Interface & Logic ---
187
+
188
with gr.Blocks() as demo:
    gr.Markdown("# 3D Asset Collection Generator")
    gr.Markdown("Generate a list based on a theme, refine prompts, preview images, and generate selected 3D models using Hunyuan3D-2.")

    # Start-up notices are rendered as Markdown so they actually appear on the
    # page: gr.Warning only reaches the browser from inside an event handler,
    # and gr.Error is an exception class that does nothing unless raised.
    if not HF_TOKEN:
        gr.Markdown("**Warning:** Hugging Face Token not found. Prompt refinement quality/rate limits may be affected. Consider adding HUGGINGFACE_TOKEN to Space secrets.")
    if not image_pipeline:
        gr.Markdown("**Warning:** Local Image Generation model failed to load. Image previews will be skipped. Check Space hardware/logs.")
    if not hunyuan_client:
        gr.Markdown("**Error:** Failed to connect to the Hunyuan3D-2 Space. 3D generation will not work.")

    # Per-session state. Handlers must *return* new values for these (they are
    # listed in each event's outputs); assigning to `.value` inside a handler
    # does not update the session's copy.
    list_items_state = gr.State([])
    refined_prompts_state = gr.State({})  # Dict: {item_name: refined_prompt}
    image_paths_state = gr.State({})  # Dict: {item_name: image_path}
    selected_items_state = gr.State([])  # Unused: the selection is read from the checkbox group below
    generated_3d_files_state = gr.State([])  # List of paths to successfully generated models

    with gr.Row():
        theme_input = gr.Textbox(label="Theme", placeholder="e.g., reptiles, kitchen appliances, medieval weapons")
        count_input = gr.Number(label="Number of Items", value=5, minimum=1, step=1)

    generate_list_button = gr.Button("1. Generate List & Refine Prompts")
    list_output_display = gr.Markdown("List will appear here...")

    generate_images_button = gr.Button("2. Generate Image Previews", visible=False)  # Hidden initially
    image_gallery = gr.Gallery(label="Image Previews", visible=False, elem_id="image_gallery")
    # Multi-select of item names. A CheckboxGroup passes the selection to a
    # button-click handler as a plain list, which a Dataset cannot do.
    selection_data = gr.CheckboxGroup(choices=[], label="Select Items for 3D Generation", visible=False)

    generate_3d_button = gr.Button("3. Generate 3D Models for Selected Items", visible=False)  # Hidden initially
    status_output = gr.Markdown("")  # For progress updates
    final_zip_output = gr.File(label="Download 3D Model Collection (ZIP)", visible=False)

    # --- Event Logic ---

    def run_list_and_refine(theme, count):
        """Step 1: generate the item list for ``theme`` and refine a prompt per item.

        Returns a dict of component updates. On success it also carries the
        new values for ``list_items_state`` and ``refined_prompts_state``.
        """
        theme = (theme or "").strip()
        if not theme:
            return {list_output_display: "Please enter a theme.", generate_images_button: gr.Button(visible=False)}

        # The number field can be cleared by the user, which arrives as None.
        try:
            count = int(count)
        except (TypeError, ValueError):
            count = 0
        if count < 1:
            return {list_output_display: "Please enter a number of items of at least 1.",
                    generate_images_button: gr.Button(visible=False)}

        # Ensure output dirs exist
        os.makedirs(os.path.join(OUTPUT_DIR, IMAGES_SUBDIR), exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, MODELS_SUBDIR), exist_ok=True)

        gr.Info("Generating list...")
        items = generate_list_local(theme, count)
        # The list generator reports failure as a single "Error: ..." entry.
        if not items or items[0].startswith("Error:"):
            reason = items[0] if items else "Unknown error"
            return {list_output_display: f"Failed to generate list: {reason}",
                    generate_images_button: gr.Button(visible=False)}

        gr.Info("Refining prompts via API...")
        refined_prompts = {}
        md_lines = ["### Generated List & Refined Prompts:\n"]
        for item in items:
            refined_prompts[item] = refine_prompt_api(item)
            md_lines.append(f"* **{item}:** {refined_prompts[item]}")

        return {
            list_output_display: "\n".join(md_lines),
            generate_images_button: gr.Button(visible=True),  # Enable next step
            list_items_state: items,
            refined_prompts_state: refined_prompts,
        }

    generate_list_button.click(
        fn=run_list_and_refine,
        inputs=[theme_input, count_input],
        outputs=[list_output_display, generate_images_button, list_items_state, refined_prompts_state],
    )

    def run_image_generation(items, refined_prompts_dict):
        """Step 2: render a preview per item and expose the items for selection.

        Returns a dict of component updates plus the new ``image_paths_state``
        value (``{item_name: image_path}`` for the previews that succeeded).
        """
        # De-duplicated, order-preserving choices for the checkbox group.
        choices = list(dict.fromkeys(items or []))

        if not image_pipeline:
            # Allow proceeding to 3D generation without previews.
            gr.Warning("Image pipeline not loaded. Skipping image previews.")
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: gr.CheckboxGroup(choices=choices, value=[], visible=True),
                generate_3d_button: gr.Button(visible=True),
                image_paths_state: {},
            }

        gr.Info("Generating image previews... (this may take a while)")
        img_dir = os.path.join(OUTPUT_DIR, IMAGES_SUBDIR)
        os.makedirs(img_dir, exist_ok=True)

        image_paths = {}
        gallery_images = []  # (path, caption) pairs so each preview shows its item name
        for item in choices:
            refined_prompt = refined_prompts_dict.get(item, f"Image of {item}")
            safe_item_name = "".join(c if c.isalnum() else "_" for c in item)
            # The uuid keeps repeated runs from overwriting earlier previews.
            img_path = os.path.join(img_dir, f"{safe_item_name}_{uuid.uuid4()}.png")

            generated_path = generate_image_local(refined_prompt, img_path)
            if generated_path:
                image_paths[item] = generated_path
                gallery_images.append((generated_path, item))
            # Items whose preview failed stay selectable; they just have no image.

        return {
            image_gallery: gr.Gallery(value=gallery_images, visible=True),
            selection_data: gr.CheckboxGroup(choices=choices, value=[], visible=True),
            generate_3d_button: gr.Button(visible=True),  # Show 3D gen button
            image_paths_state: image_paths,
        }

    generate_images_button.click(
        fn=run_image_generation,
        inputs=[list_items_state, refined_prompts_state],
        outputs=[image_gallery, selection_data, generate_3d_button, image_paths_state],
    )

    def run_3d_generation(selected_items, refined_prompts_dict):
        """Step 3: generate a 3D model per selected item and zip the results.

        This is a generator: every UI update, including the early exits and
        the final one, is *yielded*, because a generator's return value is
        not delivered to the UI.

        Args:
            selected_items: Item names ticked in the checkbox group.
            refined_prompts_dict: ``{item_name: refined_prompt}`` from step 1.
        """
        max_retries = 1  # Extra attempts per item after the first failure

        if not hunyuan_client:
            yield {status_output: "Hunyuan3D client not initialized. Cannot generate.",
                   final_zip_output: gr.File(visible=False)}
            return

        if not selected_items:
            yield {status_output: "Please select items from the list above before generating 3D models.",
                   final_zip_output: gr.File(visible=False)}
            return

        generated_files = []
        status_messages = ["### 3D Generation Status:\n"]
        model_dir = os.path.join(OUTPUT_DIR, MODELS_SUBDIR)
        total_selected = len(selected_items)

        for position, item_name in enumerate(selected_items, start=1):
            current_status = f"({position}/{total_selected}) Generating model for: **{item_name}**..."
            print(current_status)
            status_messages.append(f"* {current_status}")
            # Update UI status progressively
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

            # The image prompt doubles as the 3D caption; fall back to a basic
            # caption if no refined prompt was stored for this item.
            prompt_for_3d = (refined_prompts_dict or {}).get(item_name) or f"A high quality 3D model of a {item_name}"
            item_name_safe = "".join(c if c.isalnum() else "_" for c in item_name)

            model_path = None
            last_error = "Unknown error"
            attempts = 0
            for attempts in range(1, max_retries + 2):
                if attempts > 1:
                    status_messages.append(f"  * Retrying ({attempts - 1}/{max_retries})...")
                    yield {status_output: "\n".join(status_messages)}
                    time.sleep(2)  # Brief pause before retry

                model_path, last_error = generate_3d_model_hunyuan(prompt_for_3d, model_dir, item_name_safe)
                if model_path:
                    generated_files.append(model_path)
                    status_messages.append("  * Success! Model saved.")
                    break
                status_messages.append(f"  * Attempt {attempts} failed: {last_error}")

            if not model_path:
                status_messages.append(f"  * **Failed** after {attempts} attempt(s). Last error: {last_error}")

            # Update UI status after each item
            yield {status_output: "\n".join(status_messages)}

        if not generated_files:
            status_messages.append("\nNo 3D models were successfully generated.")
            yield {status_output: "\n".join(status_messages),
                   final_zip_output: gr.File(visible=False),
                   generated_3d_files_state: []}
            return

        status_messages.append("\nCreating ZIP archive...")
        yield {status_output: "\n".join(status_messages)}
        zip_path = os.path.join(OUTPUT_DIR, ZIP_FILENAME)
        final_zip = create_zip(generated_files, zip_path)
        status_messages.append(f"\n**Collection ready!** Download '{ZIP_FILENAME}' below.")
        yield {status_output: "\n".join(status_messages),
               final_zip_output: gr.File(value=final_zip, visible=True),
               generated_3d_files_state: generated_files}

    # The checkbox group supplies the selected item names; prompts come from state.
    generate_3d_button.click(
        fn=run_3d_generation,
        inputs=[selection_data, refined_prompts_state],
        outputs=[status_output, final_zip_output, generated_3d_files_state],
    )


# Launch the Gradio app
demo.queue().launch(debug=True)  # Enable queue for longer processes, debug for detailed errors