zc-huhu commited on
Commit
e99c87a
·
1 Parent(s): 684703f

add api calling

Browse files
Files changed (1) hide show
  1. app.py +77 -9
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import gradio as gr
2
  import os
 
 
 
3
 
4
  # assets loading
5
  example_path = os.path.join(os.path.dirname(__file__), 'data')
@@ -10,8 +13,73 @@ garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_l
10
  person_list = os.listdir(os.path.join(example_path,"person"))
11
  person_list_path = [os.path.join(example_path, "person", person) for person in person_list]
12
 
13
- def run_turbo(person_img, garm_img, category):
14
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  with gr.Blocks() as Huhu_Turbo:
17
  with gr.Row():
@@ -41,25 +109,25 @@ with gr.Blocks() as Huhu_Turbo:
41
  """)
42
  with gr.Row():
43
  with gr.Column(elem_id = "col-garment"):
44
- garm_img = gr.Image(label="Garment image", sources='upload', type="numpy")
45
  category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
46
  example = gr.Examples(
47
  inputs=garm_img,
48
- examples_per_page=6,
49
  examples=garm_list_path
50
  )
51
  with gr.Column(elem_id = "col-person"):
52
- person_img = gr.Image(label="Person image", sources='upload', type="numpy")
53
  example = gr.Examples(
54
  inputs=person_img,
55
- examples_per_page=6,
56
  examples=person_list_path
57
  )
58
  with gr.Column(elem_id = "col-result"):
59
  result_img = gr.Image(label="Result", show_share_button=False)
60
  with gr.Row():
61
  result_info = gr.Text(label="Generation time")
62
- generate_button = gr.Button(value="RUN", elem_id="button")
63
 
64
  generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
65
 
@@ -75,9 +143,9 @@ with gr.Blocks() as Huhu_Turbo:
75
  """)
76
  show_case = gr.Examples(
77
  examples=[
78
- ["data/examples/garment_example.png", "data/examples/person_example.png", "data/examples/result_example.png"],
79
  ],
80
- inputs=[person_img, garm_img, result_img],
81
  label=None
82
  )
83
 
 
1
  import gradio as gr
2
  import os
3
+ import time
4
+ import requests
5
+
6
 
7
  # assets loading
8
  example_path = os.path.join(os.path.dirname(__file__), 'data')
 
13
  person_list = os.listdir(os.path.join(example_path,"person"))
14
  person_list_path = [os.path.join(example_path, "person", person) for person in person_list]
15
 
16
+
17
+ def call_tryon_api(person_file, garm_file, category, model_type='SD_V1'):
18
+ tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
19
+ payload = {'garment_type': category, 'model_type': model_type}
20
+ files = {
21
+ 'image_garment_file': open(garm_file, 'rb'),
22
+ 'image_model_file': open(person_file, 'rb')
23
+ }
24
+ headers = {
25
+ 'x-api-key': os.environ['API_KEY']
26
+ }
27
+
28
+ try:
29
+ response = requests.post(tryon_url, headers=headers, data=payload, files=files)
30
+ if response.ok:
31
+ data = response.json()
32
+ return data['job_id'], data['status']
33
+ except Exception as e:
34
+ print(f"call tryon api error: {e}")
35
+
36
+ # if the API call fails, return pop up error
37
+ raise gr.Error("Over heated, please try again later")
38
+
39
+ def get_tryon_result(job_id):
40
+ result_url = os.environ['API_ENDPOINT'] + "/requests/v1" + f"?job_id={job_id}"
41
+ headers = {
42
+ 'x-api-key': os.environ['API_KEY']
43
+ }
44
+
45
+ try:
46
+ response = requests.get(result_url, headers=headers)
47
+
48
+ if response.ok:
49
+ data = response.json()
50
+ if data["status"] == "completed":
51
+ image_url = data['output'][0]['image_url']
52
+ return image_url, data['status']
53
+ else:
54
+ return None, data['status']
55
+ except Exception as e:
56
+ print(f"get tryon result error: {e}")
57
+ return None, None
58
+
59
+ def run_turbo(person_img, garm_img, category="Top"):
60
+ if person_img is None or garm_img is None:
61
+ gr.Warning("input image is missing")
62
+ return None, "No input image"
63
+
64
+ info = "" # placeholder for now
65
+
66
+ job_id, status = call_tryon_api(person_img, garm_img, category)
67
+
68
+ time.sleep(3) # wait before fetching the result
69
+
70
+ # check the status of the job
71
+ max_retry = 30 # 30x2s = 60s timeout for sinlge job run
72
+ while status not in ["completed", "failed"]:
73
+ try:
74
+ result_image_url, status = get_tryon_result(job_id)
75
+ if result_image_url is not None:
76
+ return result_image_url, info
77
+ except:
78
+ pass
79
+ time.sleep(2) # Wait before retrying
80
+
81
+ gr.Warning("Over heated, please try again later")
82
+ return None, info
83
 
84
  with gr.Blocks() as Huhu_Turbo:
85
  with gr.Row():
 
109
  """)
110
  with gr.Row():
111
  with gr.Column(elem_id = "col-garment"):
112
+ garm_img = gr.Image(label="Garment image", sources='upload', type="filepath")
113
  category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
114
  example = gr.Examples(
115
  inputs=garm_img,
116
+ examples_per_page=10,
117
  examples=garm_list_path
118
  )
119
  with gr.Column(elem_id = "col-person"):
120
+ person_img = gr.Image(label="Person image", sources='upload', type="filepath")
121
  example = gr.Examples(
122
  inputs=person_img,
123
+ examples_per_page=10,
124
  examples=person_list_path
125
  )
126
  with gr.Column(elem_id = "col-result"):
127
  result_img = gr.Image(label="Result", show_share_button=False)
128
  with gr.Row():
129
  result_info = gr.Text(label="Generation time")
130
+ generate_button = gr.Button(value="RUN", elem_id="button")
131
 
132
  generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
133
 
 
143
  """)
144
  show_case = gr.Examples(
145
  examples=[
146
+ ["data/examples/garment_example.png", "data/examples/person_example.png", "Top", "data/examples/result_example.png"],
147
  ],
148
+ inputs=[person_img, garm_img, category, result_img],
149
  label=None
150
  )
151