# Hugging Face Space — status: Running
import gradio as gr | |
import os | |
import time | |
import requests | |
from PIL import Image | |
# ux format: custom CSS passed to gr.Blocks(css=...).
# Rules shared by several elements use grouped selectors so the three
# columns (and the three image panes) cannot drift out of sync.
tryon_css = """
#col-garment, #col-person, #col-result {
    margin: 0 auto;
    max-width: 420px;
}
#garm_img, #person_img, #result_img {
    aspect-ratio: 3 / 4;
    width: 100%;
    max-height: 560px;
    object-fit: contain;
}
#col-examples {
    margin: 0 auto;
    max-width: 1000px;
}
#col-examples img {
    aspect-ratio: 3 / 4;
    object-fit: contain;
}
#button {
    background-color: #A47764;
    color: white;
}
"""
# assets loading: example images live next to this script under data/garment and data/person.
example_path = os.path.join(os.path.dirname(__file__), 'data')


def _list_example_images(subdir):
    """Return the sorted, non-hidden file names in ``example_path/<subdir>``."""
    folder = os.path.join(example_path, subdir)
    # os.listdir order is arbitrary; sort so the example gallery is stable across runs.
    # Hidden files (e.g. .DS_Store) are not images and must not become examples.
    return sorted(name for name in os.listdir(folder) if not name.startswith('.'))


def _build_category_mapping(file_names):
    """Map each file name to its capitalized third ``_``-separated token.

    Names with fewer than three tokens are skipped instead of raising
    IndexError at import time; update_category then uses its default.
    """
    mapping = {}
    for name in file_names:
        parts = name.split("_")
        if len(parts) > 2:
            mapping[name] = parts[2].capitalize()
    return mapping


garm_list = _list_example_images("garment")
garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list]
person_list = _list_example_images("person")
person_list_path = [os.path.join(example_path, "person", person) for person in person_list]
# Keyed by bare file name (the basename of each garm_list_path entry).
garm_img_category_mapping = _build_category_mapping(garm_list)
def load_header(header_file):
    """Read *header_file* as UTF-8 text and return its entire contents."""
    with open(header_file, encoding='utf-8') as source:
        return source.read()
def preprocess_img(img_path, max_size=1024):
    """Shrink the image file at *img_path* in place so its longest side is <= *max_size*.

    The file is rewritten only when it is larger than *max_size*; resizing uses
    ``Image.thumbnail``, which keeps the aspect ratio.

    Args:
        img_path: Path to an image file, or None when the component is empty.
        max_size: Upper bound, in pixels, for both width and height.

    Returns:
        *img_path* (the same path, possibly now holding a smaller image), or
        None when *img_path* is None.
    """
    if img_path is None:
        return None
    # Use a context manager so the file handle is released even if resizing
    # or saving fails; the original never closed the opened image.
    with Image.open(img_path) as img:
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size))
            # Save while still open: the image is unusable after close().
            img.save(img_path)
    return img_path
def update_category(selected_garm_file):
    """Return a Gradio update that sets the garment-type value for *selected_garm_file*.

    The category is looked up by file name in garm_img_category_mapping;
    names not in the mapping fall back to "Fullbody".
    """
    garm_name = os.path.basename(selected_garm_file)
    category_value = garm_img_category_mapping.get(garm_name, "Fullbody")
    return gr.update(value=category_value)
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API and return its job id and status.

    Args:
        person_file: Path to the person/model image (sent as ``image_model_file``).
        garm_file: Path to the garment image (sent as ``image_garment_file``).
        category: Garment type sent as the ``garment_type`` form field.
        model_type: Value sent as the ``model_type`` form field.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(job_id, status)`` read from the JSON body of a successful response.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the environment.
        gr.Error: If the request raises, times out, or returns a non-OK status.
    """
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type, 'repaint_other_garment': 'false'}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    # Both uploads are closed when the block exits, on success and on failure alike.
    with open(garm_file, 'rb') as garm_fh, open(person_file, 'rb') as person_fh:
        files = {
            'image_garment_file': garm_fh,
            'image_model_file': person_fh,
        }
        try:
            # An explicit timeout keeps a stalled backend from hanging this worker forever.
            response = requests.post(tryon_url, headers=headers, data=payload, files=files, timeout=timeout)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")
    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")
def get_tryon_result(job_id, timeout=10):
    """Query the API once for the state of try-on job *job_id*.

    Args:
        job_id: Identifier returned by call_tryon_api.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(image_url, status)`` when the reported status is ``"completed"``,
        ``(None, status)`` for any other reported status, and
        ``(None, None)`` if the request raises or the response is not OK.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is missing from the environment.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    try:
        # Let requests build and URL-encode the query string instead of an f-string.
        response = requests.get(result_url, headers=headers, params={'job_id': job_id}, timeout=timeout)
        if response.ok:
            data = response.json()
            if data["status"] == "completed":
                return data['output'][0]['image_url'], data['status']
            return None, data['status']
        # Surface non-OK responses the same way call_tryon_api does.
        print(response.content)
    except Exception as e:
        print(f"get tryon result error: {e}")
    return None, None
def run_turbo(person_img, garm_img, category="Top"):
    """Run one try-on job end to end for the "RUN" button.

    Submits the job with call_tryon_api, waits 8 s, then polls
    get_tryon_result at most ``max_retry`` times, 1.5 s apart.

    Args:
        person_img: File path of the person image, or None if not provided.
        garm_img: File path of the garment image, or None if not provided.
        category: Garment type forwarded to the API.

    Returns:
        ``(result_image_url, info)`` on success; ``(None, "No input image")``
        if an input is missing; ``(None, info)`` after showing a warning when
        the job fails or the polling budget runs out.

    Raises:
        gr.Error: Propagated from call_tryon_api when submission fails.
        KeyError: If ``MODEL_TYPE`` is missing from the environment.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"
    info = ""  # placeholder for now
    job_id, status = call_tryon_api(person_img, garm_img, category, model_type=os.environ['MODEL_TYPE'])
    time.sleep(8)  # wait before fetching the result
    # Bounded polling: 40 x 1.5s = 60s timeout for a single job run.
    # The original loop ignored max_retry and could spin forever.
    max_retry = 40
    for _ in range(max_retry):
        if status == "failed":
            break  # terminal failure reported by the backend; stop polling
        try:
            result_image_url, status = get_tryon_result(job_id)
        except Exception as e:  # e.g. missing env config; keep polling, but log it
            print(f"poll tryon result error: {e}")
        else:
            # Polling at least once also picks up jobs already "completed" at submission.
            if result_image_url is not None:
                return result_image_url, info
        time.sleep(1.5)  # Wait before retrying
    gr.Warning("Over heated, please try again later")
    return None, info
# Showcase rows: (example index, garment type). Each index expands to the
# person/garment/result images under data/examples, resolved via example_path
# so the app does not depend on the process working directory.
_showcase_rows = [(1, "Top"), (2, "Top"), (3, "Top"), (4, "Fullbody"), (5, "Top")]
showcase_examples = [
    [
        os.path.join(example_path, "examples", f"person_example_{idx}.png"),
        os.path.join(example_path, "examples", f"garment_example_{idx}.png"),
        garment_type,
        os.path.join(example_path, "examples", f"result_example_{idx}.png"),
    ]
    for idx, garment_type in _showcase_rows
]

with gr.Blocks(css=tryon_css) as Huhu_Turbo:
    # Anchor the header to example_path (like the other assets) instead of the CWD.
    gr.HTML(load_header(os.path.join(example_path, "header.html")))
    # Column headings.
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                    Upload your garment image 🧥
                </div>
            </div>
            """)
        with gr.Column(elem_id="col-person"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                    Select a model image 🧍
                </div>
            </div>
            """)
        with gr.Column(elem_id="col-result"):
            gr.HTML("""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                <div>
                    “RUN” to get results 🪄
                </div>
            </div>
            """)
    # Inputs, result pane and their example pickers.
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath", elem_id="garm_img")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            garm_example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
                cache_examples=False
            )
        with gr.Column(elem_id="col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath", elem_id="person_img")
            person_example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        # NOTE(review): nesting of the result row/button is reconstructed; confirm layout.
        with gr.Column(elem_id="col-result"):
            result_img = gr.Image(label="Result", show_share_button=False, elem_id="result_img")
            with gr.Row():
                result_info = gr.Text(label="Tryon inference runtime", visible=False)
            generate_button = gr.Button(value="RUN", elem_id="button")
    # Picking a garment example also sets the garment-type dropdown.
    garm_example.load_input_event.then(
        fn=update_category,
        inputs=[garm_img],
        outputs=[category]
    )
    # NOTE(review): each handler writes back to the component it listens on;
    # confirm Gradio does not re-fire .change for an unchanged path.
    garm_img.change(fn=preprocess_img, inputs=[garm_img], outputs=[garm_img])
    person_img.change(fn=preprocess_img, inputs=[person_img], outputs=[person_img])
    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
    with gr.Column(elem_id="col-examples"):
        gr.HTML("""
        <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
            <div> </div>
            <br>
            <div>
                Huhu Try-on Turbo examples in pairs of garment and model images
            </div>
        </div>
        """)
        show_case = gr.Examples(
            examples=showcase_examples,
            inputs=[person_img, garm_img, category, result_img],
            label=None
        )
Huhu_Turbo.queue(api_open=False).launch(show_api=False)