stillerman commited on
Commit
d4f5d63
Β·
1 Parent(s): b478961

show links instead of images

Browse files
Files changed (1) hide show
  1. app.py +41 -38
app.py CHANGED
@@ -103,7 +103,7 @@ def start_training(
103
  except json.JSONDecodeError:
104
  return "❌ Error: Invalid response from server", ""
105
 
106
- def check_job_status(job_id: str, job_status_url: str) -> tuple[str, Optional[Image.Image]]:
107
  """
108
  Check the current status of a LoRA training job or image generation job.
109
 
@@ -112,14 +112,17 @@ def check_job_status(job_id: str, job_status_url: str) -> tuple[str, Optional[Im
112
  Note that if we are invoking this function with MCP, the user cannot neccecarily see the images
113
  in the tool call, so you will have to render them again in the chat.
114
 
 
 
 
115
  Parameters:
116
  - job_id (str, required): The unique job identifier returned from start_training or generate_images function
117
  - job_status_url (str, required): Modal API endpoint for checking job status, format: "https://modal-app-url-api-job-status.modal.run". If the app is already deployed, this can be found in the Modal [dashboard](https://modal.com/apps/) . Otherwise, the app can get deployed with the deploy_for_user function.
118
 
119
  Returns:
120
- - tuple[str, Optional[Image.Image]]: (status_message, first_image)
121
  - status_message: Detailed status message containing job information
122
- - first_image: First PIL Image object if images are available, None otherwise
123
 
124
  Possible status values:
125
  - "completed": Job finished successfully
@@ -132,7 +135,7 @@ def check_job_status(job_id: str, job_status_url: str) -> tuple[str, Optional[Im
132
  """
133
 
134
  if not job_id or not job_id.strip():
135
- return "❌ Error: Job ID is required", None
136
 
137
  try:
138
  response = requests.get(
@@ -156,37 +159,34 @@ def check_job_status(job_id: str, job_status_url: str) -> tuple[str, Optional[Im
156
  message += f"**Message:** {training_result.get('message', 'Generation finished')}\n"
157
  if training_result.get('lora_repo'):
158
  message += f"**LoRA Used:** {training_result['lora_repo']}\n"
159
-
160
  images_data = training_result.get('images', [])
161
- first_image = None
162
-
163
  if images_data:
164
  message += f"**Images Generated:** {len(images_data)}\n\n"
165
-
166
  # Show all prompts
167
  message += "**Generated Images:**\n"
168
  for i, img_data in enumerate(images_data):
169
  prompt = img_data.get('prompt', f'Image {i+1}')
170
  message += f"**{i+1}.** {prompt}\n"
171
-
172
- # But only decode and return the first image
173
- if len(images_data) > 0:
174
- first_img_data = images_data[0]
175
- base64_data = first_img_data.get('image', '')
176
- first_prompt = first_img_data.get('prompt', 'Image 1')
177
-
178
  if base64_data:
179
  try:
180
  image_bytes = base64.b64decode(base64_data)
181
- first_image = Image.open(BytesIO(image_bytes))
182
- message += f"\n**Displaying first image:** {first_prompt}"
183
- if len(images_data) > 1:
184
- message += f"\n*({len(images_data) - 1} additional images were generated but not displayed)*"
185
  except Exception as e:
186
- print(f"Error decoding first image: {e}")
187
- message += f"\n**Error loading first image:** {e}"
188
-
189
- return message, first_image
 
 
190
  else:
191
  # Training job
192
  message = "πŸŽ‰ **Training Completed!**\n\n"
@@ -198,35 +198,35 @@ def check_job_status(job_id: str, job_status_url: str) -> tuple[str, Optional[Im
198
  message += f"**Training Steps:** {training_result['training_steps']}\n"
199
  if training_result.get('training_prompt'):
200
  message += f"**Training Prompt:** {training_result['training_prompt']}\n"
201
- return message, None
202
  else:
203
  message = "πŸŽ‰ **Job Completed!**\n\n"
204
  message += f"**Result:** {training_result}"
205
- return message, None
206
 
207
  elif status == "running":
208
- return f"πŸ”„ **Job in Progress**\n\nThe job is still running. Check back in a few minutes.", None
209
 
210
  elif status == "failed":
211
  error_msg = result.get("message", "Job failed with unknown error")
212
- return f"❌ **Job Failed**\n\n**Error:** {error_msg}", None
213
 
214
  elif status == "error":
215
  error_msg = result.get("message", "Unknown error occurred")
216
- return f"❌ **Error**\n\n**Message:** {error_msg}", None
217
 
218
  else:
219
- return f"❓ **Unknown Status**\n\n**Status:** {status}\n**Response:** {json.dumps(result, indent=2)}", None
220
 
221
  else:
222
- return f"❌ HTTP Error {response.status_code}: {response.text}", None
223
 
224
  except requests.exceptions.Timeout:
225
- return "❌ Error: Request timed out", None
226
  except requests.exceptions.RequestException as e:
227
- return f"❌ Error: Failed to connect to status service: {str(e)}", None
228
  except json.JSONDecodeError:
229
- return "❌ Error: Invalid response from server", None
230
 
231
  def generate_images(
232
  prompts_json: str,
@@ -614,18 +614,21 @@ with gr.Blocks(title="FluxFoundry LoRA Training", theme=gr.themes.Soft()) as app
614
 
615
  status_output = gr.Markdown(label="Job Status")
616
 
617
- # Add single image component for displaying the first generated image
618
- generated_image = gr.Image(
619
- label="First Generated Image",
620
  show_label=True,
621
  interactive=False,
622
- visible=True
 
 
 
623
  )
624
 
625
  status_btn.click(
626
  fn=check_job_status,
627
  inputs=[job_id_input, job_status_url],
628
- outputs=[status_output, generated_image]
629
  )
630
 
631
  gr.Markdown("---")
 
103
  except json.JSONDecodeError:
104
  return "❌ Error: Invalid response from server", ""
105
 
106
+ def check_job_status(job_id: str, job_status_url: str) -> tuple[str, List[Image.Image]]:
107
  """
108
  Check the current status of a LoRA training job or image generation job.
109
 
 
112
  Note that if we are invoking this function with MCP, the user cannot neccecarily see the images
113
  in the tool call, so you will have to render them again in the chat.
114
 
115
+ **MCP Client Limitation:** Due to MCP client constraints, we cannot render a gallery of images in the chat.
116
+ The MCP client should render these URLs as clickable markdown links when possible.
117
+
118
  Parameters:
119
  - job_id (str, required): The unique job identifier returned from start_training or generate_images function
120
  - job_status_url (str, required): Modal API endpoint for checking job status, format: "https://modal-app-url-api-job-status.modal.run". If the app is already deployed, this can be found in the Modal [dashboard](https://modal.com/apps/) . Otherwise, the app can get deployed with the deploy_for_user function.
121
 
122
  Returns:
123
+ - tuple[str, List[Image.Image]]: (status_message, all_images)
124
  - status_message: Detailed status message containing job information
125
+ - all_images: List of PIL Image objects if images are available, empty list otherwise
126
 
127
  Possible status values:
128
  - "completed": Job finished successfully
 
135
  """
136
 
137
  if not job_id or not job_id.strip():
138
+ return "❌ Error: Job ID is required", []
139
 
140
  try:
141
  response = requests.get(
 
159
  message += f"**Message:** {training_result.get('message', 'Generation finished')}\n"
160
  if training_result.get('lora_repo'):
161
  message += f"**LoRA Used:** {training_result['lora_repo']}\n"
162
+
163
  images_data = training_result.get('images', [])
164
+ all_images = []
165
+
166
  if images_data:
167
  message += f"**Images Generated:** {len(images_data)}\n\n"
168
+
169
  # Show all prompts
170
  message += "**Generated Images:**\n"
171
  for i, img_data in enumerate(images_data):
172
  prompt = img_data.get('prompt', f'Image {i+1}')
173
  message += f"**{i+1}.** {prompt}\n"
174
+
175
+ # Decode and return all images
176
+ for i, img_data in enumerate(images_data):
177
+ base64_data = img_data.get('image', '')
 
 
 
178
  if base64_data:
179
  try:
180
  image_bytes = base64.b64decode(base64_data)
181
+ image = Image.open(BytesIO(image_bytes))
182
+ all_images.append(image)
 
 
183
  except Exception as e:
184
+ print(f"Error decoding image {i+1}: {e}")
185
+ message += f"\n**Error loading image {i+1}:** {e}"
186
+
187
+ message += f"\n**Displaying all {len(all_images)} generated images**"
188
+
189
+ return message, all_images
190
  else:
191
  # Training job
192
  message = "πŸŽ‰ **Training Completed!**\n\n"
 
198
  message += f"**Training Steps:** {training_result['training_steps']}\n"
199
  if training_result.get('training_prompt'):
200
  message += f"**Training Prompt:** {training_result['training_prompt']}\n"
201
+ return message, []
202
  else:
203
  message = "πŸŽ‰ **Job Completed!**\n\n"
204
  message += f"**Result:** {training_result}"
205
+ return message, []
206
 
207
  elif status == "running":
208
+ return f"πŸ”„ **Job in Progress**\n\nThe job is still running. Check back in a few minutes.", []
209
 
210
  elif status == "failed":
211
  error_msg = result.get("message", "Job failed with unknown error")
212
+ return f"❌ **Job Failed**\n\n**Error:** {error_msg}", []
213
 
214
  elif status == "error":
215
  error_msg = result.get("message", "Unknown error occurred")
216
+ return f"❌ **Error**\n\n**Message:** {error_msg}", []
217
 
218
  else:
219
+ return f"❓ **Unknown Status**\n\n**Status:** {status}\n**Response:** {json.dumps(result, indent=2)}", []
220
 
221
  else:
222
+ return f"❌ HTTP Error {response.status_code}: {response.text}", []
223
 
224
  except requests.exceptions.Timeout:
225
+ return "❌ Error: Request timed out", []
226
  except requests.exceptions.RequestException as e:
227
+ return f"❌ Error: Failed to connect to status service: {str(e)}", []
228
  except json.JSONDecodeError:
229
+ return "❌ Error: Invalid response from server", []
230
 
231
  def generate_images(
232
  prompts_json: str,
 
614
 
615
  status_output = gr.Markdown(label="Job Status")
616
 
617
+ # Add gallery component for displaying all generated images
618
+ generated_images = gr.Gallery(
619
+ label="Generated Images",
620
  show_label=True,
621
  interactive=False,
622
+ visible=True,
623
+ columns=2,
624
+ rows=2,
625
+ height="auto"
626
  )
627
 
628
  status_btn.click(
629
  fn=check_job_status,
630
  inputs=[job_id_input, job_status_url],
631
+ outputs=[status_output, generated_images]
632
  )
633
 
634
  gr.Markdown("---")