File size: 8,255 Bytes
7b580ae
d2733eb
e99c87a
 
67557f4
e99c87a
7b580ae
1d18396
 
 
 
 
218fb5f
 
449bee0
 
fceefab
9b03efc
1d18396
218fb5f
1d18396
 
218fb5f
 
449bee0
 
fceefab
9b03efc
1d18396
218fb5f
1d18396
 
218fb5f
 
449bee0
 
fceefab
9b03efc
1d18396
 
 
 
 
94e6be6
 
 
 
1d18396
214cdb2
 
1d18396
 
 
d2733eb
 
7b580ae
d2733eb
 
7b580ae
d2733eb
 
 
4d958b5
262e590
 
9b03efc
 
 
 
 
67557f4
5705ac4
 
67557f4
 
 
 
 
 
262e590
c540905
262e590
e99c87a
 
 
7713351
e99c87a
 
214cdb2
 
e99c87a
 
 
 
 
 
 
 
 
d8b253f
12776f2
e99c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6041b2
e99c87a
f4003e7
e99c87a
 
f4003e7
e99c87a
 
 
 
 
 
 
f4003e7
e99c87a
 
 
d2733eb
1d18396
4d958b5
d2733eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218fb5f
d2733eb
262e590
d2733eb
e99c87a
262e590
 
d2733eb
 
218fb5f
262e590
d2733eb
e99c87a
d2733eb
 
 
218fb5f
d2733eb
c88f726
e99c87a
d2733eb
262e590
 
 
 
 
67557f4
 
 
262e590
d2733eb
 
1d18396
d2733eb
 
 
 
 
fe79c88
d2733eb
 
 
 
 
fe79c88
 
 
 
 
d2733eb
e99c87a
d2733eb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import gradio as gr
import os
import time
import requests
from PIL import Image


# ux format
# Custom CSS for the Gradio UI; it is passed to gr.Blocks(css=...) when the
# interface is built. Selectors target the elem_id values given to the
# garment/person/result columns, their image components, the examples column
# and the RUN button. Image rules keep a 3:4 portrait frame capped at 560px
# tall and use object-fit: contain so pictures are letterboxed, not cropped.
tryon_css="""
#col-garment {
    margin: 0 auto;
    max-width: 420px;
}
#garm_img {
    aspect-ratio: 3 / 4;
    width: 100%;
    max-height: 560px;
    object-fit: contain;
}
#col-person {
    margin: 0 auto;
    max-width: 420px;
}
#person_img {
    aspect-ratio: 3 / 4;
    width: 100%;
    max-height: 560px;
    object-fit: contain;
}
#col-result {
    margin: 0 auto;
    max-width: 420px;
}
#result_img {
    aspect-ratio: 3 / 4;
    width: 100%;
    max-height: 560px;
    object-fit: contain;
}
#col-examples {
    margin: 0 auto;
    max-width: 1000px;
}
#col-examples img {
    aspect-ratio: 3 / 4;
    object-fit: contain;
}
#button {
    background-color: #A47764;
    color: white;
}
"""

# assets loading
# Example images are resolved relative to this script's directory so the
# galleries do not depend on the current working directory.
example_path = os.path.join(os.path.dirname(__file__), 'data')


def _visible_files(folder):
    """Return the sorted names of the non-hidden entries in ``folder``.

    ``os.listdir`` returns entries in arbitrary order, so sorting keeps the
    example galleries stable across runs. Dot-files (e.g. ``.DS_Store``) are
    skipped because they are not example images.
    """
    return sorted(name for name in os.listdir(folder) if not name.startswith("."))


def _category_from_filename(filename):
    """Return the garment category encoded in ``filename``'s basename.

    The category is the third ``_``-separated field, capitalized. Returns
    ``None`` when the basename has fewer than three fields.
    """
    parts = os.path.basename(filename).split("_")
    return parts[2].capitalize() if len(parts) > 2 else None


garm_list = _visible_files(os.path.join(example_path, "garment"))
garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list]

person_list = _visible_files(os.path.join(example_path, "person"))
person_list_path = [os.path.join(example_path, "person", person) for person in person_list]

# Garment basename -> category, used by update_category(). Files whose names
# do not encode a category are left out, so update_category's .get() falls
# back to its default instead of app start-up dying with an IndexError.
garm_img_category_mapping = {
    os.path.basename(garm_file): garm_category
    for garm_file, garm_category in (
        (path, _category_from_filename(path)) for path in garm_list_path
    )
    if garm_category is not None
}


def load_header(header_file):
    """Return the entire contents of ``header_file``, decoded as UTF-8 text."""
    with open(header_file, encoding='utf-8') as handle:
        return handle.read()

def preprocess_img(img_path, max_size=1024):
    """Downscale an image file in place so its longer side is at most ``max_size``.

    Args:
        img_path: Path to the image file (the image components use
            ``type="filepath"``), or ``None`` when the component is empty.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        ``img_path`` unchanged (the file may have been overwritten with a
        smaller version), or ``None`` if ``img_path`` is ``None``.
    """
    if img_path is None:
        return None
    # Use a context manager so the underlying file handle is always released;
    # the previous version opened the image and never closed it.
    with Image.open(img_path) as img:
        if max(img.size) <= max_size:
            return img_path
        # thumbnail() resizes in place, preserving the aspect ratio.
        img.thumbnail((max_size, max_size))
        # Save before the `with` exits: a closed Pillow image can no longer be saved.
        img.save(img_path)
    return img_path

def update_category(selected_garm_file):
    """Build a Gradio update that sets the garment-type dropdown.

    The category is looked up by the file's basename in
    ``garm_img_category_mapping``; unknown files default to "Fullbody".
    """
    garm_name = os.path.basename(selected_garm_file)
    looked_up = garm_img_category_mapping.get(garm_name, "Fullbody")
    return gr.update(value=looked_up)

def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API and return its id and status.

    The endpoint and key come from the ``API_ENDPOINT`` and ``API_KEY``
    environment variables. Both images are uploaded as multipart files.

    Args:
        person_file: Path to the person/model image.
        garm_file: Path to the garment image.
        category: Garment type sent as ``garment_type`` (e.g. "Top").
        model_type: Model identifier sent as ``model_type``.
        timeout: Seconds passed to ``requests.post`` so a stalled connection
            cannot block the worker indefinitely.

    Returns:
        ``(job_id, status)`` taken from the JSON response.

    Raises:
        KeyError: if ``API_ENDPOINT`` or ``API_KEY`` is not set.
        OSError: if either image file cannot be opened.
        gr.Error: if the request fails, returns a non-2xx status, or the
            response lacks the expected fields.
    """
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type, 'repaint_other_garment': 'false'}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }

    # Open the uploads in a `with` so both handles are closed whether the
    # request succeeds or fails (they were previously leaked).
    with open(garm_file, 'rb') as garm_fh, open(person_file, 'rb') as person_fh:
        files = {
            'image_garment_file': garm_fh,
            'image_model_file': person_fh,
        }
        try:
            response = requests.post(tryon_url, headers=headers, data=payload, files=files, timeout=timeout)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            # Boundary with an external service: log and surface a UI error below.
            print(f"call tryon api error: {e}")

    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")

def get_tryon_result(job_id, timeout=30):
    """Query the API once for the state of try-on job ``job_id``.

    Args:
        job_id: Identifier returned by ``call_tryon_api``.
        timeout: Seconds passed to ``requests.get``.

    Returns:
        ``(image_url, status)``. ``image_url`` is the first output's URL when
        ``status == "completed"`` and ``None`` otherwise. Both are ``None`` when
        the request fails, the server returns a non-2xx status, or the payload
        is missing the expected fields. A 2-tuple is always returned, so
        callers can unpack it safely.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }

    try:
        # `params` URL-encodes the job id instead of splicing it into the URL.
        response = requests.get(result_url, headers=headers, params={'job_id': job_id}, timeout=timeout)
        if not response.ok:
            # Previously this path fell through and returned a bare None,
            # which made the caller's tuple unpacking raise.
            print(response.content)
            return None, None
        data = response.json()
        status = data["status"]
        if status == "completed":
            return data['output'][0]['image_url'], status
        return None, status
    except Exception as e:
        print(f"get tryon result error: {e}")
        return None, None

def run_turbo(person_img, garm_img, category="Top"):
    """Run one virtual try-on: submit the job, then poll until it finishes.

    Args:
        person_img: Filepath of the person image, or ``None`` if missing.
        garm_img: Filepath of the garment image, or ``None`` if missing.
        category: Garment type forwarded to the API.

    Returns:
        ``(result_image_url, info)``. ``result_image_url`` is ``None`` when
        inputs are missing, the job failed, or polling gave up.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"

    info = ""  # placeholder for now

    # call_tryon_api already defaults model_type to 'SD_V1'; use the same
    # fallback instead of crashing with KeyError when MODEL_TYPE is unset.
    model_type = os.environ.get('MODEL_TYPE', 'SD_V1')
    job_id, status = call_tryon_api(person_img, garm_img, category, model_type=model_type)

    time.sleep(8)  # wait before fetching the result

    # Poll the job status. max_retry bounds the loop: 40 x 1.5s = 60s timeout
    # for a single job run. The previous version defined it but never checked
    # it, so a job stuck in a non-terminal state was polled forever.
    max_retry = 40
    attempts = 0
    while status not in ("completed", "failed") and attempts < max_retry:
        attempts += 1
        try:
            result_image_url, status = get_tryon_result(job_id)
            if result_image_url is not None:
                return result_image_url, info
        except Exception as e:
            # Keep polling on transient errors, but log them instead of
            # silently swallowing everything (including KeyboardInterrupt).
            print(f"poll tryon result error: {e}")
        time.sleep(1.5)  # Wait before retrying

    gr.Warning("Over heated, please try again later")
    return None, info

with gr.Blocks(css=tryon_css) as Huhu_Turbo:
    # Bundled assets are resolved via example_path (relative to this script),
    # like the example galleries, rather than the current working directory,
    # so the app also starts when launched from another directory.
    gr.HTML(load_header(os.path.join(example_path, "header.html")))
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                Upload your garment image 🧥
                </div>
            </div>
            """)
        with gr.Column(elem_id = "col-person"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                Select a model image 🧍
                </div>
            </div>
            """)
        with gr.Column(elem_id = "col-result"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                “RUN” to get results 🪄
                </div>
            </div>
            """)
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath", elem_id="garm_img")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'],  value="Top")
            garm_example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
                cache_examples=False
            )
        with gr.Column(elem_id = "col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath", elem_id="person_img")
            person_example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        with gr.Column(elem_id = "col-result"):
            result_img = gr.Image(label="Result", show_share_button=False, elem_id="result_img")
            with gr.Row():
                result_info = gr.Text(label="Tryon inference runtime", visible=False)
            generate_button = gr.Button(value="RUN", elem_id="button")

    # Picking a garment example also sets the garment-type dropdown.
    garm_example.load_input_event.then(
        fn=update_category,
        inputs=[garm_img],
        outputs=[category]
    )

    # Downscale oversized uploads before they are sent to the API.
    garm_img.change(fn=preprocess_img, inputs=[garm_img], outputs=[garm_img])
    person_img.change(fn=preprocess_img, inputs=[person_img], outputs=[person_img])

    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)

    with gr.Column(elem_id = "col-examples"):
        gr.HTML("""
        <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
            <div> </div>
            <br>
            <div>
            Huhu Try-on Turbo examples in pairs of garment and model images
            </div>
        </div>
        """)
        # Showcase rows are (person, garment, garment type, result); example i
        # uses the files *_example_{i}.png under <example_path>/examples.
        showcase_dir = os.path.join(example_path, "examples")
        showcase_categories = ["Top", "Top", "Top", "Fullbody", "Top"]
        show_case = gr.Examples(
            examples=[
                [
                    os.path.join(showcase_dir, f"person_example_{i}.png"),
                    os.path.join(showcase_dir, f"garment_example_{i}.png"),
                    showcase_category,
                    os.path.join(showcase_dir, f"result_example_{i}.png"),
                ]
                for i, showcase_category in enumerate(showcase_categories, start=1)
            ],
            inputs=[person_img, garm_img, category, result_img],
            label=None
        )

Huhu_Turbo.queue(api_open=False).launch(show_api=False)