Spaces:
Running
Running
File size: 6,221 Bytes
7b580ae d2733eb e99c87a 7b580ae d2733eb 7b580ae d2733eb 7b580ae d2733eb ebe7ed8 262e590 e99c87a d8b253f 12776f2 e99c87a c6041b2 e99c87a d2733eb e99c87a d2733eb 262e590 d2733eb e99c87a 262e590 d2733eb e99c87a 262e590 d2733eb e99c87a d2733eb 85c691b e99c87a d2733eb 262e590 d2733eb 85c691b d2733eb e99c87a d2733eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import gradio as gr
import os
import time
import requests
# Example assets: sample garment and person images stored in the "data"
# directory that sits next to this file.
example_path = os.path.join(os.path.dirname(__file__), 'data')

_garment_dir = os.path.join(example_path, "garment")
garm_list = os.listdir(_garment_dir)
garm_list_path = [os.path.join(_garment_dir, name) for name in garm_list]

_person_dir = os.path.join(example_path, "person")
person_list = os.listdir(_person_dir)
person_list_path = [os.path.join(_person_dir, name) for name in person_list]

# Garment path -> garment type. The type is the second "_"-separated token of
# the file name, capitalized (a name such as "x_top_y.png" maps to "Top").
garm_img_category_mapping = {
    path: name.split("_")[1].capitalize()
    for path, name in zip(garm_list_path, garm_list)
}
def update_category(selected_garm_file):
    """Build the Dropdown update matching the selected garment file.

    The file is looked up in ``garm_img_category_mapping``; a file that is not
    in the mapping yields an update whose value is ``None``.
    """
    return gr.update(value=garm_img_category_mapping.get(selected_garm_file))
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API.

    Args:
        person_file: Path of the person (model) image to upload.
        garm_file: Path of the garment image to upload.
        category: Garment type, sent as ``garment_type``.
        model_type: Model variant, sent as ``model_type``.
        timeout: Seconds to wait on the HTTP request before giving up, so a
            stalled connection cannot block the caller indefinitely.

    Returns:
        ``(job_id, status)`` read from the JSON response.

    Raises:
        gr.Error: If the request fails or times out, the response is not OK,
            or the response lacks the expected fields.
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the
            environment.
        OSError: If either image file cannot be opened.
    """
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    # Context managers guarantee both upload handles are closed on every exit
    # path (success, non-OK response, or exception).
    with open(garm_file, 'rb') as garment_fp, open(person_file, 'rb') as model_fp:
        files = {
            'image_garment_file': garment_fp,
            'image_model_file': model_fp,
        }
        try:
            response = requests.post(
                tryon_url,
                headers=headers,
                data=payload,
                files=files,
                timeout=timeout,
            )
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")
    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")
def get_tryon_result(job_id, timeout=30):
    """Fetch the current state of a try-on job.

    Args:
        job_id: Identifier of the job to look up.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        ``(image_url, status)``. ``image_url`` is set only when the reported
        status is ``"completed"``; for any other status it is ``None``. Both
        values are ``None`` when the request fails or times out, the response
        is not OK, or the response lacks the expected fields.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the
            environment.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    try:
        # Let requests build the query string so job_id is URL-encoded.
        response = requests.get(
            result_url,
            headers=headers,
            params={'job_id': job_id},
            timeout=timeout,
        )
        if response.ok:
            data = response.json()
            status = data['status']
            if status == "completed":
                return data['output'][0]['image_url'], status
            return None, status
        # Log the non-OK body instead of dropping it silently.
        print(response.content)
    except Exception as e:
        print(f"get tryon result error: {e}")
    return None, None
def run_turbo(person_img, garm_img, category="Top"):
    """Run one try-on job end to end and return its result for the UI.

    Submits the job, then polls for the result until an image URL is
    available, the job reports a terminal status, or the retry budget is
    exhausted.

    Args:
        person_img: Path of the person image, or ``None`` if not provided.
        garm_img: Path of the garment image, or ``None`` if not provided.
        category: Garment type forwarded to the try-on API.

    Returns:
        ``(result_image_url, info)``. ``result_image_url`` is ``None`` when an
        input image is missing (``info`` is then ``"No input image"``) or when
        no result was obtained within the retry budget.

    Raises:
        KeyError: If ``MODEL_TYPE`` is missing from the environment.
        Exception: Whatever ``call_tryon_api`` raises when the submission
            fails is propagated unchanged.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"

    info = ""  # placeholder for now
    job_id, status = call_tryon_api(
        person_img, garm_img, category, model_type=os.environ['MODEL_TYPE']
    )
    time.sleep(3)  # wait before fetching the result

    # check the status of the job
    max_retry = 30  # 30 x 2s = 60s timeout for a single job run
    # Bounded polling: without the cap, a job that never reaches a terminal
    # status (or a persistently failing status endpoint) would loop forever.
    for _ in range(max_retry):
        if status in ("completed", "failed"):
            break
        try:
            result_image_url, status = get_tryon_result(job_id)
        except Exception as e:
            # Best effort: log and retry on the next iteration.
            print(f"run turbo polling error: {e}")
        else:
            if result_image_url is not None:
                return result_image_url, info
        time.sleep(2)  # Wait before retrying

    gr.Warning("Over heated, please try again later")
    return None, info
# Gradio UI definition: a caption row, an input/result row, the event wiring,
# and a showcase section, followed by the app launch.
# NOTE(review): the block nesting below is reconstructed; confirm the placement
# of the RUN button, the event wiring and the showcase column against the
# deployed layout.
with gr.Blocks() as Huhu_Turbo:
    # Caption row: one heading per column (garment, person, result).
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Upload your garment image 🧥
</div>
</div>
""")
        with gr.Column(elem_id = "col-person"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Select a model image 🧍
</div>
</div>
""")
        with gr.Column(elem_id = "col-result"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
“RUN” to get results 🪄
</div>
</div>
""")
    # Main row: garment input, person input, and result output.
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            # type="filepath": the image value is handed to callbacks as a path.
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            # Clickable garment examples taken from garm_list_path.
            garm_example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
                cache_examples=False
            )
        with gr.Column(elem_id = "col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath")
            # Clickable person examples taken from person_list_path.
            person_example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        with gr.Column(elem_id = "col-result"):
            result_img = gr.Image(label="Result", show_share_button=False)
            with gr.Row():
                # Receives the second value returned by run_turbo.
                result_info = gr.Text(label="Tryon inference runtime")
            generate_button = gr.Button(value="RUN", elem_id="button")
    # After a garment example is loaded, sync the garment-type dropdown with it.
    garm_example.load_input_event.then(
        fn=update_category,
        inputs=[garm_img],
        outputs=[category]
    )
    # RUN triggers the try-on; outputs are the result image and the info text.
    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
    # Showcase: a pre-computed example that fills person, garment, type and result.
    with gr.Column(elem_id = "col-showcase"):
        gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div> </div>
<br>
<div>
Huhu-turbo try-on examples in pairs of garment and person images
</div>
</div>
""")
        # NOTE(review): these example paths are relative to the working
        # directory, unlike the example_path-based lists above — confirm the
        # app is always started from the directory containing "data".
        show_case = gr.Examples(
            examples=[
                ["data/examples/person_example.png", "data/examples/garment_example.png", "Top", "data/examples/result_example.png"],
            ],
            inputs=[person_img, garm_img, category, result_img],
            label=None
        )
# Enable the request queue and start the app with the API surface hidden.
Huhu_Turbo.queue(api_open=False).launch(show_api=False)
|