haizad commited on
Commit
ccd41f5
·
1 Parent(s): c7c9368

add AI generated garment

Browse files
Files changed (1) hide show
  1. app.py +133 -13
app.py CHANGED
@@ -49,8 +49,35 @@ def url_to_base64(url):
49
  print(f"Error converting URL to base64: {str(e)}")
50
  return None
51
 
52
- def run_viton(model_image_path, garment_image_path, model_url, garment_url,
53
- n_steps=20, image_scale=2.0, seed=-1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  try:
55
  api_url = os.environ.get("SERVER_URL")
56
  print(f"Using API URL: {api_url}") # Add this to debug
@@ -110,6 +137,94 @@ def run_viton(model_image_path, garment_image_path, model_url, garment_url,
110
  img = base64_to_image(img_b64, output_path) # Remove 'self.'
111
  generated_images.append(img)
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  print(f"Successfully generated {len(generated_images)} images")
114
  return generated_images
115
  else:
@@ -124,48 +239,53 @@ block = gr.Blocks().queue()
124
  with block:
125
  with gr.Row():
126
  gr.Markdown("# Virtual Try-On")
127
- with gr.Row():
128
- gr.Markdown("**Instructions:** You can either upload images using the file upload interface or provide direct URLs to images. URL inputs will take priority over uploaded files.")
129
  with gr.Row():
130
  with gr.Column():
 
131
  model_url = gr.Textbox(
132
  label="Enter Model Image URL",
133
  )
134
  vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
135
  example = gr.Examples(
136
  inputs=vton_img,
137
- examples_per_page=5,
138
  examples=[
139
- os.path.join(example_path, 'model/model_8.png'),
140
  os.path.join(example_path, 'model/model_2.png'),
141
  os.path.join(example_path, 'model/model_7.png'),
142
  os.path.join(example_path, 'model/model_4.png'),
143
  os.path.join(example_path, 'model/model_5.png'),
144
  ])
145
  with gr.Column():
 
146
  garment_url = gr.Textbox(
147
- label="Enter Garment Image URL",
 
 
 
148
  )
149
  garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
150
  example = gr.Examples(
151
  inputs=garm_img,
152
- examples_per_page=5,
153
  examples=[
154
- os.path.join(example_path, 'garment/00055_00.jpg'),
155
  os.path.join(example_path, 'garment/07764_00.jpg'),
156
  os.path.join(example_path, 'garment/03032_00.jpg'),
157
  os.path.join(example_path, 'garment/048554_1.jpg'),
158
  os.path.join(example_path, 'garment/049805_1.jpg'),
159
  ])
160
  with gr.Column():
161
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
 
162
  with gr.Column():
163
- run_button = gr.Button(value="Run")
 
164
  n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
165
  image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
166
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
167
 
168
- ips = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
169
- run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)
 
 
170
 
171
  block.launch(mcp_server=True)
 
49
  print(f"Error converting URL to base64: {str(e)}")
50
  return None
51
 
52
+ def run_viton(model_image_path: str = None,
53
+ garment_image_path: str = None,
54
+ model_url: str = None,
55
+ garment_url: str = None,
56
+ n_steps=20,
57
+ image_scale=2.0,
58
+ seed=-1
59
+ ):
60
+ """
61
+ Run the Virtual Try-On model with provided images path or URLs.
62
+
63
+ Args:
64
+ model_image_path (str): Path to the model image file.
65
+ garment_image_path (str): Path to the garment image file.
66
+ model_url (str): URL of the model image.
67
+ garment_url (str): URL of the garment image.
68
+ n_steps (int): Number of steps for the model.
69
+ image_scale (float): Scale for the generated images.
70
+ seed (int): Random seed for reproducibility.
71
+
72
+ Returns:
73
+ list: List of generated images in base64 format.
74
+ """
75
+ if not model_image_path and not model_url:
76
+ print("Error: No model image provided")
77
+ return []
78
+ if not garment_image_path and not garment_url:
79
+ print("Error: No garment image provided")
80
+ return []
81
  try:
82
  api_url = os.environ.get("SERVER_URL")
83
  print(f"Using API URL: {api_url}") # Add this to debug
 
137
  img = base64_to_image(img_b64, output_path) # Remove 'self.'
138
  generated_images.append(img)
139
 
140
+ print(f"Successfully generated {len(generated_images)} images")
141
+ return generated_images
142
+ else:
143
+ print(f"Request failed with status code: {response.status_code}")
144
+ return []
145
+
146
+ except Exception as e:
147
+ print(f"Exception occurred: {str(e)}") # Add this
148
+ return []
149
+
150
+ def run_new_garment(model_image_path: str = None,
151
+ garment_prompt: str = None,
152
+ model_url: str = None,
153
+ n_steps=20,
154
+ image_scale=2.0,
155
+ seed=-1
156
+ ):
157
+ """
158
+ Run the Virtual Try-On model with provided model image and garment image generated using FLUX.1-dev
159
+ based on description of the garment obtained from the user
160
+
161
+ Args:
162
+ model_image_path (str): Path to the model image file.
163
+ garment_prompt (str): Description of the garment.
164
+ model_url (str): URL of the model image.
165
+ n_steps (int): Number of steps for the model.
166
+ image_scale (float): Scale for the generated images.
167
+ seed (int): Random seed for reproducibility.
168
+
169
+ Returns:
170
+ list: List of generated images in base64 format.
171
+ """
172
+ if not model_image_path and not model_url:
173
+ print("Error: No model image provided")
174
+ return []
175
+ if not garment_prompt or not garment_prompt.strip():
176
+ print("Error: No garment description provided")
177
+ return []
178
+ try:
179
+ api_url = os.environ.get("SERVER_URL")
180
+ print(f"Using API URL: {api_url}") # Add this to debug
181
+
182
+ # Determine which inputs to use (file upload or URL)
183
+ model_b64 = None
184
+
185
+ # Handle model image
186
+ if model_url and model_url.strip():
187
+ print(f"Using model URL: {model_url}")
188
+ model_b64 = url_to_base64(model_url.strip())
189
+ elif model_image_path:
190
+ print(f"Using model file: {model_image_path}")
191
+ model_b64 = image_to_base64(model_image_path)
192
+
193
+ # Check if we have both images
194
+ if not model_b64 or not garment_prompt:
195
+ print("Error: Missing model or garment description")
196
+ return []
197
+
198
+ # Prepare request
199
+ request_data = {
200
+ "model_image_base64": model_b64,
201
+ "garment_prompt": garment_prompt.strip(),
202
+ "n_samples": 1,
203
+ "n_steps": n_steps,
204
+ "image_scale": image_scale,
205
+ "seed": seed
206
+ }
207
+
208
+ # Send request
209
+ response = requests.post(f"{api_url}/new-garment",
210
+ json=request_data,
211
+ timeout=300)
212
+
213
+ print(f"Request sent to {api_url}/new-garment")
214
+ print(f"Response status code: {response.status_code}")
215
+
216
+ if response.status_code == 200:
217
+ result = response.json()
218
+ if result.get("error"):
219
+ print(f"Error: {result['error']}")
220
+ return []
221
+
222
+ generated_images = []
223
+ for i, img_b64 in enumerate(result.get("images_base64", [])):
224
+ output_path = f"flux_output_{i}.png"
225
+ img = base64_to_image(img_b64, output_path) # Remove 'self.'
226
+ generated_images.append(img)
227
+
228
  print(f"Successfully generated {len(generated_images)} images")
229
  return generated_images
230
  else:
 
239
  with block:
240
  with gr.Row():
241
  gr.Markdown("# Virtual Try-On")
 
 
242
  with gr.Row():
243
  with gr.Column():
244
+ gr.Markdown("### Provide image or URL of upper body photo")
245
  model_url = gr.Textbox(
246
  label="Enter Model Image URL",
247
  )
248
  vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
249
  example = gr.Examples(
250
  inputs=vton_img,
251
+ examples_per_page=4,
252
  examples=[
 
253
  os.path.join(example_path, 'model/model_2.png'),
254
  os.path.join(example_path, 'model/model_7.png'),
255
  os.path.join(example_path, 'model/model_4.png'),
256
  os.path.join(example_path, 'model/model_5.png'),
257
  ])
258
  with gr.Column():
259
+ gr.Markdown("### Provide image, URL or description of a garment")
260
  garment_url = gr.Textbox(
261
+ label="Enter Garment Image URL",
262
+ )
263
+ garment_promt = gr.Textbox(
264
+ label="Describe Garment",
265
  )
266
  garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
267
  example = gr.Examples(
268
  inputs=garm_img,
269
+ examples_per_page=4,
270
  examples=[
 
271
  os.path.join(example_path, 'garment/07764_00.jpg'),
272
  os.path.join(example_path, 'garment/03032_00.jpg'),
273
  os.path.join(example_path, 'garment/048554_1.jpg'),
274
  os.path.join(example_path, 'garment/049805_1.jpg'),
275
  ])
276
  with gr.Column():
277
+ gr.Markdown("### 2D Result")
278
+ result_gallery = gr.Gallery(label='Output 2D', show_label=False, elem_id="gallery", preview=True, scale=1)
279
  with gr.Column():
280
+ run_button = gr.Button(value="Try On with your garment")
281
+ run_button2 = gr.Button(value="Try On with AI generated garment")
282
  n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
283
  image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
284
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
285
 
286
+ ips1 = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
287
+ run_button.click(fn=run_viton, inputs=ips1, outputs=result_gallery)
288
+ ips2 = [vton_img, garment_promt, model_url, n_steps, image_scale, seed]
289
+ run_button2.click(fn=run_new_garment, inputs=ips2, outputs=result_gallery)
290
 
291
  block.launch(mcp_server=True)