zc-huhu's picture
minor
5705ac4
import gradio as gr
import os
import time
import requests
from PIL import Image
# UI styling passed to gr.Blocks(css=...): centred, width-capped columns for the
# garment / person / result panels (matched via elem_id), 3:4 portrait framing
# for the three image components and the example thumbnails, and a custom
# colour for the RUN button (elem_id="button").
tryon_css="""
#col-garment {
margin: 0 auto;
max-width: 420px;
}
#garm_img {
aspect-ratio: 3 / 4;
width: 100%;
max-height: 560px;
object-fit: contain;
}
#col-person {
margin: 0 auto;
max-width: 420px;
}
#person_img {
aspect-ratio: 3 / 4;
width: 100%;
max-height: 560px;
object-fit: contain;
}
#col-result {
margin: 0 auto;
max-width: 420px;
}
#result_img {
aspect-ratio: 3 / 4;
width: 100%;
max-height: 560px;
object-fit: contain;
}
#col-examples {
margin: 0 auto;
max-width: 1000px;
}
#col-examples img {
aspect-ratio: 3 / 4;
object-fit: contain;
}
#button {
background-color: #A47764;
color: white;
}
"""
# Example assets shipped next to this script under data/:
#   data/garment/ - garment example images
#   data/person/  - person (model) example images
# Listings are sorted so the example galleries have a stable order (os.listdir
# order is arbitrary), and hidden files such as .DS_Store are skipped.
example_path = os.path.join(os.path.dirname(__file__), 'data')


def _list_example_files(subdir):
    """Return the sorted, non-hidden file names in ``example_path/subdir``."""
    return sorted(
        name for name in os.listdir(os.path.join(example_path, subdir))
        if not name.startswith(".")
    )


def _category_from_filename(filename):
    """Return the garment category encoded in ``filename``, or None if absent.

    The category is the third ``_``-separated field of the name with the file
    extension removed, capitalized (e.g. ``xxx_yyy_top.png`` -> ``"Top"``).
    Stripping the extension first keeps ``.png``/``.jpg`` out of the label
    when the category is the last field.
    """
    parts = os.path.splitext(filename)[0].split("_")
    if len(parts) < 3:
        return None
    return parts[2].capitalize()


garm_list = _list_example_files("garment")
garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list]
person_list = _list_example_files("person")
person_list_path = [os.path.join(example_path, "person", person) for person in person_list]
# Garment file basename -> category label for the "Garment type" dropdown.
# Names without a category field are omitted, so update_category falls back to
# its default for them instead of this line crashing at startup.
garm_img_category_mapping = {
    name: garm_category
    for name, garm_category in zip(garm_list, map(_category_from_filename, garm_list))
    if garm_category is not None
}
def load_header(header_file):
    """Return the entire contents of ``header_file``, decoded as UTF-8 text.

    Used to fetch the HTML snippet rendered at the top of the page.
    """
    with open(header_file, encoding='utf-8') as handle:
        return handle.read()
def preprocess_img(img_path, max_size=1024):
    """Downscale the image at ``img_path`` in place so neither side exceeds ``max_size``.

    The aspect ratio is preserved (``Image.thumbnail``) and the resized image
    is written back to the same path; images already within the limit are left
    untouched on disk.

    Args:
        img_path: Path of the image file, or None when the input is empty.
        max_size: Maximum allowed width/height in pixels.

    Returns:
        ``img_path`` itself (or None), so the result can be fed straight back
        into the same Gradio component.
    """
    if img_path is None:
        return None
    # Use the image as a context manager so its file handle is always released,
    # including the no-resize path and when thumbnail/save raises.
    with Image.open(img_path) as img:
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size))
            img.save(img_path)
    return img_path
def update_category(selected_garm_file):
    """Return a dropdown update selecting the category of the chosen garment example.

    The category is looked up by file basename in ``garm_img_category_mapping``;
    files not in the mapping default to "Fullbody". When no garment file is
    set, an empty update is returned so the dropdown keeps its current value
    (``os.path.basename`` would raise on None).
    """
    if not selected_garm_file:
        return gr.update()
    selected_category = garm_img_category_mapping.get(os.path.basename(selected_garm_file), "Fullbody")
    return gr.update(value=selected_category)
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API and return its ``(job_id, status)``.

    Posts both images as multipart uploads to ``$API_ENDPOINT/tryon/v1``,
    authenticated with the ``x-api-key`` header taken from ``$API_KEY``.

    Args:
        person_file: Path to the person/model image.
        garm_file: Path to the garment image.
        category: Garment type, sent as ``garment_type``.
        model_type: Model identifier, sent as ``model_type``.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        ``(job_id, status)`` read from the JSON response.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is not set.
        OSError: If either image file cannot be opened.
        gr.Error: If the request raises, the response has an HTTP error
            status, or the response lacks the expected fields; Gradio shows
            this to the user as a pop-up.
    """
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type, 'repaint_other_garment': 'false'}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    # Open the uploads in a `with` so both handles are closed on every path
    # (the original left them open).
    with open(garm_file, 'rb') as garm_fh, open(person_file, 'rb') as person_fh:
        files = {
            'image_garment_file': garm_fh,
            'image_model_file': person_fh,
        }
        try:
            # An explicit timeout keeps a stalled endpoint from hanging the worker forever.
            response = requests.post(tryon_url, headers=headers, data=payload, files=files, timeout=timeout)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")
    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")
def get_tryon_result(job_id, timeout=30):
    """Query the API once for the state of try-on job ``job_id``.

    Args:
        job_id: Identifier returned by ``call_tryon_api``.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        ``(image_url, "completed")`` with the first output's image URL once the
        job has completed; ``(None, status)`` for any other reported status;
        ``(None, None)`` when the request raises, returns an HTTP error status,
        or the response cannot be parsed.

    Raises:
        KeyError: If ``API_ENDPOINT`` or ``API_KEY`` is not set.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    try:
        # Let requests build and URL-encode the query string; the explicit
        # timeout keeps one poll from blocking indefinitely.
        response = requests.get(result_url, headers=headers, params={'job_id': job_id}, timeout=timeout)
        if response.ok:
            data = response.json()
            if data["status"] == "completed":
                image_url = data['output'][0]['image_url']
                return image_url, data['status']
            return None, data['status']
        # Surface HTTP errors instead of dropping them silently.
        print(response.content)
    except Exception as e:
        print(f"get tryon result error: {e}")
    return None, None
def run_turbo(person_img, garm_img, category="Top"):
    """Gradio handler for the RUN button: submit a try-on job and wait for its image.

    Submits the job with ``call_tryon_api``, waits ``initial_wait`` seconds,
    then polls ``get_tryon_result`` at most ``max_retry`` times,
    ``poll_interval`` seconds apart, stopping early on a terminal status.

    Returns:
        ``(result_image_url, info)`` on success; ``(None, "No input image")``
        with a warning if an input image is missing; otherwise ``(None, info)``
        with a warning (job failed, completed without an image, or did not
        finish within the polling budget).
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"
    info = ""  # placeholder for now
    job_id, status = call_tryon_api(person_img, garm_img, category, model_type=os.environ['MODEL_TYPE'])
    initial_wait = 8  # seconds to wait before the first status check
    poll_interval = 1.5  # seconds between status checks
    max_retry = 40  # 40 x 1.5s = 60s polling budget for a single job run
    # A job already reported as failed at submission is not polled. A job
    # reported as completed is still polled once, because only the result
    # endpoint returns the image URL.
    if status != "failed":
        time.sleep(initial_wait)
        # Bounded loop: the previous `while` never used max_retry and could spin
        # forever if the job never reached a terminal status.
        for _ in range(max_retry):
            try:
                result_image_url, status = get_tryon_result(job_id)
            except Exception as e:
                # get_tryon_result handles its own errors; this is a last-resort
                # guard so one bad poll does not abort the whole run.
                print(f"poll tryon result error: {e}")
            else:
                if result_image_url is not None:
                    return result_image_url, info
                if status in ("completed", "failed"):
                    break
            time.sleep(poll_interval)
    gr.Warning("Over heated, please try again later")
    return None, info
# Page layout: header HTML, a row of column titles, a row with the garment
# input, the person input and the result + RUN button, then a gallery of
# pre-computed example triples. Files are resolved from example_path (the
# script's data/ directory), like the garment/person galleries, so the app
# does not depend on the current working directory.
with gr.Blocks(css=tryon_css) as Huhu_Turbo:
    gr.HTML(load_header(os.path.join(example_path, "header.html")))
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Upload your garment image 🧥
</div>
</div>
""")
        with gr.Column(elem_id="col-person"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Select a model image 🧍
</div>
</div>
""")
        with gr.Column(elem_id="col-result"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
“RUN” to get results 🪄
</div>
</div>
""")
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath", elem_id="garm_img")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            garm_example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
                cache_examples=False
            )
        with gr.Column(elem_id="col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath", elem_id="person_img")
            person_example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        with gr.Column(elem_id="col-result"):
            result_img = gr.Image(label="Result", show_share_button=False, elem_id="result_img")
            with gr.Row():
                result_info = gr.Text(label="Tryon inference runtime", visible=False)
                generate_button = gr.Button(value="RUN", elem_id="button")
    # Picking a garment example also selects its category in the dropdown.
    garm_example.load_input_event.then(
        fn=update_category,
        inputs=[garm_img],
        outputs=[category]
    )
    # Downscale large uploads in place before they are sent to the API.
    garm_img.change(fn=preprocess_img, inputs=[garm_img], outputs=[garm_img])
    person_img.change(fn=preprocess_img, inputs=[person_img], outputs=[person_img])
    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
    with gr.Column(elem_id="col-examples"):
        gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div> </div>
<br>
<div>
Huhu Try-on Turbo examples in pairs of garment and model images
</div>
</div>
""")
        # Each row: person image, garment image, garment category, result image.
        show_case = gr.Examples(
            examples=[
                [
                    os.path.join(example_path, "examples", f"person_example_{idx}.png"),
                    os.path.join(example_path, "examples", f"garment_example_{idx}.png"),
                    showcase_category,
                    os.path.join(example_path, "examples", f"result_example_{idx}.png"),
                ]
                for idx, showcase_category in enumerate(["Top", "Top", "Top", "Fullbody", "Top"], start=1)
            ],
            inputs=[person_img, garm_img, category, result_img],
            label=None
        )
Huhu_Turbo.queue(api_open=False).launch(show_api=False)