zc-huhu's picture
debug
12776f2
raw
history blame
5.7 kB
import gradio as gr
import os
import time
import requests
# assets loading: example images bundled next to this file under data/garment and data/person.
example_path = os.path.join(os.path.dirname(__file__), 'data')
# os.listdir returns entries in arbitrary order; sort so the example galleries
# are shown in the same order on every deployment.
garm_list = sorted(os.listdir(os.path.join(example_path, "garment")))
garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list]
person_list = sorted(os.listdir(os.path.join(example_path, "person")))
person_list_path = [os.path.join(example_path, "person", person) for person in person_list]
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1'):
    """Submit a virtual try-on job to the remote API.

    Args:
        person_file: Path to the person/model image to upload.
        garm_file: Path to the garment image to upload.
        category: Garment type, sent as the ``garment_type`` form field.
        model_type: Backend model identifier, sent as the ``model_type`` form field.

    Returns:
        ``(job_id, status)`` taken from the API's JSON response.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the environment.
        OSError: If either image file cannot be opened.
        gr.Error: If the request raises, the API answers with a non-OK status,
            or the response cannot be parsed into ``job_id``/``status``.
    """
    print(person_file, garm_file, category)
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type}
    # Open both uploads in a context manager so the handles are closed on
    # every path (success, error response, or exception).
    with open(garm_file, 'rb') as garm_fh, open(person_file, 'rb') as person_fh:
        files = {
            'image_garment_file': garm_fh,
            'image_model_file': person_fh,
        }
        headers = {
            'x-api-key': os.environ['API_KEY']
        }
        try:
            response = requests.post(tryon_url, headers=headers, data=payload, files=files)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")
    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")
def get_tryon_result(job_id, timeout=30):
    """Poll the try-on API once for the state of a submitted job.

    Args:
        job_id: Job identifier returned by ``call_tryon_api``.
        timeout: Seconds to wait for the HTTP response (passed to
            ``requests.get``). This stops a stalled connection from blocking
            the caller's bounded polling loop indefinitely.

    Returns:
        ``(image_url, status)`` when the reported status is ``"completed"``;
        ``(None, status)`` for any other reported status; and
        ``(None, None)`` when the request raises, the HTTP status is not OK,
        or the response body lacks the expected fields.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the environment.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    try:
        # Let requests build and URL-encode the query string.
        response = requests.get(
            result_url,
            headers=headers,
            params={'job_id': job_id},
            timeout=timeout,
        )
        if not response.ok:
            # Always return a 2-tuple: callers unpack the result directly.
            print(f"get tryon result error: HTTP {response.status_code}")
            return None, None
        data = response.json()
        status = data['status']
        if status == "completed":
            return data['output'][0]['image_url'], status
        return None, status
    except Exception as e:
        print(f"get tryon result error: {e}")
        return None, None
def run_turbo(person_img, garm_img, category="Top"):
    """Gradio click handler: submit a try-on job and poll until it yields an image.

    Args:
        person_img: File path of the person image (the UI uses
            ``gr.Image(type="filepath")``), or None if nothing was provided.
        garm_img: File path of the garment image, or None.
        category: Garment type forwarded to ``call_tryon_api``.

    Returns:
        ``(result_image_url, info)`` on success. ``(None, info)`` when the job
        fails, completes without an image, or the polling budget runs out.
        ``(None, "No input image")`` when either input is missing.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"
    info = ""  # placeholder for now
    job_id, status = call_tryon_api(person_img, garm_img, category)
    time.sleep(3)  # wait before fetching the result
    # Bounded polling: 30 polls x 2s ≈ 60s timeout for a single job run.
    max_retry = 30
    for attempt in range(max_retry):
        if status == "failed":
            break
        try:
            result_image_url, status = get_tryon_result(job_id)
        except Exception as e:
            print(f"poll tryon result error: {e}")
        else:
            if result_image_url is not None:
                return result_image_url, info
            if status == "completed":
                # Completed without an image URL: further polls cannot help.
                break
        if attempt < max_retry - 1:
            time.sleep(2)  # wait before retrying
    gr.Warning("Over heated, please try again later")
    return None, info
# Build the try-on UI. The columns are garment input, person input, and result,
# followed by a showcase of pre-computed examples.
with gr.Blocks() as Huhu_Turbo:
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                Upload your garment image 🧥
                </div>
            </div>
            """)
        with gr.Column(elem_id="col-person"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                Select a model image 🧍
                </div>
            </div>
            """)
        with gr.Column(elem_id="col-result"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                “RUN” to get results 🪄
                </div>
            </div>
            """)
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
            )
        with gr.Column(elem_id="col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath")
            gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path,
            )
        with gr.Column(elem_id="col-result"):
            result_img = gr.Image(label="Result", show_share_button=False)
            with gr.Row():
                result_info = gr.Text(label="Generation time")
                generate_button = gr.Button(value="RUN", elem_id="button")
    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
    with gr.Column(elem_id="col-showcase"):
        gr.HTML("""
        <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
            <div> </div>
            <br>
            <div>
            Huhu-turbo try-on examples in pairs of garment and person images
            </div>
        </div>
        """)
        # Each example row is (garment, person, category, result), so `inputs`
        # must list the components in that same order. Paths are anchored to
        # this file, like the other example assets, instead of the CWD.
        show_case = gr.Examples(
            examples=[
                [
                    os.path.join(example_path, "examples", "garment_example.png"),
                    os.path.join(example_path, "examples", "person_example.png"),
                    "Top",
                    os.path.join(example_path, "examples", "result_example.png"),
                ],
            ],
            inputs=[garm_img, person_img, category, result_img],
            label=None,
        )
if __name__ == "__main__":
    # Enable the request queue (with the API page closed) and start the server
    # only when run as a script, so importing this module does not launch it.
    Huhu_Turbo.queue(api_open=False).launch(show_api=False)