Spaces:
Running
Running
File size: 5,697 Bytes
7b580ae d2733eb e99c87a 7b580ae d2733eb 7b580ae d2733eb 7b580ae d2733eb e99c87a 12776f2 e99c87a d8b253f 12776f2 e99c87a d2733eb e99c87a d2733eb e99c87a d2733eb e99c87a d2733eb e99c87a d2733eb e99c87a d2733eb e99c87a d2733eb e99c87a d2733eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import gradio as gr
import os
import time
import requests
# Example assets: list the bundled garment and person images stored in the
# "data" directory that sits next to this file.
example_path = os.path.join(os.path.dirname(__file__), 'data')

_garment_dir = os.path.join(example_path, "garment")
garm_list = os.listdir(_garment_dir)
garm_list_path = [os.path.join(_garment_dir, name) for name in garm_list]

_person_dir = os.path.join(example_path, "person")
person_list = os.listdir(_person_dir)
person_list_path = [os.path.join(_person_dir, name) for name in person_list]
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API.

    Args:
        person_file: Path of the person image; opened in binary mode and
            uploaded as ``image_model_file``.
        garm_file: Path of the garment image; opened in binary mode and
            uploaded as ``image_garment_file``.
        category: Value sent as the ``garment_type`` form field.
        model_type: Value sent as the ``model_type`` form field.
        timeout: Seconds handed to ``requests.post`` so a stalled server
            cannot block the caller indefinitely.

    Returns:
        A ``(job_id, status)`` tuple read from the JSON response.

    Raises:
        gr.Error: If the request fails, the response is not OK, or the
            response body lacks the expected fields.
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is not set in the
            environment.
        OSError: If either image file cannot be opened.
    """
    print(person_file, garm_file, category)
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    # Open both uploads in a context manager so the handles are always
    # closed, whether the request succeeds, fails, or raises.
    with open(garm_file, 'rb') as garm_fh, open(person_file, 'rb') as person_fh:
        files = {
            'image_garment_file': garm_fh,
            'image_model_file': person_fh,
        }
        try:
            response = requests.post(
                tryon_url,
                headers=headers,
                data=payload,
                files=files,
                timeout=timeout,
            )
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            # Best-effort: log the cause, then surface a generic UI error below.
            print(f"call tryon api error: {e}")
    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")
def get_tryon_result(job_id, timeout=30):
    """Fetch the current state of a try-on job.

    Args:
        job_id: Identifier of the job, sent as the ``job_id`` query parameter.
        timeout: Seconds handed to ``requests.get`` so a stalled server
            cannot block the caller indefinitely.

    Returns:
        ``(image_url, status)`` when the job status is ``"completed"``,
        ``(None, status)`` for any other reported status, and
        ``(None, None)`` when the request fails, the response is not OK, or
        the response body cannot be interpreted.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is not set in the
            environment.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    try:
        # Pass job_id through ``params`` so it is URL-encoded rather than
        # spliced into the query string by hand.
        response = requests.get(
            result_url,
            headers=headers,
            params={'job_id': job_id},
            timeout=timeout,
        )
        if response.ok:
            data = response.json()
            if data["status"] == "completed":
                image_url = data['output'][0]['image_url']
                return image_url, data['status']
            return None, data['status']
        # Log non-OK responses instead of dropping them silently.
        print(response.content)
    except Exception as e:
        print(f"get tryon result error: {e}")
    return None, None
def run_turbo(person_img, garm_img, category="Top"):
    """Run one try-on job and wait for its result image.

    Submits the job through ``call_tryon_api`` and then polls
    ``get_tryon_result`` a bounded number of times.

    Args:
        person_img: Path of the person image, or ``None`` if not provided.
        garm_img: Path of the garment image, or ``None`` if not provided.
        category: Garment type forwarded to ``call_tryon_api``.

    Returns:
        ``(image_url, info)`` on success, ``(None, "No input image")`` when
        an input is missing, and ``(None, info)`` when the job fails or does
        not finish within the polling budget.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"

    info = ""  # placeholder for now
    job_id, status = call_tryon_api(person_img, garm_img, category)

    max_retry = 30  # 30 polls x 2s = roughly 60s budget for a single job run
    if status != "failed":
        time.sleep(3)  # wait before fetching the result
        # Bounded polling: the loop always terminates even if the API never
        # reports a terminal status or keeps erroring.
        for _ in range(max_retry):
            try:
                result_image_url, status = get_tryon_result(job_id)
            except Exception as e:
                # Keep polling on unexpected errors, but leave a trace.
                print(f"polling tryon result error: {e}")
            else:
                if result_image_url is not None:
                    return result_image_url, info
                if status in ("completed", "failed"):
                    # Terminal status without an image: stop polling.
                    break
            time.sleep(2)  # Wait before retrying

    gr.Warning("Over heated, please try again later")
    return None, info
# Inline style shared by every centered banner in the UI.
_BANNER_STYLE = "display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;"


def _banner_html(text, lead=""):
    """Return the centered banner markup wrapping *text*.

    ``lead`` is optional markup inserted before the text container (used by
    the showcase banner for its spacer elements).
    """
    return (
        f'\n<div style="{_BANNER_STYLE}">\n'
        f'{lead}'
        '<div>\n'
        f'{text}\n'
        '</div>\n'
        '</div>\n'
    )


# UI definition: a banner row, an input/result row, and a showcase section.
# NOTE(review): container nesting follows the element grouping; confirm the
# rendered layout matches expectations.
with gr.Blocks() as Huhu_Turbo:
    # Row 1: one banner per column.
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            gr.HTML(_banner_html("Upload your garment image 🧥"))
        with gr.Column(elem_id="col-person"):
            gr.HTML(_banner_html("Select a model image 🧍"))
        with gr.Column(elem_id="col-result"):
            gr.HTML(_banner_html("“RUN” to get results 🪄"))

    # Row 2: garment input, person input, and result output.
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path
            )
        with gr.Column(elem_id="col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath")
            example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        with gr.Column(elem_id="col-result"):
            result_img = gr.Image(label="Result", show_share_button=False)
            with gr.Row():
                result_info = gr.Text(label="Generation time")
                generate_button = gr.Button(value="RUN", elem_id="button")
            generate_button.click(
                fn=run_turbo,
                inputs=[person_img, garm_img, category],
                outputs=[result_img, result_info],
                api_name=False,
                concurrency_limit=30,
            )

    # Showcase: a pre-computed garment/person/result triple.
    with gr.Column(elem_id="col-showcase"):
        gr.HTML(_banner_html(
            "Huhu-turbo try-on examples in pairs of garment and person images",
            lead="<div> </div>\n<br>\n",
        ))
        show_case = gr.Examples(
            examples=[
                ["data/examples/garment_example.png", "data/examples/person_example.png", "Top", "data/examples/result_example.png"],
            ],
            # Each example row is ordered garment, person, category, result,
            # so the target components must be listed in that same order.
            inputs=[garm_img, person_img, category, result_img],
            label=None
        )

Huhu_Turbo.queue(api_open=False).launch(show_api=False)
|