File size: 10,907 Bytes
d2bc90c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c9368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588bc14
ccd41f5
 
 
 
 
 
 
 
 
 
 
 
588bc14
ccd41f5
588bc14
 
d2bc90c
 
588bc14
 
d2bc90c
588bc14
c7c9368
 
588bc14
c7c9368
 
 
588bc14
 
c7c9368
 
 
 
 
588bc14
c7c9368
 
 
588bc14
 
c7c9368
 
 
 
 
588bc14
d2bc90c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588bc14
d2bc90c
 
 
 
588bc14
d2bc90c
 
588bc14
 
 
ccd41f5
 
 
588bc14
ccd41f5
588bc14
 
ccd41f5
588bc14
 
 
ccd41f5
 
 
 
 
 
 
 
588bc14
ccd41f5
 
588bc14
ccd41f5
588bc14
 
ccd41f5
 
588bc14
 
ccd41f5
588bc14
ccd41f5
 
588bc14
ccd41f5
 
 
588bc14
 
ccd41f5
 
 
 
588bc14
 
ccd41f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588bc14
ccd41f5
 
 
 
588bc14
ccd41f5
 
588bc14
 
 
d2bc90c
 
 
588bc14
d2bc90c
588bc14
 
d2bc90c
588bc14
 
d2bc90c
 
 
 
 
 
 
ccd41f5
c7c9368
 
 
 
d2bc90c
 
ccd41f5
d2bc90c
 
 
 
 
 
 
ccd41f5
c7c9368
ccd41f5
 
 
 
c7c9368
 
d2bc90c
 
ccd41f5
d2bc90c
 
 
 
 
 
 
ccd41f5
 
d2bc90c
ccd41f5
 
d2bc90c
 
 
 
ccd41f5
 
 
 
d2bc90c
a1b5165
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import gradio as gr
import os
from PIL import Image
import requests
import base64
import io
from dotenv import load_dotenv

# Load variables such as SERVER_URL from a local .env file into the environment.
load_dotenv()

# Folder of bundled sample images used by the gr.Examples widgets of the UI.
example_path = os.path.join(os.path.dirname(__file__), "examples")

def image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path*, base64-encoded as a str.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 text of the file contents.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()

def base64_to_image(base64_str, output_path):
    """Decode a base64-encoded image, save a copy to disk, and return it.

    Args:
        base64_str: Base64 text of the encoded image bytes.
        output_path: File path the decoded image is saved to.

    Returns:
        The PIL image opened from the decoded bytes.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    decoded = Image.open(buffer)
    decoded.save(output_path)
    return decoded

def download_image_from_url(url, output_path):
    """Download an image from *url*, save it to *output_path*, and check it decodes.

    Best-effort: any failure (network error, HTTP error status, or content
    that Pillow cannot open/verify as an image) is printed and reported by
    returning None instead of raising.

    Args:
        url: URL of the image to fetch.
        output_path: Local file path the downloaded bytes are written to.

    Returns:
        output_path on success, otherwise None.
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        # Save the image
        with open(output_path, 'wb') as f:
            f.write(response.content)

        # Image.open() is lazy and keeps the file handle open; verify() checks
        # the file's integrity, and the context manager releases the handle.
        with Image.open(output_path) as image:
            image.verify()
        return output_path
    except Exception as e:
        print(f"Error downloading image from {url}: {e}")
        return None

def url_to_base64(url):
    """Fetch *url* over HTTP and return the response body base64-encoded.

    Args:
        url: Address of the resource to download.

    Returns:
        The base64 text of the response body, or None (after printing the
        error) if the request raised or the server returned an error status.
    """
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        body = resp.content
    except Exception as exc:
        print(f"Error converting URL to base64: {str(exc)}")
        return None
    return base64.b64encode(body).decode()


def _has_image_source(image_path, url):
    """Return True if a file path or a non-blank URL was supplied for an image."""
    return bool(image_path) or bool(url and url.strip())


def _get_server_url():
    """Return the inference server base URL from the SERVER_URL env variable.

    A trailing slash is removed so endpoint URLs never contain "//".

    Raises:
        gr.Error: If SERVER_URL is unset or empty.
    """
    api_url = (os.environ.get("SERVER_URL") or "").rstrip("/")
    if not api_url:
        raise gr.Error("❌ SERVER_URL not configured in environment variables")
    return api_url


def _load_image_b64(kind, image_path, url):
    """Return the base64 encoding of an image given by URL or by local file.

    A non-blank *url* takes precedence over *image_path*.

    Args:
        kind: Label used in log and error messages ("model" or "garment").
        image_path: Local file path, or None/empty.
        url: Image URL, or None/blank.

    Returns:
        The base64 string, or None when neither source was supplied.

    Raises:
        gr.Error: If *url* was given but could not be fetched.
    """
    url = url.strip() if url else ""
    if url:
        print(f"Using {kind} URL: {url}")
        encoded = url_to_base64(url)
        if not encoded:
            raise gr.Error(f"❌ Failed to load {kind} image from URL. Please check the URL is valid.")
        return encoded
    if image_path:
        print(f"Using {kind} file: {image_path}")
        return image_to_base64(image_path)
    return None


def _request_images(api_url, endpoint, request_data, output_prefix):
    """POST *request_data* to ``{api_url}/{endpoint}`` and decode the returned images.

    Each decoded image is also written to ``{output_prefix}_{i}.png`` in the
    current working directory by base64_to_image.

    Args:
        api_url: Server base URL without a trailing slash.
        endpoint: Endpoint path appended to api_url (e.g. "viton").
        request_data: JSON-serialisable request body.
        output_prefix: File-name prefix for the saved output images.

    Returns:
        Non-empty list of images produced by base64_to_image.

    Raises:
        gr.Error: On a non-200 status, a server-reported "error" field, or
            an empty image list.
    """
    endpoint_url = f"{api_url}/{endpoint}"
    response = requests.post(endpoint_url, json=request_data, timeout=300)

    print(f"Request sent to {endpoint_url}")
    print(f"Response status code: {response.status_code}")

    if response.status_code != 200:
        # Log a snippet of the body for debugging; the user only sees the status.
        print(f"Response body: {response.text[:500]}")
        raise gr.Error(f"❌ Request failed with status code: {response.status_code}")

    result = response.json()
    if result.get("error"):
        raise gr.Error(f"❌ Server error: {result['error']}")

    # `or []` also covers an explicit null "images_base64" in the response.
    generated_images = [
        base64_to_image(img_b64, f"{output_prefix}_{i}.png")
        for i, img_b64 in enumerate(result.get("images_base64") or [])
    ]
    if not generated_images:
        raise gr.Error("❌ No images were generated. Please try again.")

    print(f"Successfully generated {len(generated_images)} images")
    return generated_images


def run_viton(model_image_path: str = None,
              garment_image_path: str = None,
              model_url: str = None,
              garment_url: str = None,
              n_steps=20,
              image_scale=2.0,
              seed=-1
            ):
    """
    Run the Virtual Try-On model with provided images path or URLs.

    Dresses the person in the model image with the garment image. For each
    image, a non-blank URL takes precedence over the local file path.

    Args:
        model_image_path: Local file path of the person (model) photo.
        garment_image_path: Local file path of the garment photo.
        model_url: URL of the person photo; used instead of model_image_path when non-blank.
        garment_url: URL of the garment photo; used instead of garment_image_path when non-blank.
        n_steps: Step count forwarded to the server (UI slider range 20-40).
        image_scale: Guidance scale forwarded to the server.
        seed: Seed forwarded to the server (UI default -1).

    Returns:
        List of generated images.

    Raises:
        gr.Error: On missing inputs, unreachable URLs, server errors, or any
            unexpected failure.
    """
    if not _has_image_source(model_image_path, model_url):
        raise gr.Error("❌ Please provide either a model image file or URL")
    if not _has_image_source(garment_image_path, garment_url):
        raise gr.Error("❌ Please provide either a garment image file or URL")

    try:
        api_url = _get_server_url()
        print(f"Using API URL: {api_url}")

        model_b64 = _load_image_b64("model", model_image_path, model_url)
        garment_b64 = _load_image_b64("garment", garment_image_path, garment_url)
        if not model_b64 or not garment_b64:
            raise gr.Error("❌ Failed to process images. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_image_base64": garment_b64,
            "n_samples": 1,
            "n_steps": n_steps,
            "image_scale": image_scale,
            "seed": seed
        }
        return _request_images(api_url, "viton", request_data, "ootd_output")

    except gr.Error:
        raise  # Re-raise Gradio errors
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"❌ An unexpected error occurred: {str(e)}") from e


def run_new_garment(model_image_path: str = None,
                    garment_prompt: str = None,
                    model_url: str = None,
                    n_steps=20,
                    image_scale=2.0,
                    seed=-1
                    ):
    """
    Run the Virtual Try-On model with provided model image and garment prompt.

    The garment is described in text instead of supplied as an image. For the
    model image, a non-blank URL takes precedence over the local file path.

    Args:
        model_image_path: Local file path of the person (model) photo.
        garment_prompt: Text description of the garment to generate.
        model_url: URL of the person photo; used instead of model_image_path when non-blank.
        n_steps: Step count forwarded to the server (UI slider range 20-40).
        image_scale: Guidance scale forwarded to the server.
        seed: Seed forwarded to the server (UI default -1).

    Returns:
        List of generated images.

    Raises:
        gr.Error: On missing inputs, unreachable URLs, server errors, or any
            unexpected failure.
    """
    if not _has_image_source(model_image_path, model_url):
        raise gr.Error("❌ Please provide either a model image file or URL")
    if not garment_prompt or not garment_prompt.strip():
        raise gr.Error("❌ Please provide a garment description")

    try:
        api_url = _get_server_url()
        print(f"Using API URL: {api_url}")

        model_b64 = _load_image_b64("model", model_image_path, model_url)
        if not model_b64:
            raise gr.Error("❌ Failed to process model image. Please try again.")

        request_data = {
            "model_image_base64": model_b64,
            "garment_prompt": garment_prompt.strip(),
            "n_samples": 1,
            "n_steps": n_steps,
            "image_scale": image_scale,
            "seed": seed
        }
        return _request_images(api_url, "new-garment", request_data, "flux_output")

    except gr.Error:
        raise  # Re-raise Gradio errors
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        raise gr.Error(f"❌ An unexpected error occurred: {str(e)}") from e

# Gradio UI. `mcp_server=True` in the launch call additionally exposes the
# click handlers (run_viton, run_new_garment) as MCP tools.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("# Virtual Try-On")
    with gr.Row():
        # Column 1: the person ("model") photo, by URL or upload/webcam.
        with gr.Column():
            gr.Markdown("### Provide image or URL of upper body photo")
            model_url = gr.Textbox(
                label="Enter Model Image URL",
            )
            vton_img = gr.Image(label="Model", sources=['upload', 'webcam'], type="filepath", height=384)
            # Clicking an example fills vton_img; the Examples object itself is not needed.
            gr.Examples(
                inputs=vton_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'model/model_2.png'),
                    os.path.join(example_path, 'model/model_7.png'),
                    os.path.join(example_path, 'model/model_4.png'),
                    os.path.join(example_path, 'model/model_5.png'),
                ])
        # Column 2: the garment, by URL, upload/webcam, or a text description.
        with gr.Column():
            gr.Markdown("### Provide image, URL or description of a garment")
            garment_url = gr.Textbox(
                label="Enter Garment Image URL",
            )
            garment_prompt = gr.Textbox(
                label="Describe Garment",
            )
            garm_img = gr.Image(label="Garment", sources=['upload', 'webcam'], type="filepath", height=384)
            gr.Examples(
                inputs=garm_img,
                examples_per_page=4,
                examples=[
                    os.path.join(example_path, 'garment/07764_00.jpg'),
                    os.path.join(example_path, 'garment/03032_00.jpg'),
                    os.path.join(example_path, 'garment/048554_1.jpg'),
                    os.path.join(example_path, 'garment/049805_1.jpg'),
                ])
        # Column 3: generated results.
        with gr.Column():
            gr.Markdown("### 2D Result")
            result_gallery = gr.Gallery(label='Output 2D', show_label=False, elem_id="gallery", preview=True, scale=1)
    with gr.Column():
        run_button = gr.Button(value="Try On with your garment")
        run_button2 = gr.Button(value="Try On with AI generated garment")
        n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
        image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

    # Inputs are passed positionally: the list order must match each handler's signature.
    ips1 = [vton_img, garm_img, model_url, garment_url, n_steps, image_scale, seed]
    run_button.click(fn=run_viton, inputs=ips1, outputs=result_gallery)
    ips2 = [vton_img, garment_prompt, model_url, n_steps, image_scale, seed]
    run_button2.click(fn=run_new_garment, inputs=ips2, outputs=result_gallery)

block.launch(mcp_server=True)