import gradio as gr import os import time import requests from PIL import Image # ux format tryon_css=""" #col-garment { margin: 0 auto; max-width: 420px; } #garm_img { aspect-ratio: 3 / 4; width: 100%; max-height: 560px; object-fit: contain; } #col-person { margin: 0 auto; max-width: 420px; } #person_img { aspect-ratio: 3 / 4; width: 100%; max-height: 560px; object-fit: contain; } #col-result { margin: 0 auto; max-width: 420px; } #result_img { aspect-ratio: 3 / 4; width: 100%; max-height: 560px; object-fit: contain; } #col-examples { margin: 0 auto; max-width: 1000px; } #col-examples img { aspect-ratio: 3 / 4; object-fit: contain; } #button { background-color: #A47764; color: white; } """ # assets loading example_path = os.path.join(os.path.dirname(__file__), 'data') garm_list = os.listdir(os.path.join(example_path,"garment")) garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list] person_list = os.listdir(os.path.join(example_path,"person")) person_list_path = [os.path.join(example_path, "person", person) for person in person_list] garm_img_category_mapping = {os.path.basename(garm_file): os.path.basename(garm_file).split("_")[2].capitalize() for garm_file in garm_list_path} def load_header(header_file): with open(header_file, 'r', encoding='utf-8') as f: content = f.read() return content def preprocess_img(img_path, max_size=1024): if img_path is None: return None img = Image.open(img_path) if max(img.size) > max_size: img.thumbnail((max_size, max_size)) img.save(img_path) return img_path def update_category(selected_garm_file): selected_category = garm_img_category_mapping.get(os.path.basename(selected_garm_file), "Fullbody") return gr.update(value=selected_category) def call_tryon_api(person_file, garm_file, category, model_type='SD_V1'): tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1" payload = {'garment_type': category, 'model_type': model_type, 'repaint_other_garment': 'false'} files = { 'image_garment_file': open(garm_file, 'rb'), 'image_model_file': open(person_file, 'rb'), } headers = { 'x-api-key': os.environ['API_KEY'] } try: response = requests.post(tryon_url, headers=headers, data=payload, files=files) if response.ok: data = response.json() return data['job_id'], data['status'] else: print(response.content) except Exception as e: print(f"call tryon api error: {e}") # if the API call fails, return pop up error raise gr.Error("Over heated, please try again later") def get_tryon_result(job_id): result_url = os.environ['API_ENDPOINT'] + "/requests/v1" + f"?job_id={job_id}" headers = { 'x-api-key': os.environ['API_KEY'] } try: response = requests.get(result_url, headers=headers) if response.ok: data = response.json() if data["status"] == "completed": image_url = data['output'][0]['image_url'] return image_url, data['status'] else: return None, data['status'] except Exception as e: print(f"get tryon result error: {e}") return None, None def run_turbo(person_img, garm_img, category="Top"): if person_img is None or garm_img is None: gr.Warning("input image is missing") return None, "No input image" info = "" # placeholder for now job_id, status = call_tryon_api(person_img, garm_img, category, model_type= os.environ['MODEL_TYPE']) time.sleep(8) # wait before fetching the result # check the status of the job max_retry = 40 # 40x1.5s = 60s timeout for sinlge job run while status not in ["completed", "failed"]: try: result_image_url, status = get_tryon_result(job_id) if result_image_url is not None: return result_image_url, info except: pass time.sleep(1.5) # Wait before retrying gr.Warning("Over heated, please try again later") return None, info with gr.Blocks(css=tryon_css) as Huhu_Turbo: gr.HTML(load_header("data/header.html")) with gr.Row(): with gr.Column(elem_id = "col-garment"): gr.HTML("""