File size: 5,697 Bytes
7b580ae
d2733eb
e99c87a
 
 
7b580ae
d2733eb
 
7b580ae
d2733eb
 
7b580ae
d2733eb
 
 
e99c87a
 
12776f2
e99c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8b253f
12776f2
e99c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2733eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e99c87a
d2733eb
 
 
e99c87a
d2733eb
 
 
e99c87a
d2733eb
 
e99c87a
d2733eb
 
 
 
 
 
e99c87a
d2733eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e99c87a
d2733eb
e99c87a
d2733eb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import os
import time
import requests


# assets loading: example images bundled next to this script under
# data/garment/ and data/person/
example_path = os.path.join(os.path.dirname(__file__), 'data')


def _list_example_files(subdir):
    """Return ``(names, paths)`` for the example files in ``example_path/subdir``.

    Names are sorted so the example gallery order is deterministic
    (``os.listdir`` returns entries in arbitrary order). Dot-files such as
    ``.DS_Store`` and sub-directories are skipped so they are not offered
    as example images.
    """
    folder = os.path.join(example_path, subdir)
    names = sorted(
        name for name in os.listdir(folder)
        if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
    )
    return names, [os.path.join(folder, name) for name in names]


garm_list, garm_list_path = _list_example_files("garment")
person_list, person_list_path = _list_example_files("person")


def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API.

    Args:
        person_file: Path to the person/model image.
        garm_file: Path to the garment image.
        category: Garment type, sent as the ``garment_type`` form field.
        model_type: Backend model identifier, sent as ``model_type``.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        ``(job_id, status)`` taken from the JSON response body.

    Raises:
        gr.Error: If the request raises, the server answers with a non-2xx
            status, or the response body cannot be parsed.
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the
            environment.
        OSError: If either image file cannot be opened.
    """
    print(person_file, garm_file, category)
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type}

    # Context managers guarantee both upload handles are closed even when
    # the request raises.
    with open(garm_file, 'rb') as garm_fp, open(person_file, 'rb') as person_fp:
        files = {
            'image_garment_file': garm_fp,
            'image_model_file': person_fp,
        }
        headers = {
            'x-api-key': os.environ['API_KEY']
        }

        try:
            # Without a timeout, a stalled server would block this worker forever.
            response = requests.post(tryon_url, headers=headers, data=payload,
                                     files=files, timeout=timeout)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")

    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")

def get_tryon_result(job_id, timeout=30):
    """Poll the API once for the state of a try-on job.

    Args:
        job_id: Identifier returned by ``call_tryon_api``.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        ``(image_url, status)`` when the job status is ``"completed"``;
        ``(None, status)`` for any other reported status;
        ``(None, None)`` if the request raises, the server answers with a
        non-2xx status, or the response body cannot be parsed.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the
            environment.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }

    try:
        # `params` lets requests URL-encode the job id instead of pasting it raw.
        response = requests.get(result_url, headers=headers,
                                params={'job_id': job_id}, timeout=timeout)

        if not response.ok:
            # Always return a 2-tuple so callers can unpack the result safely.
            print(response.content)
            return None, None

        data = response.json()
        if data["status"] == "completed":
            image_url = data['output'][0]['image_url']
            return image_url, data['status']
        return None, data['status']
    except Exception as e:
        print(f"get tryon result error: {e}")
        return None, None

def run_turbo(person_img, garm_img, category="Top"):
    """Run a try-on job end to end and return the result for the UI.

    Submits the job with ``call_tryon_api`` and then polls
    ``get_tryon_result`` every 2 s, at most ``max_retry`` times (~60 s).
    Polling stops when an image URL is returned, or when the job reaches
    ``"completed"``/``"failed"`` without one.

    Args:
        person_img: Filepath of the person image, or None if not provided.
        garm_img: Filepath of the garment image, or None if not provided.
        category: Garment type passed through to the API.

    Returns:
        ``(result_image_url, info)`` on success; ``(None, info)`` on failure
        or timeout (after a UI warning); ``(None, "No input image")`` when
        an input image is missing.

    Raises:
        gr.Error: Propagated from ``call_tryon_api`` when the job cannot be
            submitted.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"

    info = ""  # placeholder for now

    job_id, status = call_tryon_api(person_img, garm_img, category)

    # Poll even when submission already reports "completed": call_tryon_api
    # only returns the job id and status, so the image URL must still be fetched.
    if status != "failed":
        time.sleep(3)  # wait before fetching the result
        max_retry = 30  # 30x2s = 60s timeout for single job run
        for attempt in range(max_retry):
            if attempt:
                time.sleep(2)  # wait before retrying
            try:
                result_image_url, status = get_tryon_result(job_id)
            except Exception as e:  # transient error: log it and keep polling
                print(f"poll tryon result error: {e}")
                continue
            if result_image_url is not None:
                return result_image_url, info
            if status in ("completed", "failed"):
                break  # terminal state without an image; stop polling

    gr.Warning("Over heated, please try again later")
    return None, info

# Three-column layout: garment input | person input | result.
with gr.Blocks() as Huhu_Turbo:
    # Column headers
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                Upload your garment image 🧥
                </div>
            </div>
            """)
        with gr.Column(elem_id = "col-person"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                Select a model image 🧍
                </div>
            </div>
            """)
        with gr.Column(elem_id = "col-result"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                “RUN” to get results 🪄
                </div>
            </div>
            """)
    # Inputs, example galleries and the result panel
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'],  value="Top")
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path
            )
        with gr.Column(elem_id = "col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath")
            example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        with gr.Column(elem_id = "col-result"):
            result_img = gr.Image(label="Result", show_share_button=False)
            with gr.Row():
                result_info = gr.Text(label="Generation time")
            generate_button = gr.Button(value="RUN", elem_id="button")

    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)

    # Showcase of complete garment/person/result triples
    with gr.Column(elem_id = "col-showcase"):
        gr.HTML("""
        <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
            <div> </div>
            <br>
            <div>
            Huhu-turbo try-on examples in pairs of garment and person images
            </div>
        </div>
        """)
        show_case = gr.Examples(
            # Resolve from example_path (script-relative) like the other
            # assets, rather than relying on the current working directory.
            examples=[
                [
                    os.path.join(example_path, "examples", "garment_example.png"),
                    os.path.join(example_path, "examples", "person_example.png"),
                    "Top",
                    os.path.join(example_path, "examples", "result_example.png"),
                ],
            ],
            # Order must match each example row: garment, person, category, result.
            inputs=[garm_img, person_img, category, result_img],
            label=None
        )

Huhu_Turbo.queue(api_open=False).launch(show_api=False)