haizad commited on
Commit
c7c9368
·
1 Parent(s): a1b5165

enable image url as input

Browse files
Files changed (1) hide show
  1. app.py +63 -9
app.py CHANGED
@@ -22,15 +22,63 @@ def base64_to_image(base64_str, output_path): # Remove 'self'
22
  image.save(output_path)
23
  return image
24
 
25
- def run_viton(model_image_path, garment_image_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  n_steps=20, image_scale=2.0, seed=-1):
27
  try:
28
  api_url = os.environ.get("SERVER_URL")
29
  print(f"Using API URL: {api_url}") # Add this to debug
30
 
31
- # Convert images to base64 (remove 'self.')
32
- model_b64 = image_to_base64(model_image_path)
33
- garment_b64 = image_to_base64(garment_image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # Prepare request
36
  request_data = {
@@ -73,14 +121,17 @@ def run_viton(model_image_path, garment_image_path,
73
  return [] # Fix: should return list, not dict for gallery
74
 
75
  block = gr.Blocks().queue()
76
- default_model = os.path.join(example_path, 'model/model_8.png')
77
- default_garment = os.path.join(example_path, 'garment/00055_00.jpg')
78
  with block:
79
  with gr.Row():
80
  gr.Markdown("# Virtual Try-On")
 
 
81
  with gr.Row():
82
  with gr.Column():
83
- vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384, value=default_model)
 
 
 
84
  example = gr.Examples(
85
  inputs=vton_img,
86
  examples_per_page=5,
@@ -92,7 +143,10 @@ with block:
92
  os.path.join(example_path, 'model/model_5.png'),
93
  ])
94
  with gr.Column():
95
- garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384, value=default_garment)
 
 
 
96
  example = gr.Examples(
97
  inputs=garm_img,
98
  examples_per_page=5,
@@ -111,7 +165,7 @@ with block:
111
  image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
112
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
113
 
114
- ips = [vton_img, garm_img, n_steps, image_scale, seed]
115
  run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)
116
 
117
  block.launch(mcp_server=True)
 
22
  image.save(output_path)
23
  return image
24
 
25
+ def download_image_from_url(url, output_path):
26
+ """Download image from URL and save to local path"""
27
+ try:
28
+ response = requests.get(url, timeout=30)
29
+ response.raise_for_status()
30
+
31
+ # Save the image
32
+ with open(output_path, 'wb') as f:
33
+ f.write(response.content)
34
+
35
+ # Verify it's a valid image
36
+ image = Image.open(output_path)
37
+ return output_path
38
+ except Exception as e:
39
+ print(f"Error downloading image from {url}: {str(e)}")
40
+ return None
41
+
42
+ def url_to_base64(url):
43
+ """Convert image URL to base64 string"""
44
+ try:
45
+ response = requests.get(url, timeout=30)
46
+ response.raise_for_status()
47
+ return base64.b64encode(response.content).decode()
48
+ except Exception as e:
49
+ print(f"Error converting URL to base64: {str(e)}")
50
+ return None
51
+
52
+ def run_viton(model_image_path, garment_image_path, model_url, garment_url,
53
  n_steps=20, image_scale=2.0, seed=-1):
54
  try:
55
  api_url = os.environ.get("SERVER_URL")
56
  print(f"Using API URL: {api_url}") # Add this to debug
57
 
58
+ # Determine which inputs to use (file upload or URL)
59
+ model_b64 = None
60
+ garment_b64 = None
61
+
62
+ # Handle model image
63
+ if model_url and model_url.strip():
64
+ print(f"Using model URL: {model_url}")
65
+ model_b64 = url_to_base64(model_url.strip())
66
+ elif model_image_path:
67
+ print(f"Using model file: {model_image_path}")
68
+ model_b64 = image_to_base64(model_image_path)
69
+
70
+ # Handle garment image
71
+ if garment_url and garment_url.strip():
72
+ print(f"Using garment URL: {garment_url}")
73
+ garment_b64 = url_to_base64(garment_url.strip())
74
+ elif garment_image_path:
75
+ print(f"Using garment file: {garment_image_path}")
76
+ garment_b64 = image_to_base64(garment_image_path)
77
+
78
+ # Check if we have both images
79
+ if not model_b64 or not garment_b64:
80
+ print("Error: Missing model or garment image")
81
+ return []
82
 
83
  # Prepare request
84
  request_data = {
 
121
  return [] # Fix: should return list, not dict for gallery
122
 
123
  block = gr.Blocks().queue()
 
 
124
  with block:
125
  with gr.Row():
126
  gr.Markdown("# Virtual Try-On")
127
+ with gr.Row():
128
+ gr.Markdown("**Instructions:** You can either upload images using the file upload interface or provide direct URLs to images. URL inputs will take priority over uploaded files.")
129
  with gr.Row():
130
  with gr.Column():
131
+ model_url = gr.Textbox(
132
+ label="Enter Model Image URL",
133
+ )
134
+ vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
135
  example = gr.Examples(
136
  inputs=vton_img,
137
  examples_per_page=5,
 
143
  os.path.join(example_path, 'model/model_5.png'),
144
  ])
145
  with gr.Column():
146
+ garment_url = gr.Textbox(
147
+ label="Enter Garment Image URL",
148
+ )
149
+ garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
150
  example = gr.Examples(
151
  inputs=garm_img,
152
  examples_per_page=5,
 
165
  image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
166
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
167
 
168
+ ips = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
169
  run_button.click(fn=run_viton, inputs=ips, outputs=result_gallery)
170
 
171
  block.launch(mcp_server=True)